Skip to content

Commit 3da56ad

Browse files
committed
Rust: Infer certain type information for struct expressions
1 parent edec76a commit 3da56ad

5 files changed

Lines changed: 84 additions & 39 deletions

File tree

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,18 @@ module CertainTypeInference {
372372
)
373373
}
374374

375+
pragma[nomagic]
376+
private Type inferCertainStructExprType(AstNode n, TypePath path) {
377+
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
378+
// The fields of a struct expressions is a coercion site, hence we can
379+
// only determine certain type information for the struct expression
380+
// itself
381+
apos.isStructPos() and
382+
n = a.getNodeAt(apos) and
383+
result = StructExprMatching::getAccessDeclaredType(a, apos, path)
384+
)
385+
}
386+
375387
predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
376388
prefix1.isEmpty() and
377389
prefix2.isEmpty() and
@@ -440,6 +452,8 @@ module CertainTypeInference {
440452
or
441453
result = inferLogicalOperationType(n, path)
442454
or
455+
result = inferCertainStructExprType(n, path)
456+
or
443457
result = inferRangeExprType(n) and
444458
path.isEmpty()
445459
or
@@ -743,7 +757,14 @@ private module StructExprMatchingInput implements MatchingInputSig {
743757
class AccessPosition = DeclarationPosition;
744758

745759
class Access extends StructExpr {
746-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
760+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
761+
exists(TypeMention arg, int i, Path p |
762+
result = arg.resolveTypeAt(path) and
763+
p = this.getPath() and
764+
arg = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
765+
apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i))
766+
)
767+
}
747768

