Skip to content

Commit e5ebce5

Browse files
committed
C#: Extract extension types and members. Replacing invocations to static generated methods with invocation of extension type member.
1 parent 5075de5 commit e5ebce5

File tree

8 files changed

+256
-20
lines changed

8 files changed

+256
-20
lines changed

csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.IO;
55
using System.Linq;
66
using Microsoft.CodeAnalysis;
7+
using Semmle.Util;
78
using Semmle.Extraction.CSharp.Entities;
89

910
namespace Semmle.Extraction.CSharp
@@ -164,6 +165,7 @@ public static void BuildTypeId(this ITypeSymbol type, Context cx, EscapingTextWr
164165
case TypeKind.Enum:
165166
case TypeKind.Delegate:
166167
case TypeKind.Error:
168+
case TypeKind.Extension:
167169
var named = (INamedTypeSymbol)type;
168170
named.BuildNamedTypeId(cx, trapFile, symbolBeingDefined, constructUnderlyingTupleType);
169171
return;
@@ -275,6 +277,20 @@ private static void BuildFunctionPointerTypeId(this IFunctionPointerTypeSymbol f
275277
public static IEnumerable<IFieldSymbol?> GetTupleElementsMaybeNull(this INamedTypeSymbol type) =>
276278
type.TupleElements;
277279

280+
private static void BuildExtensionTypeId(this INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile)
281+
{
282+
trapFile.Write("extension(");
283+
if (named.ExtensionMarkerName is not null)
284+
{
285+
trapFile.Write(named.ExtensionMarkerName);
286+
}
287+
else
288+
{
289+
trapFile.Write("unknown");
290+
}
291+
trapFile.Write(")");
292+
}
293+
278294
private static void BuildQualifierAndName(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined)
279295
{
280296
if (named.ContainingType is not null)
@@ -289,8 +305,18 @@ private static void BuildQualifierAndName(INamedTypeSymbol named, Context cx, Es
289305
named.ContainingNamespace.BuildNamespace(cx, trapFile);
290306
}
291307

292-
var name = named.IsFileLocal ? named.MetadataName : named.Name;
293-
trapFile.Write(name);
308+
if (named.IsFileLocal)
309+
{
310+
trapFile.Write(named.MetadataName);
311+
}
312+
else if (named.IsExtension)
313+
{
314+
named.BuildExtensionTypeId(cx, trapFile);
315+
}
316+
else
317+
{
318+
trapFile.Write(named.Name);
319+
}
294320
}
295321

296322
private static void BuildTupleId(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined)
@@ -391,6 +417,7 @@ public static void BuildDisplayName(this ITypeSymbol type, Context cx, TextWrite
391417
case TypeKind.Enum:
392418
case TypeKind.Delegate:
393419
case TypeKind.Error:
420+
case TypeKind.Extension:
394421
var named = (INamedTypeSymbol)type;
395422
named.BuildNamedTypeDisplayName(cx, trapFile, constructUnderlyingTupleType);
396423
return;
@@ -465,6 +492,20 @@ public static void BuildFunctionPointerSignature(IFunctionPointerTypeSymbol funp
465492
private static void BuildFunctionPointerTypeDisplayName(this IFunctionPointerTypeSymbol funptr, Context cx, TextWriter trapFile) =>
466493
BuildFunctionPointerSignature(funptr, trapFile, s => s.BuildDisplayName(cx, trapFile));
467494

495+
private static void BuildExtensionTypeDisplayName(this INamedTypeSymbol named, Context cx, TextWriter trapFile)
496+
{
497+
trapFile.Write("extension(");
498+
if (named.ExtensionParameter?.Type is ITypeSymbol type)
499+
{
500+
type.BuildDisplayName(cx, trapFile);
501+
}
502+
else
503+
{
504+
trapFile.Write("unknown");
505+
}
506+
trapFile.Write(")");
507+
}
508+
468509
private static void BuildNamedTypeDisplayName(this INamedTypeSymbol namedType, Context cx, TextWriter trapFile, bool constructUnderlyingTupleType)
469510
{
470511
if (!constructUnderlyingTupleType && namedType.IsTupleType)
@@ -484,6 +525,12 @@ private static void BuildNamedTypeDisplayName(this INamedTypeSymbol namedType, C
484525
return;
485526
}
486527

528+
if (namedType.IsExtension)
529+
{
530+
namedType.BuildExtensionTypeDisplayName(cx, trapFile);
531+
return;
532+
}
533+
487534
if (namedType.IsAnonymousType)
488535
{
489536
namedType.BuildAnonymousName(cx, trapFile);
@@ -596,6 +643,87 @@ public static bool IsSourceDeclaration(this IParameterSymbol parameter)
596643
return true;
597644
}
598645

646+
/// <summary>
647+
/// Return true if this method is a compiler-generated extension method.
648+
/// </summary>
649+
public static bool IsCompilerGeneratedExtensionMethod(this IMethodSymbol method) =>
650+
method.TryGetExtensionMethod(out _);
651+
652+
/// <summary>
653+
/// Returns true if this method is a compiler-generated extension method,
654+
/// and outputs the original extension method declaration.
655+
/// </summary>
656+
public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodSymbol? declaration)
657+
{
658+
declaration = null;
659+
if (method.IsImplicitlyDeclared && method.ContainingSymbol is INamedTypeSymbol containingType)
660+
{
661+
// Extension types are declared within the same type as the generated
662+
// extension method implementation.
663+
var extensions = containingType.GetMembers()
664+
.OfType<INamedTypeSymbol>()
665+
.Where(t => t.IsExtension);
666+
// Find the (possibly unbound) original extension method that maps to this implementation (if any).
667+
var unboundDeclaration = extensions.SelectMany(e => e.GetMembers())
668+
.OfType<IMethodSymbol>()
669+
.FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom));
670+
671+
var isFullyConstructed = method.IsBoundGenericMethod();
672+
if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType)
673+
{
674+
try
675+
{
676+
// Use the type arguments from the constructed extension method to construct the extension type.
677+
var arguments = method.TypeArguments.ToArray();
678+
var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length);
679+
680+
// Construct the extension type.
681+
var boundExtensionType = extensionType.IsUnboundGenericType()
682+
? extensionType.Construct(extensionTypeArguments.ToArray())
683+
: extensionType;
684+
685+
// Find the extension method declaration within the constructed extension type.
686+
var extensionDeclaration = boundExtensionType.GetMembers()
687+
.OfType<IMethodSymbol>()
688+
.First(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration));
689+
690+
// If the extension declaration is unbound apply the remaning type arguments and construct it.
691+
declaration = extensionDeclaration.IsUnboundGenericMethod()
692+
? extensionDeclaration.Construct(extensionMethodArguments.ToArray())
693+
: extensionDeclaration;
694+
}
695+
catch
696+
{
697+
// If anything goes wrong, fall back to the unbound declaration.
698+
declaration = unboundDeclaration;
699+
}
700+
}
701+
else
702+
{
703+
declaration = unboundDeclaration;
704+
}
705+
706+
}
707+
return declaration is not null;
708+
}
709+
710+
/// <summary>
711+
/// Returns true if this method is an unbound generic method.
712+
/// </summary>
713+
public static bool IsUnboundGenericMethod(this IMethodSymbol method) =>
714+
method.IsGenericMethod && SymbolEqualityComparer.Default.Equals(method.ConstructedFrom, method);
715+
716+
/// <summary>
717+
/// Returns true if this method is a bound generic method.
718+
/// </summary>
719+
public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !method.IsUnboundGenericMethod();
720+
721+
/// <summary>
722+
/// Returns true if this type is an unbound generic type.
723+
/// </summary>
724+
public static bool IsUnboundGenericType(this INamedTypeSymbol type) =>
725+
type.IsGenericType && SymbolEqualityComparer.Default.Equals(type.ConstructedFrom, type);
726+
599727
/// <summary>
600728
/// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base
601729
/// types of type parameters as well as `object` base types.

csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions/Invocation.cs

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ private Invocation(ExpressionNodeInfo info)
2424

2525
private bool IsExplicitDelegateInvokeCall() => Kind == ExprKind.DELEGATE_INVOCATION && Context.GetModel(Syntax.Expression).GetSymbolInfo(Syntax.Expression).Symbol is IMethodSymbol m && m.MethodKind == MethodKind.DelegateInvoke;
2626

27+
private bool IsOperatorCall() => Kind == ExprKind.OPERATOR_INVOCATION;
28+
29+
private bool IsValidMemberAccessKind()
30+
{
31+
return Kind == ExprKind.METHOD_INVOCATION ||
32+
IsEventDelegateCall() ||
33+
IsExplicitDelegateInvokeCall() ||
34+
IsOperatorCall();
35+
}
36+
2737
protected override void PopulateExpression(TextWriter trapFile)
2838
{
2939
if (IsNameof(Syntax))
@@ -37,7 +47,7 @@ protected override void PopulateExpression(TextWriter trapFile)
3747
var target = TargetSymbol;
3848
switch (Syntax.Expression)
3949
{
40-
case MemberAccessExpressionSyntax memberAccess when Kind == ExprKind.METHOD_INVOCATION || IsEventDelegateCall() || IsExplicitDelegateInvokeCall():
50+
case MemberAccessExpressionSyntax memberAccess when IsValidMemberAccessKind():
4151
memberName = memberAccess.Name.Identifier.Text;
4252
if (Syntax.Expression.Kind() == SyntaxKind.SimpleMemberAccessExpression)
4353
// Qualified method call; `x.M()`
@@ -113,14 +123,31 @@ private static bool IsDynamicCall(ExpressionNodeInfo info)
113123

114124
public SymbolInfo SymbolInfo => info.SymbolInfo;
115125

126+
private static bool IsOperatorLikeCall(ExpressionNodeInfo info)
127+
{
128+
return info.SymbolInfo.Symbol is IMethodSymbol method &&
129+
method.TryGetExtensionMethod(out var original) &&
130+
original!.MethodKind == MethodKind.UserDefinedOperator;
131+
}
132+
116133
public IMethodSymbol? TargetSymbol
117134
{
118135
get
119136
{
120137
var si = SymbolInfo;
121138

122-
if (si.Symbol is not null)
123-
return si.Symbol as IMethodSymbol;
139+
if (si.Symbol is ISymbol symbol)
140+
{
141+
var method = symbol as IMethodSymbol;
142+
// Case for compiler-generated extension methods.
143+
if (method is not null &&
144+
method.TryGetExtensionMethod(out var original))
145+
{
146+
return original;
147+
}
148+
149+
return method;
150+
}
124151

125152
if (si.CandidateReason == CandidateReason.OverloadResolutionFailure)
126153
{
@@ -196,15 +223,25 @@ private static bool IsLocalFunctionInvocation(ExpressionNodeInfo info)
196223

197224
private static ExprKind GetKind(ExpressionNodeInfo info)
198225
{
199-
return IsNameof((InvocationExpressionSyntax)info.Node)
200-
? ExprKind.NAMEOF
201-
: IsDelegateLikeCall(info)
202-
? IsDelegateInvokeCall(info)
203-
? ExprKind.DELEGATE_INVOCATION
204-
: ExprKind.FUNCTION_POINTER_INVOCATION
205-
: IsLocalFunctionInvocation(info)
206-
? ExprKind.LOCAL_FUNCTION_INVOCATION
207-
: ExprKind.METHOD_INVOCATION;
226+
if (IsNameof((InvocationExpressionSyntax)info.Node))
227+
{
228+
return ExprKind.NAMEOF;
229+
}
230+
if (IsDelegateLikeCall(info))
231+
{
232+
return IsDelegateInvokeCall(info)
233+
? ExprKind.DELEGATE_INVOCATION
234+
: ExprKind.FUNCTION_POINTER_INVOCATION;
235+
}
236+
if (IsLocalFunctionInvocation(info))
237+
{
238+
return ExprKind.LOCAL_FUNCTION_INVOCATION;
239+
}
240+
if (IsOperatorLikeCall(info))
241+
{
242+
return ExprKind.OPERATOR_INVOCATION;
243+
}
244+
return ExprKind.METHOD_INVOCATION;
208245
}
209246

210247
private static bool IsNameof(InvocationExpressionSyntax syntax)

csharp/extractor/Semmle.Extraction.CSharp/Entities/Method.cs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,37 @@ internal abstract class Method : CachedSymbol<IMethodSymbol>, IExpressionParentE
1414
protected Method(Context cx, IMethodSymbol init)
1515
: base(cx, init) { }
1616

17+
private SyntheticExtensionParameter? SyntheticParameter { get; set; }
18+
19+
private int SynthesizeExtensionParameter()
20+
{
21+
// Synthesize implicit parameter for extension methods declared using extension(...) syntax.
22+
if (Symbol.ContainingSymbol is INamedTypeSymbol type &&
23+
type.IsExtension && type.ExtensionParameter is IParameterSymbol parameter &&
24+
!string.IsNullOrEmpty(parameter.Name) && !Symbol.IsStatic)
25+
{
26+
var originalSyntheticParam = OriginalDefinition.SyntheticParameter;
27+
SyntheticParameter = SyntheticExtensionParameter.Create(Context, this, parameter, originalSyntheticParam);
28+
return 1;
29+
}
30+
31+
return 0;
32+
}
33+
1734
protected void PopulateParameters()
1835
{
1936
var originalMethod = OriginalDefinition;
37+
var positionOffset = SynthesizeExtensionParameter();
38+
2039
IEnumerable<IParameterSymbol> parameters = Symbol.Parameters;
2140
IEnumerable<IParameterSymbol> originalParameters = originalMethod.Symbol.Parameters;
2241

2342
foreach (var p in parameters.Zip(originalParameters, (paramSymbol, originalParam) => new { paramSymbol, originalParam }))
2443
{
2544
var original = SymbolEqualityComparer.Default.Equals(p.paramSymbol, p.originalParam)
2645
? null
27-
: Parameter.Create(Context, p.originalParam, originalMethod);
28-
Parameter.Create(Context, p.paramSymbol, this, original);
46+
: Parameter.Create(Context, p.originalParam, originalMethod, null, positionOffset);
47+
Parameter.Create(Context, p.paramSymbol, this, original, positionOffset);
2948
}
3049

3150
if (Symbol.IsVararg)
@@ -302,9 +321,9 @@ public static void AddExplicitInterfaceQualifierToId(Context cx, EscapingTextWri
302321
/// <summary>
303322
/// Whether this method has unbound type parameters.
304323
/// </summary>
305-
public bool IsUnboundGeneric => IsGeneric && SymbolEqualityComparer.Default.Equals(Symbol.ConstructedFrom, Symbol);
324+
public bool IsUnboundGeneric => Symbol.IsUnboundGenericMethod();
306325

307-
public bool IsBoundGeneric => IsGeneric && !IsUnboundGeneric;
326+
public bool IsBoundGeneric => Symbol.IsBoundGenericMethod();
308327

309328
protected IMethodSymbol ConstructedFromSymbol => Symbol.ConstructedFrom;
310329

csharp/extractor/Semmle.Extraction.CSharp/Entities/OrdinaryMethod.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ protected OrdinaryMethod(Context cx, IMethodSymbol init)
2323
? Symbol.ContainingType.GetSymbolLocation()
2424
: BodyDeclaringSymbol.GetSymbolLocation();
2525

26-
public override bool NeedsPopulation => base.NeedsPopulation || IsCompilerGeneratedDelegate();
26+
public override bool NeedsPopulation =>
27+
(base.NeedsPopulation || IsCompilerGeneratedDelegate()) &&
28+
// Exclude compiler-generated extension methods. A call to such a method
29+
// is replaced by a call to the defining extension method.
30+
!Symbol.IsCompilerGeneratedExtensionMethod();
2731

2832
public override void Populate(TextWriter trapFile)
2933
{

csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/NamedType.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ private NamedType(Context cx, INamedTypeSymbol init, bool constructUnderlyingTup
2020
public static NamedType Create(Context cx, INamedTypeSymbol type) =>
2121
NamedTypeFactory.Instance.CreateEntityFromSymbol(cx, type);
2222

23+
public NamedType OriginalDefinition => Create(Context, Symbol.OriginalDefinition);
24+
2325
/// <summary>
2426
/// Creates a named type entity from a tuple type. Unlike <see cref="Create"/>, this
2527
/// will create an entity for the underlying `System.ValueTuple` struct.
@@ -90,6 +92,25 @@ public override void Populate(TextWriter trapFile)
9092
{
9193
trapFile.anonymous_types(this);
9294
}
95+
96+
if (Symbol.IsExtension && Symbol.ExtensionParameter is IParameterSymbol parameter)
97+
{
98+
// For some reason an extension type has a receiver parameter with an empty name
99+
// even when there is no parameter.
100+
if (!string.IsNullOrEmpty(parameter.Name))
101+
{
102+
var originalType = OriginalDefinition;
103+
// In case this is a constructed generic, we also need to create the unbound parameter.
104+
var originalParameter = SymbolEqualityComparer.Default.Equals(Symbol, originalType.Symbol.ExtensionParameter) || originalType.Symbol.ExtensionParameter is null
105+
? null
106+
: Parameter.Create(Context, originalType.Symbol.ExtensionParameter, originalType);
107+
Parameter.Create(Context, parameter, this, originalParameter);
108+
}
109+
110+
// Use the parameter type as the receiver type.
111+
var receiverType = Type.Create(Context, parameter.Type).TypeRef;
112+
trapFile.extension_receiver_type(this, receiverType);
113+
}
93114
}
94115

95116
private readonly Lazy<Type[]> typeArgumentsLazy;

csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/Type.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ public Kinds.TypeKind GetTypeKind(Context cx, bool constructUnderlyingTupleType)
105105
case TypeKind.Pointer: return Kinds.TypeKind.POINTER;
106106
case TypeKind.FunctionPointer: return Kinds.TypeKind.FUNCTION_POINTER;
107107
case TypeKind.Error: return Kinds.TypeKind.UNKNOWN;
108+
case TypeKind.Extension: return Kinds.TypeKind.EXTENSION;
108109
default:
109110
cx.ModelError(Symbol, $"Unhandled type kind '{Symbol.TypeKind}'");
110111
return Kinds.TypeKind.UNKNOWN;
@@ -366,7 +367,7 @@ private class DelegateTypeParameter : Parameter
366367
private DelegateTypeParameter(Context cx, IParameterSymbol init, IEntity parent, Parameter? original)
367368
: base(cx, init, parent, original) { }
368369

369-
public static new DelegateTypeParameter Create(Context cx, IParameterSymbol param, IEntity parent, Parameter? original = null) =>
370+
public static DelegateTypeParameter Create(Context cx, IParameterSymbol param, IEntity parent, Parameter? original = null) =>
370371
// We need to use a different cache key than `param` to avoid mixing up
371372
// `DelegateTypeParameter`s and `Parameter`s
372373
DelegateTypeParameterFactory.Instance.CreateEntity(cx, (typeof(DelegateTypeParameter), new SymbolEqualityWrapper(param)), (param, parent, original));

0 commit comments

Comments
 (0)