Skip to content

Commit ee2ced2

Browse files
committed
Start generation of other operations
1 parent e83777a commit ee2ced2

File tree

10 files changed

+133
-67
lines changed

10 files changed

+133
-67
lines changed

NetFabric.Hyperlinq.SourceGenerator.UnitTests/GenerateSourceTests.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,18 @@ public static TheoryData<string[], string> GeneratorSources
7474
new[] { "TestData/Source/AsValueEnumerable.TestValueEnumerable.cs" },
7575
"TestData/Results/AsValueEnumerable.TestValueEnumerable.cs"
7676
},
77-
//{
78-
// new[] { "TestData/Source/Count.Array.cs" },
79-
// "TestData/Results/Count.Array.cs"
80-
//},
77+
{
78+
new[] { "TestData/Source/Count.Array.cs" },
79+
"TestData/Results/Count.Array.cs"
80+
},
8181
//{
8282
// new[] { "TestData/Source/Count.Span.cs" },
8383
// "TestData/Results/Count.Span.cs"
8484
//},
85+
{
86+
new[] { "TestData/Source/Count.TestEnumerableWithValueTypeEnumerator.cs" },
87+
"TestData/Results/Count.TestEnumerableWithValueTypeEnumerator.cs"
88+
},
8589
};
8690

8791
[Theory]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#nullable enable
2+
3+
using System;
4+
using System.Collections;
5+
using System.Collections.Generic;
6+
using System.Runtime.CompilerServices;
7+
8+
namespace NetFabric.Hyperlinq
9+
{
10+
static partial class GeneratedExtensionMethods
11+
{
12+
}
13+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Linq;
3+
using NetFabric.Hyperlinq;
4+
5+
partial class TestsSource
6+
{
7+
static void Count_TestEnumerableWithValueTypeEnumerator()
8+
{
9+
_ = new TestEnumerableWithValueTypeEnumerator<TestValueType>()
10+
.AsValueEnumerable()
11+
.AsEnumerable();
12+
}
13+
}

NetFabric.Hyperlinq.SourceGenerator/Generator.AsValueEnumerable.cs

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,21 @@
44
using NetFabric.CodeAnalysis;
55
using System;
66
using System.Collections.Generic;
7+
using System.Collections.Immutable;
78
using System.Threading;
89

910
namespace NetFabric.Hyperlinq.SourceGenerator
1011
{
1112
public partial class Generator
1213
{
13-
static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, CancellationToken cancellationToken, bool isUnitTest)
14+
static ValueEnumerableType? GenerateAsValueEnumerable(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, HashSet<MethodSignature> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
1415
{
15-
var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);
16-
17-
// Check if an extension method is defined for this type
18-
if (semanticModel.GetSymbolInfo(expressionSyntax).Symbol is not null)
19-
return false;
16+
// Check if the method is already defined in the project source
17+
if (semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken).Symbol is IMethodSymbol methodSymbol and not null)
18+
return new ValueEnumerableType(Name: methodSymbol.ReturnType.Name);
2019

2120
// Get the type this operator is applied to
22-
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression).Type;
21+
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression, cancellationToken).Type;
2322

2423
// Check if NetFabric.Hyperlinq already contains specific overloads for this type
2524
// This is required for when the 'using NetFabric.Hyperlinq;' statement is missing
@@ -31,8 +30,9 @@ static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolsCache ty
3130
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolsCache[typeof(Memory<>)])
3231
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolsCache[typeof(ReadOnlyMemory<>)])
3332
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolsCache[typeof(List<>)])
33+
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolsCache[typeof(ImmutableArray<>)])
3434
)
35-
return false; // no need to generate an implementation
35+
return null; // no need to generate an implementation
3636

3737
var receiverTypeString = receiverTypeSymbol.ToDisplayString();
3838

@@ -47,10 +47,9 @@ static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolsCache ty
4747
.AppendLine($"public static {receiverTypeString} AsValueEnumerable(this {receiverTypeString} source)")
4848
.AppendIdentation().AppendLine($"=> source;");
4949

50-
return true;
50+
return new ValueEnumerableType(Name: receiverTypeString);
5151
}
5252

53-
5453
// Receiver type is an enumerable
5554

5655
if (receiverTypeSymbol.IsEnumerable(compilation, out var enumerableSymbols))
@@ -86,23 +85,16 @@ static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolsCache ty
8685
}
8786

