|
5 | 5 | using System.Linq; |
6 | 6 | using Microsoft.CodeAnalysis; |
7 | 7 | using Semmle.Extraction.CSharp.Entities; |
8 | | -using Semmle.Extraction.CSharp.Entities.Statements; |
9 | 8 |
|
10 | 9 | namespace Semmle.Extraction.CSharp |
11 | 10 | { |
@@ -647,14 +646,36 @@ public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodS |
647 | 646 | var extensions = containingType.GetMembers() |
648 | 647 | .OfType<INamedTypeSymbol>() |
649 | 648 | .Where(t => t.IsExtension); |
650 | | - // Find the original extension method that maps to this implementation (if any). |
651 | | - declaration = extensions.SelectMany(e => e.GetMembers()) |
| 649 | + // Find the (possibly unbound) original extension method that maps to this implementation (if any). |
| 650 | + var unboundDeclaration = extensions.SelectMany(e => e.GetMembers()) |
652 | 651 | .OfType<IMethodSymbol>() |
653 | | - .FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method)); |
654 | | - return declaration is not null; |
| 652 | + .FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom)); |
| 653 | + |
| 654 | + var isFullyConstructed = method.IsBoundGenericMethod(); |
| 655 | + if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType && extensionType.IsGenericType) |
| 656 | + { |
| 657 | + // Use the type arguments from the constructed extension method to construct the extension type. |
| 658 | + var arguments = method.TypeArguments.ToArray(); |
| 659 | + var boundExtensionType = extensionType.Construct(arguments); |
| 660 | + var boundDeclaration = boundExtensionType.GetMembers() |
| 661 | + .OfType<IMethodSymbol>() |
| 662 | + .FirstOrDefault(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration)); |
| 663 | + declaration = boundDeclaration; |
| 664 | + } |
| 665 | + else |
| 666 | + { |
| 667 | + declaration = unboundDeclaration; |
| 668 | + } |
| 669 | + |
655 | 670 | } |
656 | | - return false; |
| 671 | + return declaration is not null; |
657 | 672 | } |
| 673 | + |
| 674 | + public static bool IsUnboundGenericMethod(this IMethodSymbol method) => |
| 675 | + method.IsGenericMethod && SymbolEqualityComparer.Default.Equals(method.ConstructedFrom, method); |
| 676 | + |
| 677 | + public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !IsUnboundGenericMethod(method); |
| 678 | + |
658 | 679 | /// <summary> |
659 | 680 | /// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base |
660 | 681 | /// types of type parameters as well as `object` base types. |
|
0 commit comments