Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 85 additions & 21 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::collections::HashMap;
use std::rc::Rc;

use rustc_hir::lang_items::LangItem;
use rustc_index::IndexVec;
use rustc_middle::mir::{self, BasicBlock, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};
use rustc_span::def_id::{DefId, LocalDefId};
Expand Down Expand Up @@ -114,6 +115,33 @@ enum DefTy<'tcx> {
Deferred(DeferredDefTy<'tcx>),
}

#[derive(Debug, Clone, Default)]
pub struct EnumDefs {
defs: HashMap<DefId, rty::EnumDatatypeDef>,
}

impl EnumDefs {
pub fn find_by_name(&self, name: &chc::DatatypeSymbol) -> Option<&rty::EnumDatatypeDef> {
self.defs.values().find(|def| &def.name == name)
}

pub fn get(&self, def_id: DefId) -> Option<&rty::EnumDatatypeDef> {
self.defs.get(&def_id)
}

pub fn insert(&mut self, def_id: DefId, def: rty::EnumDatatypeDef) {
self.defs.insert(def_id, def);
}
}

impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef {
self.borrow().find_by_name(name).unwrap().clone()
}
}

pub type Env = refine::Env<Rc<RefCell<EnumDefs>>>;

#[derive(Clone)]
pub struct Analyzer<'tcx> {
tcx: TyCtxt<'tcx>,
Expand All @@ -131,7 +159,7 @@ pub struct Analyzer<'tcx> {
basic_blocks: HashMap<LocalDefId, HashMap<BasicBlock, BasicBlockType>>,
def_ids: did_cache::DefIdCache<'tcx>,

enum_defs: Rc<RefCell<HashMap<DefId, rty::EnumDatatypeDef>>>,
enum_defs: Rc<RefCell<EnumDefs>>,
}

impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> {
Expand Down Expand Up @@ -174,7 +202,53 @@ impl<'tcx> Analyzer<'tcx> {
}
}

