Skip to content

Commit 81814ca

Browse files
committed
C#: Support replacing invocations of generic methods in generic extensions.
1 parent 7f32d50 commit 81814ca

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.IO;
55
using System.Linq;
66
using Microsoft.CodeAnalysis;
7+
using Semmle.Util;
78
using Semmle.Extraction.CSharp.Entities;
89

910
namespace Semmle.Extraction.CSharp
@@ -652,15 +653,33 @@ public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodS
652653
.FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom));
653654

654655
var isFullyConstructed = method.IsBoundGenericMethod();
656+
// TODO: We also need to handle generic methods in non-generic extension types.
655657
if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType && extensionType.IsGenericType)
656658
{
657-
// Use the type arguments from the constructed extension method to construct the extension type.
658-
var arguments = method.TypeArguments.ToArray();
659-
var boundExtensionType = extensionType.Construct(arguments);
660-
var boundDeclaration = boundExtensionType.GetMembers()
661-
.OfType<IMethodSymbol>()
662-
.FirstOrDefault(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration));
663-
declaration = boundDeclaration;
659+
try
660+
{
661+
// Use the type arguments from the constructed extension method to construct the extension type.
662+
var arguments = method.TypeArguments.ToArray();
663+
var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length);
664+
665+
// Construct the extension type.
666+
var boundExtensionType = extensionType.Construct(extensionTypeArguments.ToArray());
667+
668+
// Find the extension method declaration within the constructed extension type.
669+
var extensionDeclaration = boundExtensionType.GetMembers()
670+
.OfType<IMethodSymbol>()
671+
.First(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration));
672+
673+
// If the extension declaration is unbound apply the remaning type arguments and construct it.
674+
declaration = extensionDeclaration.IsUnboundGenericMethod()
675+
? extensionDeclaration.Construct(extensionMethodArguments.ToArray())
676+
: extensionDeclaration;
677+
}
678+
catch (Exception)
679+
{
680+
// If anything goes wrong, fall back to the unbound declaration.
681+
declaration = unboundDeclaration;
682+
}
664683
}
665684
else
666685
{

csharp/extractor/Semmle.Util/IEnumerableExtensions.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,24 @@ public static int SequenceHash<T>(this IEnumerable<T> items) where T : notnull
119119
/// </summary>
120120
public static IEnumerable<T> WhereNotNull<T>(this IEnumerable<T?> items) where T : class =>
121121
items.Where(i => i is not null)!;
122+
123+
/// <summary>
124+
/// Splits the sequence at the given index.
125+
/// </summary>
126+
public static (IEnumerable<T>, IEnumerable<T>) SplitAt<T>(this IEnumerable<T> items, int index)
127+
{
128+
var left = new List<T>();
129+
var right = new List<T>();
130+
var i = 0;
131+
foreach (var item in items)
132+
{
133+
if (i < index)
134+
left.Add(item);
135+
else
136+
right.Add(item);
137+
i++;
138+
}
139+
return (left, right);
140+
}
122141
}
123142
}

0 commit comments

Comments
 (0)