Skip to content

Commit 77d9d0a

Browse files
committed
Rust: Implement support for associated types accessed on type parameters
1 parent 077fc9f commit 77d9d0a

File tree

7 files changed

+233
-22
lines changed

7 files changed

+233
-22
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/**
2+
* Provides classes and helper predicates for associated types.
3+
*/
4+
5+
private import rust
6+
private import codeql.rust.internal.PathResolution
7+
private import TypeMention
8+
private import Type
9+
private import TypeInference
10+
11+
/** An associated type, that is, a type alias in a trait block. */
12+
final class AssocType extends TypeAlias {
13+
Trait trait;
14+
15+
AssocType() { this = trait.getAssocItemList().getAnAssocItem() }
16+
17+
Trait getTrait() { result = trait }
18+
19+
string getText() { result = this.getName().getText() }
20+
}
21+
22+
/** Gets an associated type of `trait` or of a supertrait of `trait`. */
23+
AssocType getTraitAssocType(Trait trait) {
24+
result = trait.getSupertrait*().getAssocItemList().getAnAssocItem()
25+
}
26+
27+
/** Holds if `path` is of the form `<type as trait>::name` */
28+
predicate asTraitPath(Path path, TypeRepr typeRepr, Path traitPath, string name) {
29+
exists(PathSegment segment |
30+
segment = path.getQualifier().getSegment() and
31+
typeRepr = segment.getTypeRepr() and
32+
traitPath = segment.getTraitTypeRepr().getPath() and
33+
name = path.getText()
34+
)
35+
}
36+
37+
/**
38+
* Holds if `assoc` is accessed on `tp` in `path`.
39+
*
40+
* That is this is the case when `path` is of the form `<tp as
41+
* Trait>::AssocType` or `tp::AssocType`; and `AssocType` resolves to `assoc`.
42+
*/
43+
predicate tpAssociatedType(TypeParam tp, AssocType assoc, Path path) {
44+
resolvePath(path.getQualifier()) = tp and
45+
resolvePath(path) = assoc
46+
or
47+
exists(TypeRepr typeRepr, Path traitPath, string name |
48+
asTraitPath(path, typeRepr, traitPath, name) and
49+
tp = resolvePath(typeRepr.(PathTypeRepr).getPath()) and
50+
assoc = resolvePath(traitPath).(TraitItemNode).getAssocItem(name)
51+
)
52+
}
53+
54+
/**
55+
* Holds if `bound` is a type bound for `tp` that gives rise to `assoc` being
56+
* present for `tp`.
57+
*/
58+
predicate tpBoundAssociatedType(
59+
TypeParam tp, TypeBound bound, Path path, TraitItemNode trait, AssocType assoc
60+
) {
61+
bound = tp.getATypeBound() and
62+
path = bound.getTypeRepr().(PathTypeRepr).getPath() and
63+
trait = resolvePath(path) and
64+
assoc = getTraitAssocType(trait)
65+
}

rust/ql/lib/codeql/rust/internal/typeinference/Type.qll

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@ private import codeql.rust.elements.internal.generated.Raw
88
private import codeql.rust.elements.internal.generated.Synth
99
private import codeql.rust.frameworks.stdlib.Stdlib
1010
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
11-
12-
/** Gets a type alias of `trait` or of a supertrait of `trait`. */
13-
private TypeAlias getTraitTypeAlias(Trait trait) {
14-
result = trait.getSupertrait*().getAssocItemList().getAnAssocItem()
15-
}
11+
private import AssociatedTypes
1612

1713
/**
1814
* Holds if a dyn trait type for the trait `trait` should have a type parameter
@@ -31,7 +27,7 @@ private TypeAlias getTraitTypeAlias(Trait trait) {
3127
*/
3228
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
3329
trait = any(DynTraitTypeRepr dt).getTrait() and
34-
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitTypeAlias(trait)]
30+
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitAssocType(trait)]
3531
}
3632

3733
cached
@@ -43,8 +39,11 @@ newtype TType =
4339
TNeverType() or
4440
TUnknownType() or
4541
TTypeParamTypeParameter(TypeParam t) or
46-
TAssociatedTypeTypeParameter(Trait trait, TypeAlias typeAlias) {
47-
getTraitTypeAlias(trait) = typeAlias
42+
TAssociatedTypeTypeParameter(Trait trait, AssocType typeAlias) {
43+
getTraitAssocType(trait) = typeAlias
44+
} or
45+
TTypeParamAssociatedTypeTypeParameter(TypeParam tp, AssocType assoc) {
46+
tpAssociatedType(tp, assoc, _)
4847
} or
4948
TDynTraitTypeParameter(Trait trait, AstNode n) { dynTraitTypeParameter(trait, n) } or
5049
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
@@ -464,6 +463,52 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
464463
override Location getLocation() { result = typeAlias.getLocation() }
465464
}
466465

