diff --git a/src/analyze.rs b/src/analyze.rs index 574a44b..f8078d5 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -16,6 +16,8 @@ use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; +use crate::analyze; +use crate::annot::{AnnotFormula, AnnotParser, Resolver}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{self, BasicBlockType, TypeBuilder}; @@ -435,4 +437,54 @@ impl<'tcx> Analyzer<'tcx> { let body = self.tcx.optimized_mir(local_def_id); self.local_fn_sig_with_body(local_def_id, body) } + + fn extract_require_annot( + &self, + def_id: DefId, + resolver: T, + self_type_name: Option, + ) -> Option> + where + T: Resolver, + { + let mut require_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); + for attrs in self + .tcx + .get_attrs_by_path(def_id, &analyze::annot::requires_path()) + { + if require_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let require = parser.parse_formula(ts).unwrap(); + require_annot = Some(require); + } + require_annot + } + + fn extract_ensure_annot( + &self, + def_id: DefId, + resolver: T, + self_type_name: Option, + ) -> Option> + where + T: Resolver, + { + let mut ensure_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); + for attrs in self + .tcx + .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) + { + if ensure_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let ensure = parser.parse_formula(ts).unwrap(); + ensure_annot = Some(ensure); + } + ensure_annot + } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 67033eb..97a20fe 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -53,66 +53,38 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { - fn extract_require_annot(&self, resolver: T) -> Option> - where - T: annot::Resolver, - { - let mut require_annot = None; - for attrs in self.tcx.get_attrs_by_path( - self.local_def_id.to_def_id(), - &analyze::annot::requires_path(), - ) { - if require_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - require_annot = Some(require); - } - require_annot - } - - fn extract_ensure_annot(&self, resolver: T) -> Option> - where - T: annot::Resolver, - { - let mut ensure_annot = None; - for attrs in self.tcx.get_attrs_by_path( - self.local_def_id.to_def_id(), - &analyze::annot::ensures_path(), - ) { - if ensure_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - ensure_annot = Some(ensure); - } - ensure_annot - } - - fn extract_param_annots(&self, resolver: T) -> Vec<(Ident, rty::RefinedType)> + fn extract_param_annots( + &self, + resolver: T, + self_type_name: Option, + ) -> Vec<(Ident, rty::RefinedType)> where T: annot::Resolver, { let mut param_annots = Vec::new(); + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) { let ts = analyze::annot::extract_annot_tokens(attrs.clone()); let (ident, ts) = analyze::annot::split_param(&ts); - let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + let param = parser.parse_rty(ts).unwrap(); param_annots.push((ident, param)); } param_annots } - fn extract_ret_annot(&self, resolver: T) -> Option> + fn extract_ret_annot( + &self, + resolver: T, + self_type_name: Option, + ) -> Option> where T: annot::Resolver, { let mut ret_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path()) @@ -121,14 +93,34 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { unimplemented!(); } let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + let ret = parser.parse_rty(ts).unwrap(); ret_annot = Some(ret); } ret_annot } + fn impl_type(&self) -> Option> { + use rustc_hir::def::DefKind; + + let parent_def_id = self.tcx.parent(self.local_def_id.to_def_id()); + + if !matches!(self.tcx.def_kind(parent_def_id), DefKind::Impl { .. }) { + return None; + } + + let self_ty = self.tcx.type_of(parent_def_id).instantiate_identity(); + + Some(self_ty) + } + pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) { - let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string(); + // predicate's name + let impl_type = self.impl_type(); + let pred_item_name = self.tcx.item_name(local_def_id.to_def_id()).to_string(); + let pred_name = match impl_type { + Some(t) => t.to_string() + "_" + &pred_item_name, + None => pred_item_name, + }; // function's body use rustc_hir::{Block, Expr, ExprKind}; @@ -252,6 +244,17 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { || (all_params_annotated && has_ret) } + pub fn trait_item_id(&self) -> Option { + let impl_item_assoc = self + .tcx + .opt_associated_item(self.local_def_id.to_def_id())?; + let trait_item_id = impl_item_assoc + .trait_item_def_id + .and_then(|id| id.as_local())?; + + Some(trait_item_id) + } + pub fn expected_ty(&mut self) -> rty::RefinedType { let sig = self .ctx @@ -268,16 +271,47 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { param_resolver.push_param(input_ident.name, input_ty.to_sort()); } - let mut require_annot = self.extract_require_annot(¶m_resolver); - let mut ensure_annot = { - let output_ty = self.type_builder.build(sig.output()); - let resolver = annot::StackedResolver::default() - .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) - .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); - self.extract_ensure_annot(resolver) - }; - let param_annots = self.extract_param_annots(¶m_resolver); - let ret_annot = self.extract_ret_annot(¶m_resolver); + let output_ty = self.type_builder.build(sig.output()); + let result_param_resolver = annot::StackedResolver::default() + .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) + .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); + + let self_type_name = self.impl_type().map(|ty| ty.to_string()); + + let mut require_annot = self.ctx.extract_require_annot( + self.local_def_id.to_def_id(), + ¶m_resolver, + self_type_name.clone(), + ); + + let mut ensure_annot = self.ctx.extract_ensure_annot( + self.local_def_id.to_def_id(), + &result_param_resolver, + self_type_name.clone(), + ); + + if let Some(trait_item_id) = self.trait_item_id() { + tracing::info!("trait item fonud: {:?}", trait_item_id); + let trait_require_annot = self.ctx.extract_require_annot( + trait_item_id.into(), + ¶m_resolver, + self_type_name.clone(), + ); + let trait_ensure_annot = self.ctx.extract_ensure_annot( + trait_item_id.into(), + &result_param_resolver, + self_type_name.clone(), + ); + + assert!(require_annot.is_none() || trait_require_annot.is_none()); + require_annot = require_annot.or(trait_require_annot); + + assert!(ensure_annot.is_none() || trait_ensure_annot.is_none()); + ensure_annot = ensure_annot.or(trait_ensure_annot); + } + + let param_annots = self.extract_param_annots(¶m_resolver, self_type_name.clone()); + let ret_annot = self.extract_ret_annot(¶m_resolver, self_type_name); if self.is_annotated_as_callable() { if require_annot.is_some() || ensure_annot.is_some() { diff --git a/src/annot.rs b/src/annot.rs index 30bf7d3..769e0e5 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -250,6 +250,7 @@ impl FormulaOrTerm { /// A parser for refinement type annotations and formula annotations. struct Parser<'a, T> { resolver: T, + self_type_name: Option, cursor: RefTokenTreeCursor<'a>, formula_existentials: HashMap, } @@ -453,6 +454,7 @@ where TokenTree::Delimited(_, _, Delimiter::Parenthesis, s) => { let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: s.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -493,6 +495,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: args.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -518,11 +521,40 @@ where }; let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: s.trees(), formula_existentials: self.formula_existentials.clone(), }; let args = parser.parse_arg_terms()?; parser.end_of_input()?; + + // Identify struct-bound predicates call such as `Self::pred()` + match path.segments.first() { + Some(AnnotPathSegment { + ident: Ident { name: symbol, .. }, + generic_args, + }) if symbol.as_str() == "Self" && generic_args.is_empty() => { + if path.segments.len() != 2 { + unimplemented!("long path beginning with `Self::`"); + } + + let func_name = path.segments.get(1).unwrap().ident.name.as_str(); + let pred_name = if let Some(self_type_name) = &self.self_type_name { + self_type_name.clone() + "_" + func_name + } else { + func_name.to_string() + }; + + let pred_symbol = chc::UserDefinedPred::new(pred_name); + let pred = chc::Pred::UserDefined(pred_symbol); + + let atom = chc::Atom::new(pred, args); + let formula = chc::Formula::Atom(atom); + return Ok(FormulaOrTerm::Formula(formula)); + } + _ => {} + } + let (term, sort) = path.to_datatype_ctor(args); FormulaOrTerm::Term(term, sort) } @@ -908,6 +940,7 @@ where TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => { let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -1014,6 +1047,7 @@ where TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => { let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -1050,6 +1084,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -1074,6 +1109,7 @@ where let mut parser = Parser { resolver: RefinementResolver::new(self.boxed_resolver()), + self_type_name: self.self_type_name.clone(), cursor: parser.cursor, formula_existentials: self.formula_existentials.clone(), }; @@ -1199,11 +1235,15 @@ impl<'a, T> StackedResolver<'a, T> { #[derive(Debug, Clone)] pub struct AnnotParser { resolver: T, + self_type_name: Option, } impl AnnotParser { - pub fn new(resolver: T) -> Self { - Self { resolver } + pub fn new(resolver: T, self_type_name: Option) -> Self { + Self { + resolver, + self_type_name, + } } } @@ -1214,6 +1254,7 @@ where pub fn parse_rty(&self, ts: TokenStream) -> Result> { let mut parser = Parser { resolver: &self.resolver, + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: Default::default(), }; @@ -1225,6 +1266,7 @@ where pub fn parse_formula(&self, ts: TokenStream) -> Result> { let mut parser = Parser { resolver: &self.resolver, + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: Default::default(), }; diff --git a/tests/ui/fail/annot_preds_trait.rs b/tests/ui/fail/annot_preds_trait.rs new file mode 100644 index 0000000..bd0bdbc --- /dev/null +++ b/tests/ui/fail/annot_preds_trait.rs @@ -0,0 +1,43 @@ +//@error-in-other-file: Unsat +//@compile-flags: -Adead_code -C debug-assertions=off + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(Self::is_double(*self, ^self))] + fn double(&mut self); +} + +impl Double for A { + // Write concrete definitions for predicates in `impl` blocks + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // (tuple_proj.0 self) is equivalent to self.x + // self.x * 3 == doubled.x (this isn't actually doubled!) is written as following: + "(= + (* (tuple_proj.0 self) 3) + (tuple_proj.0 doubled) + )"; true // This definition does not comply with annotations in trait! + } + + // Check if this method complies with annotations in + // trait definition. + fn double(&mut self) { + self.x += self.x; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); +} diff --git a/tests/ui/fail/annot_preds_trait_multi.rs b/tests/ui/fail/annot_preds_trait_multi.rs new file mode 100644 index 0000000..a277358 --- /dev/null +++ b/tests/ui/fail/annot_preds_trait_multi.rs @@ -0,0 +1,71 @@ +//@error-in-other-file: Unsat +//@compile-flags: -Adead_code -C debug-assertions=off + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(Self::is_double(*self, ^self))] + fn double(&mut self); +} + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +impl Double for A { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 2 == doubled.x + "(= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + )"; true + } + + fn double(&mut self) { + self.x += self.x; + } +} + +// B is represented as Tuple in SMT-LIB2 format. +struct B { + x: i64, + y: i64, +} + +impl Double for B { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 3 == doubled.x && self.y * 2 == doubled.y (this isn't actually doubled!) + "(and + (= + (* (tuple_proj.0 self) 3) + (tuple_proj.0 doubled) + ) + (= + (* (tuple_proj.1 self) 2) + (tuple_proj.1 doubled) + ) + )"; true // This definition does not comply with annotations in trait! + } + + fn double(&mut self) { + self.x += self.x; + self.y += self.y; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); + + let mut b = B { x: 2, y: 5 }; + b.double(); + assert!(b.x == 4 && b.y == 10); +} diff --git a/tests/ui/pass/annot_preds_trait.rs b/tests/ui/pass/annot_preds_trait.rs new file mode 100644 index 0000000..bb2e37c --- /dev/null +++ b/tests/ui/pass/annot_preds_trait.rs @@ -0,0 +1,43 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(Self::is_double(*self, ^self))] + fn double(&mut self); +} + +impl Double for A { + // Write concrete definitions for predicates in `impl` blocks + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // (tuple_proj.0 self) is equivalent to self.x + // self.x * 2 == doubled.x is written as following: + "(= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + )"; true + } + + // Check if this method complies with annotations in + // trait definition. + fn double(&mut self) { + self.x += self.x; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); +} diff --git a/tests/ui/pass/annot_preds_trait_multi.rs b/tests/ui/pass/annot_preds_trait_multi.rs new file mode 100644 index 0000000..a51ff84 --- /dev/null +++ b/tests/ui/pass/annot_preds_trait_multi.rs @@ -0,0 +1,71 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(Self::is_double(*self, ^self))] + fn double(&mut self); +} + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +impl Double for A { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 2 == doubled.x + "(= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + )"; true + } + + fn double(&mut self) { + self.x += self.x; + } +} + +// B is represented as Tuple in SMT-LIB2 format. +struct B { + x: i64, + y: i64, +} + +impl Double for B { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 2 == doubled.x && self.y * 2 == doubled.y + "(and + (= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + ) + (= + (* (tuple_proj.1 self) 2) + (tuple_proj.1 doubled) + ) + )"; true + } + + fn double(&mut self) { + self.x += self.x; + self.y += self.y; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); + + let mut b = B { x: 2, y: 5 }; + b.double(); + assert!(b.x == 4 && b.y == 10); +}