Skip to content

Commit 0694319

Browse files
authored
Merge pull request #21596 from hvitved/rust/data-flow-closure-type
Rust: Track closure types in data flow
2 parents 609621f + 23f0810 commit 0694319

File tree

4 files changed

+193
-19
lines changed

4 files changed

+193
-19
lines changed

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ module LocalFlow {
193193
}
194194

195195
pragma[nomagic]
196-
predicate localFlowStepCommon(Node nodeFrom, Node nodeTo) {
197-
nodeFrom.asExpr() = getALastEvalNode(nodeTo.asExpr())
198-
or
196+
predicate localMustFlowStep(Node nodeFrom, Node nodeTo) {
199197
// An edge from the right-hand side of a let statement to the left-hand side.
200198
exists(LetStmt s |
201199
nodeFrom.asExpr() = s.getInitializer() and
@@ -238,6 +236,15 @@ module LocalFlow {
238236
nodeTo.asPat() = match.getAnArm().getPat()
239237
)
240238
or
239+
nodeFrom.asExpr() = nodeTo.asExpr().(ParenExpr).getExpr()
240+
}
241+
242+
pragma[nomagic]
243+
predicate localFlowStepCommon(Node nodeFrom, Node nodeTo) {
244+
localMustFlowStep(nodeFrom, nodeTo)
245+
or
246+
nodeFrom.asExpr() = getALastEvalNode(nodeTo.asExpr())
247+
or
241248
nodeFrom.asPat().(OrPat).getAPat() = nodeTo.asPat()
242249
or
243250
nodeTo.(PostUpdateNode).getPreUpdateNode().asExpr() =
@@ -263,10 +270,84 @@ predicate lambdaCallExpr(CallExprImpl::DynamicCallExpr call, LambdaCallKind kind
263270
exists(kind)
264271
}
265272

273+
// NOTE: We do not yet track type information, except for closures where
274+
// we use the closure itself to represent the unique type.
275+
final class DataFlowType extends TDataFlowType {
276+
Expr asClosureExpr() { this = TClosureExprType(result) }
277+
278+
predicate isUnknownType() { this = TUnknownType() }
279+
280+
predicate isSourceContextParameterType() { this = TSourceContextParameterType() }
281+
282+
string toString() {
283+
exists(this.asClosureExpr()) and
284+
result = "... => .."
285+
or
286+
this.isUnknownType() and
287+
result = ""
288+
or
289+
this.isSourceContextParameterType() and
290+
result = "<source context parameter type>"
291+
}
292+
}
293+
294+
pragma[nomagic]
295+
private predicate compatibleTypesSourceContextParameterTypeLeft(DataFlowType t1, DataFlowType t2) {
296+
t1.isSourceContextParameterType() and not exists(t2.asClosureExpr())
297+
}
298+
299+
pragma[nomagic]
300+
private predicate compatibleTypesLeft(DataFlowType t1, DataFlowType t2) {
301+
t1.isUnknownType() and exists(t2)
302+
or
303+
t1.asClosureExpr() = t2.asClosureExpr()
304+
or
305+
compatibleTypesSourceContextParameterTypeLeft(t1, t2)
306+
}
307+
308+
predicate compatibleTypes(DataFlowType t1, DataFlowType t2) {
309+
compatibleTypesLeft(t1, t2)
310+
or
311+
compatibleTypesLeft(t2, t1)
312+
}
313+
314+
pragma[nomagic]
315+
predicate typeStrongerThan(DataFlowType t1, DataFlowType t2) {
316+
not t1.isUnknownType() and t2.isUnknownType()
317+
or
318+
compatibleTypesSourceContextParameterTypeLeft(t1, t2)
319+
}
320+
321+
DataFlowType getNodeType(NodePublic node) {
322+
result.asClosureExpr() = node.asExpr()
323+
or
324+
result.asClosureExpr() = node.(ClosureParameterNode).getCfgScope()
325+
or
326+
exists(VariableCapture::Flow::SynthesizedCaptureNode scn |
327+
scn = node.(CaptureNode).getSynthesizedCaptureNode() and
328+
if scn.isInstanceAccess()
329+
then result.asClosureExpr() = scn.getEnclosingCallable()
330+
else result.isUnknownType()
331+
)
332+
or
333+
not lambdaCreationExpr(node.asExpr()) and
334+
not node instanceof ClosureParameterNode and
335+
not node instanceof CaptureNode and
336+
result.isUnknownType()
337+
}
338+
266339
// Defines a set of aliases needed for the `RustDataFlow` module
267340
private module Aliases {
268341
class DataFlowCallableAlias = DataFlowCallable;
269342

343+
class DataFlowTypeAlias = DataFlowType;
344+
345+
predicate compatibleTypesAlias = compatibleTypes/2;
346+
347+
predicate typeStrongerThanAlias = typeStrongerThan/2;
348+
349+
predicate getNodeTypeAlias = getNodeType/1;
350+
270351
class ReturnKindAlias = ReturnKind;
271352

272353
class DataFlowCallAlias = DataFlowCall;
@@ -398,8 +479,6 @@ module RustDataFlowGen<RustDataFlowInputSig Input> implements InputSig<Location>
398479
result = node.(Node::Node).getEnclosingCallable()
399480
}
400481

401-
DataFlowType getNodeType(Node node) { any() }
402-
403482
predicate nodeIsHidden(Node node) {
404483
node instanceof SsaNode or
405484
node.(FlowSummaryNode).getSummaryNode().isHidden() or
@@ -486,15 +565,17 @@ module RustDataFlowGen<RustDataFlowInputSig Input> implements InputSig<Location>
486565
*/
487566
OutNode getAnOutNode(DataFlowCall call, ReturnKind kind) { call = result.getCall(kind) }
488567

489-
// NOTE: For now we use the type `Unit` and do not benefit from type
490-
// information in the data flow analysis.
491-
final class DataFlowType extends Unit {
492-
string toString() { result = "" }
493-
}
568+
class DataFlowType = DataFlowTypeAlias;
494569

495-
predicate compatibleTypes(DataFlowType t1, DataFlowType t2) { any() }
570+
predicate compatibleTypes = compatibleTypesAlias/2;
496571

497-
predicate typeStrongerThan(DataFlowType t1, DataFlowType t2) { none() }
572+
predicate typeStrongerThan = typeStrongerThanAlias/2;
573+
574+
DataFlowType getSourceContextParameterNodeType(Node p) {
575+
exists(p) and result.isSourceContextParameterType()
576+
}
577+
578+
predicate getNodeType = getNodeTypeAlias/1;
498579

499580
class Content = ContentAlias;
500581

@@ -897,6 +978,8 @@ module RustDataFlowGen<RustDataFlowInputSig Input> implements InputSig<Location>
897978
predicate localMustFlowStep(Node node1, Node node2) {
898979
SsaFlow::localMustFlowStep(node1, node2)
899980
or
981+
LocalFlow::localMustFlowStep(node1, node2)
982+
or
900983
FlowSummaryImpl::Private::Steps::summaryLocalMustFlowStep(node1
901984
.(FlowSummaryNode)
902985
.getSummaryNode(), node2.(FlowSummaryNode).getSummaryNode())
@@ -1110,6 +1193,12 @@ private module Cached {
11101193
TCfgScope(CfgScope scope) or
11111194
TSummarizedCallable(SummarizedCallable c)
11121195

1196+
cached
1197+
newtype TDataFlowType =
1198+
TClosureExprType(Expr e) { lambdaCreationExpr(e) } or
1199+
TUnknownType() or
1200+
TSourceContextParameterType()
1201+
11131202
/** This is the local flow predicate that is exposed. */
11141203
cached
11151204
predicate localFlowStepImpl(Node nodeFrom, Node nodeTo) {

rust/ql/lib/codeql/rust/dataflow/internal/ModelsAsData.qll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ private module Debug {
291291
private import Content
292292
private import codeql.rust.dataflow.internal.DataFlowImpl
293293
private import codeql.rust.internal.typeinference.TypeMention
294-
private import codeql.rust.internal.typeinference.Type
294+
private import codeql.rust.internal.typeinference.Type as Type
295295

296296
private predicate relevantManualModel(SummarizedCallableImpl sc, string can) {
297297
exists(Provenance manual |
@@ -309,7 +309,7 @@ private module Debug {
309309
sc.propagatesFlow(input, _, _, _, _, _) and
310310
input.head() = SummaryComponent::argument(pos) and
311311
p = pos.getParameterIn(sc.getParamList()) and
312-
tm.getType() instanceof RefType and
312+
tm.getType() instanceof Type::RefType and
313313
not input.tail().head() = SummaryComponent::content(TSingletonContentSet(TReferenceContent()))
314314
|
315315
tm = p.getTypeRepr()
@@ -324,7 +324,7 @@ private module Debug {
324324
exists(TypeMention tm |
325325
relevantManualModel(sc, can) and
326326
sc.propagatesFlow(_, output, _, _, _, _) and
327-
tm.getType() instanceof RefType and
327+
tm.getType() instanceof Type::RefType and
328328
output.head() = SummaryComponent::return(_) and
329329
not output.tail().head() =
330330
SummaryComponent::content(TSingletonContentSet(TReferenceContent())) and

rust/ql/test/library-tests/dataflow/lambdas/inline-flow.expected

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ edges
1313
| main.rs:29:13:29:22 | f(...) | main.rs:29:9:29:9 | b | provenance | |
1414
| main.rs:29:21:29:21 | a | main.rs:27:20:27:23 | ... | provenance | |
1515
| main.rs:29:21:29:21 | a | main.rs:29:13:29:22 | f(...) | provenance | |
16-
| main.rs:37:16:37:25 | source(...) | main.rs:39:5:39:5 | [post] f [captured capt] | provenance | |
17-
| main.rs:39:5:39:5 | [post] f [captured capt] | main.rs:40:10:40:13 | capt | provenance | |
18-
| main.rs:39:5:39:5 | [post] f [captured capt] | main.rs:44:5:44:5 | g [captured capt] | provenance | |
16+
| main.rs:37:16:37:25 | source(...) | main.rs:39:5:39:5 | [post] f : ... => .. [captured capt] | provenance | |
17+
| main.rs:39:5:39:5 | [post] f : ... => .. [captured capt] | main.rs:40:10:40:13 | capt | provenance | |
18+
| main.rs:39:5:39:5 | [post] f : ... => .. [captured capt] | main.rs:44:5:44:5 | g [captured capt] | provenance | |
1919
| main.rs:44:5:44:5 | g [captured capt] | main.rs:42:14:42:17 | capt | provenance | |
2020
| main.rs:47:29:49:1 | { ... } | main.rs:57:10:57:12 | f(...) | provenance | |
2121
| main.rs:48:5:48:14 | source(...) | main.rs:47:29:49:1 | { ... } | provenance | |
@@ -30,6 +30,31 @@ edges
3030
| main.rs:77:13:77:22 | f(...) | main.rs:77:9:77:9 | b | provenance | |
3131
| main.rs:77:21:77:21 | a | main.rs:66:24:66:32 | ...: i64 | provenance | |
3232
| main.rs:77:21:77:21 | a | main.rs:77:13:77:22 | f(...) | provenance | |
33+
| main.rs:81:28:81:33 | ...: i64 | main.rs:82:7:82:7 | x | provenance | |
34+
| main.rs:81:28:81:33 | ...: i64 | main.rs:82:7:82:7 | x | provenance | |
35+
| main.rs:81:28:81:33 | ...: i64 | main.rs:82:7:82:7 | x | provenance | |
36+
| main.rs:82:7:82:7 | x | main.rs:87:12:87:12 | ... | provenance | |
37+
| main.rs:82:7:82:7 | x | main.rs:89:12:89:12 | ... | provenance | |
38+
| main.rs:82:7:82:7 | x | main.rs:99:17:99:17 | ... | provenance | |
39+
| main.rs:82:7:82:7 | x | main.rs:101:17:101:17 | ... | provenance | |
40+
| main.rs:86:9:86:9 | a | main.rs:87:24:87:24 | a | provenance | |
41+
| main.rs:86:13:86:22 | source(...) | main.rs:86:9:86:9 | a | provenance | |
42+
| main.rs:87:12:87:12 | ... | main.rs:87:20:87:20 | x | provenance | |
43+
| main.rs:87:24:87:24 | a | main.rs:81:28:81:33 | ...: i64 | provenance | |
44+
| main.rs:88:9:88:9 | b | main.rs:89:24:89:24 | b | provenance | |
45+
| main.rs:88:13:88:22 | source(...) | main.rs:88:9:88:9 | b | provenance | |
46+
| main.rs:89:12:89:12 | ... | main.rs:89:20:89:20 | x | provenance | |
47+
| main.rs:89:24:89:24 | b | main.rs:81:28:81:33 | ...: i64 | provenance | |
48+
| main.rs:93:33:93:38 | ...: i64 | main.rs:94:14:94:14 | x | provenance | |
49+
| main.rs:94:14:94:14 | x | main.rs:81:28:81:33 | ...: i64 | provenance | |
50+
| main.rs:98:9:98:9 | a | main.rs:99:29:99:29 | a | provenance | |
51+
| main.rs:98:13:98:22 | source(...) | main.rs:98:9:98:9 | a | provenance | |
52+
| main.rs:99:17:99:17 | ... | main.rs:99:25:99:25 | x | provenance | |
53+
| main.rs:99:29:99:29 | a | main.rs:93:33:93:38 | ...: i64 | provenance | |
54+
| main.rs:100:9:100:9 | b | main.rs:101:29:101:29 | b | provenance | |
55+
| main.rs:100:13:100:22 | source(...) | main.rs:100:9:100:9 | b | provenance | |
56+
| main.rs:101:17:101:17 | ... | main.rs:101:25:101:25 | x | provenance | |
57+
| main.rs:101:29:101:29 | b | main.rs:93:33:93:38 | ...: i64 | provenance | |
3358
nodes
3459
| main.rs:10:20:10:52 | if cond {...} else {...} | semmle.label | if cond {...} else {...} |
3560
| main.rs:10:30:10:39 | source(...) | semmle.label | source(...) |
@@ -48,7 +73,7 @@ nodes
4873
| main.rs:29:21:29:21 | a | semmle.label | a |
4974
| main.rs:30:10:30:10 | b | semmle.label | b |
5075
| main.rs:37:16:37:25 | source(...) | semmle.label | source(...) |
51-
| main.rs:39:5:39:5 | [post] f [captured capt] | semmle.label | [post] f [captured capt] |
76+
| main.rs:39:5:39:5 | [post] f : ... => .. [captured capt] | semmle.label | [post] f : ... => .. [captured capt] |
5277
| main.rs:40:10:40:13 | capt | semmle.label | capt |
5378
| main.rs:42:14:42:17 | capt | semmle.label | capt |
5479
| main.rs:44:5:44:5 | g [captured capt] | semmle.label | g [captured capt] |
@@ -68,6 +93,34 @@ nodes
6893
| main.rs:77:13:77:22 | f(...) | semmle.label | f(...) |
6994
| main.rs:77:21:77:21 | a | semmle.label | a |
7095
| main.rs:78:10:78:10 | b | semmle.label | b |
96+
| main.rs:81:28:81:33 | ...: i64 | semmle.label | ...: i64 |
97+
| main.rs:81:28:81:33 | ...: i64 | semmle.label | ...: i64 |
98+
| main.rs:81:28:81:33 | ...: i64 | semmle.label | ...: i64 |
99+
| main.rs:82:7:82:7 | x | semmle.label | x |
100+
| main.rs:82:7:82:7 | x | semmle.label | x |
101+
| main.rs:82:7:82:7 | x | semmle.label | x |
102+
| main.rs:86:9:86:9 | a | semmle.label | a |
103+
| main.rs:86:13:86:22 | source(...) | semmle.label | source(...) |
104+
| main.rs:87:12:87:12 | ... | semmle.label | ... |
105+
| main.rs:87:20:87:20 | x | semmle.label | x |
106+
| main.rs:87:24:87:24 | a | semmle.label | a |
107+
| main.rs:88:9:88:9 | b | semmle.label | b |
108+
| main.rs:88:13:88:22 | source(...) | semmle.label | source(...) |
109+
| main.rs:89:12:89:12 | ... | semmle.label | ... |
110+
| main.rs:89:20:89:20 | x | semmle.label | x |
111+
| main.rs:89:24:89:24 | b | semmle.label | b |
112+
| main.rs:93:33:93:38 | ...: i64 | semmle.label | ...: i64 |
113+
| main.rs:94:14:94:14 | x | semmle.label | x |
114+
| main.rs:98:9:98:9 | a | semmle.label | a |
115+
| main.rs:98:13:98:22 | source(...) | semmle.label | source(...) |
116+
| main.rs:99:17:99:17 | ... | semmle.label | ... |
117+
| main.rs:99:25:99:25 | x | semmle.label | x |
118+
| main.rs:99:29:99:29 | a | semmle.label | a |
119+
| main.rs:100:9:100:9 | b | semmle.label | b |
120+
| main.rs:100:13:100:22 | source(...) | semmle.label | source(...) |
121+
| main.rs:101:17:101:17 | ... | semmle.label | ... |
122+
| main.rs:101:25:101:25 | x | semmle.label | x |
123+
| main.rs:101:29:101:29 | b | semmle.label | b |
71124
subpaths
72125
| main.rs:29:21:29:21 | a | main.rs:27:20:27:23 | ... | main.rs:27:26:27:52 | if cond {...} else {...} | main.rs:29:13:29:22 | f(...) |
73126
| main.rs:77:21:77:21 | a | main.rs:66:24:66:32 | ...: i64 | main.rs:66:42:72:1 | { ... } | main.rs:77:13:77:22 | f(...) |
@@ -81,3 +134,9 @@ testFailures
81134
| main.rs:52:10:52:13 | data | main.rs:62:13:62:22 | source(...) | main.rs:52:10:52:13 | data | $@ | main.rs:62:13:62:22 | source(...) | source(...) |
82135
| main.rs:57:10:57:12 | f(...) | main.rs:48:5:48:14 | source(...) | main.rs:57:10:57:12 | f(...) | $@ | main.rs:48:5:48:14 | source(...) | source(...) |
83136
| main.rs:78:10:78:10 | b | main.rs:76:13:76:22 | source(...) | main.rs:78:10:78:10 | b | $@ | main.rs:76:13:76:22 | source(...) | source(...) |
137+
| main.rs:87:20:87:20 | x | main.rs:86:13:86:22 | source(...) | main.rs:87:20:87:20 | x | $@ | main.rs:86:13:86:22 | source(...) | source(...) |
138+
| main.rs:89:20:89:20 | x | main.rs:88:13:88:22 | source(...) | main.rs:89:20:89:20 | x | $@ | main.rs:88:13:88:22 | source(...) | source(...) |
139+
| main.rs:99:25:99:25 | x | main.rs:98:13:98:22 | source(...) | main.rs:99:25:99:25 | x | $@ | main.rs:98:13:98:22 | source(...) | source(...) |
140+
| main.rs:99:25:99:25 | x | main.rs:100:13:100:22 | source(...) | main.rs:99:25:99:25 | x | $@ | main.rs:100:13:100:22 | source(...) | source(...) |
141+
| main.rs:101:25:101:25 | x | main.rs:98:13:98:22 | source(...) | main.rs:101:25:101:25 | x | $@ | main.rs:98:13:98:22 | source(...) | source(...) |
142+
| main.rs:101:25:101:25 | x | main.rs:100:13:100:22 | source(...) | main.rs:101:25:101:25 | x | $@ | main.rs:100:13:100:22 | source(...) | source(...) |

rust/ql/test/library-tests/dataflow/lambdas/main.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,30 @@ fn function_flows_through() {
7878
sink(b); // $ hasValueFlow=56
7979
}
8080

81+
fn apply<F: Fn(i64)>(f: F, x: i64) {
82+
f(x);
83+
}
84+
85+
fn test_apply() {
86+
let a = source(77);
87+
apply(|x| sink(x), a); // $ hasValueFlow=77
88+
let b = source(78);
89+
apply(|x| sink(x), b); // $ hasValueFlow=78
90+
apply(|x| sink(x), 0);
91+
}
92+
93+
fn apply_wrap<F: Fn(i64)>(f: F, x: i64) {
94+
apply(f, x);
95+
}
96+
97+
fn test_apply_wrap() {
98+
let a = source(79);
99+
apply_wrap(|x| sink(x), a); // $ hasValueFlow=79 $ SPURIOUS: hasValueFlow=80
100+
let b = source(80);
101+
apply_wrap(|x| sink(x), b); // $ hasValueFlow=80 $ SPURIOUS: hasValueFlow=79
102+
apply_wrap(|x| sink(x), 0);
103+
}
104+
81105
fn main() {
82106
closure_flow_out();
83107
closure_flow_in();
@@ -86,4 +110,6 @@ fn main() {
86110
function_flow_in();
87111
function_flow_out();
88112
function_flows_through();
113+
test_apply();
114+
test_apply_wrap();
89115
}

0 commit comments

Comments
 (0)