8887
// Define what value enumerator type will be used
89-
string enumeratorTypeString;
9088
var enumeratorImplementsIEnumerator = getEnumeratorReturnType.ImplementsInterface(SpecialType.System_Collections_Generic_IEnumerator_T, out var _);
9189
var useConstraints = enumerableImplementsIEnumerable;
92-
if (enumeratorImplementsIEnumerator)
93-
{
94-
if (getEnumeratorReturnType.IsValueType)
95-
enumeratorTypeString = getEnumeratorReturnTypeString;
96-
else
97-
enumeratorTypeString = $"ValueEnumerator<{itemTypeString}>";
98-
}
99-
else
100-
{
101-
if (useConstraints)
102-
enumeratorTypeString = $"{enumerableTypeString}<TEnumerable>.Enumerator";
103-
else
104-
enumeratorTypeString = $"{enumerableTypeString}.Enumerator";
105-
}
90+
91+
var enumeratorTypeString = enumeratorImplementsIEnumerator
92+
? getEnumeratorReturnType.IsValueType
93+
? getEnumeratorReturnTypeString
94+
: $"ValueEnumerator<{itemTypeString}>"
95+
: useConstraints
96+
? $"{enumerableTypeString}<TEnumerable>.Enumerator"
97+
: $"{enumerableTypeString}.Enumerator";
10698

10799
// Generate the method
108100
_ = useConstraints
@@ -118,13 +110,14 @@ static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolsCache ty
118110
.AppendIdentation().AppendLine($"=> new(source);");
119111

120112
// Generate the value enumerable wrapper
121-
_ = useConstraints
122-
? builder
123-
.AppendLine()
124-
.AppendLine($"public readonly struct {enumerableTypeString}<TEnumerable>")
125-
: builder
113+
string valueEnumerableTypeName;
114+
valueEnumerableTypeName = useConstraints
115+
? $"{enumerableTypeString}<TEnumerable>"
116+
: enumerableTypeString;
117+
118+
_ = builder
126119
.AppendLine()
127-
.AppendLine($"public readonly struct {enumerableTypeString}");
120+
.AppendLine($"public readonly struct {valueEnumerableTypeName}");
128121

129122
// Define what interfaces the wrapper implements
130123
if (enumerableImplementsIList || enumerableImplementsIReadOnlyList)
@@ -378,11 +371,13 @@ static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolsCache ty
378371
}
379372
}
380373

381-
return true;
374+
// A new AsValueEnumerable method has been generated
375+
_ = generatedMethods.Add(new MethodSignature("AsValueEnumerable", receiverTypeString));
376+
return new ValueEnumerableType(Name: valueEnumerableTypeName);
382377
}
383378
}
384379

385-
return false;
380+
return null;
386381
}
387382
}
388383
}

NetFabric.Hyperlinq.SourceGenerator/Generator.cs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Collections.Generic;
88
using System.Collections.Immutable;
99
using System.Diagnostics;
10+
using System.Linq;
1011
using System.Text;
1112
using System.Threading;
1213

@@ -24,7 +25,7 @@ public partial class Generator
2425
DiagnosticSeverity.Error,
2526
isEnabledByDefault: true);
2627

