Rasagar/Library/PackageCache/com.unity.burst/Runtime/Editor/BurstReflection.cs

644 lines
26 KiB
C#
Raw Normal View History

2024-08-26 13:07:20 -07:00
#if UNITY_EDITOR
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using Unity.Jobs.LowLevel.Unsafe;
using UnityEditor;
using UnityEditor.Compilation;
using Debug = UnityEngine.Debug;
[assembly: InternalsVisibleTo("Unity.Burst.Editor.Tests")]
namespace Unity.Burst.Editor
{
using static BurstCompilerOptions;
internal static class BurstReflection
{
// The TypeCache API was added in 2019.2. So there are two versions of FindExecuteMethods,
// one that uses TypeCache and one that doesn't.
public static FindExecuteMethodsResult FindExecuteMethods(List<System.Reflection.Assembly> assemblyList, BurstReflectionAssemblyOptions options)
{
var methodsToCompile = new List<BurstCompileTarget>();
var methodsToCompileSet = new HashSet<MethodInfo>();
var logMessages = new List<LogMessage>();
var interfaceToProducer = new Dictionary<Type, Type>();
var assemblySet = new HashSet<System.Reflection.Assembly>(assemblyList);
void AddTarget(BurstCompileTarget target)
{
if (target.Method.Name.EndsWith("$BurstManaged")) return;
// We will not try to record more than once a method in the methods to compile
// This can happen if a job interface is inheriting from another job interface which are using in the end the same
// job producer type
if (!target.IsStaticMethod && !methodsToCompileSet.Add(target.Method))
{
return;
}
if (options.HasFlag(BurstReflectionAssemblyOptions.ExcludeTestAssemblies) &&
target.JobType.Assembly.GetReferencedAssemblies().Any(x => IsNUnitDll(x.Name)))
{
return;
}
methodsToCompile.Add(target);
}
var staticMethodTypes = new HashSet<Type>();
// -------------------------------------------
// Find job structs using TypeCache.
// -------------------------------------------
var jobProducerImplementations = TypeCache.GetTypesWithAttribute<JobProducerTypeAttribute>();
foreach (var jobProducerImplementation in jobProducerImplementations)
{
var attrs = jobProducerImplementation.GetCustomAttributes(typeof(JobProducerTypeAttribute), false);
if (attrs.Length == 0)
{
continue;
}
staticMethodTypes.Add(jobProducerImplementation);
var attr = (JobProducerTypeAttribute)attrs[0];
interfaceToProducer.Add(jobProducerImplementation, attr.ProducerType);
}
foreach (var jobProducerImplementation in jobProducerImplementations)
{
if (!jobProducerImplementation.IsInterface)
{
continue;
}
var jobTypes = TypeCache.GetTypesDerivedFrom(jobProducerImplementation);
foreach (var jobType in jobTypes)
{
if (jobType.IsGenericType || !jobType.IsValueType)
{
continue;
}
ScanJobType(jobType, interfaceToProducer, logMessages, AddTarget);
}
}
// -------------------------------------------
// Find static methods using TypeCache.
// -------------------------------------------
void AddStaticMethods(TypeCache.MethodCollection methods)
{
foreach (var method in methods)
{
if (HasBurstCompileAttribute(method.DeclaringType))
{
staticMethodTypes.Add(method.DeclaringType);
// NOTE: Make sure that we don't use a value type generic definition (e.g `class Outer<T> { struct Inner { } }`)
// We are only working on plain type or generic type instance!
if (!method.DeclaringType.IsGenericTypeDefinition &&
method.IsStatic &&
!method.ContainsGenericParameters)
{
AddTarget(new BurstCompileTarget(method, method.DeclaringType, null, true));
}
}
}
}
// Add [BurstCompile] static methods.
AddStaticMethods(TypeCache.GetMethodsWithAttribute<BurstCompileAttribute>());
// Add [TestCompiler] static methods.
if (!options.HasFlag(BurstReflectionAssemblyOptions.ExcludeTestAssemblies))
{
var testCompilerAttributeType = Type.GetType("Burst.Compiler.IL.Tests.TestCompilerAttribute, Unity.Burst.Tests.UnitTests, Version=0.0.0.0, Culture=neutral, PublicKeyToken=null");
if (testCompilerAttributeType != null)
{
AddStaticMethods(TypeCache.GetMethodsWithAttribute(testCompilerAttributeType));
}
}
// -------------------------------------------
// Find job types and static methods based on
// generic instances types. These will not be
// found by the TypeCache scanning above.
// -------------------------------------------
FindExecuteMethodsForGenericInstances(
assemblySet,
staticMethodTypes,
interfaceToProducer,
AddTarget,
logMessages);
return new FindExecuteMethodsResult(methodsToCompile, logMessages);
}
private static void ScanJobType(
Type jobType,
Dictionary<Type, Type> interfaceToProducer,
List<LogMessage> logMessages,
Action<BurstCompileTarget> addTarget)
{
foreach (var interfaceType in jobType.GetInterfaces())
{
var genericLessInterface = interfaceType;
if (interfaceType.IsGenericType)
{
genericLessInterface = interfaceType.GetGenericTypeDefinition();
}
if (interfaceToProducer.TryGetValue(genericLessInterface, out var foundProducer))
{
var genericParams = new List<Type> { jobType };
if (interfaceType.IsGenericType)
{
genericParams.AddRange(interfaceType.GenericTypeArguments);
}
try
{
var executeType = foundProducer.MakeGenericType(genericParams.ToArray());
var executeMethod = executeType.GetMethod("Execute", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
if (executeMethod == null)
{
throw new InvalidOperationException($"Burst reflection error. The type `{executeType}` does not contain an `Execute` method");
}
addTarget(new BurstCompileTarget(executeMethod, jobType, interfaceType, false));
}
catch (Exception ex)
{
logMessages.Add(new LogMessage(ex));
}
}
}
}
private static void FindExecuteMethodsForGenericInstances(
HashSet<System.Reflection.Assembly> assemblyList,
HashSet<Type> staticMethodTypes,
Dictionary<Type, Type> interfaceToProducer,
Action<BurstCompileTarget> addTarget,
List<LogMessage> logMessages)
{
var valueTypes = new List<TypeToVisit>();
//Debug.Log("Filtered Assembly List: " + string.Join(", ", assemblyList.Select(assembly => assembly.GetName().Name)));
// Find all ways to execute job types (via producer attributes)
var typesVisited = new HashSet<string>();
var typesToVisit = new HashSet<string>();
var allTypesAssembliesCollected = new HashSet<Type>();
foreach (var assembly in assemblyList)
{
var types = new List<Type>();
try
{
// Collect all generic type instances (excluding indirect instances)
CollectGenericTypeInstances(
assembly,
x => assemblyList.Contains(x.Assembly),
types,
allTypesAssembliesCollected);
}
catch (Exception ex)
{
logMessages.Add(new LogMessage(LogType.Warning, "Unexpected exception while collecting types in assembly `" + assembly.FullName + "` Exception: " + ex));
}
for (var i = 0; i < types.Count; i++)
{
var t = types[i];
if (typesToVisit.Add(t.AssemblyQualifiedName))
{
// Because the list of types returned by CollectGenericTypeInstances does not detect nested generic classes that are not
// used explicitly, we need to create them if a declaring type is actually used
// so for example if we have:
// class MyClass<T> { class MyNestedClass { } }
// class MyDerived : MyClass<int> { }
// The CollectGenericTypeInstances will return typically the type MyClass<int>, but will not list MyClass<int>.MyNestedClass
// So the following code is correcting this in order to fully query the full graph of generic instance types, including indirect types
var nestedTypes = t.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic);
foreach (var nestedType in nestedTypes)
{
if (t.IsGenericType && !t.IsGenericTypeDefinition)
{
var parentGenericTypeArguments = t.GetGenericArguments();
// Only create nested types that are closed generic types (full generic instance types)
// It happens if for example the parent class is `class MClass<T> { class MyNestedGeneric<T1> {} }`
// In that case, MyNestedGeneric<T1> is opened in the context of MClass<int>, so we don't process them
if (nestedType.GetGenericArguments().Length == parentGenericTypeArguments.Length)
{
try
{
var instanceNestedType = nestedType.MakeGenericType(parentGenericTypeArguments);
types.Add(instanceNestedType);
}
catch (Exception ex)
{
var error = $"Unexpected Burst Inspector error. Invalid generic type instance. Trying to instantiate the generic type {nestedType.FullName} with the generic arguments <{string.Join(", ", parentGenericTypeArguments.Select(x => x.FullName))}> is not supported: {ex}";
logMessages.Add(new LogMessage(LogType.Warning, error));
}
}
}
else
{
types.Add(nestedType);
}
}
}
}
foreach (var t in types)
{
// If the type has been already visited, don't try to visit it
if (!typesVisited.Add(t.AssemblyQualifiedName) || (t.IsGenericTypeDefinition && !t.IsInterface))
{
continue;
}
try
{
// collect methods with types having a [BurstCompile] attribute
var staticMethodDeclaringType = t;
if (t.IsGenericType)
{
staticMethodDeclaringType = t.GetGenericTypeDefinition();
}
bool visitStaticMethods = staticMethodTypes.Contains(staticMethodDeclaringType);
bool isValueType = false;
if (t.IsValueType)
{
// NOTE: Make sure that we don't use a value type generic definition (e.g `class Outer<T> { struct Inner { } }`)
// We are only working on plain type or generic type instance!
if (!t.IsGenericTypeDefinition)
isValueType = true;
}
if (isValueType || visitStaticMethods)
{
valueTypes.Add(new TypeToVisit(t, visitStaticMethods));
}
}
catch (Exception ex)
{
logMessages.Add(new LogMessage(LogType.Warning,
"Unexpected exception while inspecting type `" + t +
"` IsConstructedGenericType: " + t.IsConstructedGenericType +
" IsGenericTypeDef: " + t.IsGenericTypeDefinition +
" IsGenericParam: " + t.IsGenericParameter +
" Exception: " + ex));
}
}
}
// Revisit all types to find things that are compilable using the above producers.
foreach (var typePair in valueTypes)
{
var type = typePair.Type;
// collect static [BurstCompile] methods
if (typePair.CollectStaticMethods)
{
try
{
var methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
foreach (var method in methods)
{
if (HasBurstCompileAttribute(method))
{
addTarget(new BurstCompileTarget(method, type, null, true));
}
}
}
catch (Exception ex)
{
logMessages.Add(new LogMessage(ex));
}
}
// If the type is not a value type, we don't need to proceed with struct Jobs
if (!type.IsValueType)
{
continue;
}
ScanJobType(type, interfaceToProducer, logMessages, addTarget);
}
}
public sealed class FindExecuteMethodsResult
{
public readonly List<BurstCompileTarget> CompileTargets;
public readonly List<LogMessage> LogMessages;
public FindExecuteMethodsResult(List<BurstCompileTarget> compileTargets, List<LogMessage> logMessages)
{
CompileTargets = compileTargets;
LogMessages = logMessages;
}
}
public sealed class LogMessage
{
public readonly LogType LogType;
public readonly string Message;
public readonly Exception Exception;
public LogMessage(LogType logType, string message)
{
LogType = logType;
Message = message;
}
public LogMessage(Exception exception)
{
LogType = LogType.Exception;
Exception = exception;
}
}
public enum LogType
{
Warning,
Exception,
}
/// <summary>
/// This method exists solely to ensure that the static constructor has been called.
/// </summary>
public static void EnsureInitialized() { }
public static readonly List<System.Reflection.Assembly> EditorAssembliesThatCanPossiblyContainJobs;
public static readonly List<System.Reflection.Assembly> EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies;
/// <summary>
/// Collects (and caches) all editor assemblies - transitively.
/// </summary>
static BurstReflection()
{
EditorAssembliesThatCanPossiblyContainJobs = new List<System.Reflection.Assembly>();
EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies = new List<System.Reflection.Assembly>();
// TODO: Not sure there is a better way to match assemblies returned by CompilationPipeline.GetAssemblies
// with runtime assemblies contained in the AppDomain.CurrentDomain.GetAssemblies()
// Filter the assemblies
var assemblyList = CompilationPipeline.GetAssemblies(AssembliesType.Editor);
var assemblyNames = new HashSet<string>();
foreach (var assembly in assemblyList)
{
CollectAssemblyNames(assembly, assemblyNames);
}
var allAssemblies = new HashSet<System.Reflection.Assembly>();
foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
{
if (!assemblyNames.Contains(assembly.GetName().Name))
{
continue;
}
CollectAssembly(assembly, allAssemblies);
}
}
// For an assembly to contain something "interesting" when we're scanning for things to compile,
// it needs to either:
// (a) be one of these assemblies, or
// (b) reference one of these assemblies
private static readonly string[] ScanMarkerAssemblies = new[]
{
// Contains [BurstCompile] attribute
"Unity.Burst",
// Contains [JobProducerType] attribute
"UnityEngine.CoreModule"
};
private static void CollectAssembly(System.Reflection.Assembly assembly, HashSet<System.Reflection.Assembly> collect)
{
if (!collect.Add(assembly))
{
return;
}
var referencedAssemblies = assembly.GetReferencedAssemblies();
var shouldCollectReferences = false;
var name = assembly.GetName().Name;
if (ScanMarkerAssemblies.Contains(name) || referencedAssemblies.Any(x => ScanMarkerAssemblies.Contains(x.Name)))
{
EditorAssembliesThatCanPossiblyContainJobs.Add(assembly);
shouldCollectReferences = true;
if (!assembly.GetReferencedAssemblies().Any(x => IsNUnitDll(x.Name)))
{
EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies.Add(assembly);
}
}
if (!shouldCollectReferences)
{
return;
}
foreach (var assemblyName in referencedAssemblies)
{
try
{
CollectAssembly(System.Reflection.Assembly.Load(assemblyName), collect);
}
catch (Exception)
{
if (BurstLoader.IsDebugging)
{
Debug.LogWarning("Could not load assembly " + assemblyName);
}
}
}
}
private static bool IsNUnitDll(string value)
{
return CultureInfo.InvariantCulture.CompareInfo.IndexOf(value, "nunit.framework") >= 0;
}
private static void CollectAssemblyNames(UnityEditor.Compilation.Assembly assembly, HashSet<string> collect)
{
if (assembly == null || assembly.name == null) return;
if (!collect.Add(assembly.name))
{
return;
}
foreach (var assemblyRef in assembly.assemblyReferences)
{
CollectAssemblyNames(assemblyRef, collect);
}
}
/// <summary>
/// Gets the list of concrete generic type instances used in an assembly.
/// See remarks
/// </summary>
/// <param name="assembly">The assembly</param>
/// <param name="types"></param>
/// <returns>The list of generic type instances</returns>
/// <remarks>
/// Note that this method fetchs only direct type instances but
/// cannot fetch transitive generic type instances.
/// </remarks>
private static void CollectGenericTypeInstances(
System.Reflection.Assembly assembly,
Func<Type, bool> typeFilter,
List<Type> types,
HashSet<Type> visited)
{
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// WARNING: THIS CODE HAS TO BE MAINTAINED IN SYNC WITH BclApp.cs
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// From: https://gist.github.com/xoofx/710aaf86e0e8c81649d1261b1ef9590e
if (assembly == null) throw new ArgumentNullException(nameof(assembly));
const int mdMaxCount = 1 << 24;
foreach (var module in assembly.Modules)
{
for (int i = 1; i < mdMaxCount; i++)
{
try
{
// Token base id for TypeSpec
const int mdTypeSpec = 0x1B000000;
var type = module.ResolveType(mdTypeSpec | i);
CollectGenericTypeInstances(type, types, visited, typeFilter);
}
catch (ArgumentOutOfRangeException)
{
break;
}
catch (ArgumentException)
{
// Can happen on ResolveType on certain generic types, so we continue
}
}
for (int i = 1; i < mdMaxCount; i++)
{
try
{
// Token base id for MethodSpec
const int mdMethodSpec = 0x2B000000;
var method = module.ResolveMethod(mdMethodSpec | i);
var genericArgs = method.GetGenericArguments();
foreach (var genArgType in genericArgs)
{
CollectGenericTypeInstances(genArgType, types, visited, typeFilter);
}
}
catch (ArgumentOutOfRangeException)
{
break;
}
catch (ArgumentException)
{
// Can happen on ResolveType on certain generic types, so we continue
}
}
for (int i = 1; i < mdMaxCount; i++)
{
try
{
// Token base id for Field
const int mdField = 0x04000000;
var field = module.ResolveField(mdField | i);
CollectGenericTypeInstances(field.FieldType, types, visited, typeFilter);
}
catch (ArgumentOutOfRangeException)
{
break;
}
catch (ArgumentException)
{
// Can happen on ResolveType on certain generic types, so we continue
}
}
}
// Scan for types used in constructor arguments to assembly-level attributes,
// such as [RegisterGenericJobType(typeof(...))].
foreach (var customAttribute in assembly.CustomAttributes)
{
foreach (var argument in customAttribute.ConstructorArguments)
{
if (argument.ArgumentType == typeof(Type))
{
CollectGenericTypeInstances((Type)argument.Value, types, visited, typeFilter);
}
}
}
}
private static void CollectGenericTypeInstances(
Type type,
List<Type> types,
HashSet<Type> visited,
Func<Type, bool> typeFilter)
{
if (type.IsPrimitive) return;
if (!visited.Add(type)) return;
// Add only concrete types
if (type.IsConstructedGenericType && !type.ContainsGenericParameters && typeFilter(type))
{
types.Add(type);
}
// Collect recursively generic type arguments
var genericTypeArguments = type.GenericTypeArguments;
foreach (var genericTypeArgument in genericTypeArguments)
{
if (!genericTypeArgument.IsPrimitive)
{
CollectGenericTypeInstances(genericTypeArgument, types, visited, typeFilter);
}
}
}
[DebuggerDisplay("{Type} (static methods: {CollectStaticMethods})")]
private struct TypeToVisit
{
public TypeToVisit(Type type, bool collectStaticMethods)
{
Type = type;
CollectStaticMethods = collectStaticMethods;
}
public readonly Type Type;
public readonly bool CollectStaticMethods;
}
}
[Flags]
internal enum BurstReflectionAssemblyOptions
{
None = 0,
ExcludeTestAssemblies = 1,
}
}
#endif