Skip to content

Commit 522e4d6

Browse files
authored
Merge pull request #21273 from paldepind/rust/tp-assoc
Rust: Implement support for associated types accessed on type parameters
2 parents 9ed2261 + 6c67475 commit 522e4d6

File tree

7 files changed

+843
-520
lines changed

7 files changed

+843
-520
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/**
2+
* Provides classes and helper predicates for associated types.
3+
*/
4+
5+
private import rust
6+
private import codeql.rust.internal.PathResolution
7+
private import TypeMention
8+
private import Type
9+
private import TypeInference
10+
11+
/** An associated type, that is, a type alias in a trait block. */
12+
final class AssocType extends TypeAlias {
13+
Trait trait;
14+
15+
AssocType() { this = trait.getAssocItemList().getAnAssocItem() }
16+
17+
Trait getTrait() { result = trait }
18+
19+
string getText() { result = this.getName().getText() }
20+
}
21+
22+
/** Gets an associated type of `trait` or of a supertrait of `trait`. */
23+
AssocType getTraitAssocType(Trait trait) { result.getTrait() = trait.getSupertrait*() }
24+
25+
/** Holds if `path` is of the form `<type as trait>::name` */
26+
pragma[nomagic]
27+
predicate pathTypeAsTraitAssoc(Path path, TypeRepr typeRepr, Path traitPath, string name) {
28+
exists(PathSegment segment |
29+
segment = path.getQualifier().getSegment() and
30+
typeRepr = segment.getTypeRepr() and
31+
traitPath = segment.getTraitTypeRepr().getPath() and
32+
name = path.getText()
33+
)
34+
}
35+
36+
/**
37+
* Holds if `assoc` is accessed on `tp` in `path`.
38+
*
39+
* That is, this is the case when `path` is of the form `<tp as
40+
* Trait>::AssocType` or `tp::AssocType`; and `AssocType` resolves to `assoc`.
41+
*/
42+
predicate tpAssociatedType(TypeParam tp, AssocType assoc, Path path) {
43+
resolvePath(path.getQualifier()) = tp and
44+
resolvePath(path) = assoc
45+
or
46+
exists(PathTypeRepr typeRepr, Path traitPath, string name |
47+
pathTypeAsTraitAssoc(path, typeRepr, traitPath, name) and
48+
tp = resolvePath(typeRepr.getPath()) and
49+
assoc = resolvePath(traitPath).(TraitItemNode).getAssocItem(name)
50+
)
51+
}
52+
53+
/**
54+
* Holds if `bound` is a type bound for `tp` that gives rise to `assoc` being
55+
* present for `tp`.
56+
*/
57+
predicate tpBoundAssociatedType(
58+
TypeParam tp, TypeBound bound, Path path, TraitItemNode trait, AssocType assoc
59+
) {
60+
bound = tp.getATypeBound() and
61+
path = bound.getTypeRepr().(PathTypeRepr).getPath() and
62+
trait = resolvePath(path) and
63+
assoc = getTraitAssocType(trait)
64+
}

rust/ql/lib/codeql/rust/internal/typeinference/Type.qll

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@ private import codeql.rust.elements.internal.generated.Raw
88
private import codeql.rust.elements.internal.generated.Synth
99
private import codeql.rust.frameworks.stdlib.Stdlib
1010
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
11-
12-
/** Gets a type alias of `trait` or of a supertrait of `trait`. */
13-
private TypeAlias getTraitTypeAlias(Trait trait) {
14-
result = trait.getSupertrait*().getAssocItemList().getAnAssocItem()
15-
}
11+
private import AssociatedType
1612

1713
/**
1814
* Holds if a dyn trait type for the trait `trait` should have a type parameter
@@ -31,7 +27,7 @@ private TypeAlias getTraitTypeAlias(Trait trait) {
3127
*/
3228
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
3329
trait = any(DynTraitTypeRepr dt).getTrait() and
34-
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitTypeAlias(trait)]
30+
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitAssocType(trait)]
3531
}
3632

