Skip to content

Commit 076faa3

Browse files
committed
add pyTorch :) code execution sinks, add proper tests
1 parent 3d7db0e commit 076faa3

File tree

5 files changed

+106
-0
lines changed

5 files changed

+106
-0
lines changed

python/ql/lib/semmle/python/Frameworks.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ private import semmle.python.frameworks.SqlAlchemy
6868
private import semmle.python.frameworks.Starlette
6969
private import semmle.python.frameworks.Stdlib
7070
private import semmle.python.frameworks.Toml
71+
private import semmle.python.frameworks.Torch
7172
private import semmle.python.frameworks.Tornado
7273
private import semmle.python.frameworks.Twisted
7374
private import semmle.python.frameworks.Ujson
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/**
2+
* Provides classes modeling security-relevant aspects of the `torch` PyPI package.
3+
* See https://pypi.org/project/torch/.
4+
*/
5+
6+
private import python
7+
private import semmle.python.Concepts
8+
private import semmle.python.ApiGraphs
9+
10+
/**
11+
* Provides models for the `torch` PyPI package.
12+
* See https://pypi.org/project/torch/.
13+
*/
14+
private module Torch {
15+
/**
16+
* A call to `torch.load`
17+
* See https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
18+
*/
19+
private class TorchLoadCall extends Decoding::Range, API::CallNode {
20+
TorchLoadCall() { this = API::moduleImport("torch").getMember("load").getACall() }
21+
22+
override predicate mayExecuteInput() {
23+
not exists(this.getParameter(2, "pickle_module").asSink()) or
24+
exists(this.getParameter(2, "pickle_module").asSink().asExpr().(None))
25+
}
26+
27+
override DataFlow::Node getAnInput() { result = this.getParameter(0, "f").asSink() }
28+
29+
override DataFlow::Node getOutput() { result = this }
30+
31+
override string getFormat() { result = "pickle" }
32+
}
33+
34+
API::Node test() {
35+
result = API::moduleImport("torch").getMember("package").getMember("PackageImporter")
36+
}
37+
38+
/**
39+
* A call to `torch.package.PackageImporter`
40+
* See https://pytorch.org/docs/stable/package.html#torch.package.PackageImporter
41+
*/
42+
private class TorchPackageImporter extends Decoding::Range, API::CallNode {
43+
TorchPackageImporter() {
44+
this = API::moduleImport("torch").getMember("package").getMember("PackageImporter").getACall() and
45+
exists(this.getAMethodCall("load_pickle"))
46+
}
47+
48+
override predicate mayExecuteInput() { any() }
49+
50+
override DataFlow::Node getAnInput() {
51+
result = this.getParameter(0, "file_or_buffer").asSink()
52+
}
53+
54+
override DataFlow::Node getOutput() { result = this.getAMethodCall("load_pickle") }
55+
56+
override string getFormat() { result = "pickle" }
57+
}
58+
59+
/**
60+
* A call to `torch.jit.load`
61+
* See https://pytorch.org/docs/stable/generated/torch.jit.load.html#torch.jit.load
62+
*/
63+
private class TorchJitLoad extends Decoding::Range, API::CallNode {
64+
TorchJitLoad() {
65+
this = API::moduleImport("torch").getMember("jit").getMember("load").getACall()
66+
}
67+
68+
override predicate mayExecuteInput() { any() }
69+
70+
override DataFlow::Node getAnInput() { result = this.getParameter(0, "f").asSink() }
71+
72+
override DataFlow::Node getOutput() { result = this }
73+
74+
override string getFormat() { result = "pickle" }
75+
}
76+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
testFailures
2+
failures
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import python
2+
import experimental.meta.ConceptsTest
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from io import BytesIO
2+
3+
import torch
4+
5+
6+
def someSafeMethod():
7+
pass
8+
9+
10+
PicklePayload = BytesIO(b"payload")
11+
torch.load(PicklePayload) # $ decodeInput=PicklePayload decodeOutput=torch.load(..) decodeFormat=pickle decodeMayExecuteInput
12+
torch.load(PicklePayload, pickle_module=None) # $ decodeInput=PicklePayload decodeOutput=torch.load(..) decodeFormat=pickle decodeMayExecuteInput
13+
torch.load(PicklePayload, pickle_module=someSafeMethod()) # $ decodeInput=PicklePayload decodeOutput=torch.load(..) decodeFormat=pickle
14+
15+
from torch.package import PackageImporter
16+
17+
importer = PackageImporter(PicklePayload) # $ decodeInput=PicklePayload PackageImporter(..) decodeFormat=pickle decodeMayExecuteInput
18+
my_tensor = importer.load_pickle("my_resources", "tensor.pkl") # $ decodeOutput=importer.load_pickle(..)
19+
20+
importer = PackageImporter(PicklePayload)
21+
22+
23+
from torch import jit
24+
25+
jit.load(PicklePayload) # $ decodeInput=PicklePayload decodeOutput=jit.load(..) decodeFormat=pickle decodeMayExecuteInput

0 commit comments

Comments
 (0)