From 32018f160d582094e5cc3fb6bf660586591b4700 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 19:00:41 +0900 Subject: [PATCH] Fix unsoundness when expanding variant of polymorphic enum --- src/analyze/basic_block.rs | 6 +- src/chc.rs | 12 ++++ src/refine/env.rs | 82 +++++++++++++++-------- tests/ui/fail/unused_variant_predicate.rs | 13 ++++ tests/ui/pass/unused_variant_predicate.rs | 13 ++++ 5 files changed, 96 insertions(+), 30 deletions(-) create mode 100644 tests/ui/fail/unused_variant_predicate.rs create mode 100644 tests/ui/pass/unused_variant_predicate.rs diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 372e481..0c71333 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -511,7 +511,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { F: FnOnce(&mut Self) -> T, { let old_env = self.env.clone(); - self.env.assume(assumption); + self.env.assume(assumption, None); let result = callback(self); self.env = old_env; result @@ -682,7 +682,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let (_, rvalue_term) = builder.subsume(rvalue_ty); builder.push_formula(local_term.mut_final().equal_to(rvalue_term)); let assumption = builder.build_assumption(); - self.env.assume(assumption); + self.env.assume(assumption, None); } } @@ -1070,7 +1070,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assumption.existentials.extend(existentials); } - self.env.assume(assumption); + self.env.assume(assumption, None); } fn unbind_atoms(&self) -> UnbindAtoms { diff --git a/src/chc.rs b/src/chc.rs index a5f046e..f0ac8b8 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -371,6 +371,8 @@ impl Function { Self::LE => Sort::bool(), Self::LT => Sort::bool(), Self::NOT => Sort::bool(), + Self::AND => Sort::bool(), + Self::OR => Sort::bool(), Self::NEG => Sort::int(), _ => unimplemented!(), } @@ -385,6 +387,8 @@ impl Function { pub const LE: Function = Function::infix("<="); pub const LT: Function = Function::infix("<"); pub const NOT: Function = Function::new("not"); + pub const AND: Function = Function::new("and"); + pub const OR: Function = Function::new("or"); pub const NEG: Function = Function::new("-"); } @@ -682,6 +686,14 @@ impl Term { Term::App(Function::NOT, vec![self]) } + pub fn bool_and(self, other: Self) -> Self { + Term::App(Function::AND, vec![self, other]) + } + + pub fn bool_or(self, other: Self) -> Self { + Term::App(Function::OR, vec![self, other]) + } + pub fn neg(self) -> Self { Term::App(Function::NEG, vec![self]) } diff --git a/src/refine/env.rs b/src/refine/env.rs index f18dadb..086219e 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -603,6 +603,7 @@ impl Env { ty: rty::PointerType, refinement: rty::Refinement, depth: usize, + guard: Option>, ) { // note that the given var is unbound here, so be careful of using indices around temp_vars let current_refinement = refinement @@ -623,7 +624,7 @@ impl Env { }; let mut inner_ty = *ty.elem; inner_ty.extend_refinement(current_refinement); - self.bind_impl(current.into(), inner_ty, depth); + self.bind_impl(current.into(), inner_ty, depth, guard); } fn bind_mut( @@ -632,6 +633,7 @@ impl Env { ty: rty::PointerType, refinement: rty::Refinement, depth: usize, + guard: Option>, ) { // note that the given var is unbound here, so be careful of using indices around temp_vars let next_index = self.temp_vars.next_index(); @@ -661,7 +663,7 @@ impl Env { ); let mut inner_ty = *ty.elem; inner_ty.extend_refinement(current_refinement); - self.bind_impl(current.into(), inner_ty, depth); + self.bind_impl(current.into(), inner_ty, depth, guard); } fn bind_tuple( @@ -670,6 +672,7 @@ impl Env { ty: rty::TupleType, refinement: rty::Refinement, depth: usize, + guard: Option>, ) { if let Var::Temp(temp) = var { // XXX: allocate `temp` once to invoke bind_var recursively @@ -681,7 +684,7 @@ impl Env { for elem in &ty.elems { let x = self.temp_vars.next_index(); xs.push(x); - self.bind_impl(x.into(), elem.clone(), depth); + self.bind_impl(x.into(), elem.clone(), depth, guard.clone()); } let assumption = { let tuple_ty = PlaceType::tuple( @@ -702,7 +705,7 @@ impl Env { existentials.extend(refinement.existentials); Assumption::new(existentials, formula) }; - self.assume(assumption); + self.assume(assumption, guard); let binding = FlowBinding::Tuple(xs.clone()); match var { Var::Local(local) => { @@ -720,6 +723,7 @@ impl Env { ty: rty::EnumType, refinement: rty::Refinement, depth: usize, + guard: Option>, ) { if let Var::Temp(temp) = var { // XXX: allocate `temp` once to invoke bind_var recursively @@ -730,15 +734,29 @@ impl Env { let def = self.enum_defs[&ty.symbol].clone(); let matcher_pred = chc::MatcherPred::new(ty.symbol.clone(), ty.arg_sorts()); + let discr_var = self + .temp_vars + .push(TempVarBinding::Type(rty::RefinedType::unrefined( + rty::Type::int(), + ))); + let mut variants = IndexVec::new(); - for variant_def in &def.variants { + for (variant_idx, variant_def) in def.variants.iter_enumerated() { let mut fields = IndexVec::new(); + let variant_guard = { + let discr_term = chc::Term::var(discr_var.into()); + let condition = discr_term.eq(chc::Term::int(variant_def.discr as i64)); + match guard.clone() { + Some(g) => g.bool_and(condition), + None => condition, + } + }; for field_ty in &variant_def.field_tys { let x = self.temp_vars.next_index(); fields.push(x); let mut field_ty = rty::RefinedType::unrefined(field_ty.clone().vacuous()); field_ty.instantiate_ty_params(ty.args.clone()); - self.bind_impl(x.into(), field_ty.boxed(), depth); + self.bind_impl(x.into(), field_ty.boxed(), depth, Some(variant_guard.clone())); } variants.push(FlowBindingVariant { fields }); } @@ -773,11 +791,6 @@ impl Env { assumption .body .push_conj(chc::Atom::new(matcher_pred.into(), pred_args)); - let discr_var = self - .temp_vars - .push(TempVarBinding::Type(rty::RefinedType::unrefined( - rty::Type::int(), - ))); assumption .body .push_conj( @@ -786,7 +799,7 @@ impl Env { chc::Term::var(value_var_ev.into()), )), ); - self.assume(assumption); + self.assume(assumption, guard); let binding = FlowBinding::Enum { discr: discr_var, @@ -803,7 +816,14 @@ impl Env { } } - fn bind_var(&mut self, var: Var, rty: rty::RefinedType) { + fn bind_var(&mut self, var: Var, mut rty: rty::RefinedType, guard: Option>) { + if let Some(guard) = guard { + let guard_false = guard + .equal_to(chc::Term::bool(false)) + .map_var(rty::RefinedTypeVar::Free); + let body = std::mem::take(&mut rty.refinement.body); + rty.refinement.body = chc::Formula::Or(vec![chc::Formula::Atom(guard_false), body.formula]).into(); + } match var { Var::Local(local) => { self.locals.insert(local, rty); @@ -814,43 +834,51 @@ impl Env { } } - fn bind_impl(&mut self, var: Var, rty: rty::RefinedType, depth: usize) { + fn bind_impl(&mut self, var: Var, rty: rty::RefinedType, depth: usize, guard: Option>) { if depth >= self.enum_expansion_depth_limit { - self.bind_var(var, rty); + self.bind_var(var, rty, guard); return; } match rty.ty { - rty::Type::Pointer(ty) if ty.is_own() => self.bind_own(var, ty, rty.refinement, depth), - rty::Type::Pointer(ty) if ty.is_mut() => self.bind_mut(var, ty, rty.refinement, depth), + rty::Type::Pointer(ty) if ty.is_own() => self.bind_own(var, ty, rty.refinement, depth, guard), + rty::Type::Pointer(ty) if ty.is_mut() => self.bind_mut(var, ty, rty.refinement, depth, guard), rty::Type::Tuple(ty) if !ty.is_unit() => { - self.bind_tuple(var, ty, rty.refinement, depth) + self.bind_tuple(var, ty, rty.refinement, depth, guard) } - rty::Type::Enum(ty) => self.bind_enum(var, ty, rty.refinement, depth + 1), - _ => self.bind_var(var, rty), + rty::Type::Enum(ty) => self.bind_enum(var, ty, rty.refinement, depth + 1, guard), + _ => self.bind_var(var, rty, guard), } } pub fn mut_bind(&mut self, local: Local, rty: rty::RefinedType) { let rty_disp = rty.clone(); - self.bind_impl(local.into(), rty, 0); + self.bind_impl(local.into(), rty, 0, None); tracing::debug!(local = ?local, rty = %rty_disp.display(), place_type = %self.local_type(local).display(), "mut_bind"); } pub fn immut_bind(&mut self, local: Local, rty: rty::RefinedType) { let rty_disp = rty.clone(); - self.bind_var(local.into(), rty); + self.bind_var(local.into(), rty, None); tracing::debug!(local = ?local, rty = %rty_disp.display(), place_type = %self.local_type(local).display(), "immut_bind"); } - pub fn assume(&mut self, assumption: impl Into) { - let assumption = assumption.into(); + pub fn assume(&mut self, assumption: impl Into, guard: Option>) { + let mut assumption = assumption.into(); + if let Some(guard) = guard { + let guard_false = guard + .equal_to(chc::Term::bool(false)) + .map_var(PlaceTypeVar::Var); + let body = std::mem::take(&mut assumption.body); + assumption.body = chc::Formula::Or(vec![chc::Formula::Atom(guard_false), body.formula]).into(); + } tracing::debug!(assumption = %assumption.display(), "assume"); self.assumptions.push(assumption); } pub fn extend_assumptions(&mut self, assumptions: Vec>) { - self.assumptions - .extend(assumptions.into_iter().map(Into::into)); + for assumption in assumptions { + self.assume(assumption, None); + } } pub fn dependencies(&self) -> impl Iterator + '_ { @@ -1146,7 +1174,7 @@ impl Env { pub fn drop_local(&mut self, local: Local) { let assumption = self.dropping_assumption(&Path::Local(local)); if !assumption.is_top() { - self.assume(assumption); + self.assume(assumption, None); } } } diff --git a/tests/ui/fail/unused_variant_predicate.rs b/tests/ui/fail/unused_variant_predicate.rs new file mode 100644 index 0000000..bd8c500 --- /dev/null +++ b/tests/ui/fail/unused_variant_predicate.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +enum X { + None1, + None2, + Some(T), +} + +fn main() { + let mut opt: X = X::None1; + opt = X::None2; + assert!(matches!(opt, X::None1)); +} diff --git a/tests/ui/pass/unused_variant_predicate.rs b/tests/ui/pass/unused_variant_predicate.rs new file mode 100644 index 0000000..15bd1e4 --- /dev/null +++ b/tests/ui/pass/unused_variant_predicate.rs @@ -0,0 +1,13 @@ +//@check-pass + +enum X { + None1, + None2, + Some(T), +} + +fn main() { + let mut opt: X = X::None1; + opt = X::None2; + assert!(matches!(opt, X::None2)); +}