Skip to content

Commit 47ee3ed

Browse files
committed
Rust: Improve type inference for closures
1 parent 37c819b commit 47ee3ed

File tree

3 files changed

+575
-122
lines changed

3 files changed

+575
-122
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 142 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,18 @@ private module Input2 implements InputSig2 {
205205
// `DynTypeBoundListMention` for further details.
206206
exists(DynTraitTypeRepr object |
207207
abs = object and
208-
condition = object.getTypeBoundList() and
208+
condition = object.getTypeBoundList()
209+
|
209210
constraint = object.getTrait()
211+
// or
212+
// TTrait(object.getTrait()) =
213+
// constraint
214+
// .(ImplTraitTypeRepr)
215+
// .getTypeBoundList()
216+
// .getABound()
217+
// .getTypeRepr()
218+
// .(TypeMention)
219+
// .resolveType()
210220
)
211221
)
212222
}
@@ -407,6 +417,14 @@ private predicate isPanicMacroCall(MacroExpr me) {
407417
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
408418
}
409419

420+
// Due to "binding modes" the type of the pattern is not necessarily the
421+
// same as the type of the initializer. The pattern being an identifier
422+
// pattern is sufficient to ensure that this is not the case.
423+
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
424+
let.getPat() = lhs and
425+
let.getInitializer() = rhs
426+
}
427+
410428
/** Module for inferring certain type information. */
411429
module CertainTypeInference {
412430
pragma[nomagic]
@@ -484,11 +502,7 @@ module CertainTypeInference {
484502
// is not a certain type equality.
485503
exists(LetStmt let |
486504
not let.hasTypeRepr() and
487-
// Due to "binding modes" the type of the pattern is not necessarily the
488-
// same as the type of the initializer. The pattern being an identifier
489-
// pattern is sufficient to ensure that this is not the case.
490-
let.getPat().(IdentPat) = n1 and
491-
let.getInitializer() = n2
505+
identLetStmt(let, n1, n2)
492506
)
493507
or
494508
exists(LetExpr let |
@@ -512,6 +526,25 @@ module CertainTypeInference {
512526
)
513527
else prefix2.isEmpty()
514528
)
529+
or
530+
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
531+
n1 = dce.getArgList() and
532+
tt.getArity() = dce.getNumberOfSyntacticArguments() and
533+
n2 = dce.getSyntacticPositionalArgument(i) and
534+
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
535+
prefix2.isEmpty()
536+
)
537+
or
538+
exists(ClosureExpr ce, int index |
539+
n1 = ce and
540+
n2 = ce.getParam(index).getPat() and
541+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
542+
prefix2.isEmpty()
543+
)
544+
or
545+
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
546+
prefix1 = closureReturnPath() and
547+
prefix2.isEmpty()
515548
}
516549

517550
pragma[nomagic]
@@ -781,17 +814,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
781814
prefix2.isEmpty() and
782815
s = getRangeType(n1)
783816
)
784-
or
785-
exists(ClosureExpr ce, int index |
786-
n1 = ce and
787-
n2 = ce.getParam(index).getPat() and
788-
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
789-
prefix2.isEmpty()
790-
)
791-
or
792-
n1.(ClosureExpr).getClosureBody() = n2 and
793-
prefix1 = closureReturnPath() and
794-
prefix2.isEmpty()
795817
}
796818

