Skip to content

Commit b41df50

Browse files
authored
Merge pull request #13 from coord-e/subst-body
Instantiate body instead of using ParamTypeMapper
2 parents 37b4fed + e008bff commit b41df50

File tree

7 files changed

+124
-99
lines changed

7 files changed

+124
-99
lines changed

src/analyze.rs

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,14 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> {
104104
}
105105

106106
#[derive(Debug, Clone)]
107-
struct DeferredDefTy {
108-
cache: Rc<RefCell<HashMap<rty::TypeArgs, rty::RefinedType>>>,
107+
struct DeferredDefTy<'tcx> {
108+
cache: Rc<RefCell<HashMap<mir_ty::GenericArgsRef<'tcx>, rty::RefinedType>>>,
109109
}
110110

111111
#[derive(Debug, Clone)]
112-
enum DefTy {
112+
enum DefTy<'tcx> {
113113
Concrete(rty::RefinedType),
114-
Deferred(DeferredDefTy),
114+
Deferred(DeferredDefTy<'tcx>),
115115
}
116116

117117
#[derive(Clone)]
@@ -123,7 +123,7 @@ pub struct Analyzer<'tcx> {
123123
/// currently contains only local-def templates,
124124
/// but will be extended to contain externally known def's refinement types
125125
/// (at least for every defs referenced by local def bodies)
126-
defs: HashMap<DefId, DefTy>,
126+
defs: HashMap<DefId, DefTy<'tcx>>,
127127

128128
/// Resulting CHC system.
129129
system: Rc<RefCell<chc::System>>,
@@ -241,15 +241,17 @@ impl<'tcx> Analyzer<'tcx> {
241241
pub fn def_ty_with_args(
242242
&mut self,
243243
def_id: DefId,
244-
rty_args: rty::TypeArgs,
244+
generic_args: mir_ty::GenericArgsRef<'tcx>,
245245
) -> Option<rty::RefinedType> {
246246
let deferred_ty = match self.defs.get(&def_id)? {
247247
DefTy::Concrete(rty) => {
248+
let type_builder = TypeBuilder::new(self.tcx, def_id);
249+
248250
let mut def_ty = rty.clone();
249251
def_ty.instantiate_ty_params(
250-
rty_args
251-
.clone()
252-
.into_iter()
252+
generic_args
253+
.types()
254+
.map(|ty| type_builder.build(ty))
253255
.map(rty::RefinedType::unrefined)
254256
.collect(),
255257
);
@@ -259,21 +261,17 @@ impl<'tcx> Analyzer<'tcx> {
259261
};
260262

261263
let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self
262-
if let Some(rty) = deferred_ty_cache.borrow().get(&rty_args) {
264+
if let Some(rty) = deferred_ty_cache.borrow().get(&generic_args) {
263265
return Some(rty.clone());
264266
}
265267

266-
let type_builder = TypeBuilder::new(self.tcx, def_id).with_param_mapper({
267-
let rty_args = rty_args.clone();
268-
move |ty: rty::ParamType| rty_args[ty.idx].clone()
269-
});
270268
let mut analyzer = self.local_def_analyzer(def_id.as_local()?);
271-
analyzer.type_builder(type_builder);
269+
analyzer.generic_args(generic_args);
272270

273271
let expected = analyzer.expected_ty();
274272
deferred_ty_cache
275273
.borrow_mut()
276-
.insert(rty_args, expected.clone());
274+
.insert(generic_args, expected.clone());
277275

278276
analyzer.run(&expected);
279277
Some(expected)
@@ -340,4 +338,30 @@ impl<'tcx> Analyzer<'tcx> {
340338
self.tcx.dcx().err(format!("verification error: {:?}", err));
341339
}
342340
}
341+
342+
/// Computes the signature of the local function.
343+
///
344+
/// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
345+
/// but extracts parameter and return types directly from the given `body` to obtain a signature that
346+
/// reflects potential type instantiations happened after `optimized_mir`.
347+
pub fn local_fn_sig_with_body(
348+
&self,
349+
local_def_id: LocalDefId,
350+
body: &mir::Body<'tcx>,
351+
) -> mir_ty::FnSig<'tcx> {
352+
let ty = self.tcx.type_of(local_def_id).instantiate_identity();
353+
let sig = if let mir_ty::TyKind::Closure(_, substs) = ty.kind() {
354+
substs.as_closure().sig().skip_binder()
355+
} else {
356+
ty.fn_sig(self.tcx).skip_binder()
357+
};
358+
359+
self.tcx.mk_fn_sig(
360+
body.args_iter().map(|arg| body.local_decls[arg].ty),
361+
body.return_ty(),
362+
sig.c_variadic,
363+
sig.unsafety,
364+
sig.abi,
365+
)
366+
}
343367
}

src/analyze/basic_block.rs

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -363,15 +363,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
363363
_ty,
364364
) => {
365365
let func_ty = match operand.const_fn_def() {
366-
Some((def_id, args)) => {
367-
let rty_args: IndexVec<_, _> =
368-
args.types().map(|ty| self.type_builder.build(ty)).collect();
369-
self.ctx
370-
.def_ty_with_args(def_id, rty_args)
371-
.expect("unknown def")
372-
.ty
373-
.clone()
374-
}
366+
Some((def_id, args)) => self
367+
.ctx
368+
.def_ty_with_args(def_id, args)
369+
.expect("unknown def")
370+
.ty
371+
.clone(),
375372
_ => unimplemented!(),
376373
};
377374
PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null())
@@ -571,14 +568,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
571568
let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into());
572569
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
573570
}
574-
Some((def_id, args)) => {
575-
let rty_args = args.types().map(|ty| self.type_builder.build(ty)).collect();
576-
self.ctx
577-
.def_ty_with_args(def_id, rty_args)
578-
.expect("unknown def")
579-
.ty
580-
.vacuous()
581-
}
571+
Some((def_id, args)) => self
572+
.ctx
573+
.def_ty_with_args(def_id, args)
574+
.expect("unknown def")
575+
.ty
576+
.vacuous(),
582577
_ => self.operand_type(func.clone()).ty,
583578
};
584579
let expected_args: IndexVec<_, _> = args
@@ -1088,11 +1083,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
10881083
self
10891084
}
10901085