27-
internal static readonly ImmutableArray<string> methods = ImmutableArray.Create(new[]
28+
internal static readonly ImmutableHashSet<string> methods = ImmutableHashSet.Create(new[]
2829
{
2930
// aggregation
3031
"Count",
@@ -98,20 +99,20 @@ public void Initialize(GeneratorInitializationContext context)
9899

99100
public void Execute(GeneratorExecutionContext context)
100101
{
101-
var typeSymbolCache = new TypeSymbolsCache(context.Compilation);
102+
var typeSymbolsCache = new TypeSymbolsCache(context.Compilation);
102103

103104
// Check if NetFabric.Hyperlinq.Abstractions and NetFabric.Hyperlinq.Abstractions are referenced
104-
if (typeSymbolCache["NetFabric.Hyperlinq.IValueEnumerable`2"] is null
105-
|| typeSymbolCache["NetFabric.Hyperlinq.ValueEnumerableExtensions"] is null)
106-
return;
105+
if (typeSymbolsCache["NetFabric.Hyperlinq.IValueEnumerable`2"] is null
106+
|| typeSymbolsCache["NetFabric.Hyperlinq.ValueEnumerableExtensions"] is null)
107+
return; // TODO: return a Diagnostic?
107108

108109
if (context.SyntaxReceiver is not SyntaxReceiver receiver)
109110
return;
110111

111112
try
112113
{
113114
var builder = new CodeBuilder();
114-
GenerateSource(context.Compilation, typeSymbolCache, receiver.MemberAccessExpressions, builder, context.CancellationToken);
115+
GenerateSource(context.Compilation, typeSymbolsCache, receiver.MemberAccessExpressions, builder, context.CancellationToken);
115116
context.AddSource("ExtensionMethods.g.cs", SourceText.From(builder.ToString(), Encoding.UTF8));
116117
}
117118
catch (OperationCanceledException)
@@ -126,6 +127,8 @@ public void Execute(GeneratorExecutionContext context)
126127

127128
internal static void GenerateSource(Compilation compilation, TypeSymbolsCache typeSymbolsCache, List<MemberAccessExpressionSyntax> memberAccessExpressions, CodeBuilder builder, CancellationToken cancellationToken, bool isUnitTest = false)
128129
{
130+
var generatedMethods = new HashSet<MethodSignature>();
131+
129132
_ = builder
130133
.AppendLine("#nullable enable")
131134
.AppendLine()
@@ -140,24 +143,26 @@ internal static void GenerateSource(Compilation compilation, TypeSymbolsCache ty
140143
{
141144
foreach (var expressionSyntax in memberAccessExpressions)
142145
{
143-
cancellationToken.ThrowIfCancellationRequested();
146+
var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);
144147

145-
_ = expressionSyntax.Name.ToString() switch
146-
{
147-
"AsValueEnumerable" => HandleAsValueEnumerable(compilation, typeSymbolsCache, expressionSyntax, builder, cancellationToken, isUnitTest),
148-
_ => HandleMethod(compilation, expressionSyntax, builder, cancellationToken, isUnitTest),
149-
};
148+
_ = GenerateSource(compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest);
150149
}
151150
}
152151
}
153152

154-
static bool HandleMethod(Compilation compilation, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, CancellationToken cancellationToken, bool isUnitTest)
153+
static ValueEnumerableType? GenerateSource(Compilation compilation, SemanticModel semanticModel, TypeSymbolsCache typeSymbolsCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, HashSet<MethodSignature> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
154+
=> expressionSyntax.Name.ToString() switch
155+
{
156+
"AsValueEnumerable" => GenerateAsValueEnumerable(compilation, semanticModel, typeSymbolsCache, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest),
157+
_ => GenerateOperationSource(compilation, semanticModel, expressionSyntax, builder, generatedMethods, cancellationToken, isUnitTest),
158+
};
159+
160+
static ValueEnumerableType? GenerateOperationSource(Compilation compilation, SemanticModel semanticModel, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, HashSet<MethodSignature> generatedMethods, CancellationToken cancellationToken, bool isUnitTest)
155161
{
156-
var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);
157-
if (semanticModel.GetSymbolInfo(expressionSyntax, cancellationToken).Symbol is not null) // method already exists
158-
return false;
162+
// Get the type this operator is applied to
163+
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression, cancellationToken).Type;
159164

160-
return false;
165+
return null;
161166
}
162167

163168
static bool IsValueEnumerable(ITypeSymbol symbol, TypeSymbolsCache typeSymbolsCache)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using System;
2+
using System.Collections.Immutable;
3+
using System.Linq;
4+
5+
namespace NetFabric.Hyperlinq.SourceGenerator
6+
{
7+
readonly struct MethodSignature
8+
: IEquatable<MethodSignature>
9+
{
10+
public MethodSignature(string name, params string[] parameters)
11+
=> (Name, Parameters) = (name, ImmutableArray.Create(parameters));
12+
13+
readonly string Name { get; }
14+
readonly ImmutableArray<string> Parameters { get; }
15+
16+
public bool Equals(MethodSignature other)
17+
=> Name == other.Name && Parameters.SequenceEqual(other.Parameters);
18+
19+
public override bool Equals(object other)
20+
=> other is MethodSignature signature && Equals(signature);
21+
22+
public override int GetHashCode()
23+
{
24+
unchecked
25+
{
26+
const int HashingBase = (int)2166136261;
27+
const int HashingMultiplier = 16777619;
28+
29+
var hash = HashingBase;
30+
hash = (hash * HashingMultiplier) ^ Name.GetHashCode();
31+
foreach(var parameter in Parameters)
32+
hash = (hash * HashingMultiplier) ^ parameter.GetHashCode();
33+
return hash;
34+
}
35+
}
36+
}
37+
}

NetFabric.Hyperlinq.SourceGenerator/MethodsSet.cs

Lines changed: 0 additions & 11 deletions
This file was deleted.

NetFabric.Hyperlinq.SourceGenerator/NetFabric.Hyperlinq.SourceGenerator.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
1616
</PackageReference>-->
1717
<PackageReference Include="Ben.TypeDictionary" Version="0.1.4" />
18+
<PackageReference Include="IsExternalInit" Version="1.0.1">
19+
<PrivateAssets>all</PrivateAssets>
20+
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
21+
</PackageReference>
1822
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.2" PrivateAssets="all" />
1923
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="3.10.0" />
2024
<PackageReference Include="NetFabric.CodeAnalysis" Version="4.0.2" />

NetFabric.Hyperlinq.SourceGenerator/SyntaxReceiver.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ namespace NetFabric.Hyperlinq.SourceGenerator
88
{
99
class SyntaxReceiver : ISyntaxReceiver
1010
{
11-
readonly ImmutableArray<string> memberAccessNames;
11+
readonly ImmutableHashSet<string> memberAccessNames;
1212

13-
public SyntaxReceiver(ImmutableArray<string> memberAccessNames)
13+
public SyntaxReceiver(ImmutableHashSet<string> memberAccessNames)
1414
=> this.memberAccessNames = memberAccessNames;
1515

1616
public List<MemberAccessExpressionSyntax> MemberAccessExpressions { get; } = new();
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using System;
2+
3+
namespace NetFabric.Hyperlinq.SourceGenerator
4+
{
5+
record ValueEnumerableType(string Name);
6+
}

0 commit comments

Comments
 (0)