Skip to content

Commit 66bb700

Browse files
committed
Fixes
1 parent 463ce41 commit 66bb700

File tree

12 files changed

+210
-100
lines changed

12 files changed

+210
-100
lines changed

.github/workflows/dotnetcore.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ on: [push]
44

55
jobs:
66
windows:
7-
runs-on: windows-latest
7+
runs-on: ubuntu-latest
88

99
steps:
1010
- name: Check out repository
1111
uses: actions/checkout@v2
1212
- name: Setup .NET Core
13-
uses: actions/setup-dotnet@v1.7.2
13+
uses: actions/setup-dotnet@v1
1414
with:
15-
dotnet-version: '6.0.100-preview.2.21155.3'
15+
dotnet-version: 6.0.x
1616
- name: Test source generator
1717
run: dotnet test ./NetFabric.Hyperlinq.SourceGenerator.UnitTests/NetFabric.Hyperlinq.SourceGenerator.UnitTests.csproj
1818
- name: Build solution

NetFabric.Hyperlinq.SourceGenerator.UnitTests/GenerateSourceTests.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.CodeAnalysis;
22
using Microsoft.CodeAnalysis.CSharp.Syntax;
33
using NetFabric.Assertive;
4+
using System;
45
using System.Collections.Generic;
56
using System.IO;
67
using System.Linq;
@@ -51,19 +52,24 @@ public async Task GenerateSourceShouldGenerate(string[] paths, string expected)
5152
{
5253
// Arrange
5354
var sources = paths
54-
//.Concat(Directory.EnumerateFiles("TestData/Source/Common", "*.cs", SearchOption.AllDirectories))
55+
.Concat(Directory.EnumerateFiles("TestData/Source/Common", "*.cs", SearchOption.AllDirectories))
5556
.Select(path => File.ReadAllText(path));
5657
var project = Verifier.CreateProject(sources);
5758
var compilation = await project.GetCompilationAsync().ConfigureAwait(false)
5859
?? throw new System.Exception("Error getting compilation!");
60+
var errors = compilation
61+
.GetDiagnostics()
62+
.Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
63+
.ToArray();
64+
_ = errors.Must().BeEqualTo(Array.Empty<Diagnostic>());
5965
var memberAccessExpressions = compilation.SyntaxTrees
6066
.SelectMany(tree => tree.GetRoot().DescendantNodes().OfType<MemberAccessExpressionSyntax>())
61-
.Where(memberAccess => ((ICollection<string>)OverloadsGenerator.methods).Contains(memberAccess.Name.Identifier.ValueText))
67+
.Where(memberAccess => Generator.methods.Contains(memberAccess.Name.Identifier.ValueText))
6268
.ToList();
6369

6470
// Act
6571
var builder = new CodeBuilder();
66-
OverloadsGenerator.GenerateSource(compilation, memberAccessExpressions, builder, CancellationToken.None);
72+
Generator.GenerateSource(compilation, memberAccessExpressions, builder, CancellationToken.None);
6773
var result = builder.ToString();
6874

6975
// Assert

NetFabric.Hyperlinq.SourceGenerator.UnitTests/NetFabric.Hyperlinq.SourceGenerator.UnitTests.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
<ItemGroup>
1414
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="3.10.0" />
1515
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.10.0" />
16-
<PackageReference Include="NetFabric.Assertive" Version="3.0.1" />
16+
<PackageReference Include="NetFabric.Assertive" Version="4.0.0" />
1717
<PackageReference Include="xunit" Version="2.4.1" />
1818
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">
1919
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>

NetFabric.Hyperlinq.SourceGenerator.UnitTests/TestData/Source/Common/Enumerables.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public void Dispose()
5353
}
5454
}
5555

