Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 142 additions & 81 deletions rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,18 @@
// `DynTypeBoundListMention` for further details.
exists(DynTraitTypeRepr object |
abs = object and
condition = object.getTypeBoundList() and
condition = object.getTypeBoundList()
|
constraint = object.getTrait()
// or
// TTrait(object.getTrait()) =
// constraint
// .(ImplTraitTypeRepr)
// .getTypeBoundList()
// .getABound()
// .getTypeRepr()
// .(TypeMention)
// .resolveType()
)
)
}
Expand Down Expand Up @@ -407,6 +417,14 @@
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
}

// Due to "binding modes" the type of the pattern is not necessarily the
// same as the type of the initializer. The pattern being an identifier
// pattern is sufficient to ensure that this is not the case.
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
let.getPat() = lhs and
let.getInitializer() = rhs
}

/** Module for inferring certain type information. */
module CertainTypeInference {
pragma[nomagic]
Expand Down Expand Up @@ -484,11 +502,7 @@
// is not a certain type equality.
exists(LetStmt let |
not let.hasTypeRepr() and
// Due to "binding modes" the type of the pattern is not necessarily the
// same as the type of the initializer. The pattern being an identifier
// pattern is sufficient to ensure that this is not the case.
let.getPat().(IdentPat) = n1 and
let.getInitializer() = n2
identLetStmt(let, n1, n2)
)
or
exists(LetExpr let |
Expand All @@ -512,6 +526,25 @@
)
else prefix2.isEmpty()
)
or
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
n1 = dce.getArgList() and
tt.getArity() = dce.getNumberOfSyntacticArguments() and
n2 = dce.getSyntacticPositionalArgument(i) and
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
prefix2.isEmpty()
)
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
prefix2.isEmpty()
)
or
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
prefix1 = closureReturnPath() and
prefix2.isEmpty()
}

pragma[nomagic]
Expand Down Expand Up @@ -781,17 +814,6 @@
prefix2.isEmpty() and
s = getRangeType(n1)
)
or
exists(ClosureExpr ce, int index |
n1 = ce and
n2 = ce.getParam(index).getPat() and
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
prefix2.isEmpty()
)
or
n1.(ClosureExpr).getClosureBody() = n2 and
prefix1 = closureReturnPath() and
prefix2.isEmpty()
}

/**
Expand Down Expand Up @@ -828,6 +850,19 @@
prefix.isEmpty()
}

private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
inferType(n, path) = TUnknownType() and
// Normally, these are coercion sites, but in case a type is unknown we
// allow for type information to flow from the type annotation.
exists(TypeMention tm | result = tm.resolveTypeAt(path) |
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
or
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
or
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
)
}

/**
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
Expand Down Expand Up @@ -1509,6 +1544,8 @@
* or
* 4. `MethodCallOperation`: an operation expression, `x + y`, which is syntactic sugar
* for `Add::add(x, y)`.
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
*
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
*
Expand All @@ -1520,7 +1557,7 @@
abstract class MethodCall extends Expr {
abstract predicate hasNameAndArity(string name, int arity);

abstract Expr getArg(ArgumentPosition pos);
abstract AstNode getArg(ArgumentPosition pos);

abstract predicate supportsAutoDerefAndBorrow();

Expand Down Expand Up @@ -1888,6 +1925,16 @@
)
}

private Method testresolveCallTarget(

Check warning

Code scanning / CodeQL

Dead code Warning

This code is never used, and it's not publicly exported.
ImplOrTraitItemNode i, DerefChain derefChain, BorrowKind borrow
) {
this = Debug::getRelevantLocatable() and
exists(MethodCallCand mcc |
mcc = MkMethodCallCand(this, derefChain, borrow) and
result = mcc.resolveCallTarget(i)
)
}

/**
* Holds if the argument `arg` of this call has been implicitly dereferenced
* and borrowed according to `derefChain` and `borrow`, in order to be able to
Expand Down Expand Up @@ -2050,6 +2097,26 @@
override Trait getTrait() { super.isOverloaded(result, _, _) }
}

private class ClosureMethodCall extends MethodCall instanceof CallExprImpl::DynamicCallExpr {
pragma[nomagic]
override predicate hasNameAndArity(string name, int arity) {
name = "call_once" and // todo: handle call_mut and call
arity = 1 // args are passed in a tuple
}

override AstNode getArg(ArgumentPosition pos) {
pos.isSelf() and
result = super.getFunction()
or
pos.asPosition() = 0 and
result = super.getArgList()
}

override predicate supportsAutoDerefAndBorrow() { any() }

override Trait getTrait() { result instanceof AnyFnTrait }
}

pragma[nomagic]
private Method getMethodSuccessor(ImplOrTraitItemNode i, string name, int arity) {
result = i.getASuccessor(name) and
Expand Down Expand Up @@ -2471,7 +2538,8 @@
class Access extends MethodCallFinal, ContextTyping::ContextTypedCallCand {
Access() {
// handled in the `OperationMatchingInput` module
not this instanceof Operation
not this instanceof Operation //and
// this = Debug::getRelevantLocatable()
}

pragma[nomagic]
Expand Down Expand Up @@ -2523,6 +2591,16 @@
result = this.getInferredNonSelfType(apos, path)
}

private Type testgetInferredType(string derefChainBorrow, AccessPosition apos, TypePath path) {

Check warning

Code scanning / CodeQL

Dead code Warning

This code is never used, and it's not publicly exported.
this = Debug::getRelevantLocatable() and
(
result = this.getInferredSelfType(apos, derefChainBorrow, path)
or
result = this.getInferredNonSelfType(apos, path) and
derefChainBorrow = ""
)
}

Method getTarget(ImplOrTraitItemNode i, string derefChainBorrow) {
exists(DerefChain derefChain, BorrowKind borrow |
derefChainBorrow = encodeDerefChainBorrow(derefChain, borrow) and
Expand Down Expand Up @@ -2596,6 +2674,7 @@
private Type inferMethodCallTypeSelf(
AstNode n, DerefChain derefChain, BorrowKind borrow, TypePath path
) {
// n = Debug::getRelevantLocatable() and
exists(MethodCallMatchingInput::AccessPosition apos, string derefChainBorrow |
result = inferMethodCallType0(_, apos, n, derefChainBorrow, path) and
apos.isSelf() and
Expand Down Expand Up @@ -2639,6 +2718,11 @@
isReturn = false
}

private Type testinferMethodCallTypePreCheck(AstNode n, boolean isReturn, TypePath path) {

Check warning

Code scanning / CodeQL

Dead code Warning

This code is never used, and it's not publicly exported.
result = inferMethodCallTypePreCheck(n, isReturn, path) and
n = Debug::getRelevantLocatable()
}

/**
* Gets the type of `n` at `path`, where `n` is either a method call or an
* argument/receiver of a method call.
Expand Down Expand Up @@ -3137,6 +3221,7 @@
}

class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
// Access() { this = Debug::getRelevantLocatable() }
pragma[nomagic]
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
result = getCallExprTypeArgument(this, apos, path)
Expand Down Expand Up @@ -3210,6 +3295,12 @@
)
}

pragma[nomagic]
private Type inferNonMethodCallType1(AstNode n, boolean isReturn, TypePath path) {

Check warning

Code scanning / CodeQL

Dead code Warning

This code is never used, and it's not publicly exported.
result = inferNonMethodCallType0(n, isReturn, path) and
n = Debug::getRelevantLocatable()
}

private predicate inferNonMethodCallType =
ContextTyping::CheckContextTyping<inferNonMethodCallType0/3>::check/2;

Expand Down Expand Up @@ -3892,73 +3983,39 @@
TypePath::singleton(getTupleTypeParameter(arity, index)))
}

/** Gets the path to the return type of the `FnOnce` trait. */
private TypePath fnReturnPath() {
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
}

