Skip to content

Commit 587e83f

Browse files
committed
wip
1 parent 309756c commit 587e83f

3 files changed

Lines changed: 128 additions & 59 deletions

File tree

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 123 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ pragma[nomagic]
467467
private Type inferStructExprType(AstNode n, TypePath path) {
468468
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
469469
n = a.getNodeAt(apos) and
470-
result = StructExprMatching::inferAccessType(a, apos, path)
470+
result = StructExprMatching::inferAccessType(a, apos, path, _)
471471
)
472472
}
473473

@@ -516,7 +516,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
516516
}
517517
}
518518

519-
abstract class Declaration extends AstNode {
519+
abstract additional class Decl extends AstNode {
520520
abstract TypeParameter getTypeParameter(TypeParameterPosition ppos);
521521

522522
pragma[nomagic]
@@ -532,7 +532,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
532532
}
533533
}
534534

535-
private class TupleStructDecl extends Declaration, Struct {
535+
private class TupleStructDecl extends Decl, Struct {
536536
TupleStructDecl() { this.isTuple() }
537537

538538
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
@@ -555,7 +555,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
555555
}
556556
}
557557

558-
private class TupleVariantDecl extends Declaration, Variant {
558+
private class TupleVariantDecl extends Decl, Variant {
559559
TupleVariantDecl() { this.isTuple() }
560560

561561
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
@@ -580,7 +580,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
580580
}
581581
}
582582

583-
private class FunctionDecl extends Declaration, Function {
583+
private class FunctionDecl extends Decl, Function {
584584
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
585585
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
586586
or
@@ -632,6 +632,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
632632
}
633633
}
634634

635+
additional newtype TDeclaration = additional MkDeclaration(Decl decl, Boolean isDeref)
636+
637+
class Declaration extends MkDeclaration {
638+
private Decl decl;
639+
private boolean isDeref;
640+
641+
Declaration() { this = MkDeclaration(decl, isDeref) }
642+
643+
Decl getDecl() { result = decl }
644+
645+
TypeParameter getTypeParameter(TypeParameterPosition ppos) {
646+
result = this.getDecl().getTypeParameter(ppos)
647+
}
648+
649+
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
650+
result = this.getDecl().getDeclaredType(dpos, path)
651+
}
652+
653+
Type getParameterType(DeclarationPosition dpos, TypePath path) {
654+
result = this.getDecl().getParameterType(dpos, path)
655+
}
656+
657+
string toString() { result = this.getDecl().toString() }
658+
659+
Location getLocation() { result = this.getDecl().getLocation() }
660+
}
661+
635662
private predicate argPos(CallExprBase call, Expr e, int pos, boolean isMethodCall) {
636663
exists(ArgList al |
637664
e = al.getArg(pos) and
@@ -706,9 +733,12 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
706733
}
707734

708735
override Declaration getTarget() {
709-
result = CallExprImpl::getResolvedFunction(this)
736+
result = MkDeclaration(CallExprImpl::getResolvedFunction(this), false)
710737
or
711-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
738+
exists(Function f, boolean isDeref |
739+
f = inferMethodCallTarget(this, isDeref) and // mutual recursion; resolving method calls requires resolving types and vice versa
740+
result = MkDeclaration(f, isDeref)
741+
)
712742
}
713743
}
714744

@@ -733,7 +763,10 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
733763
}
734764

735765
override Declaration getTarget() {
736-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
766+
exists(Function f, boolean isDeref |
767+
f = inferMethodCallTarget(this, isDeref) and // mutual recursion; resolving method calls requires resolving types and vice versa
768+
result = MkDeclaration(f, isDeref)
769+
)
737770
}
738771
}
739772

