1+ using System ;
12using System . IO ;
23using System . Linq ;
34using System . Collections . Generic ;
5+ using System . Reflection ;
46using Mono . Cecil ;
57using Mono . Cecil . Cil ;
8+ using Mono . Cecil . Rocks ;
69using Unity . CompilationPipeline . Common . Diagnostics ;
710using Unity . CompilationPipeline . Common . ILPostProcessing ;
811using ILPPInterface = Unity . CompilationPipeline . Common . ILPostProcessing . ILPostProcessor ;
12+ using MethodAttributes = Mono . Cecil . MethodAttributes ;
913
1014namespace Unity . Netcode . Editor . CodeGen
1115{
@@ -14,7 +18,9 @@ internal sealed class INetworkMessageILPP : ILPPInterface
1418 {
1519 public override ILPPInterface GetInstance ( ) => this ;
1620
17- public override bool WillProcess ( ICompiledAssembly compiledAssembly ) => compiledAssembly . References . Any ( filePath => Path . GetFileNameWithoutExtension ( filePath ) == CodeGenHelpers . RuntimeAssemblyName ) ;
21+ public override bool WillProcess ( ICompiledAssembly compiledAssembly ) =>
22+ compiledAssembly . Name == CodeGenHelpers . RuntimeAssemblyName ||
23+ compiledAssembly . References . Any ( filePath => Path . GetFileNameWithoutExtension ( filePath ) == CodeGenHelpers . RuntimeAssemblyName ) ;
1824
1925 private readonly List < DiagnosticMessage > m_Diagnostics = new List < DiagnosticMessage > ( ) ;
2026
@@ -42,11 +48,24 @@ public override ILPostProcessResult Process(ICompiledAssembly compiledAssembly)
4248 {
4349 if ( ImportReferences ( mainModule ) )
4450 {
51+ var types = mainModule . GetTypes ( )
52+ . Where ( t => t . Resolve ( ) . HasInterface ( CodeGenHelpers . INetworkMessage_FullName ) && ! t . Resolve ( ) . IsAbstract )
53+ . ToList ( ) ;
4554 // process `INetworkMessage` types
46- mainModule . GetTypes ( )
47- . Where ( t => t . HasInterface ( CodeGenHelpers . INetworkMessage_FullName ) )
48- . ToList ( )
49- . ForEach ( b => ProcessINetworkMessage ( b ) ) ;
55+ if ( types . Count == 0 )
56+ {
57+ return null ;
58+ }
59+
60+ try
61+ {
62+ types . ForEach ( b => ProcessINetworkMessage ( b ) ) ;
63+ CreateModuleInitializer ( assemblyDefinition , types ) ;
64+ }
65+ catch ( Exception e )
66+ {
67+ m_Diagnostics . AddError ( ( e . ToString ( ) + e . StackTrace . ToString ( ) ) . Replace ( "\n " , "|" ) . Replace ( "\r " , "|" ) ) ;
68+ }
5069 }
5170 else
5271 {
@@ -58,6 +77,8 @@ public override ILPostProcessResult Process(ICompiledAssembly compiledAssembly)
5877 m_Diagnostics . AddError ( $ "Cannot get main module from assembly definition: { compiledAssembly . Name } ") ;
5978 }
6079
80+ mainModule . RemoveRecursiveReferences ( ) ;
81+
6182 // write
6283 var pe = new MemoryStream ( ) ;
6384 var pdb = new MemoryStream ( ) ;
@@ -77,12 +98,50 @@ public override ILPostProcessResult Process(ICompiledAssembly compiledAssembly)
7798
7899 private TypeReference m_FastBufferReader_TypeRef ;
79100 private TypeReference m_NetworkContext_TypeRef ;
101+ private FieldReference m_MessagingSystem___network_message_types_FieldRef ;
102+ private MethodReference m_Type_GetTypeFromHandle_MethodRef ;
103+
104+ private MethodReference m_List_Add_MethodRef ;
80105
81106 private bool ImportReferences ( ModuleDefinition moduleDefinition )
82107 {
83108 m_FastBufferReader_TypeRef = moduleDefinition . ImportReference ( typeof ( FastBufferReader ) ) ;
84109 m_NetworkContext_TypeRef = moduleDefinition . ImportReference ( typeof ( NetworkContext ) ) ;
85110
111+ var typeType = typeof ( Type ) ;
112+ foreach ( var methodInfo in typeType . GetMethods ( ) )
113+ {
114+ switch ( methodInfo . Name )
115+ {
116+ case nameof ( Type . GetTypeFromHandle ) :
117+ m_Type_GetTypeFromHandle_MethodRef = moduleDefinition . ImportReference ( methodInfo ) ;
118+ break ;
119+ }
120+ }
121+
122+ var messagingSystemType = typeof ( MessagingSystem ) ;
123+ foreach ( var fieldInfo in messagingSystemType . GetFields ( BindingFlags . Static | BindingFlags . NonPublic ) )
124+ {
125+ switch ( fieldInfo . Name )
126+ {
127+ case nameof ( MessagingSystem . __network_message_types ) :
128+ m_MessagingSystem___network_message_types_FieldRef = moduleDefinition . ImportReference ( fieldInfo ) ;
129+ break ;
130+ }
131+ }
132+
133+ var listType = typeof ( List < Type > ) ;
134+ foreach ( var methodInfo in listType . GetMethods ( ) )
135+ {
136+ switch ( methodInfo . Name )
137+ {
138+ case nameof ( List < Type > . Add ) :
139+ m_List_Add_MethodRef = moduleDefinition . ImportReference ( methodInfo ) ;
140+ break ;
141+ }
142+ }
143+
144+
86145 return true ;
87146 }
88147
@@ -98,6 +157,7 @@ private void ProcessINetworkMessage(TypeDefinition typeDefinition)
98157 {
99158 typeSequence = methodSequence ;
100159 }
160+
101161 if ( resolved . IsStatic && resolved . IsPublic && resolved . Name == "Receive" && resolved . Parameters . Count == 2
102162 && ! resolved . Parameters [ 0 ] . IsIn
103163 && ! resolved . Parameters [ 0 ] . ParameterType . IsByReference
@@ -118,5 +178,62 @@ private void ProcessINetworkMessage(TypeDefinition typeDefinition)
118178 m_Diagnostics . AddError ( typeSequence , $ "Class { typeDefinition . FullName } does not implement required function: `public static void Receive(FastBufferReader, in NetworkContext)`") ;
119179 }
120180 }
181+
182+ private MethodDefinition GetOrCreateStaticConstructor ( TypeDefinition typeDefinition )
183+ {
184+ var staticCtorMethodDef = typeDefinition . GetStaticConstructor ( ) ;
185+ if ( staticCtorMethodDef == null )
186+ {
187+ staticCtorMethodDef = new MethodDefinition (
188+ ".cctor" , // Static Constructor (constant-constructor)
189+ MethodAttributes . HideBySig |
190+ MethodAttributes . SpecialName |
191+ MethodAttributes . RTSpecialName |
192+ MethodAttributes . Static ,
193+ typeDefinition . Module . TypeSystem . Void ) ;
194+ staticCtorMethodDef . Body . Instructions . Add ( Instruction . Create ( OpCodes . Ret ) ) ;
195+ typeDefinition . Methods . Add ( staticCtorMethodDef ) ;
196+ }
197+
198+ return staticCtorMethodDef ;
199+ }
200+
201+ private void CreateInstructionsToRegisterType ( ILProcessor processor , List < Instruction > instructions , TypeReference type )
202+ {
203+ // MessagingSystem.__network_message_types.Add(typeof(type));
204+ instructions . Add ( processor . Create ( OpCodes . Ldsfld , m_MessagingSystem___network_message_types_FieldRef ) ) ;
205+ instructions . Add ( processor . Create ( OpCodes . Ldtoken , type ) ) ;
206+ instructions . Add ( processor . Create ( OpCodes . Call , m_Type_GetTypeFromHandle_MethodRef ) ) ;
207+ instructions . Add ( processor . Create ( OpCodes . Callvirt , m_List_Add_MethodRef ) ) ;
208+ }
209+
210+ // Creates a static module constructor (which is executed when the module is loaded) that registers all the
211+ // message types in the assembly with MessagingSystem.
212+ // This is the same behavior as annotating a static method with [ModuleInitializer] in standardized
213+ // C# (that attribute doesn't exist in Unity, but the static module constructor still works)
214+ // https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.moduleinitializerattribute?view=net-5.0
215+ // https://web.archive.org/web/20100212140402/http://blogs.msdn.com/junfeng/archive/2005/11/19/494914.aspx
216+ private void CreateModuleInitializer ( AssemblyDefinition assembly , List < TypeDefinition > networkMessageTypes )
217+ {
218+ foreach ( var typeDefinition in assembly . MainModule . Types )
219+ {
220+ if ( typeDefinition . FullName == "<Module>" )
221+ {
222+ var staticCtorMethodDef = GetOrCreateStaticConstructor ( typeDefinition ) ;
223+
224+ var processor = staticCtorMethodDef . Body . GetILProcessor ( ) ;
225+
226+ var instructions = new List < Instruction > ( ) ;
227+
228+ foreach ( var type in networkMessageTypes )
229+ {
230+ CreateInstructionsToRegisterType ( processor , instructions , type ) ;
231+ }
232+
233+ instructions . ForEach ( instruction => processor . Body . Instructions . Insert ( processor . Body . Instructions . Count - 1 , instruction ) ) ;
234+ break ;
235+ }
236+ }
237+ }
121238 }
122239}
0 commit comments