|
4 | 4 | using System.IO; |
5 | 5 | using System.Linq; |
6 | 6 | using Microsoft.CodeAnalysis; |
| 7 | +using Semmle.Util; |
7 | 8 | using Semmle.Extraction.CSharp.Entities; |
8 | 9 |
|
9 | 10 | namespace Semmle.Extraction.CSharp |
@@ -652,15 +653,33 @@ public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodS |
652 | 653 | .FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom)); |
653 | 654 |
|
654 | 655 | var isFullyConstructed = method.IsBoundGenericMethod(); |
| 656 | + // TODO: We also need to handle generic methods in non-generic extension types. |
655 | 657 | if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType && extensionType.IsGenericType) |
656 | 658 | { |
657 | | - // Use the type arguments from the constructed extension method to construct the extension type. |
658 | | - var arguments = method.TypeArguments.ToArray(); |
659 | | - var boundExtensionType = extensionType.Construct(arguments); |
660 | | - var boundDeclaration = boundExtensionType.GetMembers() |
661 | | - .OfType<IMethodSymbol>() |
662 | | - .FirstOrDefault(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration)); |
663 | | - declaration = boundDeclaration; |
| 659 | + try |
| 660 | + { |
| 661 | + // Use the type arguments from the constructed extension method to construct the extension type. |
| 662 | + var arguments = method.TypeArguments.ToArray(); |
| 663 | + var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length); |
| 664 | + |
| 665 | + // Construct the extension type. |
| 666 | + var boundExtensionType = extensionType.Construct(extensionTypeArguments.ToArray()); |
| 667 | + |
| 668 | + // Find the extension method declaration within the constructed extension type. |
| 669 | + var extensionDeclaration = boundExtensionType.GetMembers() |
| 670 | + .OfType<IMethodSymbol>() |
| 671 | + .First(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration)); |
| 672 | + |
| 673 | + // If the extension declaration is unbound apply the remaning type arguments and construct it. |
| 674 | + declaration = extensionDeclaration.IsUnboundGenericMethod() |
| 675 | + ? extensionDeclaration.Construct(extensionMethodArguments.ToArray()) |
| 676 | + : extensionDeclaration; |
| 677 | + } |
| 678 | + catch (Exception) |
| 679 | + { |
| 680 | + // If anything goes wrong, fall back to the unbound declaration. |
| 681 | + declaration = unboundDeclaration; |
| 682 | + } |
664 | 683 | } |
665 | 684 | else |
666 | 685 | { |
|
0 commit comments