466+
/**
467+
* A type parameter corresponding to an associated type accessed on a type
468+
* parameter, for example `T::AssociatedType` where `T` is a type parameter.
469+
*
470+
* These type parameters are created when a function signature accesses an
471+
* associated type on a type parameter. For example, in
472+
* ```rust
473+
* fn foo<T: SomeTrait>(arg: T::Assoc) { }
474+
* ```
475+
* we create a `TypeParamAssociatedTypeTypeParameter` for `Assoc` on `T` and the
476+
* mention `T::Assoc` resolves to this type parameter. If denoting the type
477+
* parameter by `T_Assoc` then the above function is treated as if it was
478+
* ```rust
479+
* fn foo<T: SomeTrait<Assoc = T_Assoc>, T_Assoc>(arg: T_Assoc) { }
480+
* ```
481+
*/
482+
class TypeParamAssociatedTypeTypeParameter extends TypeParameter,
483+
TTypeParamAssociatedTypeTypeParameter
484+
{
485+
private TypeParam typeParam;
486+
private AssocType assoc;
487+
488+
TypeParamAssociatedTypeTypeParameter() {
489+
this = TTypeParamAssociatedTypeTypeParameter(typeParam, assoc)
490+
}
491+
492+
/** Gets the type parameter that this associated type is accessed on. */
493+
TypeParam getTypeParam() { result = typeParam }
494+
495+
/** Gets the associated type alias. */
496+
AssocType getTypeAlias() { result = assoc }
497+
498+
/** Gets a path that accesses this type parameter. */
499+
Path getPath() { tpAssociatedType(typeParam, assoc, result) }
500+
501+
override ItemNode getDeclaringItem() { result.getTypeParam(_) = typeParam }
502+
503+
override string toString() {
504+
result =
505+
typeParam.toString() + "::" + assoc.getName().getText() + "[" +
506+
assoc.getTrait().getName().getText() + "]"
507+
}
508+
509+
override Location getLocation() { result = typeParam.getLocation() }
510+
}
511+
467512
/** Gets the associated type type-parameter corresponding directly to `typeAlias`. */
468513
AssociatedTypeTypeParameter getAssociatedTypeTypeParameter(TypeAlias typeAlias) {
469514
result.isDirect() and result.getTypeAlias() = typeAlias

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ private module Input1 implements InputSig1<Location> {
108108
id2 = idOfTypeParameterAstNode(tp0.(AssociatedTypeTypeParameter).getTypeAlias())
109109
or
110110
kind = 4 and
111+
id1 = idOfTypeParameterAstNode(tp0.(TypeParamAssociatedTypeTypeParameter).getTypeParam()) and
112+
id2 = idOfTypeParameterAstNode(tp0.(TypeParamAssociatedTypeTypeParameter).getTypeAlias())
113+
or
114+
kind = 5 and
111115
id1 = 0 and
112116
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
113117
node = tp0.(TypeParamTypeParameter).getTypeParam() or
@@ -273,9 +277,16 @@ private class FunctionDeclaration extends Function {
273277
TypeParameter getTypeParameter(ImplOrTraitItemNodeOption i, TypeParameterPosition ppos) {
274278
i = parent and
275279
(
276-
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
277-
or
278-
typeParamMatchPosition(i.asSome().getTypeParam(_), result, ppos)
280+
exists(TypeParam tp |
281+
tp = [this.getGenericParamList().getATypeParam(), i.asSome().getTypeParam(_)]
282+
|
283+
typeParamMatchPosition(tp, result, ppos)
284+
or
285+
// If `tp` is a type parameter for this function, then any associated
286+
// types accessed on `tp` are also type parameters.
287+
ppos.isImplicit() and
288+
result.(TypeParamAssociatedTypeTypeParameter).getTypeParam() = tp
289+
)
279290
or
280291
ppos.isImplicit() and result = TSelfTypeParameter(i.asSome())
281292
or

rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ private import codeql.rust.frameworks.stdlib.Stdlib
66
private import Type
77
private import TypeAbstraction
88
private import TypeInference
9+
private import AssociatedTypes
910

1011
bindingset[trait, name]
1112
pragma[inline_late]
@@ -290,6 +291,22 @@ private module MkTypeMention<getAdditionalPathTypeAtSig/2 getAdditionalPathTypeA
290291
tp = TAssociatedTypeTypeParameter(resolved, alias) and
291292
path.isEmpty()
292293
)
294+
or
295+
// If this path is a type parameter bound, then any associated types
296+
// accessed on the type parameter, that originate from this bound, should
297+
// be instantiated into the bound, as explained in the comment for
298+
// `TypeParamAssociatedTypeTypeParameter`.
299+
// ```rust
300+
// fn foo<T: SomeTrait<Assoc = T_Assoc>, T_Assoc>(arg: T_Assoc) { }
301+
// ^^^^^^^^^ ^^^^^ ^^^^^^^
302+
// this path result
303+
// ```
304+
exists(TypeParam typeParam, Trait trait, AssocType assoc |
305+
tpBoundAssociatedType(typeParam, _, this, trait, assoc) and
306+
tp = TAssociatedTypeTypeParameter(resolved, assoc) and
307+
result = TTypeParamAssociatedTypeTypeParameter(typeParam, assoc) and
308+
path.isEmpty()
309+
)
293310
}
294311

295312
bindingset[name]
@@ -343,6 +360,8 @@ private module MkTypeMention<getAdditionalPathTypeAtSig/2 getAdditionalPathTypeA
343360
or
344361
// Handles paths of the form `Self::AssocType` within a trait block
345362
result = TAssociatedTypeTypeParameter(resolvePath(this.getQualifier()), resolved)
363+
or
364+
result.(TypeParamAssociatedTypeTypeParameter).getPath() = this
346365
}
347366

348367
override Type resolvePathTypeAt(TypePath typePath) {
@@ -661,11 +680,10 @@ private predicate pathConcreteTypeAssocType(
661680
|
662681
// path of the form `<Type as Trait>::AssocType`
663682
// ^^^ tm ^^^^^^^^^ name
664-
exists(string name |
665-
name = path.getText() and
666-
trait = resolvePath(qualifier.getSegment().getTraitTypeRepr().getPath()) and
667-
getTraitAssocType(trait, name) = alias and
668-
tm = qualifier.getSegment().getTypeRepr()
683+
exists(string name, Path traitPath |
684+
asTraitPath(path, tm, traitPath, name) and
685+
trait = resolvePath(traitPath) and
686+
getTraitAssocType(trait, name) = alias
669687
)
670688
or
671689
// path of the form `Self::AssocType` within an `impl` block

rust/ql/test/library-tests/type-inference/associated_types.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,15 @@ mod type_param_access_associated_type {
271271
}
272272

273273
pub fn test() {
274-
let _o1 = tp_with_as(S); // $ target=tp_with_as MISSING: type=_o1:S3
275-
let _o2 = tp_without_as(S); // $ target=tp_without_as MISSING: type=_o2:S3
274+
let _o1 = tp_with_as(S); // $ target=tp_with_as type=_o1:S3
275+
let _o2 = tp_without_as(S); // $ target=tp_without_as type=_o2:S3
276276
let (
277277
_o3, // $ MISSING: type=_o3:S3
278-
_o4, // $ MISSING: type=_o4:bool
278+
_o4, // $ type=_o4:bool
279279
) = tp_assoc_from_supertrait(S); // $ target=tp_assoc_from_supertrait
280280

281281
let w = Wrapper(S);
282-
let _extracted = w.extract(); // $ target=extract MISSING: type=_extracted:S3
282+
let _extracted = w.extract(); // $ target=extract type=_extracted:S3
283283
}
284284
}
285285

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,7 +1927,7 @@ mod overloadable_operators {
19271927
let i64_mul = 17i64 * 18i64; // $ type=i64_mul:i64 target=mul
19281928
let i64_div = 19i64 / 20i64; // $ type=i64_div:i64 target=div
19291929
let i64_rem = 21i64 % 22i64; // $ type=i64_rem:i64 target=rem
1930-
let i64_param_add = param_add(1i64, 2i64); // $ target=param_add $ MISSING: type=i64_param_add:i64
1930+
let i64_param_add = param_add(1i64, 2i64); // $ target=param_add $ type=i64_param_add:i64
19311931

19321932
// Arithmetic assignment operators
19331933
let mut i64_add_assign = 23i64;
@@ -2232,7 +2232,7 @@ mod indexers {
22322232
let xs: [S; 1] = [S];
22332233
let x = xs[0].foo(); // $ target=foo type=x:S target=index
22342234

2235-
let y = param_index(vec, 0); // $ target=param_index $ MISSING: type=y:S
2235+
let y = param_index(vec, 0); // $ target=param_index $ type=y:S
22362236

22372237
analyze_slice(&xs); // $ target=analyze_slice
22382238
}

0 commit comments

Comments
 (0)