56-
class TestEnumerableWithValueTypeEnumerator
56+
public class TestEnumerableWithValueTypeEnumerator
5757
: IEnumerable<int>
5858
{
5959
public Enumerator GetEnumerator()

NetFabric.Hyperlinq.SourceGenerator.UnitTests/Verifier.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
using Microsoft.CodeAnalysis.Text;
44
using System;
55
using System.Buffers;
6+
using System.Collections;
67
using System.Collections.Generic;
78
using System.Collections.Immutable;
89
using System.IO;
910
using System.Linq;
10-
using System.Runtime.CompilerServices;
11+
using System.Reflection;
12+
using System.Threading.Tasks;
1113

1214
namespace NetFabric.Hyperlinq.SourceGenerator.UnitTests
1315
{
@@ -35,7 +37,8 @@ public static Project CreateProject(IEnumerable<string> sources)
3537
var solution = new AdhocWorkspace()
3638
.CurrentSolution
3739
.AddProject(projectId, testProjectName, testProjectName, LanguageNames.CSharp)
38-
.AddMetadataReferences(projectId, references);
40+
.AddMetadataReferences(projectId, references)
41+
.WithProjectCompilationOptions(projectId, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
3942

4043
var count = 0;
4144
foreach (var source in sources)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
using Microsoft.CodeAnalysis;
2+
using Microsoft.CodeAnalysis.CSharp;
3+
using Microsoft.CodeAnalysis.CSharp.Syntax;
4+
using NetFabric.CodeAnalysis;
5+
using System;
6+
using System.Threading;
7+
8+
namespace NetFabric.Hyperlinq.SourceGenerator
9+
{
10+
public partial class Generator
11+
{
12+
static bool HandleAsValueEnumerable(Compilation compilation, TypeSymbolCache typeSymbolCache, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, CancellationToken cancellationToken)
13+
{
14+
// Get the type this operator is applied to
15+
var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);
16+
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression).Type;
17+
18+
// Check if NetFabric.Hyperlinq already contains specific overloads for this type
19+
if (receiverTypeSymbol is null
20+
or { TypeKind: TypeKind.Array } // is array
21+
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolCache.GetTypeSymbol(typeof(ArraySegment<>)))
22+
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolCache.GetTypeSymbol(typeof(Span<>)))
23+
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolCache.GetTypeSymbol(typeof(ReadOnlySpan<>)))
24+
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolCache.GetTypeSymbol(typeof(Memory<>)))
25+
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, typeSymbolCache.GetTypeSymbol(typeof(ReadOnlyMemory<>)))
26+
)
27+
return false; // no need to generate an implementation
28+
29+
// Generate the method source depending on receiver type characteristics
30+
var receiverTypeString = receiverTypeSymbol.ToDisplayString();
31+
32+
// Receiver type implements IValueEnumerable<,>
33+
34+
var valueEnumerableInterface = typeSymbolCache.GetTypeSymbol("NetFabric.Hyperlinq.IValueEnumerable`2")!;
35+
if (receiverTypeSymbol.ImplementsInterface(valueEnumerableInterface, out var _))
36+
{
37+
// Receiver instance returns itself
38+
_ = builder
39+
.AppendLine()
40+
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
41+
.AppendLine($"public static {receiverTypeString} AsValueEnumerable(this {receiverTypeString} source) => source;");
42+
43+
return true;
44+
}
45+
46+
47+
// Receiver type is an enumerable that is none of the above
48+
49+
if (receiverTypeSymbol.IsEnumerable(compilation, out var enumerableSymbols))
50+
{
51+
var enumerableTypeString = $"{receiverTypeString}__ValueEnumerable";
52+
var enumeratorTypeString = "Enumerator";
53+
var itemType = enumerableSymbols.EnumeratorSymbols.Current.Type;
54+
var itemTypeString = itemType.ToDisplayString();
55+
56+
// Check if the returned type by GetEnumerator() does not require a wrapper
57+
var useGetEnumeratorReturnType = false;
58+
var getEnumeratorReturnType = enumerableSymbols.GetEnumerator.ReturnType;
59+
var enumeratorImplementsInterface = getEnumeratorReturnType.ImplementsInterface(SpecialType.System_Collections_Generic_IEnumerator_T, out var _);
60+
if (enumeratorImplementsInterface)
61+
{
62+
if (getEnumeratorReturnType.IsValueType)
63+
{
64+
// Enumerator is value type and implements IEnumerator<>
65+
useGetEnumeratorReturnType = true;
66+
enumeratorTypeString = getEnumeratorReturnType.ToDisplayString();
67+
}
68+
else
69+
{
70+
enumeratorTypeString = $"ValueEnumerator<{itemTypeString}>";
71+
}
72+
}
73+
74+
_ = builder
75+
.AppendLine()
76+
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
77+
.AppendLine($"public static {enumerableTypeString} AsValueEnumerable(this {receiverTypeString} source) => new(source);")
78+
.AppendLine();
79+
80+
using (builder.AppendBlock($"public readonly struct {enumerableTypeString} : IValueEnumerable<{itemType.ToDisplayString()}, {enumeratorTypeString}>"))
81+
{
82+
_ = builder
83+
.AppendLine($"readonly {receiverTypeString} source;") // source field
84+
.AppendLine()
85+
.AppendLine($"public {enumerableTypeString}({receiverTypeString} source) => this.source = source;"); // constructor
86+
87+
if (useGetEnumeratorReturnType)
88+
{
89+
// No wrapper required
90+
_ = builder
91+
.AppendLine()
92+
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
93+
.AppendLine($"public {enumeratorTypeString} GetEnumerator() => source.GetEnumerator();")
94+
.AppendLine()
95+
.AppendLine($"IEnumerator<{itemTypeString}> IEnumerable<{itemTypeString}>.GetEnumerator() => source.GetEnumerator();")
96+
.AppendLine()
97+
.AppendLine($"IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();");
98+
}
99+
else if(enumeratorImplementsInterface)
100+
{
101+
// Use the ValueEnumerator<> wrapper
102+
_ = builder
103+
.AppendLine()
104+
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
105+
.AppendLine($"public {enumeratorTypeString} GetEnumerator() => new(source.GetEnumerator());")
106+
.AppendLine()
107+
.AppendLine($"IEnumerator<{itemTypeString}> IEnumerable<{itemTypeString}>.GetEnumerator() => source.GetEnumerator();")
108+
.AppendLine()
109+
.AppendLine($"IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();");
110+
}
111+
else
112+
{
113+
// A custom wrapper is required
114+
_ = builder
115+
.AppendLine()
116+
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
117+
.AppendLine($"public {enumeratorTypeString} GetEnumerator() => new(this);")
118+
.AppendLine()
119+
.AppendLine($"IEnumerator<{itemTypeString}> IEnumerable<{itemTypeString}>.GetEnumerator() => new {enumeratorTypeString}(this);")
120+
.AppendLine()
121+
.AppendLine($"IEnumerator IEnumerable.GetEnumerator() => new {enumeratorTypeString}(this);")
122+
.AppendLine();
123+
124+
using (builder.AppendBlock($"public struct {enumeratorTypeString}"))
125+
{
126+
127+
}
128+
}
129+
130+
return true;
131+
}
132+
}
133+
134+
return false;
135+
}
136+
}
137+
}

