From 176851038a677aafe7851de873451038d9165bcd Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 15:03:38 +0900 Subject: [PATCH 1/3] register enum def on-demand --- src/analyze.rs | 67 +++++++++++++++++++++++++++++++------- src/analyze/basic_block.rs | 21 ++++++------ src/analyze/crate_.rs | 46 +------------------------- 3 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index e237641..651e518 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; use std::rc::Rc; use rustc_hir::lang_items::LangItem; +use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; @@ -174,7 +175,53 @@ impl<'tcx> Analyzer<'tcx> { } } - pub fn register_enum_def(&mut self, def_id: DefId, enum_def: rty::EnumDatatypeDef) { + fn build_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef { + let adt = self.tcx.adt_def(def_id); + + let name = refine::datatype_symbol(self.tcx, def_id); + let variants: IndexVec<_, _> = adt + .variants() + .iter() + .map(|variant| { + let name = refine::datatype_symbol(self.tcx, variant.def_id); + // TODO: consider using TyCtxt::tag_for_variant + let discr = resolve_discr(self.tcx, variant.discr); + let field_tys = variant + .fields + .iter() + .map(|field| { + let field_ty = self.tcx.type_of(field.did).instantiate_identity(); + TypeBuilder::new(self.tcx, def_id).build(field_ty) + }) + .collect(); + rty::EnumVariantDef { + name, + discr, + field_tys, + } + }) + .collect(); + + let generics = self.tcx.generics_of(def_id); + let ty_params = (0..generics.count()) + .filter(|idx| { + matches!( + generics.param_at(*idx, self.tcx).kind, + mir_ty::GenericParamDefKind::Type { .. } + ) + }) + .count(); + tracing::debug!(?def_id, ?name, ?ty_params, "ty_params count"); + + rty::EnumDatatypeDef { + name, + ty_params, + variants, + } + } + + pub fn register_enum_def(&self, def_id: DefId) { + let enum_def = self.build_enum_def(def_id); tracing::debug!(def_id = ?def_id, enum_def = ?enum_def, "register_enum_def"); let ctors = enum_def .variants @@ -203,17 +250,13 @@ impl<'tcx> Analyzer<'tcx> { self.system.borrow_mut().datatypes.push(datatype); } - pub fn find_enum_variant( - &self, - ty_sym: &chc::DatatypeSymbol, - v_sym: &chc::DatatypeSymbol, - ) -> Option { - self.enum_defs - .borrow() - .iter() - .find(|(_, d)| &d.name == ty_sym) - .and_then(|(_, d)| d.variants.iter().find(|v| &v.name == v_sym)) - .cloned() + pub fn get_or_register_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef { + if let Some(enum_def) = self.enum_defs.borrow().get(&def_id) { + return enum_def.clone(); + } + + self.register_enum_def(def_id); + self.enum_defs.borrow().get(&def_id).unwrap().clone() } pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) { diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 372e481..c381e1d 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -350,16 +350,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .map(|operand| self.operand_type(operand).boxed()) .collect(); match *kind { - mir::AggregateKind::Adt(did, variant_id, args, _, _) + mir::AggregateKind::Adt(did, variant_idx, args, _, _) if self.tcx.def_kind(did) == DefKind::Enum => { - let adt = self.tcx.adt_def(did); - let ty_sym = refine::datatype_symbol(self.tcx, did); - let variant = adt.variant(variant_id); - let v_sym = refine::datatype_symbol(self.tcx, variant.def_id); - - let enum_variant_def = self.ctx.find_enum_variant(&ty_sym, &v_sym).unwrap(); - let variant_rtys = enum_variant_def + let enum_def = self.ctx.get_or_register_enum_def(did); + let variant_def = &enum_def.variants[variant_idx]; + let variant_rtys = variant_def .field_tys .clone() .into_iter() @@ -386,7 +382,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let sort_args: Vec<_> = rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); - let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); + let ty = rty::EnumType::new(enum_def.name.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); let mut field_terms = Vec::new(); @@ -396,7 +392,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } builder.build( ty, - chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, field_terms), + chc::Term::datatype_ctor( + enum_def.name, + sort_args, + variant_def.name.clone(), + field_terms, + ), ) } _ => PlaceType::tuple(field_tys), diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 2a17b11..5467630 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -3,13 +3,11 @@ use std::collections::HashSet; use rustc_hir::def::DefKind; -use rustc_index::IndexVec; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; use crate::chc; -use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; /// An implementation of local crate analysis. @@ -173,49 +171,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let DefKind::Enum = self.tcx.def_kind(local_def_id) else { continue; }; - let adt = self.tcx.adt_def(local_def_id); - - let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); - let variants: IndexVec<_, _> = adt - .variants() - .iter() - .map(|variant| { - let name = refine::datatype_symbol(self.tcx, variant.def_id); - // TODO: consider using TyCtxt::tag_for_variant - let discr = analyze::resolve_discr(self.tcx, variant.discr); - let field_tys = variant - .fields - .iter() - .map(|field| { - let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty) - }) - .collect(); - rty::EnumVariantDef { - name, - discr, - field_tys, - } - }) - .collect(); - - let generics = self.tcx.generics_of(local_def_id); - let ty_params = (0..generics.count()) - .filter(|idx| { - matches!( - generics.param_at(*idx, self.tcx).kind, - mir_ty::GenericParamDefKind::Type { .. } - ) - }) - .count(); - tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); - - let def = rty::EnumDatatypeDef { - name, - ty_params, - variants, - }; - self.ctx.register_enum_def(local_def_id.to_def_id(), def); + self.ctx.register_enum_def(local_def_id.to_def_id()); } } } From bedcf79b7539567707a2c80ecf9b684dd1ce1382 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 15:40:44 +0900 Subject: [PATCH 2/3] Use latest enum_defs from Env --- src/analyze.rs | 43 ++++++++++++++++++++++-------- src/analyze/basic_block.rs | 9 ++----- src/refine.rs | 4 ++- src/refine/env.rs | 53 +++++++++++++++++++++++++++---------- tests/ui/fail/option_mut.rs | 10 +++++++ tests/ui/pass/option_mut.rs | 10 +++++++ 6 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 tests/ui/fail/option_mut.rs create mode 100644 tests/ui/pass/option_mut.rs diff --git a/src/analyze.rs b/src/analyze.rs index 651e518..3b972de 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -115,6 +115,33 @@ enum DefTy<'tcx> { Deferred(DeferredDefTy<'tcx>), } +#[derive(Debug, Clone, Default)] +pub struct EnumDefs { + defs: HashMap, +} + +impl EnumDefs { + pub fn find_by_name(&self, name: &chc::DatatypeSymbol) -> Option<&rty::EnumDatatypeDef> { + self.defs.values().find(|def| &def.name == name) + } + + pub fn get(&self, def_id: DefId) -> Option<&rty::EnumDatatypeDef> { + self.defs.get(&def_id) + } + + pub fn insert(&mut self, def_id: DefId, def: rty::EnumDatatypeDef) { + self.defs.insert(def_id, def); + } +} + +impl refine::EnumDefProvider for Rc> { + fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef { + self.borrow().find_by_name(name).unwrap().clone() + } +} + +pub type Env = refine::Env>>; + #[derive(Clone)] pub struct Analyzer<'tcx> { tcx: TyCtxt<'tcx>, @@ -132,7 +159,7 @@ pub struct Analyzer<'tcx> { basic_blocks: HashMap>, def_ids: did_cache::DefIdCache<'tcx>, - enum_defs: Rc>>, + enum_defs: Rc>, } impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> { @@ -251,12 +278,12 @@ impl<'tcx> Analyzer<'tcx> { } pub fn get_or_register_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef { - if let Some(enum_def) = self.enum_defs.borrow().get(&def_id) { + if let Some(enum_def) = self.enum_defs.borrow().get(def_id) { return enum_def.clone(); } self.register_enum_def(def_id); - self.enum_defs.borrow().get(&def_id).unwrap().clone() + self.enum_defs.borrow().get(def_id).unwrap().clone() } pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) { @@ -347,14 +374,8 @@ impl<'tcx> Analyzer<'tcx> { self.register_def(panic_def_id, rty::RefinedType::unrefined(panic_ty.into())); } - pub fn new_env(&self) -> refine::Env { - let defs = self - .enum_defs - .borrow() - .values() - .map(|def| (def.name.clone(), def.clone())) - .collect(); - refine::Env::new(defs) + pub fn new_env(&self) -> Env { + refine::Env::new(Rc::clone(&self.enum_defs)) } pub fn crate_analyzer(&mut self) -> crate_::Analyzer<'tcx, '_> { diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index c381e1d..6053489 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -34,7 +34,7 @@ pub struct Analyzer<'tcx, 'ctx> { body: Cow<'tcx, Body<'tcx>>, type_builder: TypeBuilder<'tcx>, - env: Env, + env: analyze::Env, local_decls: IndexVec>, // TODO: remove this prophecy_vars: HashMap, @@ -968,7 +968,7 @@ impl UnbindAtoms { self.existentials.extend(var_ty.existentials); } - pub fn unbind(mut self, env: &Env, ty: rty::RefinedType) -> rty::RefinedType { + pub fn unbind(mut self, env: &analyze::Env, ty: rty::RefinedType) -> rty::RefinedType { let rty::RefinedType { ty: src_ty, refinement, @@ -1137,11 +1137,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } - pub fn env(&mut self, env: Env) -> &mut Self { - self.env = env; - self - } - pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); diff --git a/src/refine.rs b/src/refine.rs index 4736b39..7ef3886 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -14,7 +14,9 @@ mod basic_block; pub use basic_block::BasicBlockType; mod env; -pub use env::{Assumption, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var}; +pub use env::{ + Assumption, EnumDefProvider, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var, +}; use crate::chc::DatatypeSymbol; use rustc_middle::ty as mir_ty; diff --git a/src/refine/env.rs b/src/refine/env.rs index f18dadb..387db93 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -290,6 +290,16 @@ impl PlaceTypeBuilder { } } +pub trait EnumDefProvider { + fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef; +} + +impl<'a, T: EnumDefProvider> EnumDefProvider for &'a T { + fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef { + ::enum_def(self, name) + } +} + #[derive(Debug, Clone)] pub struct PlaceType { pub ty: rty::Type, @@ -392,16 +402,19 @@ impl PlaceType { builder.build(ty, term) } - pub fn downcast( + pub fn downcast( self, variant_idx: VariantIdx, field_idx: FieldIdx, - enum_defs: &HashMap, - ) -> PlaceType { + enum_defs: T, + ) -> PlaceType + where + T: EnumDefProvider, + { let mut builder = PlaceTypeBuilder::default(); let (inner_ty, inner_term) = builder.subsume(self); let inner_ty = inner_ty.into_enum().unwrap(); - let def = &enum_defs[&inner_ty.symbol]; + let def = enum_defs.enum_def(&inner_ty.symbol); let variant = &def.variants[variant_idx]; let mut field_terms = Vec::new(); @@ -510,18 +523,21 @@ impl PlaceType { pub type Assumption = rty::Formula; #[derive(Debug, Clone)] -pub struct Env { +pub struct Env { locals: BTreeMap>, flow_locals: BTreeMap, temp_vars: IndexVec, assumptions: Vec, - enum_defs: HashMap, + enum_defs: T, enum_expansion_depth_limit: usize, } -impl rty::ClauseScope for Env { +impl rty::ClauseScope for Env +where + T: EnumDefProvider, +{ fn build_clause(&self) -> chc::ClauseBuilder { let mut builder = chc::ClauseBuilder::default(); for (v, sort) in self.dependencies() { @@ -565,7 +581,10 @@ impl rty::ClauseScope for Env { } } -impl refine::TemplateScope for Env { +impl refine::TemplateScope for Env +where + T: EnumDefProvider, +{ type Var = Var; fn build_template(&self) -> rty::TemplateBuilder { let mut builder = rty::TemplateBuilder::default(); @@ -576,8 +595,11 @@ impl refine::TemplateScope for Env { } } -impl Env { - pub fn new(enum_defs: HashMap) -> Self { +impl Env +where + T: EnumDefProvider, +{ + pub fn new(enum_defs: T) -> Self { Env { locals: Default::default(), flow_locals: Default::default(), @@ -727,7 +749,7 @@ impl Env { assert_eq!(temp, self.temp_vars.push(TempVarBinding::Flow(dummy))); } - let def = self.enum_defs[&ty.symbol].clone(); + let def = self.enum_defs.enum_def(&ty.symbol); let matcher_pred = chc::MatcherPred::new(ty.symbol.clone(), ty.arg_sorts()); let mut variants = IndexVec::new(); @@ -933,7 +955,7 @@ impl Env { .collect(); let arg_rtys = { - let def = &self.enum_defs[sym]; + let def = self.enum_defs.enum_def(sym); let expected_tys = def .field_tys() .map(|ty| rty::RefinedType::unrefined(ty.clone().vacuous()).boxed()); @@ -1046,7 +1068,10 @@ impl Path { } } -impl Env { +impl Env +where + T: EnumDefProvider, +{ fn path_type(&self, path: &Path) -> PlaceType { match path { Path::PlaceTy(pty) => pty.clone(), @@ -1078,7 +1103,7 @@ impl Env { .map(|i| self.dropping_assumption(&path.clone().tuple_proj(i))) .collect() } else if let Some(ety) = ty.ty.as_enum() { - let enum_def = self.enum_defs[&ety.symbol].clone(); + let enum_def = self.enum_defs.enum_def(&ety.symbol); let matcher_pred = chc::MatcherPred::new(ety.symbol.clone(), ety.arg_sorts()); let PlaceType { diff --git a/tests/ui/fail/option_mut.rs b/tests/ui/fail/option_mut.rs new file mode 100644 index 0000000..827bc9c --- /dev/null +++ b/tests/ui/fail/option_mut.rs @@ -0,0 +1,10 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut m: Option = Some(1); + if let Some(i) = &mut m { + *i += 2; + } + assert!(matches!(m, Some(1))); +} diff --git a/tests/ui/pass/option_mut.rs b/tests/ui/pass/option_mut.rs new file mode 100644 index 0000000..37f7121 --- /dev/null +++ b/tests/ui/pass/option_mut.rs @@ -0,0 +1,10 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut m: Option = Some(1); + if let Some(i) = &mut m { + *i += 2; + } + assert!(matches!(m, Some(3))); +} From 86bbefa8b640419457996f1d10c1328d58eb7034 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 19:04:41 +0900 Subject: [PATCH 3/3] Add more test cases --- tests/ui/fail/option_inc.rs | 17 +++++++++++++++++ tests/ui/fail/option_loop.rs | 14 ++++++++++++++ tests/ui/fail/result_mut.rs | 18 ++++++++++++++++++ tests/ui/fail/result_struct.rs | 22 ++++++++++++++++++++++ tests/ui/pass/option_inc.rs | 17 +++++++++++++++++ tests/ui/pass/option_loop.rs | 14 ++++++++++++++ tests/ui/pass/result_mut.rs | 18 ++++++++++++++++++ tests/ui/pass/result_struct.rs | 23 +++++++++++++++++++++++ 8 files changed, 143 insertions(+) create mode 100644 tests/ui/fail/option_inc.rs create mode 100644 tests/ui/fail/option_loop.rs create mode 100644 tests/ui/fail/result_mut.rs create mode 100644 tests/ui/fail/result_struct.rs create mode 100644 tests/ui/pass/option_inc.rs create mode 100644 tests/ui/pass/option_loop.rs create mode 100644 tests/ui/pass/result_mut.rs create mode 100644 tests/ui/pass/result_struct.rs diff --git a/tests/ui/fail/option_inc.rs b/tests/ui/fail/option_inc.rs new file mode 100644 index 0000000..4e98b06 --- /dev/null +++ b/tests/ui/fail/option_inc.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn maybe_inc(x: i32, do_it: bool) -> Option { + if do_it { + Some(x + 1) + } else { + None + } +} + +fn main() { + let res = maybe_inc(10, true); + if let Some(v) = res { + assert!(v == 12); + } +} diff --git a/tests/ui/fail/option_loop.rs b/tests/ui/fail/option_loop.rs new file mode 100644 index 0000000..0b91fe9 --- /dev/null +++ b/tests/ui/fail/option_loop.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut opt = Some(5); + while let Some(x) = opt { + if x > 0 { + opt = Some(x - 1); + } else { + opt = None; + } + } + assert!(false); +} diff --git a/tests/ui/fail/result_mut.rs b/tests/ui/fail/result_mut.rs new file mode 100644 index 0000000..7baa46b --- /dev/null +++ b/tests/ui/fail/result_mut.rs @@ -0,0 +1,18 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn mutate_res(r: &mut Result) { + match r { + Ok(v) => *v += 1, + Err(e) => *e -= 1, + } +} + +fn main() { + let mut r = Ok(10); + mutate_res(&mut r); + match r { + Ok(v) => assert!(v == 10), + Err(_) => unreachable!(), + } +} diff --git a/tests/ui/fail/result_struct.rs b/tests/ui/fail/result_struct.rs new file mode 100644 index 0000000..cb16420 --- /dev/null +++ b/tests/ui/fail/result_struct.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +struct Point { + x: i32, + y: i32, +} + +fn make_point(x: i32, y: i32) -> Result { + if x >= 0 && y >= 0 { + Ok(Point { x, y }) + } else { + Err(()) + } +} + +fn main() { + let p = make_point(1, 2); + if let Ok(pt) = p { + assert!(pt.x > 1); + } +} diff --git a/tests/ui/pass/option_inc.rs b/tests/ui/pass/option_inc.rs new file mode 100644 index 0000000..aa0530b --- /dev/null +++ b/tests/ui/pass/option_inc.rs @@ -0,0 +1,17 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn maybe_inc(x: i32, do_it: bool) -> Option { + if do_it { + Some(x + 1) + } else { + None + } +} + +fn main() { + let res = maybe_inc(10, true); + if let Some(v) = res { + assert!(v == 11); + } +} diff --git a/tests/ui/pass/option_loop.rs b/tests/ui/pass/option_loop.rs new file mode 100644 index 0000000..ea3697f --- /dev/null +++ b/tests/ui/pass/option_loop.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut opt = Some(5); + while let Some(x) = opt { + if x > 0 { + opt = Some(x - 1); + } else { + opt = None; + } + } + assert!(matches!(opt, None)); +} diff --git a/tests/ui/pass/result_mut.rs b/tests/ui/pass/result_mut.rs new file mode 100644 index 0000000..1a5e218 --- /dev/null +++ b/tests/ui/pass/result_mut.rs @@ -0,0 +1,18 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn mutate_res(r: &mut Result) { + match r { + Ok(v) => *v += 1, + Err(e) => *e -= 1, + } +} + +fn main() { + let mut r = Ok(10); + mutate_res(&mut r); + match r { + Ok(v) => assert!(v == 11), + Err(_) => unreachable!(), + } +} diff --git a/tests/ui/pass/result_struct.rs b/tests/ui/pass/result_struct.rs new file mode 100644 index 0000000..6e80f99 --- /dev/null +++ b/tests/ui/pass/result_struct.rs @@ -0,0 +1,23 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +struct Point { + x: i32, + y: i32, +} + +fn make_point(x: i32, y: i32) -> Result { + if x >= 0 && y >= 0 { + Ok(Point { x, y }) + } else { + Err(()) + } +} + +fn main() { + let p = make_point(1, 2); + if let Ok(pt) = p { + assert!(pt.x >= 0); + assert!(pt.y >= 0); + } +}