Skip to content

Commit 652c8db

Browse files
committed
Rust: Restrict type propagation into receivers
1 parent b34777e commit 652c8db

File tree

2 files changed

+75
-49
lines changed

2 files changed

+75
-49
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 75 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -778,13 +778,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
778778
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
779779
prefix2.isEmpty()
780780
or
781-
exists(Struct s |
782-
n2 = [n1.(RangeExpr).getStart(), n1.(RangeExpr).getEnd()] and
783-
prefix1 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
784-
prefix2.isEmpty() and
785-
s = getRangeType(n1)
786-
)
787-
or
788781
exists(ClosureExpr ce, int index |
789782
n1 = ce and
790783
n2 = ce.getParam(index).getPat() and
@@ -829,6 +822,12 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
829822
bodyReturns(parent, child) and
830823
strictcount(Expr e | bodyReturns(parent, e)) > 1 and
831824
prefix.isEmpty()
825+
or
826+
exists(Struct s |
827+
child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and
828+
prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
829+
s = getRangeType(parent)
830+
)
832831
}
833832

834833
/**
@@ -1031,10 +1030,10 @@ private module StructExprMatchingInput implements MatchingInputSig {
10311030
private module StructExprMatching = Matching<StructExprMatchingInput>;
10321031

10331032
pragma[nomagic]
1034-
private Type inferStructExprType0(AstNode n, boolean isReturn, TypePath path) {
1033+
private Type inferStructExprType0(AstNode n, FunctionPosition pos, TypePath path) {
10351034
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
10361035
n = a.getNodeAt(apos) and
1037-
if apos.isStructPos() then isReturn = true else isReturn = false
1036+
if apos.isStructPos() then pos.isReturn() else pos.asPosition() = 0 // the acutal position doesn't matter, as long as it is positional
10381037
|
10391038
result = StructExprMatching::inferAccessType(a, apos, path)
10401039
or
@@ -1113,6 +1112,25 @@ private Trait getCallExprTraitQualifier(CallExpr ce) {
11131112
* Provides functionality related to context-based typing of calls.
11141113
*/
11151114
private module ContextTyping {
1115+
/**
1116+
* Holds if `f` mentions type parameter `tp` at some non-return position,
1117+
* possibly via a constraint on another mentioned type parameter.
1118+
*/
1119+
pragma[nomagic]
1120+
private predicate assocFunctionMentionsTypeParameterAtNonRetPos(
1121+
ImplOrTraitItemNode i, Function f, TypeParameter tp
1122+
) {
1123+
exists(FunctionPosition nonRetPos |
1124+
not nonRetPos.isReturn() and
1125+
tp = getAssocFunctionTypeAt(f, i, nonRetPos, _)
1126+
)
1127+
or
1128+
exists(TypeParameter mid |
1129+
assocFunctionMentionsTypeParameterAtNonRetPos(i, f, mid) and
1130+
tp = getATypeParameterConstraint(mid, _)
1131+
)
1132+
}
1133+
11161134
/**
11171135
* Holds if the return type of the function `f` inside `i` at `path` is type
11181136
* parameter `tp`, and `tp` does not appear in the type of any parameter of
@@ -1129,12 +1147,7 @@ private module ContextTyping {
11291147
) {
11301148
pos.isReturn() and
11311149
tp = getAssocFunctionTypeAt(f, i, pos, path) and
1132-
not exists(FunctionPosition nonResPos | not nonResPos.isReturn() |
1133-
tp = getAssocFunctionTypeAt(f, i, nonResPos, _)
1134-
or
1135-
// `Self` types in traits implicitly mention all type parameters of the trait
1136-
getAssocFunctionTypeAt(f, i, nonResPos, _) = TSelfTypeParameter(i)
1137-
)
1150+
not assocFunctionMentionsTypeParameterAtNonRetPos(i, f, tp)
11381151
}
11391152

11401153
/**
@@ -1184,7 +1197,7 @@ private module ContextTyping {
11841197
pragma[nomagic]
11851198
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
11861199

1187-
signature Type inferCallTypeSig(AstNode n, boolean isReturn, TypePath path);
1200+
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
11881201

11891202
/**
11901203
* Given a predicate `inferCallType` for inferring the type of a call at a given
@@ -1194,19 +1207,34 @@ private module ContextTyping {
11941207
*/
11951208
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
11961209
pragma[nomagic]
1197-
private Type inferCallTypeFromContextCand(AstNode n, TypePath prefix, TypePath path) {
1198-
result = inferCallType(n, false, path) and
1210+
private Type inferCallNonReturnType(AstNode n, FunctionPosition pos, TypePath path) {
1211+
result = inferCallType(n, pos, path) and
1212+
not pos.isReturn()
1213+
}
1214+
1215+
pragma[nomagic]
1216+
private Type inferCallNonReturnType(
1217+
AstNode n, FunctionPosition pos, TypePath prefix, TypePath path
1218+
) {
1219+
result = inferCallNonReturnType(n, pos, path) and
11991220
hasUnknownType(n) and
12001221
prefix = path.getAPrefix()
12011222
}
12021223

12031224
pragma[nomagic]
12041225
Type check(AstNode n, TypePath path) {
1205-
result = inferCallType(n, true, path)
1226+
result = inferCallType(n, any(FunctionPosition pos | pos.isReturn()), path)
12061227
or
1207-
exists(TypePath prefix |
1208-
result = inferCallTypeFromContextCand(n, prefix, path) and
1228+
exists(FunctionPosition pos, TypePath prefix |
1229+
result = inferCallNonReturnType(n, pos, prefix, path) and
12091230
hasUnknownTypeAt(n, prefix)
1231+
|
1232+
pos.isPosition()
1233+
or
1234+
// Never propagate type information directly into the receiver, since its type
1235+
// must already have been known in order to resolve the call
1236+
pos.isSelf() and
1237+
not prefix.isEmpty()
12101238
)
12111239
}
12121240
}
@@ -2607,12 +2635,9 @@ private Type inferMethodCallType0(
26072635
}
26082636

26092637
pragma[nomagic]
2610-
private Type inferMethodCallTypeNonSelf(AstNode n, boolean isReturn, TypePath path) {
2611-
exists(MethodCallMatchingInput::AccessPosition apos |
2612-
result = inferMethodCallType0(_, apos, n, _, path) and
2613-
not apos.isSelf() and
2614-
if apos.isReturn() then isReturn = true else isReturn = false
2615-
)
2638+
private Type inferMethodCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePath path) {
2639+
result = inferMethodCallType0(_, pos, n, _, path) and
2640+
not pos.isSelf()
26162641
}
26172642

26182643
/**
@@ -2664,11 +2689,11 @@ private Type inferMethodCallTypeSelf(AstNode n, DerefChain derefChain, TypePath
26642689
)
26652690
}
26662691

2667-
private Type inferMethodCallTypePreCheck(AstNode n, boolean isReturn, TypePath path) {
2668-
result = inferMethodCallTypeNonSelf(n, isReturn, path)
2692+
private Type inferMethodCallTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
2693+
result = inferMethodCallTypeNonSelf(n, pos, path)
26692694
or
26702695
result = inferMethodCallTypeSelf(n, DerefChain::nil(), path) and
2671-
isReturn = false
2696+
pos.isSelf()
26722697
}
26732698

26742699
/**
@@ -3301,14 +3326,11 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
33013326
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
33023327

33033328
pragma[nomagic]
3304-
private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path) {
3305-
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
3306-
n = a.getNodeAt(apos) and
3307-
if apos.isReturn() then isReturn = true else isReturn = false
3308-
|
3309-
result = NonMethodCallMatching::inferAccessType(a, apos, path)
3329+
private Type inferNonMethodCallType0(AstNode n, FunctionPosition pos, TypePath path) {
3330+
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(pos) |
3331+
result = NonMethodCallMatching::inferAccessType(a, pos, path)
33103332
or
3311-
a.hasUnknownTypeAt(apos, path) and
3333+
a.hasUnknownTypeAt(pos, path) and
33123334
result = TUnknownType()
33133335
)
33143336
}
@@ -3379,11 +3401,10 @@ private module OperationMatchingInput implements MatchingInputSig {
33793401
private module OperationMatching = Matching<OperationMatchingInput>;
33803402

33813403
pragma[nomagic]
3382-
private Type inferOperationType0(AstNode n, boolean isReturn, TypePath path) {
3383-
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
3384-
n = a.getNodeAt(apos) and
3385-
result = OperationMatching::inferAccessType(a, apos, path) and
3386-
if apos.isReturn() then isReturn = true else isReturn = false
3404+
private Type inferOperationType0(AstNode n, FunctionPosition pos, TypePath path) {
3405+
exists(OperationMatchingInput::Access a |
3406+
n = a.getNodeAt(pos) and
3407+
result = OperationMatching::inferAccessType(a, pos, path)
33873408
)
33883409
}
33893410

@@ -3716,11 +3737,13 @@ private module AwaitSatisfiesConstraintInput implements SatisfiesConstraintInput
37163737
}
37173738
}
37183739

3740+
private module AwaitSatisfiesConstraint =
3741+
SatisfiesConstraint<AwaitTarget, AwaitSatisfiesConstraintInput>;
3742+
37193743
pragma[nomagic]
37203744
private Type inferAwaitExprType(AstNode n, TypePath path) {
37213745
exists(TypePath exprPath |
3722-
SatisfiesConstraint<AwaitTarget, AwaitSatisfiesConstraintInput>::satisfiesConstraintType(n.(AwaitExpr)
3723-
.getExpr(), _, exprPath, result) and
3746+
AwaitSatisfiesConstraint::satisfiesConstraintType(n.(AwaitExpr).getExpr(), _, exprPath, result) and
37243747
exprPath.isCons(getFutureOutputTypeParameter(), path)
37253748
)
37263749
}
@@ -3922,13 +3945,15 @@ private AssociatedTypeTypeParameter getIntoIteratorItemTypeParameter() {
39223945
result = getAssociatedTypeTypeParameter(any(IntoIteratorTrait t).getItemType())
39233946
}
39243947

3948+
private module ForIterableSatisfiesConstraint =
3949+
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>;
3950+
39253951
pragma[nomagic]
39263952
private Type inferForLoopExprType(AstNode n, TypePath path) {
39273953
// type of iterable -> type of pattern (loop variable)
39283954
exists(ForExpr fe, TypePath exprPath, AssociatedTypeTypeParameter tp |
39293955
n = fe.getPat() and
3930-
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>::satisfiesConstraintType(fe.getIterable(),
3931-
_, exprPath, result) and
3956+
ForIterableSatisfiesConstraint::satisfiesConstraintType(fe.getIterable(), _, exprPath, result) and
39323957
exprPath.isCons(tp, path)
39333958
|
39343959
tp = getIntoIteratorItemTypeParameter()
@@ -3963,10 +3988,12 @@ private module InvokedClosureSatisfiesConstraintInput implements
39633988
}
39643989
}
39653990

3991+
private module InvokedClosureSatisfiesConstraint =
3992+
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>;
3993+
39663994
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
39673995
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
3968-
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
3969-
_, path, result)
3996+
InvokedClosureSatisfiesConstraint::satisfiesConstraintType(ce, _, path, result)
39703997
}
39713998

39723999
/**

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9488,7 +9488,6 @@ inferType
94889488
| main.rs:1412:17:1412:20 | self | TRef.TSlice | main.rs:1410:14:1410:23 | T |
94899489
| main.rs:1412:17:1412:27 | self.get(...) | | {EXTERNAL LOCATION} | Option |
94909490
| main.rs:1412:17:1412:27 | self.get(...) | T | {EXTERNAL LOCATION} | & |
9491-
| main.rs:1412:17:1412:27 | self.get(...) | T.TRef | main.rs:1410:14:1410:23 | T |
94929491
| main.rs:1412:17:1412:36 | ... .unwrap() | | {EXTERNAL LOCATION} | & |
94939492
| main.rs:1412:17:1412:36 | ... .unwrap() | TRef | main.rs:1410:14:1410:23 | T |
94949493
| main.rs:1412:26:1412:26 | 0 | | {EXTERNAL LOCATION} | i32 |

0 commit comments

Comments
 (0)