1091-
pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self {
1092-
self.type_builder = type_builder;
1093-
self
1094-
}
1095-
10961086
pub fn run(&mut self, expected: &BasicBlockType) {
10971087
let span = tracing::info_span!("bb", bb = ?self.basic_block);
10981088
let _guard = span.enter();

src/analyze/crate_.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
7676

7777
// check polymorphic function def by replacing type params with some opaque type
7878
// (and this is no-op if the function is mono)
79-
let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id())
80-
.with_param_mapper(|_| rty::Type::int());
8179
let mut expected = expected.clone();
8280
let subst = rty::TypeParamSubst::new(
8381
expected
@@ -87,13 +85,62 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
8785
.collect(),
8886
);
8987
expected.subst_ty_params(&subst);
88+
let generic_args = self.placeholder_generic_args(*local_def_id);
9089
self.ctx
9190
.local_def_analyzer(*local_def_id)
92-
.type_builder(type_builder)
91+
.generic_args(generic_args)
9392
.run(&expected);
9493
}
9594
}
9695

96+
fn placeholder_generic_args(&self, local_def_id: LocalDefId) -> mir_ty::GenericArgsRef<'tcx> {
97+
let mut constrained_params = HashSet::new();
98+
let predicates = self.tcx.predicates_of(local_def_id);
99+
let sized_trait = self.tcx.lang_items().sized_trait().unwrap();
100+
for (clause, _) in predicates.predicates {
101+
let mir_ty::ClauseKind::Trait(pred) = clause.kind().skip_binder() else {
102+
continue;
103+
};
104+
if pred.def_id() == sized_trait {
105+
continue;
106+
};
107+
for arg in pred.trait_ref.args.iter().flat_map(|ty| ty.walk()) {
108+
let Some(ty) = arg.as_type() else {
109+
continue;
110+
};
111+
let mir_ty::TyKind::Param(param_ty) = ty.kind() else {
112+
continue;
113+
};
114+
constrained_params.insert(param_ty.index);
115+
}
116+
}
117+
118+
let mut args: Vec<mir_ty::GenericArg<'tcx>> = Vec::new();
119+
120+
let generics = self.tcx.generics_of(local_def_id);
121+
for idx in 0..generics.count() {
122+
let param = generics.param_at(idx, self.tcx);
123+
let arg = match param.kind {
124+
mir_ty::GenericParamDefKind::Type { .. } => {
125+
if constrained_params.contains(&param.index) {
126+
panic!(
127+
"unable to check generic function with constrained type parameter: {}",
128+
self.tcx.def_path_str(local_def_id)
129+
);
130+
}
131+
self.tcx.types.i32.into()
132+
}
133+
mir_ty::GenericParamDefKind::Const { .. } => {
134+
unimplemented!()
135+
}
136+
mir_ty::GenericParamDefKind::Lifetime { .. } => self.tcx.lifetimes.re_erased.into(),
137+
};
138+
args.push(arg);
139+
}
140+
141+
self.tcx.mk_args(&args)
142+
}
143+
97144
fn assert_callable_entry(&mut self) {
98145
if let Some((def_id, _)) = self.tcx.entry_fn(()) {
99146
// we want to assert entry function is safe to execute without any assumption

src/analyze/local_def.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
169169
}
170170

171171
pub fn expected_ty(&mut self) -> rty::RefinedType {
172-
let sig = self.tcx.fn_sig(self.local_def_id);
173-
let sig = sig.instantiate_identity().skip_binder();
172+
let sig = self
173+
.ctx
174+
.local_fn_sig_with_body(self.local_def_id, &self.body);
174175

175176
let mut param_resolver = analyze::annot::ParamResolver::default();
176177
for (input_ident, input_ty) in self
@@ -532,7 +533,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
532533
.basic_block_analyzer(self.local_def_id, bb)
533534
.body(self.body.clone())
534535
.drop_points(drop_points)
535-
.type_builder(self.type_builder.clone())
536536
.run(&rty);
537537
}
538538
}
@@ -649,8 +649,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
649649
}
650650
}
651651