3733
cached
@@ -43,8 +39,11 @@ newtype TType =
4339
TNeverType() or
4440
TUnknownType() or
4541
TTypeParamTypeParameter(TypeParam t) or
46-
TAssociatedTypeTypeParameter(Trait trait, TypeAlias typeAlias) {
47-
getTraitTypeAlias(trait) = typeAlias
42+
TAssociatedTypeTypeParameter(Trait trait, AssocType typeAlias) {
43+
getTraitAssocType(trait) = typeAlias
44+
} or
45+
TTypeParamAssociatedTypeTypeParameter(TypeParam tp, AssocType assoc) {
46+
tpAssociatedType(tp, assoc, _)
4847
} or
4948
TDynTraitTypeParameter(Trait trait, AstNode n) { dynTraitTypeParameter(trait, n) } or
5049
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
@@ -464,6 +463,52 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
464463
override Location getLocation() { result = typeAlias.getLocation() }
465464
}
466465

466+
/**
467+
* A type parameter corresponding to an associated type accessed on a type
468+
* parameter, for example `T::AssociatedType` where `T` is a type parameter.
469+
*
470+
* These type parameters are created when a function signature accesses an
471+
* associated type on a type parameter. For example, in
472+
* ```rust
473+
* fn foo<T: SomeTrait>(arg: T::Assoc) { }
474+
* ```
475+
* we create a `TypeParamAssociatedTypeTypeParameter` for `Assoc` on `T` and the
476+
* mention `T::Assoc` resolves to this type parameter. If denoting the type
477+
* parameter by `T_Assoc` then the above function is treated as if it was
478+
* ```rust
479+
* fn foo<T: SomeTrait<Assoc = T_Assoc>, T_Assoc>(arg: T_Assoc) { }
480+
* ```
481+
*/
482+
class TypeParamAssociatedTypeTypeParameter extends TypeParameter,
483+
TTypeParamAssociatedTypeTypeParameter
484+
{
485+
private TypeParam typeParam;
486+
private AssocType assoc;
487+
488+
TypeParamAssociatedTypeTypeParameter() {
489+
this = TTypeParamAssociatedTypeTypeParameter(typeParam, assoc)
490+
}
491+
492+
/** Gets the type parameter that this associated type is accessed on. */
493+
TypeParam getTypeParam() { result = typeParam }
494+
495+
/** Gets the associated type alias. */
496+
AssocType getTypeAlias() { result = assoc }
497+
498+
/** Gets a path that accesses this type parameter. */
499+
Path getAPath() { tpAssociatedType(typeParam, assoc, result) }
500+
501+
override ItemNode getDeclaringItem() { result.getTypeParam(_) = typeParam }
502+
503+
override string toString() {
504+
result =
505+
typeParam.toString() + "::" + assoc.getName().getText() + "[" +
506+
assoc.getTrait().getName().getText() + "]"
507+
}
508+
509+
override Location getLocation() { result = typeParam.getLocation() }
510+
}
511+
467512
/** Gets the associated type type-parameter corresponding directly to `typeAlias`. */
468513
AssociatedTypeTypeParameter getAssociatedTypeTypeParameter(TypeAlias typeAlias) {
469514
result.isDirect() and result.getTypeAlias() = typeAlias

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ private module Input implements InputSig1<Location>, InputSig2<PreTypeMention> {
108108
id2 = idOfTypeParameterAstNode(tp0.(AssociatedTypeTypeParameter).getTypeAlias())
109109
or
110110
kind = 4 and
111+
id1 = idOfTypeParameterAstNode(tp0.(TypeParamAssociatedTypeTypeParameter).getTypeParam()) and
112+
id2 = idOfTypeParameterAstNode(tp0.(TypeParamAssociatedTypeTypeParameter).getTypeAlias())
113+
or
114+
kind = 5 and
111115
id1 = 0 and
112116
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
113117
node = tp0.(TypeParamTypeParameter).getTypeParam() or
@@ -270,13 +274,21 @@ private class FunctionDeclaration extends Function {
270274
this = i.asSome().getAnAssocItem()
271275
}
272276

277+
TypeParam getTypeParam(ImplOrTraitItemNodeOption i) {
278+
i = parent and
279+
result = [this.getGenericParamList().getATypeParam(), i.asSome().getTypeParam(_)]
280+
}
281+
273282
TypeParameter getTypeParameter(ImplOrTraitItemNodeOption i, TypeParameterPosition ppos) {
283+
typeParamMatchPosition(this.getTypeParam(i), result, ppos)
284+
or
285+
// For every `TypeParam` of this function, any associated types accessed on
286+
// the type parameter are also type parameters.
287+
ppos.isImplicit() and
288+
result.(TypeParamAssociatedTypeTypeParameter).getTypeParam() = this.getTypeParam(i)
289+
or
274290
i = parent and
275291
(
276-
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
277-
or
278-
typeParamMatchPosition(i.asSome().getTypeParam(_), result, ppos)
279-
or
280292
ppos.isImplicit() and result = TSelfTypeParameter(i.asSome())
281293
or
282294
ppos.isImplicit() and result.(AssociatedTypeTypeParameter).getTrait() = i.asSome()

rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ private import codeql.rust.frameworks.stdlib.Stdlib
66
private import Type
77
private import TypeAbstraction
88
private import TypeInference
9+
private import AssociatedType
910

1011
bindingset[trait, name]
1112
pragma[inline_late]
@@ -319,6 +320,22 @@ private module MkTypeMention<getAdditionalPathTypeAtSig/2 getAdditionalPathTypeA
319320
tp = TAssociatedTypeTypeParameter(resolved, alias) and
320321
path.isEmpty()
321322
)
323+
or
324+
// If this path is a type parameter bound, then any associated types
325+
// accessed on the type parameter, which originate from this bound, should
326+
// be instantiated into the bound, as explained in the comment for
327+
// `TypeParamAssociatedTypeTypeParameter`.
328+
// ```rust
329+
// fn foo<T: SomeTrait<Assoc = T_Assoc>, T_Assoc>(arg: T_Assoc) { }
330+
// ^^^^^^^^^ ^^^^^ ^^^^^^^
331+
// this path result
332+
// ```
333+
exists(TypeParam typeParam, Trait trait, AssocType assoc |
334+
tpBoundAssociatedType(typeParam, _, this, trait, assoc) and
335+
tp = TAssociatedTypeTypeParameter(resolved, assoc) and
336+
result = TTypeParamAssociatedTypeTypeParameter(typeParam, assoc) and
337+
path.isEmpty()
338+
)
322339
}
323340

324341
bindingset[name]
@@ -372,6 +389,8 @@ private module MkTypeMention<getAdditionalPathTypeAtSig/2 getAdditionalPathTypeA
372389
or
373390
// Handles paths of the form `Self::AssocType` within a trait block
374391
result = TAssociatedTypeTypeParameter(resolvePath(this.getQualifier()), resolved)
392+
or
393+
result.(TypeParamAssociatedTypeTypeParameter).getAPath() = this
375394
}
376395

377396
override Type resolvePathTypeAt(TypePath typePath) {
@@ -690,11 +709,10 @@ private predicate pathConcreteTypeAssocType(
690709
|
691710
// path of the form `<Type as Trait>::AssocType`
692711
// ^^^ tm ^^^^^^^^^ name
693-
exists(string name |
694-
name = path.getText() and
695-
trait = resolvePath(qualifier.getSegment().getTraitTypeRepr().getPath()) and
696-
getTraitAssocType(trait, name) = alias and
697-
tm = qualifier.getSegment().getTypeRepr()
712+
exists(string name, Path traitPath |
713+
pathTypeAsTraitAssoc(path, tm, traitPath, name) and
714+
trait = resolvePath(traitPath) and
715+
getTraitAssocType(trait, name) = alias
698716
)
699717
or
700718
// path of the form `Self::AssocType` within an `impl` block

rust/ql/test/library-tests/type-inference/associated_types.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ impl<A> Wrapper<A> {
77
}
88
}
99

10-
#[derive(Debug, Default)]
10+
#[derive(Debug, Default, Clone, Copy)]
1111
struct S;
1212

1313
#[derive(Debug, Default)]
@@ -260,13 +260,65 @@ mod type_param_access_associated_type {
260260
)
261261
}
262262

263+
// Associated type accessed on a type parameter of an impl block
264+
impl<TI> Wrapper<TI>
265+
where
266+
TI: GetSet,
267+
{
268+
fn extract(&self) -> TI::Output {
269+
self.0.get() // $ fieldof=Wrapper target=GetSet::get
270+
}
271+
}
272+
273+
// Associated type accessed on another associated type
274+
275+
fn tp_nested_assoc_type<T: GetSet>(thing: T) -> <<T as GetSet>::Output as GetSet>::Output
276+
where
277+
<T as GetSet>::Output: GetSet,
278+
{
279+
thing.get().get() // $ target=GetSet::get target=GetSet::get
280+
}
281+
282+
pub trait GetSetWrap {
283+
type Assoc: GetSet;
284+
285+
// GetSetWrap::get_wrap
286+
fn get_wrap(&self) -> Self::Assoc;
287+
}
288+
289+
impl GetSetWrap for S {
290+
type Assoc = S;
291+
292+
// S::get_wrap
293+
fn get_wrap(&self) -> Self::Assoc {
294+
S
295+
}
296+
}
297+
298+
// Nested associated type accessed on a type parameter of an impl block
299+
impl<TI> Wrapper<TI>
300+
where
301+
TI: GetSetWrap,
302+
{
303+
fn extract2(&self) -> <<TI as GetSetWrap>::Assoc as GetSet>::Output {
304+
self.0.get_wrap().get() // $ fieldof=Wrapper target=GetSetWrap::get_wrap $ MISSING: target=GetSet::get
305+
}
306+
}
307+
263308
pub fn test() {
264-
let _o1 = tp_with_as(S); // $ target=tp_with_as MISSING: type=_o1:S3
265-
let _o2 = tp_without_as(S); // $ target=tp_without_as MISSING: type=_o2:S3
309+
let _o1 = tp_with_as(S); // $ target=tp_with_as type=_o1:S3
310+
let _o2 = tp_without_as(S); // $ target=tp_without_as type=_o2:S3
266311
let (
267312
_o3, // $ MISSING: type=_o3:S3
268-
_o4, // $ MISSING: type=_o4:bool
313+
_o4, // $ type=_o4:bool
269314
) = tp_assoc_from_supertrait(S); // $ target=tp_assoc_from_supertrait
315+
316+
let _o5 = tp_nested_assoc_type(Wrapper(S)); // $ target=tp_nested_assoc_type MISSING: type=_o5:S3
317+
318+
let w = Wrapper(S);
319+
let _extracted = w.extract(); // $ target=extract type=_extracted:S3
320+
321+
let _extracted2 = w.extract2(); // $ target=extract2 MISSING: type=_extracted2:S3
270322
}
271323
}
272324

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,7 +1748,7 @@ mod overloadable_operators {
17481748
let i64_mul = 17i64 * 18i64; // $ type=i64_mul:i64 target=mul
17491749
let i64_div = 19i64 / 20i64; // $ type=i64_div:i64 target=div
17501750
let i64_rem = 21i64 % 22i64; // $ type=i64_rem:i64 target=rem
1751-
let i64_param_add = param_add(1i64, 2i64); // $ target=param_add $ MISSING: type=i64_param_add:i64
1751+
let i64_param_add = param_add(1i64, 2i64); // $ target=param_add $ type=i64_param_add:i64
17521752

17531753
// Arithmetic assignment operators
17541754
let mut i64_add_assign = 23i64;
@@ -2053,7 +2053,7 @@ mod indexers {
20532053
let xs: [S; 1] = [S];
20542054
let x = xs[0].foo(); // $ target=foo type=x:S target=index
20552055

2056-
let y = param_index(vec, 0); // $ target=param_index $ MISSING: type=y:S
2056+
let y = param_index(vec, 0); // $ target=param_index $ type=y:S
20572057

20582058
analyze_slice(&xs); // $ target=analyze_slice
20592059
}

0 commit comments

Comments
 (0)