NetFabric.Hyperlinq.SourceGenerator/OverloadsGenerator.cs renamed to NetFabric.Hyperlinq.SourceGenerator/Generator.cs

Lines changed: 6 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,17 @@
22
using Microsoft.CodeAnalysis.CSharp;
33
using Microsoft.CodeAnalysis.CSharp.Syntax;
44
using Microsoft.CodeAnalysis.Text;
5-
using NetFabric.CodeAnalysis;
65
using System;
76
using System.Collections.Generic;
87
using System.Collections.Immutable;
98
using System.Diagnostics;
10-
using System.Linq;
119
using System.Text;
1210
using System.Threading;
1311

1412
namespace NetFabric.Hyperlinq.SourceGenerator
1513
{
1614
[Generator]
17-
public class OverloadsGenerator
15+
public partial class Generator
1816
: ISourceGenerator
1917
{
2018
static readonly DiagnosticDescriptor unhandledExceptionError = new(
@@ -25,7 +23,7 @@ public class OverloadsGenerator
2523
DiagnosticSeverity.Error,
2624
isEnabledByDefault: true);
2725

28-
internal static readonly string[] methods = new[]
26+
internal static readonly ImmutableArray<string> methods = ImmutableArray.Create(new[]
2927
{
3028
// aggregation
3129
"Count",
@@ -85,7 +83,7 @@ public class OverloadsGenerator
8583
// quantifier
8684
"Distinct",
8785
"DistinctAsync",
88-
};
86+
});
8987

9088
public void Initialize(GeneratorInitializationContext context)
9189
{
@@ -126,6 +124,8 @@ public void Execute(GeneratorExecutionContext context)
126124

127125
internal static void GenerateSource(Compilation compilation, List<MemberAccessExpressionSyntax> memberAccessExpressions, CodeBuilder builder, CancellationToken cancellationToken)
128126
{
127+
var typeSymbolCache = new TypeSymbolCache(compilation);
128+
129129
_ = builder
130130
.AppendLine("#nullable enable")
131131
.AppendLine()
@@ -144,76 +144,13 @@ internal static void GenerateSource(Compilation compilation, List<MemberAccessEx
144144

145145
_ = expressionSyntax.Name.ToString() switch
146146
{
147-
"AsValueEnumerable" => HandleAsValueEnumerable(compilation, expressionSyntax, builder, cancellationToken),
147+
"AsValueEnumerable" => HandleAsValueEnumerable(compilation, typeSymbolCache, expressionSyntax, builder, cancellationToken),
148148
_ => HandleMethod(compilation, expressionSyntax, builder, cancellationToken),
149149
};
150150
}
151151
}
152152
}
153153

154-
static bool HandleAsValueEnumerable(Compilation compilation, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, CancellationToken cancellationToken)
155-
{
156-
// Get the type this operator is applied to
157-
var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);
158-
var receiverTypeSymbol = semanticModel.GetTypeInfo(expressionSyntax.Expression).Type;
159-
160-
// Check if NetFabric.Hyperlinq already contains specific overloads for this type
161-
if (receiverTypeSymbol is null
162-
or { TypeKind: TypeKind.Array } // is array
163-
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, compilation.GetTypeByMetadataName(typeof(ArraySegment<>).FullName))
164-
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, compilation.GetTypeByMetadataName(typeof(Span<>).FullName))
165-
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, compilation.GetTypeByMetadataName(typeof(ReadOnlySpan<>).FullName))
166-
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, compilation.GetTypeByMetadataName(typeof(Memory<>).FullName))
167-
|| SymbolEqualityComparer.Default.Equals(receiverTypeSymbol.OriginalDefinition, compilation.GetTypeByMetadataName(typeof(ReadOnlyMemory<>).FullName))
168-
)
169-
return false; // no need to generate an implementation
170-
171-
// Generate the method source depending on receiver type characteristics
172-
var receiverTypeString = receiverTypeSymbol.ToDisplayString();
173-
174-
// Receiver type implements IValueEnumerable<,>
175-
176-
var valueEnumerableInterface = compilation.GetTypeByMetadataName("NetFabric.Hyperlinq.IValueEnumerable`2")!;
177-
if (receiverTypeSymbol.ImplementsInterface(valueEnumerableInterface, out var _))
178-
{
179-
// Receiver instance returns itself
180-
_ = builder
181-
.AppendLine()
182-
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
183-
.AppendLine($"public static {receiverTypeString} AsValueEnumerable(this {receiverTypeString} source) => source;");
184-
185-
return true;
186-
}
187-
188-
189-
// Receiver type is an enumerable that is none of the above
190-
191-
if (receiverTypeSymbol.IsEnumerable(compilation, out var enumerableSymbols))
192-
{
193-
var wrapperName = receiverTypeString + "__ValueEnumerable";
194-
_ = builder
195-
.AppendLine()
196-
.AppendLine("[MethodImpl(MethodImplOptions.AggressiveInlining)]")
197-
.AppendLine($"public static {wrapperName} AsValueEnumerable(this {receiverTypeString} source) => new(source);")
198-
.AppendLine();
199-
200-
var enumeratorType = enumerableSymbols.GetEnumerator.ReturnType;
201-
var itemType = enumerableSymbols.EnumeratorSymbols.Current.Type;
202-
using (builder.AppendBlock($"public readonly struct {wrapperName} : IValueEnumerable<{itemType.ToDisplayString()}, {enumeratorType.ToDisplayString()}>"))
203-
{
204-
_ = builder
205-
.AppendLine($"readonly {receiverTypeString} source;") // source field
206-
.AppendLine()
207-
.AppendLine($"public {wrapperName}({receiverTypeString} source) => this.source = source;"); // constructor
208-
209-
ImplementValueEnumerable(enumerableSymbols, builder);
210-
return true;
211-
}
212-
}
213-
214-
return false;
215-
}
216-
217154
static bool HandleMethod(Compilation compilation, MemberAccessExpressionSyntax expressionSyntax, CodeBuilder builder, CancellationToken cancellationToken)
218155
{
219156
var semanticModel = compilation.GetSemanticModel(expressionSyntax.SyntaxTree);
@@ -222,22 +159,5 @@ static bool HandleMethod(Compilation compilation, MemberAccessExpressionSyntax e
222159

223160
return false;
224161
}
225-
226-
static void ImplementValueEnumerable(EnumerableSymbols enumerableSymbols, CodeBuilder builder)
227-
{
228-
var enumeratorType = enumerableSymbols.GetEnumerator.ReturnType;
229-
var itemType = enumerableSymbols.EnumeratorSymbols.Current.Type;
230-
231-
// Enumerator is value type and implements IEnumerator<>
232-
// It can be returned directly by GetEnumerator()
233-
if (enumeratorType.IsValueType && enumeratorType.ImplementsInterface(SpecialType.System_Collections_Generic_IEnumerable_T, out var _))
234-
{
235-
_ = builder
236-
.AppendLine()
237-
.AppendLine($"public {enumeratorType.ToDisplayString()} GetEnumerator() => source.GetEnumerator();")
238-
.AppendLine($"IEnumerator<int> IEnumerable<int>.GetEnumerator() => source.GetEnumerator();")
239-
.AppendLine($"IEnumerator IEnumerable.GetEnumerator() => source.GetEnumerator();");
240-
}
241-
}
242162
}
243163
}

0 commit comments

Comments
 (0)