pub fn register_enum_def(&mut self, def_id: DefId, enum_def: rty::EnumDatatypeDef) {
fn build_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef {
let adt = self.tcx.adt_def(def_id);

let name = refine::datatype_symbol(self.tcx, def_id);
let variants: IndexVec<_, _> = adt
.variants()
.iter()
.map(|variant| {
let name = refine::datatype_symbol(self.tcx, variant.def_id);
// TODO: consider using TyCtxt::tag_for_variant
let discr = resolve_discr(self.tcx, variant.discr);
let field_tys = variant
.fields
.iter()
.map(|field| {
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
TypeBuilder::new(self.tcx, def_id).build(field_ty)
})
.collect();
rty::EnumVariantDef {
name,
discr,
field_tys,
}
})
.collect();

let generics = self.tcx.generics_of(def_id);
let ty_params = (0..generics.count())
.filter(|idx| {
matches!(
generics.param_at(*idx, self.tcx).kind,
mir_ty::GenericParamDefKind::Type { .. }
)
})
.count();
tracing::debug!(?def_id, ?name, ?ty_params, "ty_params count");

rty::EnumDatatypeDef {
name,
ty_params,
variants,
}
}

pub fn register_enum_def(&self, def_id: DefId) {
let enum_def = self.build_enum_def(def_id);
tracing::debug!(def_id = ?def_id, enum_def = ?enum_def, "register_enum_def");
let ctors = enum_def
.variants
Expand Down Expand Up @@ -203,17 +277,13 @@ impl<'tcx> Analyzer<'tcx> {
self.system.borrow_mut().datatypes.push(datatype);
}

pub fn find_enum_variant(
&self,
ty_sym: &chc::DatatypeSymbol,
v_sym: &chc::DatatypeSymbol,
) -> Option<rty::EnumVariantDef> {
self.enum_defs
.borrow()
.iter()
.find(|(_, d)| &d.name == ty_sym)
.and_then(|(_, d)| d.variants.iter().find(|v| &v.name == v_sym))
.cloned()
pub fn get_or_register_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef {
if let Some(enum_def) = self.enum_defs.borrow().get(def_id) {
return enum_def.clone();
}

self.register_enum_def(def_id);
self.enum_defs.borrow().get(def_id).unwrap().clone()
}

pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) {
Expand Down Expand Up @@ -304,14 +374,8 @@ impl<'tcx> Analyzer<'tcx> {
self.register_def(panic_def_id, rty::RefinedType::unrefined(panic_ty.into()));
}

pub fn new_env(&self) -> refine::Env {
let defs = self
.enum_defs
.borrow()
.values()
.map(|def| (def.name.clone(), def.clone()))
.collect();
refine::Env::new(defs)
pub fn new_env(&self) -> Env {
refine::Env::new(Rc::clone(&self.enum_defs))
}

pub fn crate_analyzer(&mut self) -> crate_::Analyzer<'tcx, '_> {
Expand Down
30 changes: 13 additions & 17 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct Analyzer<'tcx, 'ctx> {
body: Cow<'tcx, Body<'tcx>>,

type_builder: TypeBuilder<'tcx>,
env: Env,
env: analyze::Env,
local_decls: IndexVec<Local, mir::LocalDecl<'tcx>>,
// TODO: remove this
prophecy_vars: HashMap<usize, TempVarIdx>,
Expand Down Expand Up @@ -350,16 +350,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.map(|operand| self.operand_type(operand).boxed())
.collect();
match *kind {
mir::AggregateKind::Adt(did, variant_id, args, _, _)
mir::AggregateKind::Adt(did, variant_idx, args, _, _)
if self.tcx.def_kind(did) == DefKind::Enum =>
{
let adt = self.tcx.adt_def(did);
let ty_sym = refine::datatype_symbol(self.tcx, did);
let variant = adt.variant(variant_id);
let v_sym = refine::datatype_symbol(self.tcx, variant.def_id);

let enum_variant_def = self.ctx.find_enum_variant(&ty_sym, &v_sym).unwrap();
let variant_rtys = enum_variant_def
let enum_def = self.ctx.get_or_register_enum_def(did);
let variant_def = &enum_def.variants[variant_idx];
let variant_rtys = variant_def
.field_tys
.clone()
.into_iter()
Expand All @@ -386,7 +382,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {

let sort_args: Vec<_> =
rty_args.iter().map(|rty| rty.ty.to_sort()).collect();
let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into();
let ty = rty::EnumType::new(enum_def.name.clone(), rty_args).into();

let mut builder = PlaceTypeBuilder::default();
let mut field_terms = Vec::new();
Expand All @@ -396,7 +392,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
builder.build(
ty,
chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, field_terms),
chc::Term::datatype_ctor(
enum_def.name,
sort_args,
variant_def.name.clone(),
field_terms,
),
)
}
_ => PlaceType::tuple(field_tys),
Expand Down Expand Up @@ -967,7 +968,7 @@ impl<T> UnbindAtoms<T> {
self.existentials.extend(var_ty.existentials);
}

pub fn unbind(mut self, env: &Env, ty: rty::RefinedType<Var>) -> rty::RefinedType<T> {
pub fn unbind(mut self, env: &analyze::Env, ty: rty::RefinedType<Var>) -> rty::RefinedType<T> {
let rty::RefinedType {
ty: src_ty,
refinement,
Expand Down Expand Up @@ -1136,11 +1137,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self
}

pub fn env(&mut self, env: Env) -> &mut Self {
self.env = env;
self
}

pub fn run(&mut self, expected: &BasicBlockType) {
let span = tracing::info_span!("bb", bb = ?self.basic_block);
let _guard = span.enter();
Expand Down
46 changes: 1 addition & 45 deletions src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
use std::collections::HashSet;

use rustc_hir::def::DefKind;
use rustc_index::IndexVec;
use rustc_middle::ty::{self as mir_ty, TyCtxt};
use rustc_span::def_id::{DefId, LocalDefId};

use crate::analyze;
use crate::chc;
use crate::refine::{self, TypeBuilder};
use crate::rty::{self, ClauseBuilderExt as _};

/// An implementation of local crate analysis.
Expand Down Expand Up @@ -173,49 +171,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
let DefKind::Enum = self.tcx.def_kind(local_def_id) else {
continue;
};
let adt = self.tcx.adt_def(local_def_id);

let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id());
let variants: IndexVec<_, _> = adt
.variants()
.iter()
.map(|variant| {
let name = refine::datatype_symbol(self.tcx, variant.def_id);
// TODO: consider using TyCtxt::tag_for_variant
let discr = analyze::resolve_discr(self.tcx, variant.discr);
let field_tys = variant
.fields
.iter()
.map(|field| {
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty)
})
.collect();
rty::EnumVariantDef {
name,
discr,
field_tys,
}
})
.collect();

let generics = self.tcx.generics_of(local_def_id);
let ty_params = (0..generics.count())
.filter(|idx| {
matches!(
generics.param_at(*idx, self.tcx).kind,
mir_ty::GenericParamDefKind::Type { .. }
)
})
.count();
tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count");

let def = rty::EnumDatatypeDef {
name,
ty_params,
variants,
};
self.ctx.register_enum_def(local_def_id.to_def_id(), def);
self.ctx.register_enum_def(local_def_id.to_def_id());
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/refine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ mod basic_block;
pub use basic_block::BasicBlockType;

mod env;
pub use env::{Assumption, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var};
pub use env::{
Assumption, EnumDefProvider, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var,
};

use crate::chc::DatatypeSymbol;
use rustc_middle::ty as mir_ty;
Expand Down
Loading
Loading