797819
/**
@@ -828,6 +850,19 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
828850
prefix.isEmpty()
829851
}
830852

853+
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
854+
inferType(n, path) = TUnknownType() and
855+
// Normally, these are coercion sites, but in case a type is unknown we
856+
// allow for type information to flow from the type annotation.
857+
exists(TypeMention tm | result = tm.resolveTypeAt(path) |
858+
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
859+
or
860+
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
861+
or
862+
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
863+
)
864+
}
865+
831866
/**
832867
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
833868
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
@@ -1509,6 +1544,8 @@ private module MethodResolution {
15091544
* or
15101545
* 4. `MethodCallOperation`: an operation expression, `x + y`, which is syntactic sugar
15111546
* for `Add::add(x, y)`.
1547+
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
1548+
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
15121549
*
15131550
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
15141551
*
@@ -1520,7 +1557,7 @@ private module MethodResolution {
15201557
abstract class MethodCall extends Expr {
15211558
abstract predicate hasNameAndArity(string name, int arity);
15221559

1523-
abstract Expr getArg(ArgumentPosition pos);
1560+
abstract AstNode getArg(ArgumentPosition pos);
15241561

15251562
abstract predicate supportsAutoDerefAndBorrow();
15261563

@@ -1888,6 +1925,16 @@ private module MethodResolution {
18881925
)
18891926
}
18901927

1928+
private Method testresolveCallTarget(
1929+
ImplOrTraitItemNode i, DerefChain derefChain, BorrowKind borrow
1930+
) {
1931+
this = Debug::getRelevantLocatable() and
1932+
exists(MethodCallCand mcc |
1933+
mcc = MkMethodCallCand(this, derefChain, borrow) and
1934+
result = mcc.resolveCallTarget(i)
1935+
)
1936+
}
1937+
18911938
/**
18921939
* Holds if the argument `arg` of this call has been implicitly dereferenced
18931940
* and borrowed according to `derefChain` and `borrow`, in order to be able to
@@ -2050,6 +2097,26 @@ private module MethodResolution {
20502097
override Trait getTrait() { super.isOverloaded(result, _, _) }
20512098
}
20522099

2100+
private class ClosureMethodCall extends MethodCall instanceof CallExprImpl::DynamicCallExpr {
2101+
pragma[nomagic]
2102+
override predicate hasNameAndArity(string name, int arity) {
2103+
name = "call_once" and // todo: handle call_mut and call
2104+
arity = 1 // args are passed in a tuple
2105+
}
2106+
2107+
override AstNode getArg(ArgumentPosition pos) {
2108+
pos.isSelf() and
2109+
result = super.getFunction()
2110+
or
2111+
pos.asPosition() = 0 and
2112+
result = super.getArgList()
2113+
}
2114+
2115+
override predicate supportsAutoDerefAndBorrow() { any() }
2116+
2117+
override Trait getTrait() { result instanceof AnyFnTrait }
2118+
}
2119+
20532120
pragma[nomagic]
20542121
private Method getMethodSuccessor(ImplOrTraitItemNode i, string name, int arity) {
20552122
result = i.getASuccessor(name) and
@@ -2471,7 +2538,8 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
24712538
class Access extends MethodCallFinal, ContextTyping::ContextTypedCallCand {
24722539
Access() {
24732540
// handled in the `OperationMatchingInput` module
2474-
not this instanceof Operation
2541+
not this instanceof Operation //and
2542+
// this = Debug::getRelevantLocatable()
24752543
}
24762544

24772545
pragma[nomagic]
@@ -2523,6 +2591,16 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
25232591
result = this.getInferredNonSelfType(apos, path)
25242592
}
25252593

2594+
private Type testgetInferredType(string derefChainBorrow, AccessPosition apos, TypePath path) {
2595+
this = Debug::getRelevantLocatable() and
2596+
(
2597+
result = this.getInferredSelfType(apos, derefChainBorrow, path)
2598+
or
2599+
result = this.getInferredNonSelfType(apos, path) and
2600+
derefChainBorrow = ""
2601+
)
2602+
}
2603+
25262604
Method getTarget(ImplOrTraitItemNode i, string derefChainBorrow) {
25272605
exists(DerefChain derefChain, BorrowKind borrow |
25282606
derefChainBorrow = encodeDerefChainBorrow(derefChain, borrow) and
@@ -2596,6 +2674,7 @@ pragma[nomagic]
25962674
private Type inferMethodCallTypeSelf(
25972675
AstNode n, DerefChain derefChain, BorrowKind borrow, TypePath path
25982676
) {
2677+
// n = Debug::getRelevantLocatable() and
25992678
exists(MethodCallMatchingInput::AccessPosition apos, string derefChainBorrow |
26002679
result = inferMethodCallType0(_, apos, n, derefChainBorrow, path) and
26012680
apos.isSelf() and
@@ -2639,6 +2718,11 @@ private Type inferMethodCallTypePreCheck(AstNode n, boolean isReturn, TypePath p
26392718
isReturn = false
26402719
}
26412720

2721+
private Type testinferMethodCallTypePreCheck(AstNode n, boolean isReturn, TypePath path) {
2722+
result = inferMethodCallTypePreCheck(n, isReturn, path) and
2723+
n = Debug::getRelevantLocatable()
2724+
}
2725+
26422726
/**
26432727
* Gets the type of `n` at `path`, where `n` is either a method call or an
26442728
* argument/receiver of a method call.
@@ -3137,6 +3221,7 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
31373221
}
31383222

31393223
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
3224+
// Access() { this = Debug::getRelevantLocatable() }
31403225
pragma[nomagic]
31413226
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
31423227
result = getCallExprTypeArgument(this, apos, path)
@@ -3210,6 +3295,12 @@ private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path)
32103295
)
32113296
}
32123297

3298+
pragma[nomagic]
3299+
private Type inferNonMethodCallType1(AstNode n, boolean isReturn, TypePath path) {
3300+
result = inferNonMethodCallType0(n, isReturn, path) and
3301+
n = Debug::getRelevantLocatable()
3302+
}
3303+
32133304
private predicate inferNonMethodCallType =
32143305
ContextTyping::CheckContextTyping<inferNonMethodCallType0/3>::check/2;
32153306

@@ -3892,73 +3983,39 @@ private TypePath closureParameterPath(int arity, int index) {
38923983
TypePath::singleton(getTupleTypeParameter(arity, index)))
38933984
}
38943985

3895-
/** Gets the path to the return type of the `FnOnce` trait. */
3896-
private TypePath fnReturnPath() {
3897-
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
3898-
}
3899-
3900-
/**
3901-
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
3902-
* and index `index`.
3903-
*/
3904-
pragma[nomagic]
3905-
private TypePath fnParameterPath(int arity, int index) {
3906-
result =
3907-
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
3908-
TypePath::singleton(getTupleTypeParameter(arity, index)))
3909-
}
3910-
39113986
pragma[nomagic]
3912-
private Type inferDynamicCallExprType(Expr n, TypePath path) {
3913-
exists(InvokedClosureExpr ce |
3914-
// Propagate the function's return type to the call expression
3915-
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
3916-
n = ce.getCall() and
3917-
path = path0.stripPrefix(fnReturnPath())
3987+
private Type inferClosureExprType(AstNode n, TypePath path) {
3988+
exists(ClosureExpr ce |
3989+
n = ce and
3990+
(
3991+
path.isEmpty() and
3992+
result = closureRootType()
3993+
or
3994+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3995+
result.(TupleType).getArity() = ce.getNumberOfParams()
39183996
or
3919-
// Propagate the function's parameter type to the arguments
3920-
exists(int index |
3921-
n = ce.getCall().getSyntacticPositionalArgument(index) and
3922-
path =
3923-
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
3997+
exists(TypePath path0 |
3998+
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path0) and
3999+
path = closureReturnPath().append(path0)
39244000
)
39254001
)
39264002
or
3927-
// _If_ the invoked expression has the type of a closure, then we propagate
3928-
// the surrounding types into the closure.
3929-
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
3930-
// Propagate the type of arguments to the parameter types of closure
3931-
exists(int index, ArgList args |
3932-
n = ce and
3933-
args = ce.getCall().getArgList() and
3934-
arity = args.getNumberOfArgs() and
3935-
result = inferType(args.getArg(index), path0) and
3936-
path = closureParameterPath(arity, index).append(path0)
3937-
)
3938-
or
3939-
// Propagate the type of the call expression to the return type of the closure
3940-
n = ce and
3941-
arity = ce.getCall().getArgList().getNumberOfArgs() and
3942-
result = inferType(ce.getCall(), path0) and
3943-
path = closureReturnPath().append(path0)
4003+
exists(Param p |
4004+
p = ce.getAParam() and
4005+
not p.hasTypeRepr() and
4006+
n = p.getPat() and
4007+
result = TUnknownType() and
4008+
path.isEmpty()
39444009
)
39454010
)
39464011
}
39474012

39484013
pragma[nomagic]
3949-
private Type inferClosureExprType(AstNode n, TypePath path) {
3950-
exists(ClosureExpr ce |
3951-
n = ce and
3952-
path.isEmpty() and
3953-
result = closureRootType()
3954-
or
3955-
n = ce and
3956-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3957-
result.(TupleType).getArity() = ce.getNumberOfParams()
3958-
or
3959-
// Propagate return type annotation to body
3960-
n = ce.getClosureBody() and
3961-
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
4014+
private TupleType inferArgList(ArgList args, TypePath path) {
4015+
exists(CallExprImpl::DynamicCallExpr dce |
4016+
args = dce.getArgList() and
4017+
result.getArity() = dce.getNumberOfSyntacticArguments() and
4018+
path.isEmpty()
39624019
)
39634020
}
39644021

@@ -4005,7 +4062,9 @@ private module Cached {
40054062
or
40064063
i instanceof ImplItemNode and dispatch = false
40074064
|
4008-
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) or
4065+
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) and
4066+
not call instanceof CallExprImpl::DynamicCallExpr
4067+
or
40094068
result = call.(NonMethodResolution::NonMethodCall).resolveCallTargetViaTypeInference(i)
40104069
)
40114070
}
@@ -4115,13 +4174,15 @@ private module Cached {
41154174
or
41164175
result = inferForLoopExprType(n, path)
41174176
or
4118-
result = inferDynamicCallExprType(n, path)
4119-
or
41204177
result = inferClosureExprType(n, path)
41214178
or
4179+
result = inferArgList(n, path)
4180+
or
41224181
result = inferStructPatType(n, path)
41234182
or
41244183
result = inferTupleStructPatType(n, path)
4184+
or
4185+
result = inferUnknownTypeFromAnnotation(n, path)
41254186
)
41264187
}
41274188
}
@@ -4138,8 +4199,8 @@ private module Debug {
41384199
Locatable getRelevantLocatable() {
41394200
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
41404201
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
4141-
filepath.matches("%/sqlx.rs") and
4142-
startline = [56 .. 60]
4202+
filepath.matches("%/closure.rs") and
4203+
startline = [10]
41434204
)
41444205
}
41454206

0 commit comments

Comments
 (0)