@@ -816,13 +849,16 @@ pragma[nomagic]
816849
private Type inferCallExprBaseType(AstNode n, TypePath path) {
817850
exists(
818851
CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos,
819-
TypePath path0
852+
TypePath path0, CallExprBaseMatchingInput::Decl decl, boolean isDeref
820853
|
821854
n = a.getNodeAt(apos) and
822-
result = CallExprBaseMatching::inferAccessType(a, apos, path0)
855+
result =
856+
CallExprBaseMatching::inferAccessType(a, apos, path0,
857+
CallExprBaseMatchingInput::MkDeclaration(decl, isDeref))
823858
|
824859
if apos.isSelf()
825860
then
861+
isDeref = false and
826862
exists(Type receiverType | receiverType = inferType(n) |
827863
if receiverType = TRefType()
828864
then
@@ -960,7 +996,7 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
960996
FieldExprMatchingInput::Access a, FieldExprMatchingInput::AccessPosition apos, TypePath path0
961997
|
962998
n = a.getNodeAt(apos) and
963-
result = FieldExprMatching::inferAccessType(a, apos, path0)
999+
result = FieldExprMatching::inferAccessType(a, apos, path0, _)
9641000
|
9651001
if apos.isSelf()
9661002
then
@@ -1115,7 +1151,7 @@ pragma[nomagic]
11151151
private Type inferAwaitExprType(AstNode n, TypePath path) {
11161152
exists(AwaitExprMatchingInput::Access a, AwaitExprMatchingInput::AccessPosition apos |
11171153
n = a.getNodeAt(apos) and
1118-
result = AwaitExprMatching::inferAccessType(a, apos, path)
1154+
result = AwaitExprMatching::inferAccessType(a, apos, path, _)
11191155
)
11201156
or
11211157
// This case is needed for `async` functions and blocks, where we assign
@@ -1141,7 +1177,7 @@ private module MethodCall {
11411177
Trait getTrait() { none() }
11421178

11431179
/** Gets the type of the receiver of the method call at `path`. */
1144-
abstract Type getTypeAt(TypePath path);
1180+
abstract Type getTypeAt(TypePath path, boolean isDeref);
11451181
}
11461182

11471183
final class MethodCall = MethodCallImpl;
@@ -1152,18 +1188,21 @@ private module MethodCall {
11521188
override int getArity() { result = super.getArgList().getNumberOfArgs() }
11531189

11541190
pragma[nomagic]
1155-
override Type getTypeAt(TypePath path) {
1156-
exists(TypePath path0, Type t0 |
1191+
override Type getTypeAt(TypePath path, boolean isDeref) {
1192+
exists(TypePath path0, Type t0, boolean isDeref0 |
11571193
t0 = inferType(super.getReceiver(), path0) and
11581194
(
1159-
path0.isCons(TRefTypeParameter(), path)
1195+
path0.isCons(TRefTypeParameter(), path) and
1196+
isDeref0 = true
11601197
or
11611198
not path0.isCons(TRefTypeParameter(), _) and
11621199
not (path0.isEmpty() and result = TRefType()) and
1163-
path = path0
1200+
path = path0 and
1201+
isDeref0 = false
11641202
)
11651203
|
1166-
result = t0
1204+
result = t0 and
1205+
isDeref = isDeref0
11671206
or
11681207
// We do not yet model the `Deref` trait, so we hard-code the fact that
11691208
// `String` dereferences to `str` here. This allows us e.g. to resolve
@@ -1173,7 +1212,8 @@ private module MethodCall {
11731212
// See also https://doc.rust-lang.org/reference/expressions/method-call-expr.html#r-expr.method.autoref-deref
11741213
path.isEmpty() and
11751214
t0.(StructType).asItemNode().(Struct).getCanonicalPath() = "alloc::string::String" and
1176-
result.(StructType).asItemNode() instanceof Builtins::Str
1215+
result.(StructType).asItemNode() instanceof Builtins::Str and
1216+
isDeref = true
11771217
)
11781218
}
11791219
}
@@ -1202,7 +1242,10 @@ private module MethodCall {
12021242
override Trait getTrait() { result = trait }
12031243

12041244
pragma[nomagic]
1205-
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
1245+
override Type getTypeAt(TypePath path, boolean isDeref) {
1246+
result = inferType(receiver, path) and
1247+
isDeref = false
1248+
}
12061249
}
12071250

12081251
private class OperationMethodCall extends MethodCallImpl instanceof Operation {
@@ -1218,10 +1261,13 @@ private module MethodCall {
12181261
override Trait getTrait() { result = trait }
12191262

12201263
pragma[nomagic]
1221-
override Type getTypeAt(TypePath path) {
1222-
result = inferType(this.(BinaryExpr).getLhs(), path)
1223-
or
1224-
result = inferType(this.(PrefixExpr).getExpr(), path)
1264+
override Type getTypeAt(TypePath path, boolean isDeref) {
1265+
(
1266+
result = inferType(this.(BinaryExpr).getLhs(), path)
1267+
or
1268+
result = inferType(this.(PrefixExpr).getExpr(), path)
1269+
) and
1270+
isDeref = false
12251271
}
12261272
}
12271273
}
@@ -1251,23 +1297,51 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
12511297
methodCandidate(type, name, arity, impl)
12521298
}
12531299

1254-
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1300+
private import codeql.util.Boolean
1301+
1302+
private newtype TMethodCallDeref = MkMethodCallDeref(MethodCall mc, Boolean isDeref)
1303+
1304+
private class MethodCallDeref extends MkMethodCallDeref {
1305+
MethodCall mc;
1306+
boolean isDeref;
1307+
1308+
MethodCallDeref() { this = MkMethodCallDeref(mc, isDeref) }
1309+
1310+
MethodCall getMethodCall() { result = mc }
1311+
1312+
boolean isImplicitDeref() { result = isDeref }
1313+
1314+
Type getTypeAt(TypePath path) { result = mc.getTypeAt(path, isDeref) }
1315+
1316+
string toString() {
1317+
if isDeref = true then result = "[implicit deref] " + mc.toString() else result = mc.toString()
1318+
}
1319+
1320+
Location getLocation() { result = mc.getLocation() }
1321+
}
1322+
1323+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCallDeref> {
12551324
pragma[nomagic]
1256-
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1257-
rootType = mc.getTypeAt(TypePath::nil()) and
1258-
name = mc.getMethodName() and
1259-
arity = mc.getArity()
1325+
private predicate isMethodCall(MethodCallDeref mc, Type rootType, string name, int arity) {
1326+
exists(MethodCall mc0, boolean isDeref |
1327+
mc = MkMethodCallDeref(mc0, isDeref) and
1328+
rootType = mc0.getTypeAt(TypePath::nil(), isDeref) and
1329+
name = mc0.getMethodName() and
1330+
arity = mc0.getArity()
1331+
)
12601332
}
12611333

12621334
pragma[nomagic]
1263-
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
1335+
predicate potentialInstantiationOf(
1336+
MethodCallDeref mc, TypeAbstraction impl, TypeMention constraint
1337+
) {
12641338
exists(Type rootType, string name, int arity |
12651339
isMethodCall(mc, rootType, name, arity) and
12661340
constraint = impl.(ImplTypeAbstraction).getSelfTy()
12671341
|
1268-
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
1342+
methodCandidateTrait(rootType, mc.getMethodCall().getTrait(), name, arity, impl)
12691343
or
1270-
not exists(mc.getTrait()) and
1344+
not exists(mc.getMethodCall().getTrait()) and
12711345
methodCandidate(rootType, name, arity, impl)
12721346
)
12731347
}
@@ -1294,10 +1368,10 @@ private Function getTypeParameterMethod(TypeParameter tp, string name) {
12941368
}
12951369

12961370
/** Gets a method from an `impl` block that matches the method call `mc`. */
1297-
private Function getMethodFromImpl(MethodCall mc) {
1371+
private Function getMethodFromImpl(MethodCallDeref mc) {
12981372
exists(Impl impl |
1299-
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1300-
result = getMethodSuccessor(impl, mc.getMethodName())
1373+
IsInstantiationOf<MethodCallDeref, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1374+
result = getMethodSuccessor(impl, mc.getMethodCall().getMethodName())
13011375
)
13021376
}
13031377

@@ -1311,16 +1385,16 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
13111385
* Gets a method that the method call `mc` resolves to based on type inference,
13121386
* if any.
13131387
*/
1314-
private Function inferMethodCallTarget(MethodCall mc) {
1388+
private Function inferMethodCallTarget(MethodCall mc, boolean isDeref) {
13151389
// The method comes from an `impl` block targeting the type of the receiver.
1316-
result = getMethodFromImpl(mc)
1390+
result = getMethodFromImpl(MkMethodCallDeref(mc, isDeref))
13171391
or
13181392
// The type of the receiver is a type parameter and the method comes from a
13191393
// trait bound on the type parameter.
1320-
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1394+
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil(), isDeref), mc.getMethodName())
13211395
or
13221396
// The type of the receiver is an `impl Trait` type.
1323-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1397+
result = getTraitMethod(mc.getTypeAt(TypePath::nil(), isDeref), mc.getMethodName())
13241398
}
13251399

13261400
cached
@@ -1334,7 +1408,7 @@ private module Cached {
13341408
apos.isSelf() and
13351409
receiver = a.getNodeAt(apos) and
13361410
inferType(receiver) = TRefType() and
1337-
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) != TRefType()
1411+
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil(), _) != TRefType()
13381412
)
13391413
}
13401414

@@ -1344,7 +1418,7 @@ private module Cached {
13441418
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
13451419
apos.isSelf() and
13461420
receiver = a.getNodeAt(apos) and
1347-
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) = TRefType() and
1421+
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil(), _) = TRefType() and
13481422
inferType(receiver) != TRefType()
13491423
)
13501424
}
@@ -1357,25 +1431,25 @@ private module Cached {
13571431
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
13581432
}
13591433

1360-
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
1361-
result = inferMethodCallTarget(mc) and
1434+
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource, boolean isDeref) {
1435+
result = inferMethodCallTarget(mc, isDeref) and
13621436
(if result.fromSource() then fromSource = true else fromSource = false) and
13631437
(
13641438
// prioritize inherent implementation methods first
13651439
isInherentImplFunction(result)
13661440
or
1367-
not isInherentImplFunction(inferMethodCallTarget(mc)) and
1441+
not isInherentImplFunction(inferMethodCallTarget(mc, _)) and
13681442
(
13691443
// then trait implementation methods
13701444
isTraitImplFunction(result)
13711445
or
1372-
not isTraitImplFunction(inferMethodCallTarget(mc)) and
1446+
not isTraitImplFunction(inferMethodCallTarget(mc, _)) and
13731447
(
13741448
// then trait methods with default implementations
13751449
result.hasBody()
13761450
or
13771451
// and finally trait methods without default implementations
1378-
not inferMethodCallTarget(mc).hasBody()
1452+
not inferMethodCallTarget(mc, _).hasBody()
13791453
)
13801454
)
13811455
)
@@ -1386,10 +1460,10 @@ private module Cached {
13861460
Function resolveMethodCallTarget(MethodCall mc) {
13871461
// Functions in source code also gets extracted as library code, due to
13881462
// this duplication we prioritize functions from source code.
1389-
result = resolveMethodCallTargetFrom(mc, true)
1463+
result = resolveMethodCallTargetFrom(mc, true, _)
13901464
or
1391-
not exists(resolveMethodCallTargetFrom(mc, true)) and
1392-
result = resolveMethodCallTargetFrom(mc, false)
1465+
not exists(resolveMethodCallTargetFrom(mc, true, _)) and
1466+
result = resolveMethodCallTargetFrom(mc, false, _)
13931467
}
13941468

13951469
pragma[inline]

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,14 +1390,11 @@ inferType
13901390
| main.rs:1109:26:1109:27 | x7 | T | file://:0:0:0:0 | & |
13911391
| main.rs:1109:26:1109:27 | x7 | T.&T | main.rs:1060:5:1061:14 | S2 |
13921392
| main.rs:1111:13:1111:14 | x9 | | {EXTERNAL LOCATION} | String |
1393-
| main.rs:1111:13:1111:14 | x9 | | {EXTERNAL LOCATION} | str |
13941393
| main.rs:1111:27:1111:33 | "Hello" | | {EXTERNAL LOCATION} | str |
13951394
| main.rs:1111:27:1111:45 | "Hello".to_string() | | {EXTERNAL LOCATION} | String |
1396-
| main.rs:1111:27:1111:45 | "Hello".to_string() | | {EXTERNAL LOCATION} | str |
13971395
| main.rs:1114:13:1114:13 | u | | {EXTERNAL LOCATION} | Result |
13981396
| main.rs:1114:13:1114:13 | u | T | {EXTERNAL LOCATION} | u32 |
13991397
| main.rs:1114:17:1114:18 | x9 | | {EXTERNAL LOCATION} | String |
1400-
| main.rs:1114:17:1114:18 | x9 | | {EXTERNAL LOCATION} | str |
14011398
| main.rs:1114:17:1114:33 | x9.parse() | | {EXTERNAL LOCATION} | Result |
14021399
| main.rs:1114:17:1114:33 | x9.parse() | T | {EXTERNAL LOCATION} | u32 |
14031400
| main.rs:1121:16:1121:20 | SelfParam | | file://:0:0:0:0 | & |

0 commit comments

Comments
 (0)