Skip to content

Commit 29a0c1e

Browse files
authored
Merge pull request #21823 from Veykril/push-zvmkkvrwwppl
Implement signature type inference
2 parents b42b63f + f32fe40 commit 29a0c1e

59 files changed

Lines changed: 1089 additions & 355 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

crates/hir-def/src/db.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ use la_arena::ArenaMap;
88
use triomphe::Arc;
99

1010
use crate::{
11-
AssocItemId, AttrDefId, BlockId, BlockLoc, ConstId, ConstLoc, DefWithBodyId, EnumId, EnumLoc,
12-
EnumVariantId, EnumVariantLoc, ExternBlockId, ExternBlockLoc, ExternCrateId, ExternCrateLoc,
13-
FunctionId, FunctionLoc, GenericDefId, ImplId, ImplLoc, LocalFieldId, Macro2Id, Macro2Loc,
14-
MacroExpander, MacroId, MacroRulesId, MacroRulesLoc, MacroRulesLocFlags, ProcMacroId,
15-
ProcMacroLoc, StaticId, StaticLoc, StructId, StructLoc, TraitId, TraitLoc, TypeAliasId,
16-
TypeAliasLoc, UnionId, UnionLoc, UseId, UseLoc, VariantId,
11+
AnonConstId, AnonConstLoc, AssocItemId, AttrDefId, BlockId, BlockLoc, ConstId, ConstLoc,
12+
DefWithBodyId, EnumId, EnumLoc, EnumVariantId, EnumVariantLoc, ExpressionStoreOwner,
13+
ExternBlockId, ExternBlockLoc, ExternCrateId, ExternCrateLoc, FunctionId, FunctionLoc,
14+
GenericDefId, ImplId, ImplLoc, LocalFieldId, Macro2Id, Macro2Loc, MacroExpander, MacroId,
15+
MacroRulesId, MacroRulesLoc, MacroRulesLocFlags, ProcMacroId, ProcMacroLoc, StaticId,
16+
StaticLoc, StructId, StructLoc, TraitId, TraitLoc, TypeAliasId, TypeAliasLoc, UnionId,
17+
UnionLoc, UseId, UseLoc, VariantId,
1718
attrs::AttrFlags,
1819
expr_store::{
1920
Body, BodySourceMap, ExpressionStore, ExpressionStoreSourceMap, scope::ExprScopes,
@@ -61,6 +62,9 @@ pub trait InternDatabase: RootQueryDb {
6162
#[salsa::interned]
6263
fn intern_static(&self, loc: StaticLoc) -> StaticId;
6364

65+
#[salsa::interned]
66+
fn intern_anon_const(&self, loc: AnonConstLoc) -> AnonConstId;
67+
6468
#[salsa::interned]
6569
fn intern_trait(&self, loc: TraitLoc) -> TraitId;
6670

@@ -212,8 +216,15 @@ pub trait DefDatabase: InternDatabase + ExpandDatabase + SourceDatabase {
212216
#[salsa::invoke(Body::body_query)]
213217
fn body(&self, def: DefWithBodyId) -> Arc<Body>;
214218

219+
#[salsa::invoke(ExprScopes::body_expr_scopes_query)]
220+
fn body_expr_scopes(&self, def: DefWithBodyId) -> Arc<ExprScopes>;
221+
222+
#[salsa::invoke(ExprScopes::sig_expr_scopes_query)]
223+
fn sig_expr_scopes(&self, def: GenericDefId) -> Arc<ExprScopes>;
224+
225+
#[salsa::transparent]
215226
#[salsa::invoke(ExprScopes::expr_scopes_query)]
216-
fn expr_scopes(&self, def: DefWithBodyId) -> Arc<ExprScopes>;
227+
fn expr_scopes(&self, def: ExpressionStoreOwner) -> Arc<ExprScopes>;
217228

218229
#[salsa::transparent]
219230
#[salsa::invoke(GenericParams::new)]

crates/hir-def/src/expr_store.rs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,24 @@ pub type TypeSource = InFile<TypePtr>;
9494
pub type LifetimePtr = AstPtr<ast::Lifetime>;
9595
pub type LifetimeSource = InFile<LifetimePtr>;
9696

97+
/// Describes where a const expression originated from.
98+
///
99+
/// Used by signature/body inference to determine the expected type for each
100+
/// const expression root.
101+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
102+
pub enum ConstExprOrigin {
103+
/// Array length expression: `[T; <expr>]` — expected type is `usize`.
104+
ArrayLength,
105+
/// Const parameter default value: `const N: usize = <expr>`.
106+
ConstParam(crate::hir::generics::LocalTypeOrConstParamId),
107+
/// Const generic argument in a path: `SomeType::<{ <expr> }>` or `some_fn::<{ <expr> }>()`.
108+
/// Determining the expected type requires path resolution, so it is deferred.
109+
GenericArgsPath,
110+
}
111+
97112
// We split the store into types-only and expressions, because most stores (e.g. generics)
98113
// don't store any expressions and this saves memory. Same thing for the source map.
99-
#[derive(Debug, PartialEq, Eq)]
114+
#[derive(Debug, Clone, PartialEq, Eq)]
100115
struct ExpressionOnlyStore {
101116
exprs: Arena<Expr>,
102117
pats: Arena<Pat>,
@@ -113,9 +128,15 @@ struct ExpressionOnlyStore {
113128
/// Expressions (and destructuing patterns) that can be recorded here are single segment path, although not all single segments path refer
114129
/// to variables and have hygiene (some refer to items, we don't know at this stage).
115130
ident_hygiene: FxHashMap<ExprOrPatId, HygieneId>,
131+
132+
/// Maps const expression roots to their origin.
133+
///
134+
/// Populated during lowering. Used by signature inference to determine expected types,
135+
/// and by `signature_const_expr_roots()` to enumerate roots for scope computation.
136+
const_expr_origins: ThinVec<(ExprId, ConstExprOrigin)>,
116137
}
117138

118-
#[derive(Debug, PartialEq, Eq)]
139+
#[derive(Debug, Clone, PartialEq, Eq)]
119140
pub struct ExpressionStore {
120141
expr_only: Option<Box<ExpressionOnlyStore>>,
121142
pub types: Arena<TypeRef>,
@@ -226,6 +247,7 @@ pub struct ExpressionStoreBuilder {
226247
pub types: Arena<TypeRef>,
227248
block_scopes: Vec<BlockId>,
228249
ident_hygiene: FxHashMap<ExprOrPatId, HygieneId>,
250+
pub const_expr_origins: Option<ThinVec<(ExprId, ConstExprOrigin)>>,
229251

230252
// AST expressions can create patterns in destructuring assignments. Therefore, `ExprSource` can also map
231253
// to `PatId`, and `PatId` can also map to `ExprSource` (the other way around is unaffected).
@@ -297,6 +319,7 @@ impl ExpressionStoreBuilder {
297319
mut bindings,
298320
mut binding_owners,
299321
mut ident_hygiene,
322+
mut const_expr_origins,
300323
mut types,
301324
mut lifetimes,
302325

@@ -356,6 +379,9 @@ impl ExpressionStoreBuilder {
356379

357380
let store = {
358381
let expr_only = if has_exprs {
382+
if let Some(const_expr_origins) = &mut const_expr_origins {
383+
const_expr_origins.shrink_to_fit();
384+
}
359385
Some(Box::new(ExpressionOnlyStore {
360386
exprs,
361387
pats,
@@ -364,6 +390,7 @@ impl ExpressionStoreBuilder {
364390
binding_owners,
365391
block_scopes: block_scopes.into_boxed_slice(),
366392
ident_hygiene,
393+
const_expr_origins: const_expr_origins.unwrap_or_default(),
367394
}))
368395
} else {
369396
None
@@ -413,6 +440,29 @@ impl ExpressionStore {
413440
EMPTY.clone()
414441
}
415442

443+
/// Returns all const expression root `ExprId`s found in this store.
444+
///
445+
/// Used to compute expression scopes for signature stores.
446+
pub fn signature_const_expr_roots(&self) -> impl Iterator<Item = ExprId> {
447+
self.const_expr_origins().iter().map(|&(id, _)| id)
448+
}
449+
450+
/// Like [`Self::signature_const_expr_roots`], but also returns the origin
451+
/// of each const expression.
452+
///
453+
/// This is used by signature inference to determine the expected type for
454+
/// each root expression.
455+
pub fn signature_const_expr_roots_with_origins(
456+
&self,
457+
) -> impl Iterator<Item = (ExprId, ConstExprOrigin)> {
458+
self.const_expr_origins().iter().map(|&(id, origin)| (id, origin))
459+
}
460+
461+
/// Returns the map of const expression roots to their origins.
462+
pub fn const_expr_origins(&self) -> &[(ExprId, ConstExprOrigin)] {
463+
self.expr_only.as_ref().map_or(&[], |it| &it.const_expr_origins)
464+
}
465+
416466
/// Returns an iterator over all block expressions in this store that define inner items.
417467
pub fn blocks<'a>(
418468
&'a self,

crates/hir-def/src/expr_store/body.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl Body {
9999
DefWithBodyId::VariantId(v) => {
100100
let s = v.lookup(db);
101101
let src = s.source(db);
102-
src.map(|it| it.expr())
102+
src.map(|it| it.const_arg()?.expr())
103103
}
104104
}
105105
};

crates/hir-def/src/expr_store/lower.rs

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use crate::{
3737
attrs::AttrFlags,
3838
db::DefDatabase,
3939
expr_store::{
40-
Body, BodySourceMap, ExprPtr, ExpressionStore, ExpressionStoreBuilder,
40+
Body, BodySourceMap, ConstExprOrigin, ExprPtr, ExpressionStore, ExpressionStoreBuilder,
4141
ExpressionStoreDiagnostics, ExpressionStoreSourceMap, HygieneId, LabelPtr, LifetimePtr,
4242
PatPtr, TypePtr,
4343
expander::Expander,
@@ -79,7 +79,7 @@ pub(super) fn lower_body(
7979
let mut self_param = None;
8080
let mut source_map_self_param = None;
8181
let mut params = vec![];
82-
let mut collector = ExprCollector::new(db, module, current_file_id);
82+
let mut collector = ExprCollector::body(db, module, current_file_id);
8383

8484
let skip_body = AttrFlags::query(
8585
db,
@@ -186,7 +186,7 @@ pub(crate) fn lower_type_ref(
186186
module: ModuleId,
187187
type_ref: InFile<Option<ast::Type>>,
188188
) -> (ExpressionStore, ExpressionStoreSourceMap, TypeRefId) {
189-
let mut expr_collector = ExprCollector::new(db, module, type_ref.file_id);
189+
let mut expr_collector = ExprCollector::signature(db, module, type_ref.file_id);
190190
let type_ref =
191191
expr_collector.lower_type_ref_opt(type_ref.value, &mut ExprCollector::impl_trait_allocator);
192192
let (store, source_map) = expr_collector.store.finish();
@@ -201,7 +201,7 @@ pub(crate) fn lower_generic_params(
201201
param_list: Option<ast::GenericParamList>,
202202
where_clause: Option<ast::WhereClause>,
203203
) -> (Arc<ExpressionStore>, Arc<GenericParams>, ExpressionStoreSourceMap) {
204-
let mut expr_collector = ExprCollector::new(db, module, file_id);
204+
let mut expr_collector = ExprCollector::signature(db, module, file_id);
205205
let mut collector = generics::GenericParamsCollector::new(def);
206206
collector.lower(&mut expr_collector, param_list, where_clause);
207207
let params = collector.finish();
@@ -215,7 +215,7 @@ pub(crate) fn lower_impl(
215215
impl_syntax: InFile<ast::Impl>,
216216
impl_id: ImplId,
217217
) -> (ExpressionStore, ExpressionStoreSourceMap, TypeRefId, Option<TraitRef>, Arc<GenericParams>) {
218-
let mut expr_collector = ExprCollector::new(db, module, impl_syntax.file_id);
218+
let mut expr_collector = ExprCollector::signature(db, module, impl_syntax.file_id);
219219
let self_ty =
220220
expr_collector.lower_type_ref_opt_disallow_impl_trait(impl_syntax.value.self_ty());
221221
let trait_ = impl_syntax.value.trait_().and_then(|it| match &it {
@@ -243,7 +243,7 @@ pub(crate) fn lower_trait(
243243
trait_syntax: InFile<ast::Trait>,
244244
trait_id: TraitId,
245245
) -> (ExpressionStore, ExpressionStoreSourceMap, Arc<GenericParams>) {
246-
let mut expr_collector = ExprCollector::new(db, module, trait_syntax.file_id);
246+
let mut expr_collector = ExprCollector::signature(db, module, trait_syntax.file_id);
247247
let mut collector = generics::GenericParamsCollector::with_self_param(
248248
&mut expr_collector,
249249
trait_id.into(),
@@ -271,7 +271,7 @@ pub(crate) fn lower_type_alias(
271271
Box<[TypeBound]>,
272272
Option<TypeRefId>,
273273
) {
274-
let mut expr_collector = ExprCollector::new(db, module, alias.file_id);
274+
let mut expr_collector = ExprCollector::signature(db, module, alias.file_id);
275275
let bounds = alias
276276
.value
277277
.type_bound_list()
@@ -313,7 +313,7 @@ pub(crate) fn lower_function(
313313
bool,
314314
bool,
315315
) {
316-
let mut expr_collector = ExprCollector::new(db, module, fn_.file_id);
316+
let mut expr_collector = ExprCollector::signature(db, module, fn_.file_id);
317317
let mut collector = generics::GenericParamsCollector::new(function_id.into());
318318
collector.lower(&mut expr_collector, fn_.value.generic_param_list(), fn_.value.where_clause());
319319
let mut params = vec![];
@@ -532,7 +532,20 @@ impl BindingList {
532532
}
533533

534534
impl<'db> ExprCollector<'db> {
535-
pub fn new(
535+
/// Creates a collector for a signature store, this will populate `const_expr_origins` to any
536+
/// top level const arg roots.
537+
pub fn signature(
538+
db: &dyn DefDatabase,
539+
module: ModuleId,
540+
current_file_id: HirFileId,
541+
) -> ExprCollector<'_> {
542+
let mut this = Self::body(db, module, current_file_id);
543+
this.store.const_expr_origins = Some(Default::default());
544+
this
545+
}
546+
547+
/// Creates a collector for a bidy store.
548+
pub fn body(
536549
db: &dyn DefDatabase,
537550
module: ModuleId,
538551
current_file_id: HirFileId,
@@ -577,7 +590,10 @@ impl<'db> ExprCollector<'db> {
577590
self.expander.span_map()
578591
}
579592

580-
pub fn lower_lifetime_ref(&mut self, lifetime: ast::Lifetime) -> LifetimeRefId {
593+
pub(in crate::expr_store) fn lower_lifetime_ref(
594+
&mut self,
595+
lifetime: ast::Lifetime,
596+
) -> LifetimeRefId {
581597
// FIXME: Keyword check?
582598
let lifetime_ref = match &*lifetime.text() {
583599
"" | "'" => LifetimeRef::Error,
@@ -588,15 +604,18 @@ impl<'db> ExprCollector<'db> {
588604
self.alloc_lifetime_ref(lifetime_ref, AstPtr::new(&lifetime))
589605
}
590606

591-
pub fn lower_lifetime_ref_opt(&mut self, lifetime: Option<ast::Lifetime>) -> LifetimeRefId {
607+
pub(in crate::expr_store) fn lower_lifetime_ref_opt(
608+
&mut self,
609+
lifetime: Option<ast::Lifetime>,
610+
) -> LifetimeRefId {
592611
match lifetime {
593612
Some(lifetime) => self.lower_lifetime_ref(lifetime),
594613
None => self.alloc_lifetime_ref_desugared(LifetimeRef::Placeholder),
595614
}
596615
}
597616

598617
/// Converts an `ast::TypeRef` to a `hir::TypeRef`.
599-
pub fn lower_type_ref(
618+
pub(in crate::expr_store) fn lower_type_ref(
600619
&mut self,
601620
node: ast::Type,
602621
impl_trait_lower_fn: ImplTraitLowerFn<'_>,
@@ -621,6 +640,9 @@ impl<'db> ExprCollector<'db> {
621640
}
622641
ast::Type::ArrayType(inner) => {
623642
let len = self.lower_const_arg_opt(inner.const_arg());
643+
if let Some(const_expr_origins) = &mut self.store.const_expr_origins {
644+
const_expr_origins.push((len.expr, ConstExprOrigin::ArrayLength));
645+
}
624646
TypeRef::Array(ArrayType {
625647
ty: self.lower_type_ref_opt(inner.ty(), impl_trait_lower_fn),
626648
len,
@@ -810,7 +832,7 @@ impl<'db> ExprCollector<'db> {
810832

811833
/// Collect `GenericArgs` from the parts of a fn-like path, i.e. `Fn(X, Y)
812834
/// -> Z` (which desugars to `Fn<(X, Y), Output=Z>`).
813-
pub fn lower_generic_args_from_fn_path(
835+
pub(in crate::expr_store) fn lower_generic_args_from_fn_path(
814836
&mut self,
815837
args: Option<ast::ParenthesizedArgList>,
816838
ret_type: Option<ast::RetType>,
@@ -905,6 +927,9 @@ impl<'db> ExprCollector<'db> {
905927
}
906928
ast::GenericArg::ConstArg(arg) => {
907929
let arg = self.lower_const_arg(arg);
930+
if let Some(const_expr_origins) = &mut self.store.const_expr_origins {
931+
const_expr_origins.push((arg.expr, ConstExprOrigin::GenericArgsPath));
932+
}
908933
args.push(GenericArg::Const(arg))
909934
}
910935
}
@@ -1045,17 +1070,30 @@ impl<'db> ExprCollector<'db> {
10451070
}
10461071

10471072
fn lower_const_arg_opt(&mut self, arg: Option<ast::ConstArg>) -> ConstRef {
1048-
ConstRef { expr: self.collect_expr_opt(arg.and_then(|it| it.expr())) }
1073+
let const_expr_origins = self.store.const_expr_origins.take();
1074+
let r = ConstRef { expr: self.collect_expr_opt(arg.and_then(|it| it.expr())) };
1075+
self.store.const_expr_origins = const_expr_origins;
1076+
r
10491077
}
10501078

1051-
fn lower_const_arg(&mut self, arg: ast::ConstArg) -> ConstRef {
1052-
ConstRef { expr: self.collect_expr_opt(arg.expr()) }
1079+
pub fn lower_const_arg(&mut self, arg: ast::ConstArg) -> ConstRef {
1080+
let const_expr_origins = self.store.const_expr_origins.take();
1081+
let r = ConstRef { expr: self.collect_expr_opt(arg.expr()) };
1082+
self.store.const_expr_origins = const_expr_origins;
1083+
r
10531084
}
10541085

10551086
fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
10561087
self.maybe_collect_expr(expr).unwrap_or_else(|| self.missing_expr())
10571088
}
10581089

1090+
pub(in crate::expr_store) fn collect_expr_opt(&mut self, expr: Option<ast::Expr>) -> ExprId {
1091+
match expr {
1092+
Some(expr) => self.collect_expr(expr),
1093+
None => self.missing_expr(),
1094+
}
1095+
}
1096+
10591097
/// Returns `None` if and only if the expression is `#[cfg]`d out.
10601098
fn maybe_collect_expr(&mut self, expr: ast::Expr) -> Option<ExprId> {
10611099
let syntax_ptr = AstPtr::new(&expr);
@@ -2065,13 +2103,6 @@ impl<'db> ExprCollector<'db> {
20652103
}
20662104
}
20672105

2068-
pub fn collect_expr_opt(&mut self, expr: Option<ast::Expr>) -> ExprId {
2069-
match expr {
2070-
Some(expr) => self.collect_expr(expr),
2071-
None => self.missing_expr(),
2072-
}
2073-
}
2074-
20752106
fn collect_macro_as_stmt(
20762107
&mut self,
20772108
statements: &mut Vec<Statement>,

crates/hir-def/src/expr_store/lower/generics.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,17 @@ impl GenericParamsCollector {
141141
const_param.ty(),
142142
&mut ExprCollector::impl_trait_error_allocator,
143143
);
144-
let param = ConstParamData {
145-
name,
146-
ty,
147-
default: const_param.default_val().map(|it| ec.lower_const_arg(it)),
148-
};
149-
let _idx = self.type_or_consts.alloc(param.into());
144+
let default = const_param.default_val().map(|it| ec.lower_const_arg(it));
145+
let param = ConstParamData { name, ty, default };
146+
let idx = self.type_or_consts.alloc(param.into());
147+
if let Some(default) = default
148+
&& let Some(const_expr_origins) = &mut ec.store.const_expr_origins
149+
{
150+
const_expr_origins.push((
151+
default.expr,
152+
crate::expr_store::ConstExprOrigin::ConstParam(idx),
153+
));
154+
}
150155
}
151156
ast::GenericParam::LifetimeParam(lifetime_param) => {
152157
let lifetime = ec.lower_lifetime_ref_opt(lifetime_param.lifetime());

crates/hir-def/src/expr_store/lower/path/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn lower_path(path: ast::Path) -> (TestDB, ExpressionStore, Option<Path>) {
2121
let (db, file_id) = TestDB::with_single_file("");
2222
let krate = db.fetch_test_crate();
2323
let mut ctx =
24-
ExprCollector::new(&db, crate_def_map(&db, krate).root_module_id(), file_id.into());
24+
ExprCollector::signature(&db, crate_def_map(&db, krate).root_module_id(), file_id.into());
2525
let lowered_path = ctx.lower_path(path, &mut ExprCollector::impl_trait_allocator);
2626
let (store, _) = ctx.store.finish();
2727
(db, store, lowered_path)

0 commit comments

Comments
 (0)