652-
pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self {
653-
self.type_builder = type_builder;
652+
pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self {
653+
self.body =
654+
mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args);
654655
self
655656
}
656657

src/chc.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ impl Function {
389389
}
390390

391391
/// A logical term.
392-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
392+
#[derive(Debug, Clone)]
393393
pub enum Term<V = TermVarIdx> {
394394
Null,
395395
Var(V),
@@ -991,7 +991,7 @@ impl Pred {
991991
}
992992

993993
/// An atom is a predicate applied to a list of terms.
994-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
994+
#[derive(Debug, Clone)]
995995
pub struct Atom<V = TermVarIdx> {
996996
pub pred: Pred,
997997
pub args: Vec<Term<V>>,
@@ -1084,7 +1084,7 @@ impl<V> Atom<V> {
10841084
/// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known
10851085
/// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to
10861086
/// enforce this restriction statically. Also see the definition of [`Body`].
1087-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1087+
#[derive(Debug, Clone)]
10881088
pub enum Formula<V = TermVarIdx> {
10891089
Atom(Atom<V>),
10901090
Not(Box<Formula<V>>),
@@ -1338,7 +1338,7 @@ impl<V> Formula<V> {
13381338
}
13391339

13401340
/// The body part of a clause, consisting of atoms and a formula.
1341-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1341+
#[derive(Debug, Clone)]
13421342
pub struct Body<V = TermVarIdx> {
13431343
pub atoms: Vec<Atom<V>>,
13441344
/// NOTE: This doesn't contain predicate variables. Also see [`Formula`].

src/refine/template.rs

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,6 @@ where
5959
}
6060
}
6161

62-
trait ParamTypeMapper {
63-
fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type<rty::Closed>;
64-
}
65-
66-
impl<F> ParamTypeMapper for F
67-
where
68-
F: Fn(rty::ParamType) -> rty::Type<rty::Closed>,
69-
{
70-
fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type<rty::Closed> {
71-
self(ty)
72-
}
73-
}
74-
7562
/// Translates [`mir_ty::Ty`] to [`rty::Type`].
7663
///
7764
/// This struct implements a translation from Rust MIR types to Thrust types.
@@ -87,9 +74,6 @@ pub struct TypeBuilder<'tcx> {
8774
/// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`].
8875
/// See [`rty::TypeParamIdx`] for more details.
8976
param_idx_mapping: HashMap<u32, rty::TypeParamIdx>,
90-
/// Optionally also want to further map rty::ParamType to other rty::Type before generating
91-
/// templates. This is no-op by default.
92-
param_type_mapper: std::rc::Rc<dyn ParamTypeMapper>,
9377
}
9478

9579
impl<'tcx> TypeBuilder<'tcx> {
@@ -109,25 +93,15 @@ impl<'tcx> TypeBuilder<'tcx> {
10993
Self {
11094
tcx,
11195
param_idx_mapping,
112-
param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()),
11396
}
11497
}
11598

116-
pub fn with_param_mapper<F>(mut self, mapper: F) -> Self
117-
where
118-
F: Fn(rty::ParamType) -> rty::Type<rty::Closed> + 'static,
119-
{
120-
self.param_type_mapper = std::rc::Rc::new(mapper);
121-
self
122-
}
123-
12499
fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type<rty::Closed> {
125100
let index = *self
126101
.param_idx_mapping
127102
.get(&ty.index)
128103
.expect("unknown type param idx");
129-
let param_ty = rty::ParamType::new(index);
130-
self.param_type_mapper.map_param_ty(param_ty)
104+
rty::ParamType::new(index).into()
131105
}
132106

133107
// TODO: consolidate two impls
@@ -400,17 +374,6 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
400374
self.ret_rty = Some(rty);
401375
self
402376
}
403-
404-
pub fn would_contain_template(&self) -> bool {
405-
if self.param_tys.is_empty() {
406-
return self.ret_rty.is_none();
407-
}
408-
409-
let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1);
410-
let param_annotated =
411-
self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx);
412-
self.ret_rty.is_none() || !param_annotated
413-
}
414377
}
415378

416379
impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R>

0 commit comments

Comments
 (0)