/**
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
* and index `index`.
*/
pragma[nomagic]
private TypePath fnParameterPath(int arity, int index) {
result =
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
TypePath::singleton(getTupleTypeParameter(arity, index)))
}

pragma[nomagic]
private Type inferDynamicCallExprType(Expr n, TypePath path) {
exists(InvokedClosureExpr ce |
// Propagate the function's return type to the call expression
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
n = ce.getCall() and
path = path0.stripPrefix(fnReturnPath())
private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
(
path.isEmpty() and
result = closureRootType()
or
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
result.(TupleType).getArity() = ce.getNumberOfParams()
or
// Propagate the function's parameter type to the arguments
exists(int index |
n = ce.getCall().getSyntacticPositionalArgument(index) and
path =
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
exists(TypePath path0 |
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path0) and
path = closureReturnPath().append(path0)
)
)
or
// _If_ the invoked expression has the type of a closure, then we propagate
// the surrounding types into the closure.
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
// Propagate the type of arguments to the parameter types of closure
exists(int index, ArgList args |
n = ce and
args = ce.getCall().getArgList() and
arity = args.getNumberOfArgs() and
result = inferType(args.getArg(index), path0) and
path = closureParameterPath(arity, index).append(path0)
)
or
// Propagate the type of the call expression to the return type of the closure
n = ce and
arity = ce.getCall().getArgList().getNumberOfArgs() and
result = inferType(ce.getCall(), path0) and
path = closureReturnPath().append(path0)
exists(Param p |
p = ce.getAParam() and
not p.hasTypeRepr() and
n = p.getPat() and
result = TUnknownType() and
path.isEmpty()
)
)
}

pragma[nomagic]
private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = closureRootType()
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
result.(TupleType).getArity() = ce.getNumberOfParams()
or
// Propagate return type annotation to body
n = ce.getClosureBody() and
result = ce.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
private TupleType inferArgList(ArgList args, TypePath path) {
exists(CallExprImpl::DynamicCallExpr dce |
args = dce.getArgList() and
result.getArity() = dce.getNumberOfSyntacticArguments() and
path.isEmpty()
)
}

Expand Down Expand Up @@ -4005,7 +4062,9 @@
or
i instanceof ImplItemNode and dispatch = false
|
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) or
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) and
not call instanceof CallExprImpl::DynamicCallExpr
or
result = call.(NonMethodResolution::NonMethodCall).resolveCallTargetViaTypeInference(i)
)
}
Expand Down Expand Up @@ -4115,13 +4174,15 @@
or
result = inferForLoopExprType(n, path)
or
result = inferDynamicCallExprType(n, path)
or
result = inferClosureExprType(n, path)
or
result = inferArgList(n, path)
or
result = inferStructPatType(n, path)
or
result = inferTupleStructPatType(n, path)
or
result = inferUnknownTypeFromAnnotation(n, path)
)
}
}
Expand All @@ -4138,8 +4199,8 @@
Locatable getRelevantLocatable() {
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
filepath.matches("%/sqlx.rs") and
startline = [56 .. 60]
filepath.matches("%/closure.rs") and
startline = [10]

Check warning

Code scanning / CodeQL

Singleton set literal Warning

Singleton set literal can be replaced by its member.
)
}

Expand Down
Loading
Loading