748769
AstNode getNodeAt(AccessPosition apos) {
749770
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
@@ -778,8 +799,14 @@ private module StructExprMatching = Matching<StructExprMatchingInput>;
778799
pragma[nomagic]
779800
private Type inferStructExprType(AstNode n, TypePath path) {
780801
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
781-
n = a.getNodeAt(apos) and
782-
result = StructExprMatching::inferAccessType(a, apos, path)
802+
n = a.getNodeAt(apos)
803+
|
804+
// Exclude the struct expression itself, as to not overlap with what is
805+
// already produced by certain type information.
806+
not apos.isStructPos() and
807+
result = StructExprMatching::getAccessDeclaredType(a, apos, path)
808+
or
809+
result = StructExprMatching::getAccessInferredType(a, apos, path)
783810
)
784811
}
785812

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ mod field_access {
3434

3535
fn generic_field_access() {
3636
// Explicit type argument
37-
let x = GenericThing::<S> { a: S }; // $ type=x:A.S
37+
let x = GenericThing::<S> { a: S }; // $ certainType=x:A.S
3838
println!("{:?}", x.a); // $ fieldof=GenericThing
3939

4040
// Implicit type argument
@@ -2384,7 +2384,7 @@ mod loops {
23842384
let range_full = ..; // $ certainType=range_full:RangeFull
23852385
for i in &[1i64, 2i64, 3i64][range_full] {} // $ target=index MISSING: type=i:&T.i64
23862386

2387-
let range1 = // $ type=range1:Range type=range1:Idx.u16
2387+
let range1 = // $ certainType=range1:Range type=range1:Idx.u16
23882388
std::ops::Range {
23892389
start: 0u16,
23902390
end: 10u16,
@@ -2480,7 +2480,7 @@ mod explicit_type_args {
24802480
let x7 = S4(S2); // $ type=x7:T4.S2
24812481
let x8 = S4(0); // $ type=x8:T4.i32
24822482
let x9 = S4(S2::default()); // $ type=x9:T4.S2 target=default
2483-
let x10 = S5::<S2> // $ type=x10:T5.S2
2483+
let x10 = S5::<S2> // $ certainType=x10:T5.S2
24842484
{
24852485
field: Default::default(), // $ target=default
24862486
};
@@ -2594,7 +2594,7 @@ pub mod exec {
25942594
impl Connection for MySqlConnection {}
25952595

25962596
pub fn f() {
2597-
let c = MySqlConnection {}; // $ type=c:MySqlConnection
2597+
let c = MySqlConnection {}; // $ certainType=c:MySqlConnection
25982598

25992599
c.execute1(); // $ MISSING: target=execute1
26002600
MySqlConnection::execute1(&c); // $ MISSING: target=execute1

rust/ql/test/library-tests/type-inference/pattern_matching.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,16 +360,16 @@ pub fn record_patterns() {
360360
// RecordPat - Record (struct) patterns
361361
match point {
362362
Point { x: 0, y: 0 } => {
363-
let origin = point; // $ type=origin:Point
363+
let origin = point; // $ certainType=origin:Point
364364
println!("Origin point: {:?}", origin);
365365
}
366366
Point { x, y: 0 } => {
367367
let x_axis_x = x; // $ type=x_axis_x:i32
368-
let x_axis_point = point; // $ type=x_axis_point:Point
368+
let x_axis_point = point; // $ certainType=x_axis_point:Point
369369
println!("Point on x-axis: x={}, point={:?}", x_axis_x, x_axis_point);
370370
}
371371
Point { x: 10, .. } => {
372-
let ten_x_point = point; // $ type=ten_x_point:Point
372+
let ten_x_point = point; // $ certainType=ten_x_point:Point
373373
println!("Point with x=10: {:?}", ten_x_point);
374374
}
375375
Point { x, y } => {

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,6 @@ inferType
385385
| dereference.rs:122:23:122:29 | &... | &T | file://:0:0:0:0 | & |
386386
| dereference.rs:122:23:122:29 | &... | &T.&T | dereference.rs:99:5:100:21 | Key |
387387
| dereference.rs:122:24:122:29 | Key {...} | | dereference.rs:99:5:100:21 | Key |
388-
| dereference.rs:122:24:122:29 | Key {...} | | file://:0:0:0:0 | & |
389-
| dereference.rs:122:24:122:29 | Key {...} | &T | dereference.rs:99:5:100:21 | Key |
390388
| dereference.rs:123:16:123:28 | Some(...) | | {EXTERNAL LOCATION} | Option |
391389
| dereference.rs:123:16:123:28 | Some(...) | T | file://:0:0:0:0 | & |
392390
| dereference.rs:123:16:123:28 | Some(...) | T.&T | dereference.rs:99:5:100:21 | Key |
@@ -501,10 +499,8 @@ inferType
501499
| dyn_type.rs:61:5:61:36 | ...::new(...) | T | dyn_type.rs:33:1:36:1 | GenStruct |
502500
| dyn_type.rs:61:5:61:36 | ...::new(...) | T.A | dyn_type.rs:60:18:60:43 | A |
503501
| dyn_type.rs:61:5:61:36 | ...::new(...) | T.dyn(A) | dyn_type.rs:60:18:60:43 | A |
504-
| dyn_type.rs:61:14:61:35 | GenStruct {...} | | dyn_type.rs:10:1:13:1 | dyn GenericGet |
505502
| dyn_type.rs:61:14:61:35 | GenStruct {...} | | dyn_type.rs:33:1:36:1 | GenStruct |
506503
| dyn_type.rs:61:14:61:35 | GenStruct {...} | A | dyn_type.rs:60:18:60:43 | A |
507-
| dyn_type.rs:61:14:61:35 | GenStruct {...} | dyn(A) | dyn_type.rs:60:18:60:43 | A |
508504
| dyn_type.rs:61:33:61:33 | a | | dyn_type.rs:60:18:60:43 | A |
509505
| dyn_type.rs:64:25:64:27 | obj | | file://:0:0:0:0 | & |
510506
| dyn_type.rs:64:25:64:27 | obj | &T | dyn_type.rs:5:1:8:1 | dyn MyTrait1 |
@@ -623,16 +619,13 @@ inferType
623619
| dyn_type.rs:102:26:102:48 | &... | | file://:0:0:0:0 | & |
624620
| dyn_type.rs:102:26:102:48 | &... | &T | dyn_type.rs:5:1:8:1 | dyn MyTrait1 |
625621
| dyn_type.rs:102:26:102:48 | &... | &T | dyn_type.rs:21:1:24:1 | MyStruct |
626-
| dyn_type.rs:102:27:102:48 | MyStruct {...} | | dyn_type.rs:5:1:8:1 | dyn MyTrait1 |
627622
| dyn_type.rs:102:27:102:48 | MyStruct {...} | | dyn_type.rs:21:1:24:1 | MyStruct |
628623
| dyn_type.rs:102:45:102:46 | 42 | | {EXTERNAL LOCATION} | i32 |
629624
| dyn_type.rs:103:28:105:5 | &... | | file://:0:0:0:0 | & |
630625
| dyn_type.rs:103:28:105:5 | &... | &T | dyn_type.rs:10:1:13:1 | dyn GenericGet |
631626
| dyn_type.rs:103:28:105:5 | &... | &T | dyn_type.rs:33:1:36:1 | GenStruct |
632627
| dyn_type.rs:103:28:105:5 | &... | &T.dyn(A) | {EXTERNAL LOCATION} | String |
633-
| dyn_type.rs:103:29:105:5 | GenStruct {...} | | dyn_type.rs:10:1:13:1 | dyn GenericGet |
634628
| dyn_type.rs:103:29:105:5 | GenStruct {...} | | dyn_type.rs:33:1:36:1 | GenStruct |
635-
| dyn_type.rs:103:29:105:5 | GenStruct {...} | dyn(A) | {EXTERNAL LOCATION} | String |
636629
| dyn_type.rs:104:16:104:17 | "" | | file://:0:0:0:0 | & |
637630
| dyn_type.rs:104:16:104:17 | "" | &T | {EXTERNAL LOCATION} | str |
638631
| dyn_type.rs:107:21:107:45 | &... | | file://:0:0:0:0 | & |
@@ -641,11 +634,8 @@ inferType
641634
| dyn_type.rs:107:21:107:45 | &... | &T.A | {EXTERNAL LOCATION} | i32 |
642635
| dyn_type.rs:107:21:107:45 | &... | &T.dyn(AP) | {EXTERNAL LOCATION} | bool |
643636
| dyn_type.rs:107:21:107:45 | &... | &T.dyn(GP) | {EXTERNAL LOCATION} | i64 |
644-
| dyn_type.rs:107:22:107:45 | GenStruct {...} | | dyn_type.rs:15:1:19:1 | dyn AssocTrait |
645637
| dyn_type.rs:107:22:107:45 | GenStruct {...} | | dyn_type.rs:33:1:36:1 | GenStruct |
646638
| dyn_type.rs:107:22:107:45 | GenStruct {...} | A | {EXTERNAL LOCATION} | i32 |
647-
| dyn_type.rs:107:22:107:45 | GenStruct {...} | dyn(AP) | {EXTERNAL LOCATION} | bool |
648-
| dyn_type.rs:107:22:107:45 | GenStruct {...} | dyn(GP) | {EXTERNAL LOCATION} | i64 |
649639
| dyn_type.rs:107:41:107:43 | 100 | | {EXTERNAL LOCATION} | i32 |
650640
| loop/main.rs:7:12:7:15 | SelfParam | | loop/main.rs:6:1:8:1 | Self [trait T1] |
651641
| loop/main.rs:11:12:11:15 | SelfParam | | loop/main.rs:10:1:14:1 | Self [trait T2] |

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,10 +1467,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
14671467
private predicate typeMatch(
14681468
Access a, AccessEnvironment e, Declaration target, TypePath path, Type t, TypeParameter tp
14691469
) {
1470-
// A type given at the access corresponds directly to the type parameter
1471-
// at the target.
1472-
explicitTypeMatch(a, e, target, path, t, tp)
1473-
or
14741470
// We can infer the type of `tp` from one of the access positions
14751471
directTypeMatch(a, e, target, path, t, tp)
14761472
or
@@ -1481,6 +1477,44 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
14811477
typeConstraintBaseTypeMatch(a, e, target, path, t, tp)
14821478
}
14831479

1480+
/**
1481+
* Gets the type of `a` and `apos` at `path` using declared type
1482+
* information at the access or declaration, if any.
1483+
*/
1484+
pragma[nomagic]
1485+
Type getAccessDeclaredType(Access a, AccessEnvironment e, AccessPosition apos, TypePath path) {
1486+
exists(DeclarationPosition dpos, Declaration target |
1487+
accessDeclarationPositionMatch(apos, dpos)
1488+
|
1489+
result = target.getDeclaredType(dpos, path) and
1490+
target = a.getTarget(e) and
1491+
not result instanceof TypeParameter
1492+
or
1493+
exists(TypePath prefix, TypeParameter tp, TypePath suffix |
1494+
tp = target.getDeclaredType(dpos, prefix) and
1495+
path = prefix.append(suffix) and
1496+
explicitTypeMatch(a, e, target, suffix, result, tp)
1497+
)
1498+
)
1499+
}
1500+
1501+
/**
1502+
* Gets the inferred type of `a` and `apos` at `path` using inferred type
1503+
* information which is not explicitly declared, if any.
1504+
*/
1505+
pragma[nomagic]
1506+
Type getAccessInferredType(Access a, AccessEnvironment e, AccessPosition apos, TypePath path) {
1507+
exists(
1508+
DeclarationPosition dpos, Declaration target, TypePath prefix, TypeParameter tp,
1509+
TypePath suffix
1510+
|
1511+
accessDeclarationPositionMatch(apos, dpos) and
1512+
tp = target.getDeclaredType(dpos, prefix) and
1513+
path = prefix.append(suffix) and
1514+
typeMatch(a, e, target, suffix, result, tp)
1515+
)
1516+
}
1517+
14841518
/**
14851519
* Gets the inferred type of `a` at `path` for position `apos` and environment `e`.
14861520
*
@@ -1521,21 +1555,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
15211555
*/
15221556
pragma[nomagic]
15231557
Type inferAccessType(Access a, AccessEnvironment e, AccessPosition apos, TypePath path) {
1524-
exists(DeclarationPosition dpos | accessDeclarationPositionMatch(apos, dpos) |
1525-
// A suffix of `path` leads to a type parameter in the target
1526-
exists(Declaration target, TypePath prefix, TypeParameter tp, TypePath suffix |
1527-
tp = target.getDeclaredType(dpos, prefix) and
1528-
path = prefix.append(suffix) and
1529-
typeMatch(a, e, target, suffix, result, tp)
1530-
)
1531-
or
1532-
// `path` corresponds directly to a concrete type in the declaration
1533-
exists(Declaration target |
1534-
result = target.getDeclaredType(dpos, path) and
1535-
target = a.getTarget(e) and
1536-
not result instanceof TypeParameter
1537-
)
1538-
)
1558+
result = [getAccessDeclaredType(a, e, apos, path), getAccessInferredType(a, e, apos, path)]
15391559
}
15401560
}
15411561

@@ -1677,6 +1697,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
16771697

16781698
import M
16791699

1700+
Type getAccessDeclaredType(Input::Access a, Input::AccessPosition apos, TypePath path) {
1701+
result = M::getAccessDeclaredType(a, _, apos, path)
1702+
}
1703+
1704+
Type getAccessInferredType(Input::Access a, Input::AccessPosition apos, TypePath path) {
1705+
result = M::getAccessInferredType(a, _, apos, path)
1706+
}
1707+
16801708
Type inferAccessType(Input::Access a, Input::AccessPosition apos, TypePath path) {
16811709
result = M::inferAccessType(a, _, apos, path)
16821710
}

0 commit comments

Comments
 (0)