diff --git a/src/analyze.rs b/src/analyze.rs index 3550294..a1ebd34 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -17,7 +17,7 @@ use rustc_span::def_id::{DefId, LocalDefId}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; -use crate::refine::{self, BasicBlockType}; +use crate::refine::{self, BasicBlockType, TypeBuilder}; use crate::rty; mod annot; @@ -103,6 +103,17 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> { } } +#[derive(Debug, Clone)] +struct DeferredDefTy { + cache: Rc>>, +} + +#[derive(Debug, Clone)] +enum DefTy { + Concrete(rty::RefinedType), + Deferred(DeferredDefTy), +} + #[derive(Clone)] pub struct Analyzer<'tcx> { tcx: TyCtxt<'tcx>, @@ -112,7 +123,7 @@ pub struct Analyzer<'tcx> { /// currently contains only local-def templates, /// but will be extended to contain externally known def's refinement types /// (at least for every defs referenced by local def bodies) - defs: HashMap, + defs: HashMap, /// Resulting CHC system. system: Rc>, @@ -207,11 +218,65 @@ impl<'tcx> Analyzer<'tcx> { pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) { tracing::info!(def_id = ?def_id, rty = %rty.display(), "register_def"); - self.defs.insert(def_id, rty); + self.defs.insert(def_id, DefTy::Concrete(rty)); } - pub fn def_ty(&self, def_id: DefId) -> Option<&rty::RefinedType> { - self.defs.get(&def_id) + pub fn register_deferred_def(&mut self, def_id: DefId) { + tracing::info!(def_id = ?def_id, "register_deferred_def"); + self.defs.insert( + def_id, + DefTy::Deferred(DeferredDefTy { + cache: Rc::new(RefCell::new(HashMap::new())), + }), + ); + } + + pub fn concrete_def_ty(&self, def_id: DefId) -> Option<&rty::RefinedType> { + self.defs.get(&def_id).and_then(|def_ty| match def_ty { + DefTy::Concrete(rty) => Some(rty), + DefTy::Deferred(_) => None, + }) + } + + pub fn def_ty_with_args( + &mut self, + def_id: DefId, + rty_args: rty::TypeArgs, + ) -> Option { + let deferred_ty = match self.defs.get(&def_id)? { + DefTy::Concrete(rty) => { + let mut def_ty = rty.clone(); + def_ty.instantiate_ty_params( + rty_args + .clone() + .into_iter() + .map(rty::RefinedType::unrefined) + .collect(), + ); + return Some(def_ty); + } + DefTy::Deferred(deferred) => deferred, + }; + + let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self + if let Some(rty) = deferred_ty_cache.borrow().get(&rty_args) { + return Some(rty.clone()); + } + + let type_builder = TypeBuilder::new(self.tcx, def_id).with_param_mapper({ + let rty_args = rty_args.clone(); + move |ty: rty::ParamType| rty_args[ty.idx].clone() + }); + let mut analyzer = self.local_def_analyzer(def_id.as_local()?); + analyzer.type_builder(type_builder); + + let expected = analyzer.expected_ty(); + deferred_ty_cache + .borrow_mut() + .insert(rty_args, expected.clone()); + + analyzer.run(&expected); + Some(expected) } pub fn register_basic_block_ty( diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 7b346ce..9f6ee49 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -33,6 +33,7 @@ pub struct Analyzer<'tcx, 'ctx> { basic_block: BasicBlock, body: Cow<'tcx, Body<'tcx>>, + type_builder: TypeBuilder<'tcx>, env: Env, local_decls: IndexVec>, // TODO: remove this @@ -56,10 +57,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.basic_block_ty(self.local_def_id, bb) } - fn type_builder(&self) -> TypeBuilder<'tcx> { - TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) - } - fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { let rty = if self.is_mut_local(local) { // elaboration: @@ -226,7 +223,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - self.type_builder() + self.type_builder .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ty) @@ -267,10 +264,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { ) => { let func_ty = match operand.const_fn_def() { Some((def_id, args)) => { - if !args.is_empty() { - tracing::warn!(?args, ?def_id, "generic args ignored"); - } - self.ctx.def_ty(def_id).expect("unknown def").ty.clone() + let rty_args: IndexVec<_, _> = + args.types().map(|ty| self.type_builder.build(ty)).collect(); + self.ctx + .def_ty_with_args(def_id, rty_args) + .expect("unknown def") + .ty + .clone() } _ => unimplemented!(), }; @@ -440,7 +440,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { let inner_ty = self - .type_builder() + .type_builder .for_template(&mut self.ctx) .build(args.type_at(0)) .vacuous(); @@ -454,7 +454,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = self.type_builder().build(args.type_at(0)).vacuous(); + let inner_ty = self.type_builder.build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -472,14 +472,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } Some((def_id, args)) => { - if !args.is_empty() { - tracing::warn!(?args, ?def_id, "generic args ignored"); - } + let rty_args = args.types().map(|ty| self.type_builder.build(ty)).collect(); self.ctx - .def_ty(def_id) + .def_ty_with_args(def_id, rty_args) .expect("unknown def") .ty - .clone() .vacuous() } _ => self.operand_type(func.clone()).ty, @@ -541,7 +538,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = self.type_builder().build(ty); + let ty = self.type_builder.build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -562,7 +559,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = self.type_builder().build(prophecy_ty); + let prophecy_ty = self.type_builder.build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -675,7 +672,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let decl = self.local_decls[destination].clone(); let rty = self - .type_builder() + .type_builder .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(decl.ty); @@ -749,7 +746,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - self.type_builder() + self.type_builder .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ret_ty) @@ -955,6 +952,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let env = ctx.new_env(); let local_decls = body.local_decls.clone(); let prophecy_vars = Default::default(); + let type_builder = TypeBuilder::new(tcx, local_def_id.to_def_id()); Self { ctx, tcx, @@ -962,6 +960,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { drop_points, basic_block, body, + type_builder, env, local_decls, prophecy_vars, @@ -989,6 +988,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } + pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { + self.type_builder = type_builder; + self + } + pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 54a02d1..e15ca49 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -5,11 +5,9 @@ use std::collections::HashSet; use rustc_hir::def::DefKind; use rustc_index::IndexVec; use rustc_middle::ty::{self as mir_ty, TyCtxt}; -use rustc_span::def_id::DefId; -use rustc_span::symbol::Ident; +use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; -use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; @@ -34,172 +32,32 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { fn refine_local_defs(&mut self) { for local_def_id in self.tcx.mir_keys(()) { if self.tcx.def_kind(*local_def_id).is_fn_like() { - self.refine_fn_def(local_def_id.to_def_id()); + self.refine_fn_def(*local_def_id); } } } - fn extract_require_annot( - &self, - resolver: T, - def_id: DefId, - ) -> Option> - where - T: annot::Resolver, - { - let mut require_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::requires_path()) - { - if require_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - require_annot = Some(require); - } - require_annot - } - - fn extract_ensure_annot(&self, resolver: T, def_id: DefId) -> Option> - where - T: annot::Resolver, - { - let mut ensure_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) - { - if ensure_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - ensure_annot = Some(ensure); - } - ensure_annot - } + #[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(local_def_id)))] + fn refine_fn_def(&mut self, local_def_id: LocalDefId) { + let mut analyzer = self.ctx.local_def_analyzer(local_def_id); - fn extract_param_annots( - &self, - resolver: T, - def_id: DefId, - ) -> Vec<(Ident, rty::RefinedType)> - where - T: annot::Resolver, - { - let mut param_annots = Vec::new(); - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::param_path()) - { - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let (ident, ts) = analyze::annot::split_param(&ts); - let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); - param_annots.push((ident, param)); + if analyzer.is_annotated_as_trusted() { + assert!(analyzer.is_fully_annotated()); + self.trusted.insert(local_def_id.to_def_id()); } - param_annots - } - fn extract_ret_annot( - &self, - resolver: T, - def_id: DefId, - ) -> Option> - where - T: annot::Resolver, - { - let mut ret_annot = None; - for attrs in self + let sig = self .tcx - .get_attrs_by_path(def_id, &analyze::annot::ret_path()) - { - if ret_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); - ret_annot = Some(ret); + .fn_sig(local_def_id) + .instantiate_identity() + .skip_binder(); + use mir_ty::TypeVisitableExt as _; + if sig.has_param() && !analyzer.is_fully_annotated() { + self.ctx.register_deferred_def(local_def_id.to_def_id()); + } else { + let expected = analyzer.expected_ty(); + self.ctx.register_def(local_def_id.to_def_id(), expected); } - ret_annot - } - - #[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(def_id)))] - fn refine_fn_def(&mut self, def_id: DefId) { - let sig = self.tcx.fn_sig(def_id); - let sig = sig.instantiate_identity().skip_binder(); // TODO: is it OK? - - let mut param_resolver = analyze::annot::ParamResolver::default(); - for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty); - param_resolver.push_param(input_ident.name, input_ty.to_sort()); - } - - let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); - let mut ensure_annot = { - let output_ty = TypeBuilder::new(self.tcx, def_id).build(sig.output()); - let resolver = annot::StackedResolver::default() - .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) - .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); - self.extract_ensure_annot(resolver, def_id) - }; - let param_annots = self.extract_param_annots(¶m_resolver, def_id); - let ret_annot = self.extract_ret_annot(¶m_resolver, def_id); - - if self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::callable_path()) - .next() - .is_some() - { - if require_annot.is_some() || ensure_annot.is_some() { - unimplemented!(); - } - - require_annot = Some(AnnotFormula::top()); - ensure_annot = Some(AnnotFormula::top()); - } - - assert!(require_annot.is_none() || param_annots.is_empty()); - assert!(ensure_annot.is_none() || ret_annot.is_none()); - - if self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::trusted_path()) - .next() - .is_some() - { - assert!(require_annot.is_some() || !param_annots.is_empty()); - assert!(ensure_annot.is_some() || ret_annot.is_some()); - self.trusted.insert(def_id); - } - - let mut builder = - TypeBuilder::new(self.tcx, def_id).for_function_template(&mut self.ctx, sig); - if let Some(AnnotFormula::Formula(require)) = require_annot { - let formula = require.map_var(|idx| { - if idx.index() == sig.inputs().len() - 1 { - rty::RefinedTypeVar::Value - } else { - rty::RefinedTypeVar::Free(idx) - } - }); - builder.param_refinement(formula.into()); - } - if let Some(AnnotFormula::Formula(ensure)) = ensure_annot { - builder.ret_refinement(ensure.into()); - } - for (ident, annot_rty) in param_annots { - use annot::Resolver as _; - let (idx, _) = param_resolver.resolve(ident).expect("unknown param"); - builder.param_rty(idx, annot_rty); - } - if let Some(ret_rty) = ret_annot { - builder.ret_rty(ret_rty); - } - let rty = rty::RefinedType::unrefined(builder.build().into()); - self.ctx.register_def(def_id, rty); } fn analyze_local_defs(&mut self) { @@ -211,8 +69,28 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local_def_id, "trusted"); continue; } - let expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone(); - self.ctx.local_def_analyzer(*local_def_id).run(&expected); + let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else { + // when the local_def_id is deferred it would be skipped + continue; + }; + + // check polymorphic function def by replacing type params with some opaque type + // (and this is no-op if the function is mono) + let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()) + .with_param_mapper(|_| rty::Type::int()); + let mut expected = expected.clone(); + let subst = rty::TypeParamSubst::new( + expected + .free_ty_params() + .into_iter() + .map(|ty_param| (ty_param, rty::RefinedType::unrefined(rty::Type::int()))) + .collect(), + ); + expected.subst_ty_params(&subst); + self.ctx + .local_def_analyzer(*local_def_id) + .type_builder(type_builder) + .run(&expected); } } @@ -222,7 +100,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // TODO: replace code here with relate_* in Env + Refine context (created with empty env) let entry_ty = self .ctx - .def_ty(def_id) + .concrete_def_ty(def_id) .unwrap() .ty .as_function() diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 7e7c737..58e4b8b 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -7,8 +7,10 @@ use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Body, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut}; use rustc_span::def_id::LocalDefId; +use rustc_span::symbol::Ident; use crate::analyze; +use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{BasicBlockType, TypeBuilder}; @@ -26,9 +28,217 @@ pub struct Analyzer<'tcx, 'ctx> { body: Body<'tcx>, drop_points: HashMap, + type_builder: TypeBuilder<'tcx>, } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { + fn extract_require_annot(&self, resolver: T) -> Option> + where + T: annot::Resolver, + { + let mut require_annot = None; + for attrs in self.tcx.get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::requires_path(), + ) { + if require_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + require_annot = Some(require); + } + require_annot + } + + fn extract_ensure_annot(&self, resolver: T) -> Option> + where + T: annot::Resolver, + { + let mut ensure_annot = None; + for attrs in self.tcx.get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::ensures_path(), + ) { + if ensure_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + ensure_annot = Some(ensure); + } + ensure_annot + } + + fn extract_param_annots(&self, resolver: T) -> Vec<(Ident, rty::RefinedType)> + where + T: annot::Resolver, + { + let mut param_annots = Vec::new(); + for attrs in self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) + { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let (ident, ts) = analyze::annot::split_param(&ts); + let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + param_annots.push((ident, param)); + } + param_annots + } + + fn extract_ret_annot(&self, resolver: T) -> Option> + where + T: annot::Resolver, + { + let mut ret_annot = None; + for attrs in self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path()) + { + if ret_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + ret_annot = Some(ret); + } + ret_annot + } + + pub fn is_annotated_as_trusted(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::trusted_path(), + ) + .next() + .is_some() + } + + pub fn is_annotated_as_callable(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::callable_path(), + ) + .next() + .is_some() + } + + // TODO: unify this logic with extraction functions above + pub fn is_fully_annotated(&self) -> bool { + let has_require = self + .tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::requires_path(), + ) + .next() + .is_some(); + let has_ensure = self + .tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::ensures_path(), + ) + .next() + .is_some(); + let annotated_params: Vec<_> = self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) + .map(|attrs| { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let (ident, _) = analyze::annot::split_param(&ts); + ident + }) + .collect(); + let has_ret = self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path()) + .next() + .is_some(); + + let arg_names = self.tcx.fn_arg_names(self.local_def_id.to_def_id()); + let all_params_annotated = arg_names + .iter() + .all(|ident| annotated_params.contains(ident)); + self.is_annotated_as_callable() + || (has_require && has_ensure) + || (all_params_annotated && has_ret) + } + + pub fn expected_ty(&mut self) -> rty::RefinedType { + let sig = self.tcx.fn_sig(self.local_def_id); + let sig = sig.instantiate_identity().skip_binder(); + + let mut param_resolver = analyze::annot::ParamResolver::default(); + for (input_ident, input_ty) in self + .tcx + .fn_arg_names(self.local_def_id.to_def_id()) + .iter() + .zip(sig.inputs()) + { + let input_ty = self.type_builder.build(*input_ty); + param_resolver.push_param(input_ident.name, input_ty.to_sort()); + } + + let mut require_annot = self.extract_require_annot(¶m_resolver); + let mut ensure_annot = { + let output_ty = self.type_builder.build(sig.output()); + let resolver = annot::StackedResolver::default() + .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) + .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); + self.extract_ensure_annot(resolver) + }; + let param_annots = self.extract_param_annots(¶m_resolver); + let ret_annot = self.extract_ret_annot(¶m_resolver); + + if self.is_annotated_as_callable() { + if require_annot.is_some() || ensure_annot.is_some() { + unimplemented!(); + } + if !param_annots.is_empty() || ret_annot.is_some() { + unimplemented!(); + } + + require_annot = Some(AnnotFormula::top()); + ensure_annot = Some(AnnotFormula::top()); + } + + assert!(require_annot.is_none() || param_annots.is_empty()); + assert!(ensure_annot.is_none() || ret_annot.is_none()); + + let mut builder = self.type_builder.for_function_template(&mut self.ctx, sig); + if let Some(AnnotFormula::Formula(require)) = require_annot { + let formula = require.map_var(|idx| { + if idx.index() == sig.inputs().len() - 1 { + rty::RefinedTypeVar::Value + } else { + rty::RefinedTypeVar::Free(idx) + } + }); + builder.param_refinement(formula.into()); + } + if let Some(AnnotFormula::Formula(ensure)) = ensure_annot { + builder.ret_refinement(ensure.into()); + } + for (ident, annot_rty) in param_annots { + use annot::Resolver as _; + let (idx, _) = param_resolver.resolve(ident).expect("unknown param"); + builder.param_rty(idx, annot_rty); + } + if let Some(ret_rty) = ret_annot { + builder.ret_rty(ret_rty); + } + + // Note that we do not expect predicate variables to be generated here + // when type params are still present in the type. Callers should ensure either + // - type params are fully instantiated, or + // - the function is fully annotated + rty::RefinedType::unrefined(builder.build().into()) + } + fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool { let param_local = analyze::local_of_function_param(param_idx); self.body.local_decls[param_local].mutability.is_mut() @@ -306,7 +516,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) + let rty = self + .type_builder .for_template(&mut self.ctx) .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); @@ -321,6 +532,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .basic_block_analyzer(self.local_def_id, bb) .body(self.body.clone()) .drop_points(drop_points) + .type_builder(self.type_builder.clone()) .run(&rty); } } @@ -426,15 +638,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let tcx = ctx.tcx; let body = tcx.optimized_mir(local_def_id.to_def_id()).clone(); let drop_points = Default::default(); + let type_builder = TypeBuilder::new(tcx, local_def_id.to_def_id()); Self { ctx, tcx, local_def_id, body, drop_points, + type_builder, } } + pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { + self.type_builder = type_builder; + self + } + pub fn run(&mut self, expected: &rty::RefinedType) { let span = tracing::info_span!("def", def = %self.tcx.def_path_str(self.local_def_id)); let _guard = span.enter(); diff --git a/src/chc.rs b/src/chc.rs index 8a3309f..6f5db32 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -389,7 +389,7 @@ impl Function { } /// A logical term. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Term { Null, Var(V), @@ -984,7 +984,7 @@ impl Pred { } /// An atom is a predicate applied to a list of terms. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Atom { pub pred: Pred, pub args: Vec>, @@ -1077,7 +1077,7 @@ impl Atom { /// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known /// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to /// enforce this restriction statically. Also see the definition of [`Body`]. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Formula { Atom(Atom), Not(Box>), @@ -1296,7 +1296,7 @@ impl Formula { } /// The body part of a clause, consisting of atoms and a formula. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Body { pub atoms: Vec>, /// NOTE: This doesn't contain predicate variables. Also see [`Formula`]. diff --git a/src/refine/template.rs b/src/refine/template.rs index bec3f8a..c859419 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -59,6 +59,19 @@ where } } +trait ParamTypeMapper { + fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type; +} + +impl ParamTypeMapper for F +where + F: Fn(rty::ParamType) -> rty::Type, +{ + fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type { + self(ty) + } +} + /// Translates [`mir_ty::Ty`] to [`rty::Type`]. /// /// This struct implements a translation from Rust MIR types to Thrust types. @@ -70,9 +83,13 @@ where pub struct TypeBuilder<'tcx> { tcx: mir_ty::TyCtxt<'tcx>, /// Maps index in [`mir_ty::ParamTy`] to [`rty::TypeParamIdx`]. - /// These indices may differ because we skip lifetime parameters. + /// These indices may differ because we skip lifetime parameters and they always need to be + /// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`]. /// See [`rty::TypeParamIdx`] for more details. param_idx_mapping: HashMap, + /// Optionally also want to further map rty::ParamType to other rty::Type before generating + /// templates. This is no-op by default. + param_type_mapper: std::rc::Rc, } impl<'tcx> TypeBuilder<'tcx> { @@ -86,21 +103,31 @@ impl<'tcx> TypeBuilder<'tcx> { mir_ty::GenericParamDefKind::Type { .. } => { param_idx_mapping.insert(i as u32, param_idx_mapping.len().into()); } - mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + mir_ty::GenericParamDefKind::Const { .. } => {} } } Self { tcx, param_idx_mapping, + param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()), } } - fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { + pub fn with_param_mapper(mut self, mapper: F) -> Self + where + F: Fn(rty::ParamType) -> rty::Type + 'static, + { + self.param_type_mapper = std::rc::Rc::new(mapper); + self + } + + fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type { let index = *self .param_idx_mapping .get(&ty.index) .expect("unknown type param idx"); - rty::ParamType::new(index) + let param_ty = rty::ParamType::new(index); + self.param_type_mapper.map_param_ty(param_ty) } // TODO: consolidate two impls @@ -125,7 +152,7 @@ impl<'tcx> TypeBuilder<'tcx> { rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(), + mir_ty::TyKind::Param(ty) => self.translate_param_type(ty), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); @@ -251,7 +278,7 @@ where rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(), + mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).vacuous(), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); @@ -373,6 +400,17 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } + + pub fn would_contain_template(&self) -> bool { + if self.param_tys.is_empty() { + return self.ret_rty.is_none(); + } + + let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1); + let param_annotated = + self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx); + self.ret_rty.is_none() || !param_annotated + } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> diff --git a/src/rty.rs b/src/rty.rs index c706897..ce6ef5e 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{TypeArgs, TypeParamIdx, TypeParamSubst}; +pub use params::{RefinedTypeArgs, TypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. @@ -88,7 +88,7 @@ where /// In Thrust, function types are closed. Because of that, function types, thus its parameters and /// return type only refer to the parameters of the function itself using [`FunctionParamIdx`] and /// do not accept other type of variables from the environment. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct FunctionType { pub params: IndexVec>, pub ret: Box>, @@ -156,7 +156,7 @@ impl FunctionType { } /// The kind of a reference, which is either mutable or immutable. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum RefKind { Mut, Immut, @@ -181,7 +181,7 @@ where } /// The kind of a pointer, which is either a reference or an owned pointer. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PointerKind { Ref(RefKind), Own, @@ -221,7 +221,7 @@ impl PointerKind { } /// A pointer type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PointerType { pub kind: PointerKind, pub elem: Box>, @@ -334,7 +334,7 @@ impl PointerType { /// Note that the current implementation uses tuples to represent structs. See /// implementation in `crate::refine::template` module for details. /// It is our TODO to improve the struct representation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct TupleType { pub elems: Vec>, } @@ -458,7 +458,7 @@ impl EnumDatatypeDef { /// An enum type. /// /// An enum type includes its type arguments and the argument types can refer to outer variables `T`. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct EnumType { pub symbol: chc::DatatypeSymbol, pub args: IndexVec>, @@ -487,7 +487,7 @@ where } impl EnumType { - pub fn new(symbol: chc::DatatypeSymbol, args: TypeArgs) -> Self { + pub fn new(symbol: chc::DatatypeSymbol, args: RefinedTypeArgs) -> Self { EnumType { symbol, args } } @@ -560,7 +560,7 @@ impl EnumType { } /// A type parameter. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ParamType { pub idx: TypeParamIdx, } @@ -589,7 +589,7 @@ impl ParamType { } /// An underlying type of a refinement type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Type { Int, Bool, @@ -995,7 +995,7 @@ impl ShiftExistential for RefinedTypeVar { /// A formula, potentially equipped with an existential quantifier. /// /// Note: This is not to be confused with [`crate::chc::Formula`] in the [`crate::chc`] module, which is a different notion. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Formula { pub existentials: IndexVec, pub body: chc::Body, @@ -1236,7 +1236,7 @@ impl Instantiator { } /// A refinement type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RefinedType { pub ty: Type, pub refinement: Refinement, @@ -1372,7 +1372,7 @@ impl RefinedType { } } - pub fn instantiate_ty_params(&mut self, params: TypeArgs) + pub fn instantiate_ty_params(&mut self, params: RefinedTypeArgs) where FV: chc::Var, { diff --git a/src/rty/params.rs b/src/rty/params.rs index 6414f6a..b76b6d1 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -7,7 +7,7 @@ use rustc_index::IndexVec; use crate::chc; -use super::{Closed, RefinedType}; +use super::{Closed, RefinedType, Type}; rustc_index::newtype_index! { /// An index representing a type parameter. @@ -53,7 +53,8 @@ impl TypeParamIdx { } } -pub type TypeArgs = IndexVec>; +pub type RefinedTypeArgs = IndexVec>; +pub type TypeArgs = IndexVec>; /// A substitution for type parameters that maps type parameters to refinement types. #[derive(Debug, Clone)] @@ -71,6 +72,16 @@ impl Default for TypeParamSubst { impl From> for TypeParamSubst { fn from(params: TypeArgs) -> Self { + let subst = params + .into_iter_enumerated() + .map(|(idx, ty)| (idx, RefinedType::unrefined(ty))) + .collect(); + Self { subst } + } +} + +impl From> for TypeParamSubst { + fn from(params: RefinedTypeArgs) -> Self { let subst = params.into_iter_enumerated().collect(); Self { subst } } @@ -112,12 +123,12 @@ impl TypeParamSubst { } } - pub fn into_args(mut self, expected_len: usize, mut default: F) -> TypeArgs + pub fn into_args(mut self, expected_len: usize, mut default: F) -> RefinedTypeArgs where T: chc::Var, F: FnMut(TypeParamIdx) -> RefinedType, { - let mut args = TypeArgs::new(); + let mut args = RefinedTypeArgs::new(); for idx in 0..expected_len { let ty = self .subst diff --git a/tests/ui/fail/adt_poly_fn_poly.rs b/tests/ui/fail/adt_poly_fn_poly.rs new file mode 100644 index 0000000..899f875 --- /dev/null +++ b/tests/ui/fail/adt_poly_fn_poly.rs @@ -0,0 +1,39 @@ +//@error-in-other-file: Unsat + +pub enum X { + A(T), + B(T), +} + +#[thrust::trusted] +#[thrust::requires(true)] +#[thrust::ensures(true)] +fn rand() -> X { unimplemented!() } + +fn is_a(x: &X) -> bool { + match x { + X::A(_) => true, + X::B(_) => false, + } +} + +fn inv(x: X) -> X { + match x { + X::A(i) => X::B(i), + X::B(i) => X::A(i), + } +} + +fn rand_a() -> X { + let x = rand(); + if !is_a(&x) { loop {} } + x +} + +#[thrust::callable] +fn check() { + let x = rand_a::(); + assert!(is_a(&inv(x))); +} + +fn main() {} diff --git a/tests/ui/fail/fn_poly.rs b/tests/ui/fail/fn_poly.rs new file mode 100644 index 0000000..15351dd --- /dev/null +++ b/tests/ui/fail/fn_poly.rs @@ -0,0 +1,9 @@ +//@error-in-other-file: Unsat + +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 0); +} diff --git a/tests/ui/fail/fn_poly_annot.rs b/tests/ui/fail/fn_poly_annot.rs new file mode 100644 index 0000000..2458a98 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot.rs @@ -0,0 +1,11 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result != x.0)] +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_complex.rs b/tests/ui/fail/fn_poly_annot_complex.rs new file mode 100644 index 0000000..e37f5b8 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_complex.rs @@ -0,0 +1,12 @@ +//@error-in-other-file: Unsat + +#[thrust::requires((x.0 > 0) && (x.1 > 0))] +#[thrust::ensures((result.0 == x.1) && (result.1 == x.0))] +fn swap_positive(x: (i32, i32, T, U)) -> (i32, i32, U, T) { + (x.1, x.0, x.3, x.2) +} + +fn main() { + let result = swap_positive((-5, 10, true, 42)); + assert!(result.0 == 10); +} diff --git a/tests/ui/fail/fn_poly_annot_multi_inst.rs b/tests/ui/fail/fn_poly_annot_multi_inst.rs new file mode 100644 index 0000000..cf8331f --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_multi_inst.rs @@ -0,0 +1,15 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +fn main() { + let a = id(42); + assert!(a == 42); + + let b = id(true); + assert!(b == false); +} diff --git a/tests/ui/fail/fn_poly_annot_nested.rs b/tests/ui/fail/fn_poly_annot_nested.rs new file mode 100644 index 0000000..e1e5015 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_nested.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +#[thrust::requires(true)] +#[thrust::ensures(result != x)] +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn main() { + assert!(apply_twice(42) == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_recursive.rs b/tests/ui/fail/fn_poly_annot_recursive.rs new file mode 100644 index 0000000..aa40960 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_recursive.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(n >= 0)] +#[thrust::ensures(result == value)] +fn repeat(n: i32, value: T) -> T { + if n == 0 { + value + } else { + repeat(n - 1, value) + } +} + +fn main() { + let result = repeat(-1, 42); + assert!(result == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_ref.rs b/tests/ui/fail/fn_poly_annot_ref.rs new file mode 100644 index 0000000..10ba6b5 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_ref.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result != x)] +fn id_ref(x: &T) -> &T { + x +} + +fn main() { + let val = 42; + let r = id_ref(&val); + assert!(*r == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_stronger.rs b/tests/ui/fail/fn_poly_annot_stronger.rs new file mode 100644 index 0000000..90738a8 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_stronger.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(x > 0)] +#[thrust::ensures((result == x) && (result > 0))] +fn pass_positive(x: i32, _dummy: T) -> i32 { + x +} + +fn main() { + let result = pass_positive(-5, true); + assert!(result == -5); +} diff --git a/tests/ui/fail/fn_poly_double_nested.rs b/tests/ui/fail/fn_poly_double_nested.rs new file mode 100644 index 0000000..15040c5 --- /dev/null +++ b/tests/ui/fail/fn_poly_double_nested.rs @@ -0,0 +1,21 @@ +//@error-in-other-file: Unsat + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + apply_twice(id(x)) +} + +fn apply_four(x: T) -> T { + apply_twice(apply_twice(x)) +} + +fn main() { + assert!(apply_four(42) == 43); +} diff --git a/tests/ui/fail/fn_poly_multiple_calls.rs b/tests/ui/fail/fn_poly_multiple_calls.rs new file mode 100644 index 0000000..f510eab --- /dev/null +++ b/tests/ui/fail/fn_poly_multiple_calls.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +fn first(pair: (T, U)) -> T { + pair.0 +} + +fn main() { + let x = first((42, true)); + let y = first((true, 100)); + + assert!(x == 42); + assert!(y == false); +} diff --git a/tests/ui/fail/fn_poly_mut_ref.rs b/tests/ui/fail/fn_poly_mut_ref.rs new file mode 100644 index 0000000..1ab6ff9 --- /dev/null +++ b/tests/ui/fail/fn_poly_mut_ref.rs @@ -0,0 +1,16 @@ +//@error-in-other-file: Unsat + +fn update(x: &mut T, new_val: T) { + *x = new_val; +} + +fn chain_update(x: &mut T, temp: T, final_val: T) { + update(x, temp); + update(x, final_val); +} + +fn main() { + let mut val = 42; + chain_update(&mut val, 100, 200); + assert!(val == 42); +} diff --git a/tests/ui/fail/fn_poly_nested_calls.rs b/tests/ui/fail/fn_poly_nested_calls.rs new file mode 100644 index 0000000..609f136 --- /dev/null +++ b/tests/ui/fail/fn_poly_nested_calls.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + id(apply_twice(x)) +} + +fn main() { + assert!(apply_thrice(42) == 43); +} diff --git a/tests/ui/fail/fn_poly_param_order.rs b/tests/ui/fail/fn_poly_param_order.rs new file mode 100644 index 0000000..93af596 --- /dev/null +++ b/tests/ui/fail/fn_poly_param_order.rs @@ -0,0 +1,18 @@ +//@error-in-other-file: Unsat + +fn select(a: T, b: U, c: V, which: i32) -> T { + if which == 0 { + a + } else { + a + } +} + +fn rotate(triple: (A, B, C)) -> (B, C, A) { + (triple.1, triple.2, triple.0) +} + +fn main() { + let x = rotate((1, true, 42)); + assert!(x.0 == false); +} diff --git a/tests/ui/fail/fn_poly_recursive.rs b/tests/ui/fail/fn_poly_recursive.rs new file mode 100644 index 0000000..8b2aaba --- /dev/null +++ b/tests/ui/fail/fn_poly_recursive.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn repeat(n: i32, value: T) -> T { + if n <= 1 { + value + } else { + repeat(n - 1, repeat(1, value)) + } +} + +fn identity_loop(depth: i32, x: T) -> T { + if depth == 0 { + x + } else { + identity_loop(depth - 1, identity_loop(0, x)) + } +} + +fn main() { + assert!(repeat(5, 42) == 43); +} diff --git a/tests/ui/fail/fn_poly_ref.rs b/tests/ui/fail/fn_poly_ref.rs new file mode 100644 index 0000000..671f8b0 --- /dev/null +++ b/tests/ui/fail/fn_poly_ref.rs @@ -0,0 +1,15 @@ +//@error-in-other-file: Unsat + +fn identity_ref(x: &T) -> &T { + x +} + +fn chain_ref(x: &T) -> &T { + identity_ref(identity_ref(x)) +} + +fn main() { + let val = 42; + let r = chain_ref(&val); + assert!(*r == 43); +} diff --git a/tests/ui/fail/fn_poly_unused_param.rs b/tests/ui/fail/fn_poly_unused_param.rs new file mode 100644 index 0000000..bde86f1 --- /dev/null +++ b/tests/ui/fail/fn_poly_unused_param.rs @@ -0,0 +1,15 @@ +//@error-in-other-file: Unsat + +fn project_first(triple: (T, U, V)) -> T { + triple.0 +} + +fn chain(x: A, _phantom_b: B, _phantom_c: C) -> A { + x +} + +fn main() { + let x = project_first((42, true, 100)); + let y = chain(x, (1, 2), false); + assert!(y == 43); +} diff --git a/tests/ui/pass/adt_poly_fn_poly.rs b/tests/ui/pass/adt_poly_fn_poly.rs new file mode 100644 index 0000000..d3b91f4 --- /dev/null +++ b/tests/ui/pass/adt_poly_fn_poly.rs @@ -0,0 +1,39 @@ +//@check-pass + +pub enum X { + A(T), + B(T), +} + +#[thrust::trusted] +#[thrust::requires(true)] +#[thrust::ensures(true)] +fn rand() -> X { unimplemented!() } + +fn is_a(x: &X) -> bool { + match x { + X::A(_) => true, + X::B(_) => false, + } +} + +fn inv(x: X) -> X { + match x { + X::A(i) => X::B(i), + X::B(i) => X::A(i), + } +} + +fn rand_a() -> X { + let x = rand(); + if !is_a(&x) { loop {} } + x +} + +#[thrust::callable] +fn check() { + let x = rand_a::(); + assert!(!is_a(&inv(x))); +} + +fn main() {} diff --git a/tests/ui/pass/fn_poly.rs b/tests/ui/pass/fn_poly.rs new file mode 100644 index 0000000..4a8e678 --- /dev/null +++ b/tests/ui/pass/fn_poly.rs @@ -0,0 +1,9 @@ +//@check-pass + +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 42); +} diff --git a/tests/ui/pass/fn_poly_annot.rs b/tests/ui/pass/fn_poly_annot.rs new file mode 100644 index 0000000..3176816 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot.rs @@ -0,0 +1,11 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x.0)] +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_complex.rs b/tests/ui/pass/fn_poly_annot_complex.rs new file mode 100644 index 0000000..79d6251 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_complex.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::requires((n > 0) && (m > 0))] +#[thrust::ensures((result.0 == m) && (result.1 == n))] +fn swap_pair(n: i32, m: i32, _phantom: T) -> (i32, i32) { + (m, n) +} + +fn main() { + let result = swap_pair(5, 10, true); + assert!(result.0 == 10); + assert!(result.1 == 5); +} diff --git a/tests/ui/pass/fn_poly_annot_multi_inst.rs b/tests/ui/pass/fn_poly_annot_multi_inst.rs new file mode 100644 index 0000000..372cd66 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_multi_inst.rs @@ -0,0 +1,18 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +fn main() { + let a = id(42); + assert!(a == 42); + + let b = id(true); + assert!(b == true); + + let c = id((1, 2)); + assert!(c.0 == 1); +} diff --git a/tests/ui/pass/fn_poly_annot_nested.rs b/tests/ui/pass/fn_poly_annot_nested.rs new file mode 100644 index 0000000..927a393 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_nested.rs @@ -0,0 +1,17 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn main() { + assert!(apply_twice(42) == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_recursive.rs b/tests/ui/pass/fn_poly_annot_recursive.rs new file mode 100644 index 0000000..f7dc255 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_recursive.rs @@ -0,0 +1,17 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(n >= 0)] +#[thrust::ensures(result == value)] +fn repeat(n: i32, value: T) -> T { + if n == 0 { + value + } else { + repeat(n - 1, value) + } +} + +fn main() { + let result = repeat(5, 42); + assert!(result == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_ref.rs b/tests/ui/pass/fn_poly_annot_ref.rs new file mode 100644 index 0000000..adb6a14 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_ref.rs @@ -0,0 +1,13 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id_ref(x: &T) -> &T { + x +} + +fn main() { + let val = 42; + let r = id_ref(&val); + assert!(*r == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_stronger.rs b/tests/ui/pass/fn_poly_annot_stronger.rs new file mode 100644 index 0000000..c3ea370 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_stronger.rs @@ -0,0 +1,13 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(x > 0)] +#[thrust::ensures((result == x) && (result > 0))] +fn pass_positive(x: i32, _dummy: T) -> i32 { + x +} + +fn main() { + let result = pass_positive(42, true); + assert!(result == 42); +} diff --git a/tests/ui/pass/fn_poly_double_nested.rs b/tests/ui/pass/fn_poly_double_nested.rs new file mode 100644 index 0000000..5c9b054 --- /dev/null +++ b/tests/ui/pass/fn_poly_double_nested.rs @@ -0,0 +1,22 @@ +//@check-pass + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + apply_twice(id(x)) +} + +fn apply_four(x: T) -> T { + apply_twice(apply_twice(x)) +} + +fn main() { + assert!(apply_four(42) == 42); + assert!(apply_thrice(true) == true); +} diff --git a/tests/ui/pass/fn_poly_multiple_calls.rs b/tests/ui/pass/fn_poly_multiple_calls.rs new file mode 100644 index 0000000..aa83ba1 --- /dev/null +++ b/tests/ui/pass/fn_poly_multiple_calls.rs @@ -0,0 +1,15 @@ +//@check-pass + +fn first(pair: (T, U)) -> T { + pair.0 +} + +fn main() { + let x = first((42, true)); + let y = first((true, 100)); + let z = first(((1, 2), 3)); + + assert!(x == 42); + assert!(y == true); + assert!(z.0 == 1); +} diff --git a/tests/ui/pass/fn_poly_mut_ref.rs b/tests/ui/pass/fn_poly_mut_ref.rs new file mode 100644 index 0000000..4298a64 --- /dev/null +++ b/tests/ui/pass/fn_poly_mut_ref.rs @@ -0,0 +1,16 @@ +//@check-pass + +fn update(x: &mut T, new_val: T) { + *x = new_val; +} + +fn chain_update(x: &mut T, temp: T, final_val: T) { + update(x, temp); + update(x, final_val); +} + +fn main() { + let mut val = 42; + chain_update(&mut val, 100, 200); + assert!(val == 200); +} diff --git a/tests/ui/pass/fn_poly_nested_calls.rs b/tests/ui/pass/fn_poly_nested_calls.rs new file mode 100644 index 0000000..b2bef0b --- /dev/null +++ b/tests/ui/pass/fn_poly_nested_calls.rs @@ -0,0 +1,18 @@ +//@check-pass + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + id(apply_twice(x)) +} + +fn main() { + assert!(apply_thrice(42) == 42); + assert!(apply_twice(true) == true); +} diff --git a/tests/ui/pass/fn_poly_param_order.rs b/tests/ui/pass/fn_poly_param_order.rs new file mode 100644 index 0000000..41191de --- /dev/null +++ b/tests/ui/pass/fn_poly_param_order.rs @@ -0,0 +1,20 @@ +//@check-pass + +fn select(a: T, b: U, c: V, which: i32) -> T { + if which == 0 { + a + } else { + a + } +} + +fn rotate(triple: (A, B, C)) -> (B, C, A) { + (triple.1, triple.2, triple.0) +} + +fn main() { + let x = rotate((1, true, 42)); + assert!(x.0 == true); + assert!(x.1 == 42); + assert!(x.2 == 1); +} diff --git a/tests/ui/pass/fn_poly_recursive.rs b/tests/ui/pass/fn_poly_recursive.rs new file mode 100644 index 0000000..cfc7e2e --- /dev/null +++ b/tests/ui/pass/fn_poly_recursive.rs @@ -0,0 +1,23 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn repeat(n: i32, value: T) -> T { + if n <= 1 { + value + } else { + repeat(n - 1, repeat(1, value)) + } +} + +fn identity_loop(depth: i32, x: T) -> T { + if depth == 0 { + x + } else { + identity_loop(depth - 1, identity_loop(0, x)) + } +} + +fn main() { + assert!(repeat(5, 42) == 42); + assert!(identity_loop(3, true) == true); +} diff --git a/tests/ui/pass/fn_poly_ref.rs b/tests/ui/pass/fn_poly_ref.rs new file mode 100644 index 0000000..fae27ae --- /dev/null +++ b/tests/ui/pass/fn_poly_ref.rs @@ -0,0 +1,15 @@ +//@check-pass + +fn identity_ref(x: &T) -> &T { + x +} + +fn chain_ref(x: &T) -> &T { + identity_ref(identity_ref(x)) +} + +fn main() { + let val = 42; + let r = chain_ref(&val); + assert!(*r == 42); +} diff --git a/tests/ui/pass/fn_poly_unused_param.rs b/tests/ui/pass/fn_poly_unused_param.rs new file mode 100644 index 0000000..e2ec90e --- /dev/null +++ b/tests/ui/pass/fn_poly_unused_param.rs @@ -0,0 +1,15 @@ +//@check-pass + +fn project_first(triple: (T, U, V)) -> T { + triple.0 +} + +fn chain(x: A, _phantom_b: B, _phantom_c: C) -> A { + x +} + +fn main() { + let x = project_first((42, true, 100)); + let y = chain(x, (1, 2), false); + assert!(y == 42); +}