diff --git a/src/analyze.rs b/src/analyze.rs index 898df5b..3550294 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -123,22 +123,12 @@ pub struct Analyzer<'tcx> { enum_defs: Rc>>, } -impl<'tcx> crate::refine::TemplateTypeGenerator<'tcx> for Analyzer<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } - +impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> { fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { tmpl.into_refined_type(|pred_sig| self.generate_pred_var(pred_sig)) } } -impl<'tcx> crate::refine::UnrefinedTypeGenerator<'tcx> for Analyzer<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } -} - impl<'tcx> Analyzer<'tcx> { pub fn generate_pred_var(&mut self, sig: chc::PredSig) -> chc::PredVarId { self.system diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index cbb6589..7b346ce 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -14,7 +14,7 @@ use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{ self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, - TemplateTypeGenerator, UnrefinedTypeGenerator, Var, + TypeBuilder, Var, }; use crate::rty::{ self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _, @@ -56,6 +56,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.basic_block_ty(self.local_def_id, bb) } + fn type_builder(&self) -> TypeBuilder<'tcx> { + TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) + } + fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { let rty = if self.is_mut_local(local) { // elaboration: @@ -219,26 +223,28 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .into_iter() .map(|ty| rty::RefinedType::unrefined(ty.vacuous())); - let params: IndexVec<_, _> = args + let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - self.ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(ty) + self.type_builder() + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(ty) }) .collect(); for (field_pty, mut variant_rty) in field_tys.clone().into_iter().zip(variant_rtys) { - variant_rty.instantiate_ty_params(params.clone()); + variant_rty.instantiate_ty_params(rty_args.clone()); let cs = self .env .relate_sub_refined_type(&field_pty.into(), &variant_rty.boxed()); self.ctx.extend_clauses(cs); } - let sort_args: Vec<_> = params.iter().map(|rty| rty.ty.to_sort()).collect(); - let ty = rty::EnumType::new(ty_sym.clone(), params).into(); + let sort_args: Vec<_> = + rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); + let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); let mut field_terms = Vec::new(); @@ -433,7 +439,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let func_ty = match func.const_fn_def() { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { - let inner_ty = self.ctx.build_template_ty().ty(args.type_at(0)).vacuous(); + let inner_ty = self + .type_builder() + .for_template(&mut self.ctx) + .build(args.type_at(0)) + .vacuous(); let param = rty::RefinedType::unrefined(inner_ty.clone()); let ret_term = chc::Term::box_(chc::Term::var(rty::FunctionParamIdx::from(0_usize))); @@ -444,7 +454,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = self.ctx.unrefined_ty(args.type_at(0)).vacuous(); + let inner_ty = self.type_builder().build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -531,7 +541,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = self.ctx.unrefined_ty(ty); + let ty = self.type_builder().build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -552,7 +562,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = self.ctx.unrefined_ty(prophecy_ty); + let prophecy_ty = self.type_builder().build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -665,9 +675,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let decl = self.local_decls[destination].clone(); let rty = self - .ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(decl.ty); + .type_builder() + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(decl.ty); self.type_call(func.clone(), args.clone().into_iter().map(|a| a.node), &rty); self.bind_local(destination, rty); } @@ -738,9 +749,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - self.ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(ret_ty) + self.type_builder() + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(ret_ty) } // TODO: remove this diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9a1fa67..54a02d1 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -11,7 +11,7 @@ use rustc_span::symbol::Ident; use crate::analyze; use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; -use crate::refine::{self, TemplateTypeGenerator, UnrefinedTypeGenerator}; +use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; /// An implementation of local crate analysis. @@ -132,13 +132,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = self.ctx.unrefined_ty(*input_ty); + let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty); param_resolver.push_param(input_ident.name, input_ty.to_sort()); } let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); let mut ensure_annot = { - let output_ty = self.ctx.unrefined_ty(sig.output()); + let output_ty = TypeBuilder::new(self.tcx, def_id).build(sig.output()); let resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); @@ -175,7 +175,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(def_id); } - let mut builder = self.ctx.build_function_template_ty(sig); + let mut builder = + TypeBuilder::new(self.tcx, def_id).for_function_template(&mut self.ctx, sig); if let Some(AnnotFormula::Formula(require)) = require_annot { let formula = require.map_var(|idx| { if idx.index() == sig.inputs().len() - 1 { @@ -251,6 +252,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { continue; }; let adt = self.tcx.adt_def(local_def_id); + let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); let variants: IndexVec<_, _> = adt .variants() @@ -264,7 +266,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .iter() .map(|field| { let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - self.ctx.unrefined_ty(field_ty) + TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty) }) .collect(); rty::EnumVariantDef { @@ -275,19 +277,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }) .collect(); - let ty_params = adt - .all_fields() - .map(|f| self.tcx.type_of(f.did).instantiate_identity()) - .flat_map(|ty| { - if let mir_ty::TyKind::Param(p) = ty.kind() { - Some(p.index as usize) - } else { - None - } + let generics = self.tcx.generics_of(local_def_id); + let ty_params = (0..generics.count()) + .filter(|idx| { + matches!( + generics.param_at(*idx, self.tcx).kind, + mir_ty::GenericParamDefKind::Type { .. } + ) }) - .max() - .map(|max| max + 1) - .unwrap_or(0); + .count(); tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); let def = rty::EnumDatatypeDef { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index ef5870e..7e7c737 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -11,7 +11,7 @@ use rustc_span::def_id::LocalDefId; use crate::analyze; use crate::chc; use crate::pretty::PrettyDisplayExt as _; -use crate::refine::{BasicBlockType, TemplateTypeGenerator}; +use crate::refine::{BasicBlockType, TypeBuilder}; use crate::rty; /// An implementation of the typing of local definitions. @@ -306,7 +306,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = self.ctx.basic_block_template_ty(live_locals, ret_ty); + let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) + .for_template(&mut self.ctx) + .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); } } diff --git a/src/refine.rs b/src/refine.rs index 0371da1..4736b39 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -8,7 +8,7 @@ //! module and remove this one. mod template; -pub use template::{TemplateScope, TemplateTypeGenerator, UnrefinedTypeGenerator}; +pub use template::{TemplateRegistry, TemplateScope, TypeBuilder}; mod basic_block; pub use basic_block::BasicBlockType; diff --git a/src/refine/env.rs b/src/refine/env.rs index e647765..a1edbc1 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -558,7 +558,8 @@ impl rty::ClauseScope for Env { } } -impl refine::TemplateScope for Env { +impl refine::TemplateScope for Env { + type Var = Var; fn build_template(&self) -> rty::TemplateBuilder { let mut builder = rty::TemplateBuilder::default(); for (v, sort) in self.dependencies() { @@ -930,7 +931,7 @@ impl Env { .field_tys() .map(|ty| rty::RefinedType::unrefined(ty.clone().vacuous()).boxed()); let got_tys = field_tys.iter().map(|ty| ty.clone().into()); - rty::unify_tys_params(expected_tys, got_tys).into_params(def.ty_params, |_| { + rty::unify_tys_params(expected_tys, got_tys).into_args(def.ty_params, |_| { panic!("var_type: should unify all params") }) }; diff --git a/src/refine/template.rs b/src/refine/template.rs index b6ae7b8..bec3f8a 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -3,60 +3,188 @@ use std::collections::HashMap; use rustc_index::IndexVec; use rustc_middle::mir::{Local, Mutability}; use rustc_middle::ty as mir_ty; +use rustc_span::def_id::DefId; use super::basic_block::BasicBlockType; use crate::chc; use crate::refine; use crate::rty; -pub trait TemplateScope { - fn build_template(&self) -> rty::TemplateBuilder; +pub trait TemplateRegistry { + fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType; } -impl TemplateScope for &U +impl TemplateRegistry for &mut T where - U: TemplateScope, + T: TemplateRegistry + ?Sized, { - fn build_template(&self) -> rty::TemplateBuilder { - U::build_template(self) + fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { + T::register_template(self, tmpl) + } +} + +/// [`TemplateScope`] with no variables in scope. +#[derive(Clone, Default)] +pub struct EmptyTemplateScope; + +impl TemplateScope for EmptyTemplateScope { + type Var = rty::Closed; + fn build_template(&self) -> rty::TemplateBuilder { + rty::TemplateBuilder::default() } } -impl TemplateScope for rty::TemplateBuilder +pub trait TemplateScope { + type Var: chc::Var; + fn build_template(&self) -> rty::TemplateBuilder; +} + +impl TemplateScope for &T +where + T: TemplateScope, +{ + type Var = T::Var; + fn build_template(&self) -> rty::TemplateBuilder { + T::build_template(self) + } +} + +impl TemplateScope for rty::TemplateBuilder where T: chc::Var, { + type Var = T; fn build_template(&self) -> rty::TemplateBuilder { self.clone() } } -pub trait TemplateTypeGenerator<'tcx> { - fn tcx(&self) -> mir_ty::TyCtxt<'tcx>; - fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType; +/// Translates [`mir_ty::Ty`] to [`rty::Type`]. +/// +/// This struct implements a translation from Rust MIR types to Thrust types. +/// Thrust types may contain refinement predicates which do not exist in MIR types, +/// and [`TypeBuilder`] solely builds types with null refinement (true) in +/// [`TypeBuilder::build`]. This also provides [`TypeBuilder::for_template`] to build +/// refinement types by filling unknown predicates with templates with predicate variables. +#[derive(Clone)] +pub struct TypeBuilder<'tcx> { + tcx: mir_ty::TyCtxt<'tcx>, + /// Maps index in [`mir_ty::ParamTy`] to [`rty::TypeParamIdx`]. + /// These indices may differ because we skip lifetime parameters. + /// See [`rty::TypeParamIdx`] for more details. + param_idx_mapping: HashMap, +} - fn build_template_ty_with_scope(&mut self, scope: T) -> TemplateTypeBuilder { - TemplateTypeBuilder { - gen: self, - scope, - _marker: std::marker::PhantomData, +impl<'tcx> TypeBuilder<'tcx> { + pub fn new(tcx: mir_ty::TyCtxt<'tcx>, def_id: DefId) -> Self { + let generics = tcx.generics_of(def_id); + let mut param_idx_mapping: HashMap = Default::default(); + for i in 0..generics.count() { + let generic_param = generics.param_at(i, tcx); + match generic_param.kind { + mir_ty::GenericParamDefKind::Lifetime => {} + mir_ty::GenericParamDefKind::Type { .. } => { + param_idx_mapping.insert(i as u32, param_idx_mapping.len().into()); + } + mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + } + } + Self { + tcx, + param_idx_mapping, + } + } + + fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { + let index = *self + .param_idx_mapping + .get(&ty.index) + .expect("unknown type param idx"); + rty::ParamType::new(index) + } + + // TODO: consolidate two impls + pub fn build(&self, ty: mir_ty::Ty<'tcx>) -> rty::Type { + match ty.kind() { + mir_ty::TyKind::Bool => rty::Type::bool(), + mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), + mir_ty::TyKind::Str => rty::Type::string(), + mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { + let elem_ty = self.build(*elem_ty); + match mutbl { + mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), + mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), + } + } + mir_ty::TyKind::Tuple(ts) => { + // elaboration: all fields are boxed + let elems = ts + .iter() + .map(|ty| rty::PointerType::own(self.build(ty)).into()) + .collect(); + rty::TupleType::new(elems).into() + } + mir_ty::TyKind::Never => rty::Type::never(), + mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(), + mir_ty::TyKind::FnPtr(sig) => { + // TODO: justification for skip_binder + let sig = sig.skip_binder(); + let params = sig + .inputs() + .iter() + .map(|ty| rty::RefinedType::unrefined(self.build(*ty)).vacuous()) + .collect(); + let ret = rty::RefinedType::unrefined(self.build(sig.output())); + rty::FunctionType::new(params, ret.vacuous()).into() + } + mir_ty::TyKind::Adt(def, params) if def.is_box() => { + rty::PointerType::own(self.build(params.type_at(0))).into() + } + mir_ty::TyKind::Adt(def, params) => { + if def.is_enum() { + let sym = refine::datatype_symbol(self.tcx, def.did()); + let args: IndexVec<_, _> = params + .types() + .map(|ty| rty::RefinedType::unrefined(self.build(ty))) + .collect(); + rty::EnumType::new(sym, args).into() + } else if def.is_struct() { + let elem_tys = def + .all_fields() + .map(|field| { + let ty = field.ty(self.tcx, params); + // elaboration: all fields are boxed + rty::PointerType::own(self.build(ty)).into() + }) + .collect(); + rty::TupleType::new(elem_tys).into() + } else { + unimplemented!("unsupported ADT: {:?}", ty); + } + } + kind => unimplemented!("unrefined_ty: {:?}", kind), } } - fn build_template_ty(&mut self) -> TemplateTypeBuilder, V> { + pub fn for_template<'a, R>( + &self, + registry: &'a mut R, + ) -> TemplateTypeBuilder<'tcx, 'a, R, EmptyTemplateScope> { TemplateTypeBuilder { - gen: self, + inner: self.clone(), + registry, scope: Default::default(), - _marker: std::marker::PhantomData, } } - fn build_function_template_ty( - &mut self, + pub fn for_function_template<'a, R>( + &self, + registry: &'a mut R, sig: mir_ty::FnSig<'tcx>, - ) -> FunctionTemplateTypeBuilder<'_, 'tcx, Self> { + ) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> { FunctionTemplateTypeBuilder { - gen: self, + inner: self.clone(), + registry, param_tys: sig .inputs() .iter() @@ -71,12 +199,104 @@ pub trait TemplateTypeGenerator<'tcx> { ret_rty: None, } } +} + +/// Translates [`mir_ty::Ty`] to [`rty::Type`] using templates for refinements. +/// +/// [`rty::Template`] is a refinement type in the form of `{ T | P(x1, ..., xn) }` where `P` is a +/// predicate variable. When constructing a template, we need to know which variables can affect the +/// predicate of the template (dependencies, `x1, ..., xn`), and they are provided by the +/// [`TemplateScope`]. No variables are in scope by default and you can provide a scope using +/// [`TemplateTypeBuilder::with_scope`]. +pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { + inner: TypeBuilder<'tcx>, + // XXX: this can't be simply `R` because monomorphization instantiates types recursively + registry: &'a mut R, + scope: S, +} + +impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> { + pub fn with_scope(self, scope: T) -> TemplateTypeBuilder<'tcx, 'a, R, T> { + TemplateTypeBuilder { + inner: self.inner, + registry: self.registry, + scope, + } + } +} + +impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> +where + R: TemplateRegistry, + S: TemplateScope, +{ + pub fn build(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { + match ty.kind() { + mir_ty::TyKind::Bool => rty::Type::bool(), + mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), + mir_ty::TyKind::Str => rty::Type::string(), + mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { + let elem_ty = self.build(*elem_ty); + match mutbl { + mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), + mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), + } + } + mir_ty::TyKind::Tuple(ts) => { + // elaboration: all fields are boxed + let elems = ts + .iter() + .map(|ty| rty::PointerType::own(self.build(ty)).into()) + .collect(); + rty::TupleType::new(elems).into() + } + mir_ty::TyKind::Never => rty::Type::never(), + mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(), + mir_ty::TyKind::FnPtr(sig) => { + // TODO: justification for skip_binder + let sig = sig.skip_binder(); + let ty = self.inner.for_function_template(self.registry, sig).build(); + rty::Type::function(ty) + } + mir_ty::TyKind::Adt(def, params) if def.is_box() => { + rty::PointerType::own(self.build(params.type_at(0))).into() + } + mir_ty::TyKind::Adt(def, params) => { + if def.is_enum() { + let sym = refine::datatype_symbol(self.inner.tcx, def.did()); + let args: IndexVec<_, _> = + params.types().map(|ty| self.build_refined(ty)).collect(); + rty::EnumType::new(sym, args).into() + } else if def.is_struct() { + let elem_tys = def + .all_fields() + .map(|field| { + let ty = field.ty(self.inner.tcx, params); + // elaboration: all fields are boxed + rty::PointerType::own(self.build(ty)).into() + }) + .collect(); + rty::TupleType::new(elem_tys).into() + } else { + unimplemented!("unsupported ADT: {:?}", ty); + } + } + kind => unimplemented!("ty: {:?}", kind), + } + } - fn build_basic_block_template_ty( + pub fn build_refined(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { + // TODO: consider building ty with scope + let ty = self.inner.for_template(self.registry).build(ty).vacuous(); + let tmpl = self.scope.build_template().build(ty); + self.registry.register_template(tmpl) + } + + pub fn build_basic_block( &mut self, live_locals: I, ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockTemplateTypeBuilder<'_, 'tcx, Self> + ) -> BasicBlockType where I: IntoIterator)>, { @@ -87,51 +307,24 @@ pub trait TemplateTypeGenerator<'tcx> { locals.push((local, ty.mutbl)); tys.push(ty); } - let inner = FunctionTemplateTypeBuilder { - gen: self, + let ty = FunctionTemplateTypeBuilder { + inner: self.inner.clone(), + registry: self.registry, param_tys: tys, ret_ty, param_rtys: Default::default(), param_refinement: None, ret_rty: None, - }; - BasicBlockTemplateTypeBuilder { inner, locals } - } - - fn basic_block_template_ty( - &mut self, - live_locals: I, - ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockType - where - I: IntoIterator)>, - { - self.build_basic_block_template_ty(live_locals, ret_ty) - .build() - } - - fn function_template_ty(&mut self, sig: mir_ty::FnSig<'tcx>) -> rty::FunctionType { - self.build_function_template_ty(sig).build() - } -} - -impl<'tcx, T> TemplateTypeGenerator<'tcx> for &mut T -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ - fn tcx(&self) -> mir_ty::TyCtxt<'tcx> { - T::tcx(self) - } - - fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { - T::register_template(self, tmpl) + } + .build(); + BasicBlockType { ty, locals } } } -#[derive(Debug)] -pub struct FunctionTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { - // can't use T: TemplateTypeGenerator<'tcx> directly because of recursive instantiation - gen: &'a mut T, +/// A builder for function template types. +pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { + inner: TypeBuilder<'tcx>, + registry: &'a mut R, param_tys: Vec>, ret_ty: mir_ty::Ty<'tcx>, param_refinement: Option>, @@ -139,10 +332,7 @@ pub struct FunctionTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { ret_rty: Option>, } -impl<'a, 'tcx, T> FunctionTemplateTypeBuilder<'a, 'tcx, T> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ +impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { pub fn param_refinement( &mut self, refinement: rty::Refinement, @@ -174,7 +364,7 @@ where &mut self, refinement: rty::Refinement, ) -> &mut Self { - let ty = UnrefinedTypeGeneratorWrapper(&mut self.gen).unrefined_ty(self.ret_ty); + let ty = self.inner.build(self.ret_ty); self.ret_rty = Some(rty::RefinedType::new(ty.vacuous(), refinement)); self } @@ -183,7 +373,12 @@ where self.ret_rty = Some(rty); self } +} +impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> +where + R: TemplateRegistry, +{ pub fn build(&mut self) -> rty::FunctionType { let mut builder = rty::TemplateBuilder::default(); let mut param_rtys = IndexVec::::new(); @@ -195,19 +390,20 @@ where .unwrap_or_else(|| { if idx == self.param_tys.len() - 1 { if let Some(param_refinement) = &self.param_refinement { - let ty = UnrefinedTypeGeneratorWrapper(&mut self.gen) - .unrefined_ty(param_ty.ty); + let ty = self.inner.build(param_ty.ty); rty::RefinedType::new(ty.vacuous(), param_refinement.clone()) } else { - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(param_ty.ty) + self.inner + .for_template(self.registry) + .with_scope(&builder) + .build_refined(param_ty.ty) } } else { rty::RefinedType::unrefined( - self.gen - .build_template_ty_with_scope(&builder) - .ty(param_ty.ty), + self.inner + .for_template(self.registry) + .with_scope(&builder) + .build(param_ty.ty), ) } }); @@ -227,208 +423,21 @@ where let param_rty = if let Some(param_refinement) = &self.param_refinement { rty::RefinedType::new(rty::Type::unit(), param_refinement.clone()) } else { - let unit_ty = mir_ty::Ty::new_unit(self.gen.tcx()); - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(unit_ty) + let unit_ty = mir_ty::Ty::new_unit(self.inner.tcx); + self.inner + .for_template(self.registry) + .with_scope(&builder) + .build_refined(unit_ty) }; param_rtys.push(param_rty); } let ret_rty = self.ret_rty.clone().unwrap_or_else(|| { - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(self.ret_ty) + self.inner + .for_template(self.registry) + .with_scope(&builder) + .build_refined(self.ret_ty) }); rty::FunctionType::new(param_rtys, ret_rty) } } - -#[derive(Debug)] -pub struct BasicBlockTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { - inner: FunctionTemplateTypeBuilder<'a, 'tcx, T>, - locals: IndexVec, -} - -impl<'a, 'tcx, T> BasicBlockTemplateTypeBuilder<'a, 'tcx, T> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ - #[allow(dead_code)] - pub fn param_refinement( - &mut self, - refinement: rty::Refinement, - ) -> &mut Self { - self.inner.param_refinement(refinement); - self - } - - #[allow(dead_code)] - pub fn ret_rty(&mut self, rty: rty::RefinedType) -> &mut Self { - self.inner.ret_rty(rty); - self - } - - pub fn build(&mut self) -> BasicBlockType { - let ty = self.inner.build(); - BasicBlockType { - ty, - locals: self.locals.clone(), - } - } -} - -#[derive(Debug)] -pub struct TemplateTypeBuilder<'a, T: ?Sized, U, V> { - // can't use T: TemplateTypeGenerator<'tcx> directly because of recursive instantiation - gen: &'a mut T, - scope: U, - _marker: std::marker::PhantomData V>, -} - -impl<'a, 'tcx, T, U, V> TemplateTypeBuilder<'a, T, U, V> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, - U: TemplateScope, - V: chc::Var, -{ - pub fn ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { - match ty.kind() { - mir_ty::TyKind::Bool => rty::Type::bool(), - mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), - mir_ty::TyKind::Str => rty::Type::string(), - mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { - let elem_ty = self.ty(*elem_ty); - match mutbl { - mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), - mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), - } - } - mir_ty::TyKind::Tuple(ts) => { - // elaboration: all fields are boxed - let elems = ts - .iter() - .map(|ty| rty::PointerType::own(self.ty(ty)).into()) - .collect(); - rty::TupleType::new(elems).into() - } - mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), - mir_ty::TyKind::FnPtr(sig) => { - // TODO: justification for skip_binder - let sig = sig.skip_binder(); - let ty = self.gen.function_template_ty(sig); - rty::Type::function(ty) - } - mir_ty::TyKind::Adt(def, params) if def.is_box() => { - rty::PointerType::own(self.ty(params.type_at(0))).into() - } - mir_ty::TyKind::Adt(def, params) => { - if def.is_enum() { - let sym = refine::datatype_symbol(self.gen.tcx(), def.did()); - let args: IndexVec<_, _> = - params.types().map(|ty| self.refined_ty(ty)).collect(); - rty::EnumType::new(sym, args).into() - } else if def.is_struct() { - let elem_tys = def - .all_fields() - .map(|field| { - let ty = field.ty(self.gen.tcx(), params); - // elaboration: all fields are boxed - rty::PointerType::own(self.ty(ty)).into() - }) - .collect(); - rty::TupleType::new(elem_tys).into() - } else { - unimplemented!("unsupported ADT: {:?}", ty); - } - } - kind => unimplemented!("ty: {:?}", kind), - } - } - - pub fn refined_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { - // TODO: consider building ty with scope - let ty = self.gen.build_template_ty().ty(ty); - let tmpl = self.scope.build_template().build(ty); - self.gen.register_template(tmpl) - } -} - -pub trait UnrefinedTypeGenerator<'tcx> { - fn tcx(&self) -> mir_ty::TyCtxt<'tcx>; - - // TODO: consolidate two defs - fn unrefined_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { - match ty.kind() { - mir_ty::TyKind::Bool => rty::Type::bool(), - mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), - mir_ty::TyKind::Str => rty::Type::string(), - mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { - let elem_ty = self.unrefined_ty(*elem_ty); - match mutbl { - mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), - mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), - } - } - mir_ty::TyKind::Tuple(ts) => { - // elaboration: all fields are boxed - let elems = ts - .iter() - .map(|ty| rty::PointerType::own(self.unrefined_ty(ty)).into()) - .collect(); - rty::TupleType::new(elems).into() - } - mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), - mir_ty::TyKind::FnPtr(sig) => { - // TODO: justification for skip_binder - let sig = sig.skip_binder(); - let params = sig - .inputs() - .iter() - .map(|ty| rty::RefinedType::unrefined(self.unrefined_ty(*ty)).vacuous()) - .collect(); - let ret = rty::RefinedType::unrefined(self.unrefined_ty(sig.output())); - rty::FunctionType::new(params, ret.vacuous()).into() - } - mir_ty::TyKind::Adt(def, params) if def.is_box() => { - rty::PointerType::own(self.unrefined_ty(params.type_at(0))).into() - } - mir_ty::TyKind::Adt(def, params) => { - if def.is_enum() { - let sym = refine::datatype_symbol(self.tcx(), def.did()); - let args: IndexVec<_, _> = params - .types() - .map(|ty| rty::RefinedType::unrefined(self.unrefined_ty(ty))) - .collect(); - rty::EnumType::new(sym, args).into() - } else if def.is_struct() { - let elem_tys = def - .all_fields() - .map(|field| { - let ty = field.ty(self.tcx(), params); - // elaboration: all fields are boxed - rty::PointerType::own(self.unrefined_ty(ty)).into() - }) - .collect(); - rty::TupleType::new(elem_tys).into() - } else { - unimplemented!("unsupported ADT: {:?}", ty); - } - } - kind => unimplemented!("unrefined_ty: {:?}", kind), - } - } -} - -struct UnrefinedTypeGeneratorWrapper(T); - -impl<'tcx, T> UnrefinedTypeGenerator<'tcx> for UnrefinedTypeGeneratorWrapper -where - T: TemplateTypeGenerator<'tcx>, -{ - fn tcx(&self) -> mir_ty::TyCtxt<'tcx> { - self.0.tcx() - } -} diff --git a/src/rty.rs b/src/rty.rs index 72fede8..c706897 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{TypeParamIdx, TypeParamSubst, TypeParams}; +pub use params::{TypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. @@ -487,7 +487,7 @@ where } impl EnumType { - pub fn new(symbol: chc::DatatypeSymbol, args: TypeParams) -> Self { + pub fn new(symbol: chc::DatatypeSymbol, args: TypeArgs) -> Self { EnumType { symbol, args } } @@ -1372,7 +1372,7 @@ impl RefinedType { } } - pub fn instantiate_ty_params(&mut self, params: TypeParams) + pub fn instantiate_ty_params(&mut self, params: TypeArgs) where FV: chc::Var, { diff --git a/src/rty/params.rs b/src/rty/params.rs index 17ebc2b..6414f6a 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -11,6 +11,20 @@ use super::{Closed, RefinedType}; rustc_index::newtype_index! { /// An index representing a type parameter. + /// + /// ## Note on indexing of type parameters + /// + /// The index of [`rustc_middle::ty::ParamTy`] is based on all generic parameters in + /// the definition, including lifetimes. Given the following definition: + /// + /// ```rust + /// struct X<'a, T> { f: &'a T } + /// ``` + /// + /// The type of field `f` is `&T1` (not `&T0`) in MIR. However, in Thrust, we ignore lifetime + /// parameters and the index of [`rty::ParamType`](super::ParamType) is based on type parameters only, giving `f` + /// the type `&T0`. [`TypeBuilder`](crate::refine::TypeBuilder) takes care of this difference when translating MIR + /// types to Thrust types. #[orderable] #[debug_format = "T{}"] pub struct TypeParamIdx { } @@ -39,7 +53,7 @@ impl TypeParamIdx { } } -pub type TypeParams = IndexVec>; +pub type TypeArgs = IndexVec>; /// A substitution for type parameters that maps type parameters to refinement types. #[derive(Debug, Clone)] @@ -55,8 +69,8 @@ impl Default for TypeParamSubst { } } -impl From> for TypeParamSubst { - fn from(params: TypeParams) -> Self { +impl From> for TypeParamSubst { + fn from(params: TypeArgs) -> Self { let subst = params.into_iter_enumerated().collect(); Self { subst } } @@ -71,6 +85,10 @@ impl std::ops::Index for TypeParamSubst { } impl TypeParamSubst { + pub fn new(subst: BTreeMap>) -> Self { + Self { subst } + } + pub fn singleton(idx: TypeParamIdx, ty: RefinedType) -> Self { let mut subst = BTreeMap::default(); subst.insert(idx, ty); @@ -94,20 +112,20 @@ impl TypeParamSubst { } } - pub fn into_params(mut self, expected_len: usize, mut default: F) -> TypeParams + pub fn into_args(mut self, expected_len: usize, mut default: F) -> TypeArgs where T: chc::Var, F: FnMut(TypeParamIdx) -> RefinedType, { - let mut params = TypeParams::new(); + let mut args = TypeArgs::new(); for idx in 0..expected_len { let ty = self .subst .remove(&idx.into()) .unwrap_or_else(|| default(idx.into())); - params.push(ty); + args.push(ty); } - params + args } pub fn strip_refinement(self) -> TypeParamSubst { diff --git a/tests/ui/fail/adt_poly_ref.rs b/tests/ui/fail/adt_poly_ref.rs new file mode 100644 index 0000000..8c42d4b --- /dev/null +++ b/tests/ui/fail/adt_poly_ref.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +enum X<'a, T> { + A(&'a T), +} + +fn main() { + let i = 42; + let x = X::A(&i); + match x { + X::A(i) => assert!(*i == 41), + } +} diff --git a/tests/ui/pass/adt_poly_ref.rs b/tests/ui/pass/adt_poly_ref.rs new file mode 100644 index 0000000..f0e5e30 --- /dev/null +++ b/tests/ui/pass/adt_poly_ref.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +enum X<'a, T> { + A(&'a T), +} + +fn main() { + let i = 42; + let x = X::A(&i); + match x { + X::A(i) => assert!(*i == 42), + } +}