From 12411774676804c23cb825a7a166ae74986ba48a Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 23 Sep 2025 17:15:36 +0900 Subject: [PATCH 1/5] Support enums with lifetime params --- src/analyze/crate_.rs | 58 ++++++++++++++++++++++++++--------- src/rty/params.rs | 4 +++ tests/ui/fail/adt_poly_ref.rs | 14 +++++++++ tests/ui/pass/adt_poly_ref.rs | 14 +++++++++ 4 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 tests/ui/fail/adt_poly_ref.rs create mode 100644 tests/ui/pass/adt_poly_ref.rs diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9a1fa67..fa7143b 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -251,6 +251,29 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { continue; }; let adt = self.tcx.adt_def(local_def_id); + + // The index of TyKind::ParamTy is based on the every generic parameters in + // the definition, including lifetimes. Given the following definition: + // + // struct X<'a, T> { f: &'a T } + // + // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime + // parameters and the index of rty::ParamType is based on type parameters only. + // We're building a mapping from the original index to the new index here. + let generics = self.tcx.generics_of(local_def_id); + let mut type_param_mapping: std::collections::HashMap = + Default::default(); + for i in 0..generics.count() { + let generic_param = generics.param_at(i, self.tcx); + match generic_param.kind { + mir_ty::GenericParamDefKind::Lifetime => {} + mir_ty::GenericParamDefKind::Type { .. } => { + type_param_mapping.insert(i, type_param_mapping.len()); + } + mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + } + } + let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); let variants: IndexVec<_, _> = adt .variants() @@ -264,7 +287,26 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .iter() .map(|field| { let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - self.ctx.unrefined_ty(field_ty) + + // see the comment above about this mapping + let subst = rty::TypeParamSubst::new( + type_param_mapping + .iter() + .map(|(old, new)| { + let old = rty::TypeParamIdx::from(*old); + let new = + rty::ParamType::new(rty::TypeParamIdx::from(*new)); + (old, rty::RefinedType::unrefined(new.into())) + }) + .collect(), + ); + + // the subst doesn't contain refinements, so it's OK to take ty only + // after substitution + let mut field_rty = + rty::RefinedType::unrefined(self.ctx.unrefined_ty(field_ty)); + field_rty.subst_ty_params(&subst); + field_rty.ty }) .collect(); rty::EnumVariantDef { @@ -275,19 +317,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }) .collect(); - let ty_params = adt - .all_fields() - .map(|f| self.tcx.type_of(f.did).instantiate_identity()) - .flat_map(|ty| { - if let mir_ty::TyKind::Param(p) = ty.kind() { - Some(p.index as usize) - } else { - None - } - }) - .max() - .map(|max| max + 1) - .unwrap_or(0); + let ty_params = type_param_mapping.len(); tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); let def = rty::EnumDatatypeDef { diff --git a/src/rty/params.rs b/src/rty/params.rs index 17ebc2b..fa0daa3 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -71,6 +71,10 @@ impl std::ops::Index for TypeParamSubst { } impl TypeParamSubst { + pub fn new(subst: BTreeMap>) -> Self { + Self { subst } + } + pub fn singleton(idx: TypeParamIdx, ty: RefinedType) -> Self { let mut subst = BTreeMap::default(); subst.insert(idx, ty); diff --git a/tests/ui/fail/adt_poly_ref.rs b/tests/ui/fail/adt_poly_ref.rs new file mode 100644 index 0000000..8c42d4b --- /dev/null +++ b/tests/ui/fail/adt_poly_ref.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +enum X<'a, T> { + A(&'a T), +} + +fn main() { + let i = 42; + let x = X::A(&i); + match x { + X::A(i) => assert!(*i == 41), + } +} diff --git a/tests/ui/pass/adt_poly_ref.rs b/tests/ui/pass/adt_poly_ref.rs new file mode 100644 index 0000000..f0e5e30 --- /dev/null +++ b/tests/ui/pass/adt_poly_ref.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +enum X<'a, T> { + A(&'a T), +} + +fn main() { + let i = 42; + let x = X::A(&i); + match x { + X::A(i) => assert!(*i == 42), + } +} From 233298cd7c39fed6d57aef9b6160a3de367dd64b Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 23 Sep 2025 17:17:51 +0900 Subject: [PATCH 2/5] Rename TypeParams to TypeArgs --- src/analyze/basic_block.rs | 8 ++++---- src/refine/env.rs | 2 +- src/rty.rs | 6 +++--- src/rty/params.rs | 14 +++++++------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index cbb6589..e189695 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -219,7 +219,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .into_iter() .map(|ty| rty::RefinedType::unrefined(ty.vacuous())); - let params: IndexVec<_, _> = args + let rty_args: IndexVec<_, _> = args .types() .map(|ty| { self.ctx @@ -230,15 +230,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { for (field_pty, mut variant_rty) in field_tys.clone().into_iter().zip(variant_rtys) { - variant_rty.instantiate_ty_params(params.clone()); + variant_rty.instantiate_ty_params(rty_args.clone()); let cs = self .env .relate_sub_refined_type(&field_pty.into(), &variant_rty.boxed()); self.ctx.extend_clauses(cs); } - let sort_args: Vec<_> = params.iter().map(|rty| rty.ty.to_sort()).collect(); - let ty = rty::EnumType::new(ty_sym.clone(), params).into(); + let sort_args: Vec<_> = rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); + let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); let mut field_terms = Vec::new(); diff --git a/src/refine/env.rs b/src/refine/env.rs index e647765..5569485 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -930,7 +930,7 @@ impl Env { .field_tys() .map(|ty| rty::RefinedType::unrefined(ty.clone().vacuous()).boxed()); let got_tys = field_tys.iter().map(|ty| ty.clone().into()); - rty::unify_tys_params(expected_tys, got_tys).into_params(def.ty_params, |_| { + rty::unify_tys_params(expected_tys, got_tys).into_args(def.ty_params, |_| { panic!("var_type: should unify all params") }) }; diff --git a/src/rty.rs b/src/rty.rs index 72fede8..c706897 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{TypeParamIdx, TypeParamSubst, TypeParams}; +pub use params::{TypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. @@ -487,7 +487,7 @@ where } impl EnumType { - pub fn new(symbol: chc::DatatypeSymbol, args: TypeParams) -> Self { + pub fn new(symbol: chc::DatatypeSymbol, args: TypeArgs) -> Self { EnumType { symbol, args } } @@ -1372,7 +1372,7 @@ impl RefinedType { } } - pub fn instantiate_ty_params(&mut self, params: TypeParams) + pub fn instantiate_ty_params(&mut self, params: TypeArgs) where FV: chc::Var, { diff --git a/src/rty/params.rs b/src/rty/params.rs index fa0daa3..b57ff55 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -39,7 +39,7 @@ impl TypeParamIdx { } } -pub type TypeParams = IndexVec>; +pub type TypeArgs = IndexVec>; /// A substitution for type parameters that maps type parameters to refinement types. #[derive(Debug, Clone)] @@ -55,8 +55,8 @@ impl Default for TypeParamSubst { } } -impl From> for TypeParamSubst { - fn from(params: TypeParams) -> Self { +impl From> for TypeParamSubst { + fn from(params: TypeArgs) -> Self { let subst = params.into_iter_enumerated().collect(); Self { subst } } @@ -98,20 +98,20 @@ impl TypeParamSubst { } } - pub fn into_params(mut self, expected_len: usize, mut default: F) -> TypeParams + pub fn into_args(mut self, expected_len: usize, mut default: F) -> TypeArgs where T: chc::Var, F: FnMut(TypeParamIdx) -> RefinedType, { - let mut params = TypeParams::new(); + let mut args = TypeArgs::new(); for idx in 0..expected_len { let ty = self .subst .remove(&idx.into()) .unwrap_or_else(|| default(idx.into())); - params.push(ty); + args.push(ty); } - params + args } pub fn strip_refinement(self) -> TypeParamSubst { From 3f431a8bdd3b86c4e06b4a32efbed21f3d932773 Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 24 Oct 2025 16:05:24 +0900 Subject: [PATCH 3/5] Refactor construction of type templates --- src/analyze.rs | 12 +- src/analyze/basic_block.rs | 38 +-- src/analyze/crate_.rs | 13 +- src/analyze/local_def.rs | 6 +- src/refine.rs | 2 +- src/refine/env.rs | 3 +- src/refine/template.rs | 509 +++++++++++++++++-------------------- 7 files changed, 276 insertions(+), 307 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 898df5b..3550294 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -123,22 +123,12 @@ pub struct Analyzer<'tcx> { enum_defs: Rc>>, } -impl<'tcx> crate::refine::TemplateTypeGenerator<'tcx> for Analyzer<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } - +impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> { fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { tmpl.into_refined_type(|pred_sig| self.generate_pred_var(pred_sig)) } } -impl<'tcx> crate::refine::UnrefinedTypeGenerator<'tcx> for Analyzer<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } -} - impl<'tcx> Analyzer<'tcx> { pub fn generate_pred_var(&mut self, sig: chc::PredSig) -> chc::PredVarId { self.system diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index e189695..2258f33 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -14,7 +14,7 @@ use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{ self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, - TemplateTypeGenerator, UnrefinedTypeGenerator, Var, + TypeBuilder, Var, }; use crate::rty::{ self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _, @@ -222,9 +222,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - self.ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(ty) + TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(ty) }) .collect(); for (field_pty, mut variant_rty) in @@ -237,7 +238,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.extend_clauses(cs); } - let sort_args: Vec<_> = rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); + let sort_args: Vec<_> = + rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); @@ -433,7 +435,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let func_ty = match func.const_fn_def() { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { - let inner_ty = self.ctx.build_template_ty().ty(args.type_at(0)).vacuous(); + let inner_ty = TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .build(args.type_at(0)) + .vacuous(); let param = rty::RefinedType::unrefined(inner_ty.clone()); let ret_term = chc::Term::box_(chc::Term::var(rty::FunctionParamIdx::from(0_usize))); @@ -444,7 +449,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = self.ctx.unrefined_ty(args.type_at(0)).vacuous(); + let inner_ty = TypeBuilder::new(self.tcx).build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -531,7 +536,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = self.ctx.unrefined_ty(ty); + let ty = TypeBuilder::new(self.tcx).build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -552,7 +557,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = self.ctx.unrefined_ty(prophecy_ty); + let prophecy_ty = TypeBuilder::new(self.tcx).build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -664,10 +669,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } let decl = self.local_decls[destination].clone(); - let rty = self - .ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(decl.ty); + let rty = TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(decl.ty); self.type_call(func.clone(), args.clone().into_iter().map(|a| a.node), &rty); self.bind_local(destination, rty); } @@ -738,9 +743,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - self.ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(ret_ty) + TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(ret_ty) } // TODO: remove this diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index fa7143b..727b496 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -11,7 +11,7 @@ use rustc_span::symbol::Ident; use crate::analyze; use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; -use crate::refine::{self, TemplateTypeGenerator, UnrefinedTypeGenerator}; +use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; /// An implementation of local crate analysis. @@ -132,13 +132,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = self.ctx.unrefined_ty(*input_ty); + let input_ty = TypeBuilder::new(self.tcx).build(*input_ty); param_resolver.push_param(input_ident.name, input_ty.to_sort()); } let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); let mut ensure_annot = { - let output_ty = self.ctx.unrefined_ty(sig.output()); + let output_ty = TypeBuilder::new(self.tcx).build(sig.output()); let resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); @@ -175,7 +175,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(def_id); } - let mut builder = self.ctx.build_function_template_ty(sig); + let mut builder = TypeBuilder::new(self.tcx).for_function_template(&mut self.ctx, sig); if let Some(AnnotFormula::Formula(require)) = require_annot { let formula = require.map_var(|idx| { if idx.index() == sig.inputs().len() - 1 { @@ -303,8 +303,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // the subst doesn't contain refinements, so it's OK to take ty only // after substitution - let mut field_rty = - rty::RefinedType::unrefined(self.ctx.unrefined_ty(field_ty)); + let mut field_rty = rty::RefinedType::unrefined( + TypeBuilder::new(self.tcx).build(field_ty), + ); field_rty.subst_ty_params(&subst); field_rty.ty }) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index ef5870e..c1ab72c 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -11,7 +11,7 @@ use rustc_span::def_id::LocalDefId; use crate::analyze; use crate::chc; use crate::pretty::PrettyDisplayExt as _; -use crate::refine::{BasicBlockType, TemplateTypeGenerator}; +use crate::refine::{BasicBlockType, TypeBuilder}; use crate::rty; /// An implementation of the typing of local definitions. @@ -306,7 +306,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = self.ctx.basic_block_template_ty(live_locals, ret_ty); + let rty = TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); } } diff --git a/src/refine.rs b/src/refine.rs index 0371da1..4736b39 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -8,7 +8,7 @@ //! module and remove this one. mod template; -pub use template::{TemplateScope, TemplateTypeGenerator, UnrefinedTypeGenerator}; +pub use template::{TemplateRegistry, TemplateScope, TypeBuilder}; mod basic_block; pub use basic_block::BasicBlockType; diff --git a/src/refine/env.rs b/src/refine/env.rs index 5569485..a1edbc1 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -558,7 +558,8 @@ impl rty::ClauseScope for Env { } } -impl refine::TemplateScope for Env { +impl refine::TemplateScope for Env { + type Var = Var; fn build_template(&self) -> rty::TemplateBuilder { let mut builder = rty::TemplateBuilder::default(); for (v, sort) in self.dependencies() { diff --git a/src/refine/template.rs b/src/refine/template.rs index b6ae7b8..a2380e0 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -9,54 +9,146 @@ use crate::chc; use crate::refine; use crate::rty; -pub trait TemplateScope { - fn build_template(&self) -> rty::TemplateBuilder; +pub trait TemplateRegistry { + fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType; } -impl TemplateScope for &U +impl TemplateRegistry for &mut T where - U: TemplateScope, + T: TemplateRegistry + ?Sized, { - fn build_template(&self) -> rty::TemplateBuilder { - U::build_template(self) + fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { + T::register_template(self, tmpl) + } +} + +#[derive(Clone, Default)] +pub struct EmptyTemplateScope; + +impl TemplateScope for EmptyTemplateScope { + type Var = rty::Closed; + fn build_template(&self) -> rty::TemplateBuilder { + rty::TemplateBuilder::default() } } -impl TemplateScope for rty::TemplateBuilder +pub trait TemplateScope { + type Var: chc::Var; + fn build_template(&self) -> rty::TemplateBuilder; +} + +impl TemplateScope for &T +where + T: TemplateScope, +{ + type Var = T::Var; + fn build_template(&self) -> rty::TemplateBuilder { + T::build_template(self) + } +} + +impl TemplateScope for rty::TemplateBuilder where T: chc::Var, { + type Var = T; fn build_template(&self) -> rty::TemplateBuilder { self.clone() } } -pub trait TemplateTypeGenerator<'tcx> { - fn tcx(&self) -> mir_ty::TyCtxt<'tcx>; - fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType; +#[derive(Clone)] +pub struct TypeBuilder<'tcx> { + tcx: mir_ty::TyCtxt<'tcx>, +} - fn build_template_ty_with_scope(&mut self, scope: T) -> TemplateTypeBuilder { - TemplateTypeBuilder { - gen: self, - scope, - _marker: std::marker::PhantomData, +impl<'tcx> TypeBuilder<'tcx> { + pub fn new(tcx: mir_ty::TyCtxt<'tcx>) -> Self { + Self { tcx } + } + + // TODO: consolidate two impls + pub fn build(&self, ty: mir_ty::Ty<'tcx>) -> rty::Type { + match ty.kind() { + mir_ty::TyKind::Bool => rty::Type::bool(), + mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), + mir_ty::TyKind::Str => rty::Type::string(), + mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { + let elem_ty = self.build(*elem_ty); + match mutbl { + mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), + mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), + } + } + mir_ty::TyKind::Tuple(ts) => { + // elaboration: all fields are boxed + let elems = ts + .iter() + .map(|ty| rty::PointerType::own(self.build(ty)).into()) + .collect(); + rty::TupleType::new(elems).into() + } + mir_ty::TyKind::Never => rty::Type::never(), + mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::FnPtr(sig) => { + // TODO: justification for skip_binder + let sig = sig.skip_binder(); + let params = sig + .inputs() + .iter() + .map(|ty| rty::RefinedType::unrefined(self.build(*ty)).vacuous()) + .collect(); + let ret = rty::RefinedType::unrefined(self.build(sig.output())); + rty::FunctionType::new(params, ret.vacuous()).into() + } + mir_ty::TyKind::Adt(def, params) if def.is_box() => { + rty::PointerType::own(self.build(params.type_at(0))).into() + } + mir_ty::TyKind::Adt(def, params) => { + if def.is_enum() { + let sym = refine::datatype_symbol(self.tcx, def.did()); + let args: IndexVec<_, _> = params + .types() + .map(|ty| rty::RefinedType::unrefined(self.build(ty))) + .collect(); + rty::EnumType::new(sym, args).into() + } else if def.is_struct() { + let elem_tys = def + .all_fields() + .map(|field| { + let ty = field.ty(self.tcx, params); + // elaboration: all fields are boxed + rty::PointerType::own(self.build(ty)).into() + }) + .collect(); + rty::TupleType::new(elem_tys).into() + } else { + unimplemented!("unsupported ADT: {:?}", ty); + } + } + kind => unimplemented!("unrefined_ty: {:?}", kind), } } - fn build_template_ty(&mut self) -> TemplateTypeBuilder, V> { + pub fn for_template<'a, R>( + &self, + registry: &'a mut R, + ) -> TemplateTypeBuilder<'tcx, 'a, R, EmptyTemplateScope> { TemplateTypeBuilder { - gen: self, + tcx: self.tcx, + registry, scope: Default::default(), - _marker: std::marker::PhantomData, } } - fn build_function_template_ty( - &mut self, + pub fn for_function_template<'a, R>( + &self, + registry: &'a mut R, sig: mir_ty::FnSig<'tcx>, - ) -> FunctionTemplateTypeBuilder<'_, 'tcx, Self> { + ) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> { FunctionTemplateTypeBuilder { - gen: self, + tcx: self.tcx, + registry, param_tys: sig .inputs() .iter() @@ -71,12 +163,101 @@ pub trait TemplateTypeGenerator<'tcx> { ret_rty: None, } } +} + +pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { + tcx: mir_ty::TyCtxt<'tcx>, + registry: &'a mut R, + scope: S, +} - fn build_basic_block_template_ty( +impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> { + pub fn with_scope(self, scope: T) -> TemplateTypeBuilder<'tcx, 'a, R, T> { + TemplateTypeBuilder { + tcx: self.tcx, + registry: self.registry, + scope, + } + } +} + +impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> +where + R: TemplateRegistry, + S: TemplateScope, +{ + pub fn build(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { + match ty.kind() { + mir_ty::TyKind::Bool => rty::Type::bool(), + mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), + mir_ty::TyKind::Str => rty::Type::string(), + mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { + let elem_ty = self.build(*elem_ty); + match mutbl { + mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), + mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), + } + } + mir_ty::TyKind::Tuple(ts) => { + // elaboration: all fields are boxed + let elems = ts + .iter() + .map(|ty| rty::PointerType::own(self.build(ty)).into()) + .collect(); + rty::TupleType::new(elems).into() + } + mir_ty::TyKind::Never => rty::Type::never(), + mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::FnPtr(sig) => { + // TODO: justification for skip_binder + let sig = sig.skip_binder(); + let ty = TypeBuilder::new(self.tcx) + .for_function_template(self.registry, sig) + .build(); + rty::Type::function(ty) + } + mir_ty::TyKind::Adt(def, params) if def.is_box() => { + rty::PointerType::own(self.build(params.type_at(0))).into() + } + mir_ty::TyKind::Adt(def, params) => { + if def.is_enum() { + let sym = refine::datatype_symbol(self.tcx, def.did()); + let args: IndexVec<_, _> = + params.types().map(|ty| self.build_refined(ty)).collect(); + rty::EnumType::new(sym, args).into() + } else if def.is_struct() { + let elem_tys = def + .all_fields() + .map(|field| { + let ty = field.ty(self.tcx, params); + // elaboration: all fields are boxed + rty::PointerType::own(self.build(ty)).into() + }) + .collect(); + rty::TupleType::new(elem_tys).into() + } else { + unimplemented!("unsupported ADT: {:?}", ty); + } + } + kind => unimplemented!("ty: {:?}", kind), + } + } + + pub fn build_refined(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { + // TODO: consider building ty with scope + let ty = TypeBuilder::new(self.tcx) + .for_template(self.registry) + .build(ty) + .vacuous(); + let tmpl = self.scope.build_template().build(ty); + self.registry.register_template(tmpl) + } + + pub fn build_basic_block( &mut self, live_locals: I, ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockTemplateTypeBuilder<'_, 'tcx, Self> + ) -> BasicBlockType where I: IntoIterator)>, { @@ -87,51 +268,23 @@ pub trait TemplateTypeGenerator<'tcx> { locals.push((local, ty.mutbl)); tys.push(ty); } - let inner = FunctionTemplateTypeBuilder { - gen: self, + let ty = FunctionTemplateTypeBuilder { + tcx: self.tcx, + registry: self.registry, param_tys: tys, ret_ty, param_rtys: Default::default(), param_refinement: None, ret_rty: None, - }; - BasicBlockTemplateTypeBuilder { inner, locals } - } - - fn basic_block_template_ty( - &mut self, - live_locals: I, - ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockType - where - I: IntoIterator)>, - { - self.build_basic_block_template_ty(live_locals, ret_ty) - .build() - } - - fn function_template_ty(&mut self, sig: mir_ty::FnSig<'tcx>) -> rty::FunctionType { - self.build_function_template_ty(sig).build() - } -} - -impl<'tcx, T> TemplateTypeGenerator<'tcx> for &mut T -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ - fn tcx(&self) -> mir_ty::TyCtxt<'tcx> { - T::tcx(self) - } - - fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { - T::register_template(self, tmpl) + } + .build(); + BasicBlockType { ty, locals } } } -#[derive(Debug)] -pub struct FunctionTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { - // can't use T: TemplateTypeGenerator<'tcx> directly because of recursive instantiation - gen: &'a mut T, +pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { + tcx: mir_ty::TyCtxt<'tcx>, + registry: &'a mut R, param_tys: Vec>, ret_ty: mir_ty::Ty<'tcx>, param_refinement: Option>, @@ -139,10 +292,7 @@ pub struct FunctionTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { ret_rty: Option>, } -impl<'a, 'tcx, T> FunctionTemplateTypeBuilder<'a, 'tcx, T> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ +impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { pub fn param_refinement( &mut self, refinement: rty::Refinement, @@ -174,7 +324,7 @@ where &mut self, refinement: rty::Refinement, ) -> &mut Self { - let ty = UnrefinedTypeGeneratorWrapper(&mut self.gen).unrefined_ty(self.ret_ty); + let ty = TypeBuilder::new(self.tcx).build(self.ret_ty); self.ret_rty = Some(rty::RefinedType::new(ty.vacuous(), refinement)); self } @@ -183,7 +333,12 @@ where self.ret_rty = Some(rty); self } +} +impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> +where + R: TemplateRegistry, +{ pub fn build(&mut self) -> rty::FunctionType { let mut builder = rty::TemplateBuilder::default(); let mut param_rtys = IndexVec::::new(); @@ -195,19 +350,20 @@ where .unwrap_or_else(|| { if idx == self.param_tys.len() - 1 { if let Some(param_refinement) = &self.param_refinement { - let ty = UnrefinedTypeGeneratorWrapper(&mut self.gen) - .unrefined_ty(param_ty.ty); + let ty = TypeBuilder::new(self.tcx).build(param_ty.ty); rty::RefinedType::new(ty.vacuous(), param_refinement.clone()) } else { - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(param_ty.ty) + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build_refined(param_ty.ty) } } else { rty::RefinedType::unrefined( - self.gen - .build_template_ty_with_scope(&builder) - .ty(param_ty.ty), + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build(param_ty.ty), ) } }); @@ -227,208 +383,21 @@ where let param_rty = if let Some(param_refinement) = &self.param_refinement { rty::RefinedType::new(rty::Type::unit(), param_refinement.clone()) } else { - let unit_ty = mir_ty::Ty::new_unit(self.gen.tcx()); - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(unit_ty) + let unit_ty = mir_ty::Ty::new_unit(self.tcx); + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build_refined(unit_ty) }; param_rtys.push(param_rty); } let ret_rty = self.ret_rty.clone().unwrap_or_else(|| { - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(self.ret_ty) + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build_refined(self.ret_ty) }); rty::FunctionType::new(param_rtys, ret_rty) } } - -#[derive(Debug)] -pub struct BasicBlockTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { - inner: FunctionTemplateTypeBuilder<'a, 'tcx, T>, - locals: IndexVec, -} - -impl<'a, 'tcx, T> BasicBlockTemplateTypeBuilder<'a, 'tcx, T> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ - #[allow(dead_code)] - pub fn param_refinement( - &mut self, - refinement: rty::Refinement, - ) -> &mut Self { - self.inner.param_refinement(refinement); - self - } - - #[allow(dead_code)] - pub fn ret_rty(&mut self, rty: rty::RefinedType) -> &mut Self { - self.inner.ret_rty(rty); - self - } - - pub fn build(&mut self) -> BasicBlockType { - let ty = self.inner.build(); - BasicBlockType { - ty, - locals: self.locals.clone(), - } - } -} - -#[derive(Debug)] -pub struct TemplateTypeBuilder<'a, T: ?Sized, U, V> { - // can't use T: TemplateTypeGenerator<'tcx> directly because of recursive instantiation - gen: &'a mut T, - scope: U, - _marker: std::marker::PhantomData V>, -} - -impl<'a, 'tcx, T, U, V> TemplateTypeBuilder<'a, T, U, V> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, - U: TemplateScope, - V: chc::Var, -{ - pub fn ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { - match ty.kind() { - mir_ty::TyKind::Bool => rty::Type::bool(), - mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), - mir_ty::TyKind::Str => rty::Type::string(), - mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { - let elem_ty = self.ty(*elem_ty); - match mutbl { - mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), - mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), - } - } - mir_ty::TyKind::Tuple(ts) => { - // elaboration: all fields are boxed - let elems = ts - .iter() - .map(|ty| rty::PointerType::own(self.ty(ty)).into()) - .collect(); - rty::TupleType::new(elems).into() - } - mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), - mir_ty::TyKind::FnPtr(sig) => { - // TODO: justification for skip_binder - let sig = sig.skip_binder(); - let ty = self.gen.function_template_ty(sig); - rty::Type::function(ty) - } - mir_ty::TyKind::Adt(def, params) if def.is_box() => { - rty::PointerType::own(self.ty(params.type_at(0))).into() - } - mir_ty::TyKind::Adt(def, params) => { - if def.is_enum() { - let sym = refine::datatype_symbol(self.gen.tcx(), def.did()); - let args: IndexVec<_, _> = - params.types().map(|ty| self.refined_ty(ty)).collect(); - rty::EnumType::new(sym, args).into() - } else if def.is_struct() { - let elem_tys = def - .all_fields() - .map(|field| { - let ty = field.ty(self.gen.tcx(), params); - // elaboration: all fields are boxed - rty::PointerType::own(self.ty(ty)).into() - }) - .collect(); - rty::TupleType::new(elem_tys).into() - } else { - unimplemented!("unsupported ADT: {:?}", ty); - } - } - kind => unimplemented!("ty: {:?}", kind), - } - } - - pub fn refined_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { - // TODO: consider building ty with scope - let ty = self.gen.build_template_ty().ty(ty); - let tmpl = self.scope.build_template().build(ty); - self.gen.register_template(tmpl) - } -} - -pub trait UnrefinedTypeGenerator<'tcx> { - fn tcx(&self) -> mir_ty::TyCtxt<'tcx>; - - // TODO: consolidate two defs - fn unrefined_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { - match ty.kind() { - mir_ty::TyKind::Bool => rty::Type::bool(), - mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), - mir_ty::TyKind::Str => rty::Type::string(), - mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { - let elem_ty = self.unrefined_ty(*elem_ty); - match mutbl { - mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), - mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), - } - } - mir_ty::TyKind::Tuple(ts) => { - // elaboration: all fields are boxed - let elems = ts - .iter() - .map(|ty| rty::PointerType::own(self.unrefined_ty(ty)).into()) - .collect(); - rty::TupleType::new(elems).into() - } - mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), - mir_ty::TyKind::FnPtr(sig) => { - // TODO: justification for skip_binder - let sig = sig.skip_binder(); - let params = sig - .inputs() - .iter() - .map(|ty| rty::RefinedType::unrefined(self.unrefined_ty(*ty)).vacuous()) - .collect(); - let ret = rty::RefinedType::unrefined(self.unrefined_ty(sig.output())); - rty::FunctionType::new(params, ret.vacuous()).into() - } - mir_ty::TyKind::Adt(def, params) if def.is_box() => { - rty::PointerType::own(self.unrefined_ty(params.type_at(0))).into() - } - mir_ty::TyKind::Adt(def, params) => { - if def.is_enum() { - let sym = refine::datatype_symbol(self.tcx(), def.did()); - let args: IndexVec<_, _> = params - .types() - .map(|ty| rty::RefinedType::unrefined(self.unrefined_ty(ty))) - .collect(); - rty::EnumType::new(sym, args).into() - } else if def.is_struct() { - let elem_tys = def - .all_fields() - .map(|field| { - let ty = field.ty(self.tcx(), params); - // elaboration: all fields are boxed - rty::PointerType::own(self.unrefined_ty(ty)).into() - }) - .collect(); - rty::TupleType::new(elem_tys).into() - } else { - unimplemented!("unsupported ADT: {:?}", ty); - } - } - kind => unimplemented!("unrefined_ty: {:?}", kind), - } - } -} - -struct UnrefinedTypeGeneratorWrapper(T); - -impl<'tcx, T> UnrefinedTypeGenerator<'tcx> for UnrefinedTypeGeneratorWrapper -where - T: TemplateTypeGenerator<'tcx>, -{ - fn tcx(&self) -> mir_ty::TyCtxt<'tcx> { - self.0.tcx() - } -} From fc1b85c2108ae699d06e31cf5dbb99ed903ed7eb Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 24 Oct 2025 17:21:14 +0900 Subject: [PATCH 4/5] Handle parameter shifting in TypeBuilder --- src/analyze/basic_block.rs | 20 ++++++---- src/analyze/crate_.rs | 61 +++++++---------------------- src/analyze/local_def.rs | 2 +- src/refine/template.rs | 80 +++++++++++++++++++++++++------------- 4 files changed, 82 insertions(+), 81 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 2258f33..7b346ce 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -56,6 +56,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.basic_block_ty(self.local_def_id, bb) } + fn type_builder(&self) -> TypeBuilder<'tcx> { + TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) + } + fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { let rty = if self.is_mut_local(local) { // elaboration: @@ -222,7 +226,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - TypeBuilder::new(self.tcx) + self.type_builder() .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ty) @@ -435,7 +439,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let func_ty = match func.const_fn_def() { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { - let inner_ty = TypeBuilder::new(self.tcx) + let inner_ty = self + .type_builder() .for_template(&mut self.ctx) .build(args.type_at(0)) .vacuous(); @@ -449,7 +454,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = TypeBuilder::new(self.tcx).build(args.type_at(0)).vacuous(); + let inner_ty = self.type_builder().build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -536,7 +541,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = TypeBuilder::new(self.tcx).build(ty); + let ty = self.type_builder().build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -557,7 +562,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = TypeBuilder::new(self.tcx).build(prophecy_ty); + let prophecy_ty = self.type_builder().build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -669,7 +674,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } let decl = self.local_decls[destination].clone(); - let rty = TypeBuilder::new(self.tcx) + let rty = self + .type_builder() .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(decl.ty); @@ -743,7 +749,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - TypeBuilder::new(self.tcx) + self.type_builder() .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ret_ty) diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 727b496..54a02d1 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -132,13 +132,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = TypeBuilder::new(self.tcx).build(*input_ty); + let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty); param_resolver.push_param(input_ident.name, input_ty.to_sort()); } let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); let mut ensure_annot = { - let output_ty = TypeBuilder::new(self.tcx).build(sig.output()); + let output_ty = TypeBuilder::new(self.tcx, def_id).build(sig.output()); let resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); @@ -175,7 +175,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(def_id); } - let mut builder = TypeBuilder::new(self.tcx).for_function_template(&mut self.ctx, sig); + let mut builder = + TypeBuilder::new(self.tcx, def_id).for_function_template(&mut self.ctx, sig); if let Some(AnnotFormula::Formula(require)) = require_annot { let formula = require.map_var(|idx| { if idx.index() == sig.inputs().len() - 1 { @@ -252,28 +253,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }; let adt = self.tcx.adt_def(local_def_id); - // The index of TyKind::ParamTy is based on the every generic parameters in - // the definition, including lifetimes. Given the following definition: - // - // struct X<'a, T> { f: &'a T } - // - // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime - // parameters and the index of rty::ParamType is based on type parameters only. - // We're building a mapping from the original index to the new index here. - let generics = self.tcx.generics_of(local_def_id); - let mut type_param_mapping: std::collections::HashMap = - Default::default(); - for i in 0..generics.count() { - let generic_param = generics.param_at(i, self.tcx); - match generic_param.kind { - mir_ty::GenericParamDefKind::Lifetime => {} - mir_ty::GenericParamDefKind::Type { .. } => { - type_param_mapping.insert(i, type_param_mapping.len()); - } - mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), - } - } - let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); let variants: IndexVec<_, _> = adt .variants() @@ -287,27 +266,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .iter() .map(|field| { let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - - // see the comment above about this mapping - let subst = rty::TypeParamSubst::new( - type_param_mapping - .iter() - .map(|(old, new)| { - let old = rty::TypeParamIdx::from(*old); - let new = - rty::ParamType::new(rty::TypeParamIdx::from(*new)); - (old, rty::RefinedType::unrefined(new.into())) - }) - .collect(), - ); - - // the subst doesn't contain refinements, so it's OK to take ty only - // after substitution - let mut field_rty = rty::RefinedType::unrefined( - TypeBuilder::new(self.tcx).build(field_ty), - ); - field_rty.subst_ty_params(&subst); - field_rty.ty + TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty) }) .collect(); rty::EnumVariantDef { @@ -318,7 +277,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }) .collect(); - let ty_params = type_param_mapping.len(); + let generics = self.tcx.generics_of(local_def_id); + let ty_params = (0..generics.count()) + .filter(|idx| { + matches!( + generics.param_at(*idx, self.tcx).kind, + mir_ty::GenericParamDefKind::Type { .. } + ) + }) + .count(); tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); let def = rty::EnumDatatypeDef { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index c1ab72c..7e7c737 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -306,7 +306,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = TypeBuilder::new(self.tcx) + let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) .for_template(&mut self.ctx) .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); diff --git a/src/refine/template.rs b/src/refine/template.rs index a2380e0..1614b38 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use rustc_index::IndexVec; use rustc_middle::mir::{Local, Mutability}; use rustc_middle::ty as mir_ty; +use rustc_span::def_id::DefId; use super::basic_block::BasicBlockType; use crate::chc; @@ -60,11 +61,43 @@ where #[derive(Clone)] pub struct TypeBuilder<'tcx> { tcx: mir_ty::TyCtxt<'tcx>, + type_param_mapping: HashMap, } impl<'tcx> TypeBuilder<'tcx> { - pub fn new(tcx: mir_ty::TyCtxt<'tcx>) -> Self { - Self { tcx } + pub fn new(tcx: mir_ty::TyCtxt<'tcx>, def_id: DefId) -> Self { + // The index of TyKind::ParamTy is based on the every generic parameters in + // the definition, including lifetimes. Given the following definition: + // + // struct X<'a, T> { f: &'a T } + // + // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime + // parameters and the index of rty::ParamType is based on type parameters only. + // We're building a mapping from the original index to the new index here. + let generics = tcx.generics_of(def_id); + let mut type_param_mapping: HashMap = Default::default(); + for i in 0..generics.count() { + let generic_param = generics.param_at(i, tcx); + match generic_param.kind { + mir_ty::GenericParamDefKind::Lifetime => {} + mir_ty::GenericParamDefKind::Type { .. } => { + type_param_mapping.insert(i as u32, type_param_mapping.len().into()); + } + mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + } + } + Self { + tcx, + type_param_mapping, + } + } + + fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { + let index = *self + .type_param_mapping + .get(&ty.index) + .expect("unknown type param idx"); + rty::ParamType::new(index) } // TODO: consolidate two impls @@ -89,7 +122,7 @@ impl<'tcx> TypeBuilder<'tcx> { rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); @@ -135,7 +168,7 @@ impl<'tcx> TypeBuilder<'tcx> { registry: &'a mut R, ) -> TemplateTypeBuilder<'tcx, 'a, R, EmptyTemplateScope> { TemplateTypeBuilder { - tcx: self.tcx, + inner: self.clone(), registry, scope: Default::default(), } @@ -147,7 +180,7 @@ impl<'tcx> TypeBuilder<'tcx> { sig: mir_ty::FnSig<'tcx>, ) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> { FunctionTemplateTypeBuilder { - tcx: self.tcx, + inner: self.clone(), registry, param_tys: sig .inputs() @@ -166,7 +199,7 @@ impl<'tcx> TypeBuilder<'tcx> { } pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { - tcx: mir_ty::TyCtxt<'tcx>, + inner: TypeBuilder<'tcx>, registry: &'a mut R, scope: S, } @@ -174,7 +207,7 @@ pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> { pub fn with_scope(self, scope: T) -> TemplateTypeBuilder<'tcx, 'a, R, T> { TemplateTypeBuilder { - tcx: self.tcx, + inner: self.inner, registry: self.registry, scope, } @@ -207,13 +240,11 @@ where rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); - let ty = TypeBuilder::new(self.tcx) - .for_function_template(self.registry, sig) - .build(); + let ty = self.inner.for_function_template(self.registry, sig).build(); rty::Type::function(ty) } mir_ty::TyKind::Adt(def, params) if def.is_box() => { @@ -221,7 +252,7 @@ where } mir_ty::TyKind::Adt(def, params) => { if def.is_enum() { - let sym = refine::datatype_symbol(self.tcx, def.did()); + let sym = refine::datatype_symbol(self.inner.tcx, def.did()); let args: IndexVec<_, _> = params.types().map(|ty| self.build_refined(ty)).collect(); rty::EnumType::new(sym, args).into() @@ -229,7 +260,7 @@ where let elem_tys = def .all_fields() .map(|field| { - let ty = field.ty(self.tcx, params); + let ty = field.ty(self.inner.tcx, params); // elaboration: all fields are boxed rty::PointerType::own(self.build(ty)).into() }) @@ -245,10 +276,7 @@ where pub fn build_refined(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { // TODO: consider building ty with scope - let ty = TypeBuilder::new(self.tcx) - .for_template(self.registry) - .build(ty) - .vacuous(); + let ty = self.inner.for_template(self.registry).build(ty).vacuous(); let tmpl = self.scope.build_template().build(ty); self.registry.register_template(tmpl) } @@ -269,7 +297,7 @@ where tys.push(ty); } let ty = FunctionTemplateTypeBuilder { - tcx: self.tcx, + inner: self.inner.clone(), registry: self.registry, param_tys: tys, ret_ty, @@ -283,7 +311,7 @@ where } pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { - tcx: mir_ty::TyCtxt<'tcx>, + inner: TypeBuilder<'tcx>, registry: &'a mut R, param_tys: Vec>, ret_ty: mir_ty::Ty<'tcx>, @@ -324,7 +352,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { &mut self, refinement: rty::Refinement, ) -> &mut Self { - let ty = TypeBuilder::new(self.tcx).build(self.ret_ty); + let ty = self.inner.build(self.ret_ty); self.ret_rty = Some(rty::RefinedType::new(ty.vacuous(), refinement)); self } @@ -350,17 +378,17 @@ where .unwrap_or_else(|| { if idx == self.param_tys.len() - 1 { if let Some(param_refinement) = &self.param_refinement { - let ty = TypeBuilder::new(self.tcx).build(param_ty.ty); + let ty = self.inner.build(param_ty.ty); rty::RefinedType::new(ty.vacuous(), param_refinement.clone()) } else { - TypeBuilder::new(self.tcx) + self.inner .for_template(self.registry) .with_scope(&builder) .build_refined(param_ty.ty) } } else { rty::RefinedType::unrefined( - TypeBuilder::new(self.tcx) + self.inner .for_template(self.registry) .with_scope(&builder) .build(param_ty.ty), @@ -383,8 +411,8 @@ where let param_rty = if let Some(param_refinement) = &self.param_refinement { rty::RefinedType::new(rty::Type::unit(), param_refinement.clone()) } else { - let unit_ty = mir_ty::Ty::new_unit(self.tcx); - TypeBuilder::new(self.tcx) + let unit_ty = mir_ty::Ty::new_unit(self.inner.tcx); + self.inner .for_template(self.registry) .with_scope(&builder) .build_refined(unit_ty) @@ -393,7 +421,7 @@ where } let ret_rty = self.ret_rty.clone().unwrap_or_else(|| { - TypeBuilder::new(self.tcx) + self.inner .for_template(self.registry) .with_scope(&builder) .build_refined(self.ret_ty) From 18bbbeebf2ef11f74c30808b505cfb68ad9577fb Mon Sep 17 00:00:00 2001 From: coord_e Date: Sat, 25 Oct 2025 11:24:19 +0900 Subject: [PATCH 5/5] Enhance docs --- src/refine/template.rs | 38 +++++++++++++++++++++++++------------- src/rty/params.rs | 14 ++++++++++++++ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/refine/template.rs b/src/refine/template.rs index 1614b38..bec3f8a 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -23,6 +23,7 @@ where } } +/// [`TemplateScope`] with no variables in scope. #[derive(Clone, Default)] pub struct EmptyTemplateScope; @@ -58,43 +59,45 @@ where } } +/// Translates [`mir_ty::Ty`] to [`rty::Type`]. +/// +/// This struct implements a translation from Rust MIR types to Thrust types. +/// Thrust types may contain refinement predicates which do not exist in MIR types, +/// and [`TypeBuilder`] solely builds types with null refinement (true) in +/// [`TypeBuilder::build`]. This also provides [`TypeBuilder::for_template`] to build +/// refinement types by filling unknown predicates with templates with predicate variables. #[derive(Clone)] pub struct TypeBuilder<'tcx> { tcx: mir_ty::TyCtxt<'tcx>, - type_param_mapping: HashMap, + /// Maps index in [`mir_ty::ParamTy`] to [`rty::TypeParamIdx`]. + /// These indices may differ because we skip lifetime parameters. + /// See [`rty::TypeParamIdx`] for more details. + param_idx_mapping: HashMap, } impl<'tcx> TypeBuilder<'tcx> { pub fn new(tcx: mir_ty::TyCtxt<'tcx>, def_id: DefId) -> Self { - // The index of TyKind::ParamTy is based on the every generic parameters in - // the definition, including lifetimes. Given the following definition: - // - // struct X<'a, T> { f: &'a T } - // - // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime - // parameters and the index of rty::ParamType is based on type parameters only. - // We're building a mapping from the original index to the new index here. let generics = tcx.generics_of(def_id); - let mut type_param_mapping: HashMap = Default::default(); + let mut param_idx_mapping: HashMap = Default::default(); for i in 0..generics.count() { let generic_param = generics.param_at(i, tcx); match generic_param.kind { mir_ty::GenericParamDefKind::Lifetime => {} mir_ty::GenericParamDefKind::Type { .. } => { - type_param_mapping.insert(i as u32, type_param_mapping.len().into()); + param_idx_mapping.insert(i as u32, param_idx_mapping.len().into()); } mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), } } Self { tcx, - type_param_mapping, + param_idx_mapping, } } fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { let index = *self - .type_param_mapping + .param_idx_mapping .get(&ty.index) .expect("unknown type param idx"); rty::ParamType::new(index) @@ -198,8 +201,16 @@ impl<'tcx> TypeBuilder<'tcx> { } } +/// Translates [`mir_ty::Ty`] to [`rty::Type`] using templates for refinements. +/// +/// [`rty::Template`] is a refinement type in the form of `{ T | P(x1, ..., xn) }` where `P` is a +/// predicate variable. When constructing a template, we need to know which variables can affect the +/// predicate of the template (dependencies, `x1, ..., xn`), and they are provided by the +/// [`TemplateScope`]. No variables are in scope by default and you can provide a scope using +/// [`TemplateTypeBuilder::with_scope`]. pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { inner: TypeBuilder<'tcx>, + // XXX: this can't be simply `R` because monomorphization instantiates types recursively registry: &'a mut R, scope: S, } @@ -310,6 +321,7 @@ where } } +/// A builder for function template types. pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { inner: TypeBuilder<'tcx>, registry: &'a mut R, diff --git a/src/rty/params.rs b/src/rty/params.rs index b57ff55..6414f6a 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -11,6 +11,20 @@ use super::{Closed, RefinedType}; rustc_index::newtype_index! { /// An index representing a type parameter. + /// + /// ## Note on indexing of type parameters + /// + /// The index of [`rustc_middle::ty::ParamTy`] is based on all generic parameters in + /// the definition, including lifetimes. Given the following definition: + /// + /// ```rust + /// struct X<'a, T> { f: &'a T } + /// ``` + /// + /// The type of field `f` is `&T1` (not `&T0`) in MIR. However, in Thrust, we ignore lifetime + /// parameters and the index of [`rty::ParamType`](super::ParamType) is based on type parameters only, giving `f` + /// the type `&T0`. [`TypeBuilder`](crate::refine::TypeBuilder) takes care of this difference when translating MIR + /// types to Thrust types. #[orderable] #[debug_format = "T{}"] pub struct TypeParamIdx { }