22using Microsoft . CodeAnalysis . CSharp ;
33using Microsoft . CodeAnalysis . CSharp . Syntax ;
44using Microsoft . CodeAnalysis . Text ;
5- using NetFabric . CodeAnalysis ;
65using System ;
76using System . Collections . Generic ;
87using System . Collections . Immutable ;
98using System . Diagnostics ;
10- using System . Linq ;
119using System . Text ;
1210using System . Threading ;
1311
1412namespace NetFabric . Hyperlinq . SourceGenerator
1513{
1614 [ Generator ]
17- public class OverloadsGenerator
15+ public partial class Generator
1816 : ISourceGenerator
1917 {
2018 static readonly DiagnosticDescriptor unhandledExceptionError = new (
@@ -25,7 +23,7 @@ public class OverloadsGenerator
2523 DiagnosticSeverity . Error ,
2624 isEnabledByDefault : true ) ;
2725
28- internal static readonly string [ ] methods = new [ ]
26+ internal static readonly ImmutableArray < string > methods = ImmutableArray . Create ( new [ ]
2927 {
3028 // aggregation
3129 "Count" ,
@@ -85,7 +83,7 @@ public class OverloadsGenerator
8583 // quantifier
8684 "Distinct" ,
8785 "DistinctAsync" ,
88- } ;
86+ } ) ;
8987
9088 public void Initialize ( GeneratorInitializationContext context )
9189 {
@@ -126,6 +124,8 @@ public void Execute(GeneratorExecutionContext context)
126124
127125 internal static void GenerateSource ( Compilation compilation , List < MemberAccessExpressionSyntax > memberAccessExpressions , CodeBuilder builder , CancellationToken cancellationToken )
128126 {
127+ var typeSymbolCache = new TypeSymbolCache ( compilation ) ;
128+
129129 _ = builder
130130 . AppendLine ( "#nullable enable" )
131131 . AppendLine ( )
@@ -144,76 +144,13 @@ internal static void GenerateSource(Compilation compilation, List<MemberAccessEx
144144
145145 _ = expressionSyntax . Name . ToString ( ) switch
146146 {
147- "AsValueEnumerable" => HandleAsValueEnumerable ( compilation , expressionSyntax , builder , cancellationToken ) ,
147+ "AsValueEnumerable" => HandleAsValueEnumerable ( compilation , typeSymbolCache , expressionSyntax , builder , cancellationToken ) ,
148148 _ => HandleMethod ( compilation , expressionSyntax , builder , cancellationToken ) ,
149149 } ;
150150 }
151151 }
152152 }
153153
154- static bool HandleAsValueEnumerable ( Compilation compilation , MemberAccessExpressionSyntax expressionSyntax , CodeBuilder builder , CancellationToken cancellationToken )
155- {
156- // Get the type this operator is applied to
157- var semanticModel = compilation . GetSemanticModel ( expressionSyntax . SyntaxTree ) ;
158- var receiverTypeSymbol = semanticModel . GetTypeInfo ( expressionSyntax . Expression ) . Type ;
159-
160- // Check if NetFabric.Hyperlinq already contains specific overloads for this type
161- if ( receiverTypeSymbol is null
162- or { TypeKind : TypeKind . Array } // is array
163- || SymbolEqualityComparer . Default . Equals ( receiverTypeSymbol . OriginalDefinition , compilation . GetTypeByMetadataName ( typeof ( ArraySegment < > ) . FullName ) )
164- || SymbolEqualityComparer . Default . Equals ( receiverTypeSymbol . OriginalDefinition , compilation . GetTypeByMetadataName ( typeof ( Span < > ) . FullName ) )
165- || SymbolEqualityComparer . Default . Equals ( receiverTypeSymbol . OriginalDefinition , compilation . GetTypeByMetadataName ( typeof ( ReadOnlySpan < > ) . FullName ) )
166- || SymbolEqualityComparer . Default . Equals ( receiverTypeSymbol . OriginalDefinition , compilation . GetTypeByMetadataName ( typeof ( Memory < > ) . FullName ) )
167- || SymbolEqualityComparer . Default . Equals ( receiverTypeSymbol . OriginalDefinition , compilation . GetTypeByMetadataName ( typeof ( ReadOnlyMemory < > ) . FullName ) )
168- )
169- return false ; // no need to generate an implementation
170-
171- // Generate the method source depending on receiver type characteristics
172- var receiverTypeString = receiverTypeSymbol . ToDisplayString ( ) ;
173-
174- // Receiver type implements IValueEnumerable<,>
175-
176- var valueEnumerableInterface = compilation . GetTypeByMetadataName ( "NetFabric.Hyperlinq.IValueEnumerable`2" ) ! ;
177- if ( receiverTypeSymbol . ImplementsInterface ( valueEnumerableInterface , out var _ ) )
178- {
179- // Receiver instance returns itself
180- _ = builder
181- . AppendLine ( )
182- . AppendLine ( "[MethodImpl(MethodImplOptions.AggressiveInlining)]" )
183- . AppendLine ( $ "public static { receiverTypeString } AsValueEnumerable(this { receiverTypeString } source) => source;") ;
184-
185- return true ;
186- }
187-
188-
189- // Receiver type is an enumerable that is none of the above
190-
191- if ( receiverTypeSymbol . IsEnumerable ( compilation , out var enumerableSymbols ) )
192- {
193- var wrapperName = receiverTypeString + "__ValueEnumerable" ;
194- _ = builder
195- . AppendLine ( )
196- . AppendLine ( "[MethodImpl(MethodImplOptions.AggressiveInlining)]" )
197- . AppendLine ( $ "public static { wrapperName } AsValueEnumerable(this { receiverTypeString } source) => new(source);")
198- . AppendLine ( ) ;
199-
200- var enumeratorType = enumerableSymbols . GetEnumerator . ReturnType ;
201- var itemType = enumerableSymbols . EnumeratorSymbols . Current . Type ;
202- using ( builder . AppendBlock ( $ "public readonly struct { wrapperName } : IValueEnumerable<{ itemType . ToDisplayString ( ) } , { enumeratorType . ToDisplayString ( ) } >") )
203- {
204- _ = builder
205- . AppendLine ( $ "readonly { receiverTypeString } source;") // source field
206- . AppendLine ( )
207- . AppendLine ( $ "public { wrapperName } ({ receiverTypeString } source) => this.source = source;") ; // constructor
208-
209- ImplementValueEnumerable ( enumerableSymbols , builder ) ;
210- return true ;
211- }
212- }
213-
214- return false ;
215- }
216-
217154 static bool HandleMethod ( Compilation compilation , MemberAccessExpressionSyntax expressionSyntax , CodeBuilder builder , CancellationToken cancellationToken )
218155 {
219156 var semanticModel = compilation . GetSemanticModel ( expressionSyntax . SyntaxTree ) ;
@@ -222,22 +159,5 @@ static bool HandleMethod(Compilation compilation, MemberAccessExpressionSyntax e
222159
223160 return false ;
224161 }
225-
226- static void ImplementValueEnumerable ( EnumerableSymbols enumerableSymbols , CodeBuilder builder )
227- {
228- var enumeratorType = enumerableSymbols . GetEnumerator . ReturnType ;
229- var itemType = enumerableSymbols . EnumeratorSymbols . Current . Type ;
230-
231- // Enumerator is value type and implements IEnumerator<>
232- // It can be returned directly by GetEnumerator()
233- if ( enumeratorType . IsValueType && enumeratorType . ImplementsInterface ( SpecialType . System_Collections_Generic_IEnumerable_T , out var _ ) )
234- {
235- _ = builder
236- . AppendLine ( )
237- . AppendLine ( $ "public { enumeratorType . ToDisplayString ( ) } GetEnumerator() => source.GetEnumerator();")
238- . AppendLine ( $ "IEnumerator<int> IEnumerable<int>.GetEnumerator() => source.GetEnumerator();")
239- . AppendLine ( $ "IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();") ;
240- }
241- }
242162 }
243163}
0 commit comments