From ef7e6f4abfcd02ce0a9b66ad4db2f5a4266eb98f Mon Sep 17 00:00:00 2001 From: Artem Agvanian Date: Tue, 16 Jul 2024 14:27:15 -0700 Subject: [PATCH] Towards global transformations - Add edge annotations to the call graph - Implement global transformation pass infrastructure - Implement `dump_mir` as a global transformation pass --- .../compiler_interface.rs | 21 +- kani-compiler/src/kani_middle/mod.rs | 43 +-- kani-compiler/src/kani_middle/reachability.rs | 302 +++++++++++------- .../kani_middle/transform/dump_mir_pass.rs | 68 ++++ .../src/kani_middle/transform/mod.rs | 57 +++- 5 files changed, 324 insertions(+), 167 deletions(-) create mode 100644 kani-compiler/src/kani_middle/transform/dump_mir_pass.rs diff --git a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs index abb69a655472..8685f5631b65 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs @@ -7,6 +7,7 @@ use crate::args::ReachabilityType; use crate::codegen_cprover_gotoc::GotocCtx; use crate::kani_middle::analysis; use crate::kani_middle::attributes::{is_test_harness_description, KaniAttributes}; +use crate::kani_middle::check_reachable_items; use crate::kani_middle::codegen_units::{CodegenUnit, CodegenUnits}; use crate::kani_middle::metadata::gen_test_metadata; use crate::kani_middle::provide; @@ -14,7 +15,6 @@ use crate::kani_middle::reachability::{ collect_reachable_items, filter_const_crate_items, filter_crate_items, }; use crate::kani_middle::transform::BodyTransformation; -use crate::kani_middle::{check_reachable_items, dump_mir_items}; use crate::kani_queries::QueryDb; use cbmc::goto_program::Location; use cbmc::irep::goto_binary_serde::write_goto_binary_file; @@ -89,11 +89,26 @@ impl GotocCodegenBackend { check_contract: Option, mut transformer: BodyTransformation, ) -> (GotocCtx<'tcx>, Vec, Option) { - let items = with_timer( + let (items, call_graph) = with_timer( || collect_reachable_items(tcx, &mut transformer, starting_items), "codegen reachability analysis", ); - dump_mir_items(tcx, &mut transformer, &items, &symtab_goto.with_extension("kani.mir")); + + // Retrieve all instances from the currently codegened items. + let instances = items + .iter() + .filter_map(|item| match item { + MonoItem::Fn(instance) => Some(*instance), + MonoItem::Static(static_def) => { + let instance: Instance = (*static_def).into(); + instance.has_body().then_some(instance) + } + MonoItem::GlobalAsm(_) => None, + }) + .collect(); + + // Apply all transformation passes, including global passes. + transformer.apply_passes(tcx, starting_items, instances, call_graph); // Follow rustc naming convention (cx is abbrev for context). // https://rustc-dev-guide.rust-lang.org/conventions.html#naming-conventions diff --git a/kani-compiler/src/kani_middle/mod.rs b/kani-compiler/src/kani_middle/mod.rs index 17b08b687e30..3c300e9da52c 100644 --- a/kani-compiler/src/kani_middle/mod.rs +++ b/kani-compiler/src/kani_middle/mod.rs @@ -4,9 +4,7 @@ //! and transformations. use std::collections::HashSet; -use std::path::Path; -use crate::kani_middle::transform::BodyTransformation; use crate::kani_queries::QueryDb; use rustc_hir::{def::DefKind, def_id::LOCAL_CRATE}; use rustc_middle::span_bug; @@ -15,18 +13,14 @@ use rustc_middle::ty::layout::{ LayoutOfHelpers, TyAndLayout, }; use rustc_middle::ty::{self, Instance as InstanceInternal, Ty as TyInternal, TyCtxt}; -use rustc_session::config::OutputType; use rustc_smir::rustc_internal; use rustc_span::source_map::respan; use rustc_span::Span; use rustc_target::abi::call::FnAbi; use rustc_target::abi::{HasDataLayout, TargetDataLayout}; -use stable_mir::mir::mono::{Instance, MonoItem}; +use stable_mir::mir::mono::MonoItem; use stable_mir::ty::{FnDef, RigidTy, Span as SpanStable, TyKind}; use stable_mir::CrateDef; -use std::fs::File; -use std::io::BufWriter; -use std::io::Write; use self::attributes::KaniAttributes; @@ -92,41 +86,6 @@ pub fn check_reachable_items(tcx: TyCtxt, queries: &QueryDb, items: &[MonoItem]) tcx.dcx().abort_if_errors(); } -/// Print MIR for the reachable items if the `--emit mir` option was provided to rustc. -pub fn dump_mir_items( - tcx: TyCtxt, - transformer: &mut BodyTransformation, - items: &[MonoItem], - output: &Path, -) { - /// Convert MonoItem into a DefId. - /// Skip stuff that we cannot generate the MIR items. - fn get_instance(item: &MonoItem) -> Option { - match item { - // Exclude FnShims and others that cannot be dumped. - MonoItem::Fn(instance) => Some(*instance), - MonoItem::Static(def) => { - let instance: Instance = (*def).into(); - instance.has_body().then_some(instance) - } - MonoItem::GlobalAsm(_) => None, - } - } - - if tcx.sess.opts.output_types.contains_key(&OutputType::Mir) { - // Create output buffer. - let out_file = File::create(output).unwrap(); - let mut writer = BufWriter::new(out_file); - - // For each def_id, dump their MIR - for instance in items.iter().filter_map(get_instance) { - writeln!(writer, "// Item: {} ({})", instance.name(), instance.mangled_name()).unwrap(); - let body = transformer.body(tcx, instance); - let _ = body.dump(&mut writer, &instance.name()); - } - } -} - /// Structure that represents the source location of a definition. /// TODO: Use `InternedString` once we move it out of the cprover_bindings. /// diff --git a/kani-compiler/src/kani_middle/reachability.rs b/kani-compiler/src/kani_middle/reachability.rs index 279dcf8cc107..87e6b542142b 100644 --- a/kani-compiler/src/kani_middle/reachability.rs +++ b/kani-compiler/src/kani_middle/reachability.rs @@ -22,6 +22,7 @@ use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::FxHashSet; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; use rustc_middle::ty::{TyCtxt, VtblEntry}; +use rustc_session::config::OutputType; use rustc_smir::rustc_internal; use stable_mir::mir::alloc::{AllocId, GlobalAlloc}; use stable_mir::mir::mono::{Instance, InstanceKind, MonoItem, StaticDef}; @@ -32,6 +33,12 @@ use stable_mir::mir::{ use stable_mir::ty::{Allocation, ClosureKind, ConstantKind, RigidTy, Ty, TyKind}; use stable_mir::CrateItem; use stable_mir::{CrateDef, ItemKind}; +use std::fmt::{Display, Formatter}; +use std::{ + collections::{HashMap, HashSet}, + fs::File, + io::{BufWriter, Write}, +}; use crate::kani_middle::coercion; use crate::kani_middle::coercion::CoercionBase; @@ -42,7 +49,7 @@ pub fn collect_reachable_items( tcx: TyCtxt, transformer: &mut BodyTransformation, starting_points: &[MonoItem], -) -> Vec { +) -> (Vec, CallGraph) { // For each harness, collect items using the same collector. // I.e.: This will return any item that is reachable from one or more of the starting points. let mut collector = MonoItemsCollector::new(tcx, transformer); @@ -62,7 +69,7 @@ pub fn collect_reachable_items( // order of the errors and warnings is stable. let mut sorted_items: Vec<_> = collector.collected.into_iter().collect(); sorted_items.sort_by_cached_key(|item| to_fingerprint(tcx, item)); - sorted_items + (sorted_items, collector.call_graph) } /// Collect all (top-level) items in the crate that matches the given predicate. @@ -118,7 +125,24 @@ where } } } - roots + roots.into_iter().map(|root| root.item).collect() +} + +/// Reason for introducing an edge in the call graph. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +enum CollectionReason { + DirectCall, + IndirectCall, + VTableMethod, + Static, + StaticDrop, +} + +/// A destination of the edge in the call graph. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct CollectedItem { + item: MonoItem, + reason: CollectionReason, } struct MonoItemsCollector<'tcx, 'a> { @@ -130,8 +154,8 @@ struct MonoItemsCollector<'tcx, 'a> { collected: FxHashSet, /// Items enqueued for visiting. queue: Vec, - #[cfg(debug_assertions)] - call_graph: debug::CallGraph, + /// Call graph used for dataflow analysis. + call_graph: CallGraph, } impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { @@ -140,8 +164,7 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { tcx, collected: FxHashSet::default(), queue: vec![], - #[cfg(debug_assertions)] - call_graph: debug::CallGraph::default(), + call_graph: CallGraph::default(), transformer, } } @@ -167,17 +190,19 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { vec![] } }; - #[cfg(debug_assertions)] self.call_graph.add_edges(to_visit, &next_items); - self.queue - .extend(next_items.into_iter().filter(|item| !self.collected.contains(item))); + self.queue.extend(next_items.into_iter().filter_map( + |CollectedItem { item, .. }| { + if !self.collected.contains(&item) { Some(item) } else { None } + }, + )); } } } /// Visit a function and collect all mono-items reachable from its instructions. - fn visit_fn(&mut self, instance: Instance) -> Vec { + fn visit_fn(&mut self, instance: Instance) -> Vec { let _guard = debug_span!("visit_fn", function=?instance).entered(); let body = self.transformer.body(self.tcx, instance); let mut collector = @@ -187,19 +212,24 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { } /// Visit a static object and collect drop / initialization functions. - fn visit_static(&mut self, def: StaticDef) -> Vec { + fn visit_static(&mut self, def: StaticDef) -> Vec { let _guard = debug_span!("visit_static", ?def).entered(); let mut next_items = vec![]; // Collect drop function. let static_ty = def.ty(); let instance = Instance::resolve_drop_in_place(static_ty); - next_items.push(instance.into()); + next_items + .push(CollectedItem { item: instance.into(), reason: CollectionReason::StaticDrop }); // Collect initialization. let alloc = def.eval_initializer().unwrap(); for (_, prov) in alloc.provenance.ptrs { - next_items.extend(collect_alloc_items(prov.0).into_iter()); + next_items.extend( + collect_alloc_items(prov.0) + .into_iter() + .map(|item| CollectedItem { item, reason: CollectionReason::Static }), + ); } next_items @@ -213,7 +243,7 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { struct MonoItemsFnCollector<'a, 'tcx> { tcx: TyCtxt<'tcx>, - collected: FxHashSet, + collected: FxHashSet, body: &'a Body, } @@ -251,7 +281,9 @@ impl<'a, 'tcx> MonoItemsFnCollector<'a, 'tcx> { } }); trace!(methods=?methods.clone().collect::>(), "collect_vtable_methods"); - self.collected.extend(methods); + self.collected.extend( + methods.map(|item| CollectedItem { item, reason: CollectionReason::VTableMethod }), + ); } // Add the destructor for the concrete type. @@ -282,7 +314,12 @@ impl<'a, 'tcx> MonoItemsFnCollector<'a, 'tcx> { }; if should_collect && should_codegen_locally(&instance) { trace!(?instance, "collect_instance"); - self.collected.insert(instance.into()); + let reason = if is_direct_call { + CollectionReason::DirectCall + } else { + CollectionReason::IndirectCall + }; + self.collected.insert(CollectedItem { item: instance.into(), reason }); } } @@ -290,7 +327,11 @@ impl<'a, 'tcx> MonoItemsFnCollector<'a, 'tcx> { fn collect_allocation(&mut self, alloc: &Allocation) { debug!(?alloc, "collect_allocation"); for (_, id) in &alloc.provenance.ptrs { - self.collected.extend(collect_alloc_items(id.0).into_iter()) + self.collected.extend( + collect_alloc_items(id.0) + .into_iter() + .map(|item| CollectedItem { item, reason: CollectionReason::Static }), + ) } } } @@ -366,7 +407,8 @@ impl<'a, 'tcx> MirVisitor for MonoItemsFnCollector<'a, 'tcx> { } Rvalue::ThreadLocalRef(item) => { trace!(?item, "visit_rvalue thread_local"); - self.collected.insert(MonoItem::Static(StaticDef::try_from(item).unwrap())); + let item = MonoItem::Static(StaticDef::try_from(item).unwrap()); + self.collected.insert(CollectedItem { item, reason: CollectionReason::Static }); } _ => { /* not interesting */ } } @@ -485,128 +527,148 @@ fn collect_alloc_items(alloc_id: AllocId) -> Vec { items } -#[cfg(debug_assertions)] -#[allow(dead_code)] -mod debug { - - use std::fmt::{Display, Formatter}; - use std::{ - collections::{HashMap, HashSet}, - fs::File, - io::{BufWriter, Write}, - }; - - use rustc_session::config::OutputType; - - use super::*; +/// Call graph with edges annotated with the reason why they were added to the graph. +#[derive(Debug, Default)] +pub struct CallGraph { + /// Nodes of the graph. + nodes: HashSet, + /// Edges of the graph. + edges: HashMap>, + /// Since the graph is directed, we also store back edges. + back_edges: HashMap>, +} - #[derive(Debug, Default)] - pub struct CallGraph { - // Nodes of the graph. - nodes: HashSet, - edges: HashMap>, - back_edges: HashMap>, +/// Newtype around MonoItem. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct Node(pub MonoItem); + +/// Newtype around CollectedItem. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct CollectedNode(pub CollectedItem); + +impl CallGraph { + /// Add a new node into a graph. + fn add_node(&mut self, item: MonoItem) { + let node = Node(item); + self.nodes.insert(node.clone()); + self.edges.entry(node.clone()).or_default(); + self.back_edges.entry(node).or_default(); } - #[derive(Clone, Debug, Eq, PartialEq, Hash)] - struct Node(pub MonoItem); - - impl CallGraph { - pub fn add_node(&mut self, item: MonoItem) { - let node = Node(item); - self.nodes.insert(node.clone()); - self.edges.entry(node.clone()).or_default(); - self.back_edges.entry(node).or_default(); - } + /// Add a new edge "from" -> "to". + fn add_edge(&mut self, from: MonoItem, to: MonoItem, collection_reason: CollectionReason) { + let from_node = Node(from.clone()); + let to_node = Node(to.clone()); + self.add_node(from.clone()); + self.add_node(to.clone()); + self.edges + .get_mut(&from_node) + .unwrap() + .push(CollectedNode(CollectedItem { item: to, reason: collection_reason })); + self.back_edges + .get_mut(&to_node) + .unwrap() + .push(CollectedNode(CollectedItem { item: from, reason: collection_reason })); + } - /// Add a new edge "from" -> "to". - pub fn add_edge(&mut self, from: MonoItem, to: MonoItem) { - let from_node = Node(from.clone()); - let to_node = Node(to.clone()); - self.add_node(from); - self.add_node(to); - self.edges.get_mut(&from_node).unwrap().push(to_node.clone()); - self.back_edges.get_mut(&to_node).unwrap().push(from_node); + /// Add multiple new edges for the "from" node. + fn add_edges(&mut self, from: MonoItem, to: &[CollectedItem]) { + self.add_node(from.clone()); + for CollectedItem { item, reason } in to { + self.add_edge(from.clone(), item.clone(), *reason); } + } - /// Add multiple new edges for the "from" node. - pub fn add_edges(&mut self, from: MonoItem, to: &[MonoItem]) { - self.add_node(from.clone()); - for item in to { - self.add_edge(from.clone(), item.clone()); + /// Print the graph in DOT format to a file. + /// See for more information. + fn dump_dot(&self, tcx: TyCtxt) -> std::io::Result<()> { + if let Ok(target) = std::env::var("KANI_REACH_DEBUG") { + debug!(?target, "dump_dot"); + let outputs = tcx.output_filenames(()); + let base_path = outputs.path(OutputType::Metadata); + let path = base_path.as_path().with_extension("dot"); + let out_file = File::create(path)?; + let mut writer = BufWriter::new(out_file); + writeln!(writer, "digraph ReachabilityGraph {{")?; + if target.is_empty() { + self.dump_all(&mut writer)?; + } else { + // Only dump nodes that led the reachability analysis to the target node. + self.dump_reason(&mut writer, &target)?; } + writeln!(writer, "}}")?; } - /// Print the graph in DOT format to a file. - /// See for more information. - pub fn dump_dot(&self, tcx: TyCtxt) -> std::io::Result<()> { - if let Ok(target) = std::env::var("KANI_REACH_DEBUG") { - debug!(?target, "dump_dot"); - let outputs = tcx.output_filenames(()); - let base_path = outputs.path(OutputType::Metadata); - let path = base_path.as_path().with_extension("dot"); - let out_file = File::create(path)?; - let mut writer = BufWriter::new(out_file); - writeln!(writer, "digraph ReachabilityGraph {{")?; - if target.is_empty() { - self.dump_all(&mut writer)?; - } else { - // Only dump nodes that led the reachability analysis to the target node. - self.dump_reason(&mut writer, &target)?; - } - writeln!(writer, "}}")?; - } + Ok(()) + } - Ok(()) + /// Write all notes to the given writer. + fn dump_all(&self, writer: &mut W) -> std::io::Result<()> { + tracing::info!(nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_all"); + for node in &self.nodes { + writeln!(writer, r#""{node}""#)?; + for succ in self.edges.get(node).unwrap() { + let reason = succ.0.reason; + writeln!(writer, r#""{node}" -> "{succ}" [label={reason:?}] "#)?; + } } + Ok(()) + } - /// Write all notes to the given writer. - fn dump_all(&self, writer: &mut W) -> std::io::Result<()> { - tracing::info!(nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_all"); - for node in &self.nodes { - writeln!(writer, r#""{node}""#)?; - for succ in self.edges.get(node).unwrap() { - writeln!(writer, r#""{node}" -> "{succ}" "#)?; - } + /// Write all notes that may have led to the discovery of the given target. + fn dump_reason(&self, writer: &mut W, target: &str) -> std::io::Result<()> { + let mut queue: Vec = + self.nodes.iter().filter(|item| item.to_string().contains(target)).cloned().collect(); + let mut visited: HashSet = HashSet::default(); + tracing::info!(target=?queue, nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_reason"); + while let Some(to_visit) = queue.pop() { + if !visited.contains(&to_visit) { + visited.insert(to_visit.clone()); + queue.extend( + self.back_edges.get(&to_visit).unwrap().iter().map(|item| item.as_node()), + ); } - Ok(()) } - /// Write all notes that may have led to the discovery of the given target. - fn dump_reason(&self, writer: &mut W, target: &str) -> std::io::Result<()> { - let mut queue = self - .nodes + for node in &visited { + writeln!(writer, r#""{node}""#)?; + for succ in self + .edges + .get(node) + .unwrap() .iter() - .filter(|item| item.to_string().contains(target)) - .collect::>(); - let mut visited: HashSet<&Node> = HashSet::default(); - tracing::info!(target=?queue, nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_reason"); - while let Some(to_visit) = queue.pop() { - if !visited.contains(to_visit) { - visited.insert(to_visit); - queue.extend(self.back_edges.get(to_visit).unwrap()); - } + .filter(|item| visited.contains(&item.as_node())) + { + let reason = succ.0.reason; + writeln!(writer, r#""{node}" -> "{succ}" [label={reason:?}] "#)?; } + } + Ok(()) + } +} - for node in &visited { - writeln!(writer, r#""{node}""#)?; - for succ in - self.edges.get(node).unwrap().iter().filter(|item| visited.contains(item)) - { - writeln!(writer, r#""{node}" -> "{succ}" "#)?; - } - } - Ok(()) +impl Display for Node { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self.0 { + MonoItem::Fn(instance) => write!(f, "{}", instance.name()), + MonoItem::Static(def) => write!(f, "{}", def.name()), + MonoItem::GlobalAsm(asm) => write!(f, "{asm:?}"), } } +} - impl Display for Node { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.0 { - MonoItem::Fn(instance) => write!(f, "{}", instance.name()), - MonoItem::Static(def) => write!(f, "{}", def.name()), - MonoItem::GlobalAsm(asm) => write!(f, "{asm:?}"), - } +impl Display for CollectedNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self.0.item { + MonoItem::Fn(instance) => write!(f, "{}", instance.name()), + MonoItem::Static(def) => write!(f, "{}", def.name()), + MonoItem::GlobalAsm(asm) => write!(f, "{asm:?}"), } } } + +impl CollectedNode { + pub fn as_node(&self) -> Node { + Node(self.0.item.clone()) + } +} diff --git a/kani-compiler/src/kani_middle/transform/dump_mir_pass.rs b/kani-compiler/src/kani_middle/transform/dump_mir_pass.rs new file mode 100644 index 000000000000..1d7aec07ffc0 --- /dev/null +++ b/kani-compiler/src/kani_middle/transform/dump_mir_pass.rs @@ -0,0 +1,68 @@ +use crate::kani_middle::reachability::CallGraph; +use crate::kani_middle::transform::GlobalPass; +use crate::kani_queries::QueryDb; +use kani_metadata::ArtifactType; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::OutputType; +use stable_mir::mir::mono::{Instance, MonoItem}; +use stable_mir::mir::Body; +use std::fmt::{Debug, Formatter, Result}; +use std::fs::File; +use std::io::BufWriter; +use std::io::Write; + +/// Dump all MIR bodies. +pub struct DumpMirPass { + enabled: bool, +} + +impl DumpMirPass { + pub fn new(tcx: TyCtxt) -> Self { + Self { enabled: tcx.sess.opts.output_types.contains_key(&OutputType::Mir) } + } +} + +impl Debug for DumpMirPass { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + f.debug_struct("DumpMirPass").field("enabled", &self.enabled).finish() + } +} + +impl GlobalPass for DumpMirPass { + fn is_enabled(&self, _query_db: &QueryDb) -> bool { + self.enabled + } + + fn transform( + &mut self, + tcx: TyCtxt, + starting_items: &[MonoItem], + bodies: Vec<(Body, Instance)>, + _call_graph: &CallGraph, + ) -> Vec<(bool, Body, Instance)> { + // Create output buffer. + let file_path = { + let base_path = tcx.output_filenames(()).path(OutputType::Object); + let base_name = base_path.as_path(); + let entry_point = (starting_items.len() == 1).then_some(starting_items[0].clone()); + if let Some(MonoItem::Fn(starting_instance)) = entry_point { + let mangled_name = starting_instance.mangled_name(); + let file_stem = + format!("{}_{mangled_name}", base_name.file_stem().unwrap().to_str().unwrap()); + base_name.with_file_name(file_stem).with_extension(ArtifactType::SymTabGoto) + } else { + base_name.with_extension(ArtifactType::SymTabGoto) + } + }; + let out_file = File::create(file_path.with_extension("kani.mir")).unwrap(); + let mut writer = BufWriter::new(out_file); + + // For each def_id, dump their MIR + for (body, instance) in bodies.iter() { + writeln!(writer, "// Item: {} ({})", instance.name(), instance.mangled_name()).unwrap(); + let _ = body.dump(&mut writer, &instance.name()); + } + + bodies.into_iter().map(|(body, instance)| (false, body, instance)).collect() + } +} diff --git a/kani-compiler/src/kani_middle/transform/mod.rs b/kani-compiler/src/kani_middle/transform/mod.rs index 26a95978fcaf..ae525134d5e8 100644 --- a/kani-compiler/src/kani_middle/transform/mod.rs +++ b/kani-compiler/src/kani_middle/transform/mod.rs @@ -17,6 +17,7 @@ //! For all instrumentation passes, always use exhaustive matches to ensure soundness in case a new //! case is added. use crate::kani_middle::codegen_units::CodegenUnit; +use crate::kani_middle::reachability::CallGraph; use crate::kani_middle::transform::body::CheckType; use crate::kani_middle::transform::check_uninit::UninitPass; use crate::kani_middle::transform::check_values::ValidValuePass; @@ -24,8 +25,9 @@ use crate::kani_middle::transform::contracts::AnyModifiesPass; use crate::kani_middle::transform::kani_intrinsics::IntrinsicGeneratorPass; use crate::kani_middle::transform::stubs::{ExternFnStubPass, FnStubPass}; use crate::kani_queries::QueryDb; +use dump_mir_pass::DumpMirPass; use rustc_middle::ty::TyCtxt; -use stable_mir::mir::mono::Instance; +use stable_mir::mir::mono::{Instance, MonoItem}; use stable_mir::mir::Body; use std::collections::HashMap; use std::fmt::Debug; @@ -34,6 +36,7 @@ pub(crate) mod body; mod check_uninit; mod check_values; mod contracts; +mod dump_mir_pass; mod kani_intrinsics; mod stubs; @@ -49,6 +52,9 @@ pub struct BodyTransformation { stub_passes: Vec>, /// The passes that may add safety checks to the function body. inst_passes: Vec>, + /// The passes that operate on the whole codegen unit, they run after all previous passes are + /// done. + global_passes: Vec>, /// Cache transformation results. cache: HashMap, } @@ -58,6 +64,7 @@ impl BodyTransformation { let mut transformer = BodyTransformation { stub_passes: vec![], inst_passes: vec![], + global_passes: vec![], cache: Default::default(), }; let check_type = CheckType::new_assert_assume(tcx); @@ -87,10 +94,13 @@ impl BodyTransformation { arguments: queries.args().clone(), }, ); + transformer.add_global_pass(queries, DumpMirPass::new(tcx)); + transformer } - /// Retrieve the body of an instance. + /// Retrieve the body of an instance. This does not apply global passes, but will retrieve the + /// body after global passes running if they were previously applied. /// /// Note that this assumes that the instance does have a body since existing consumers already /// assume that. Use `instance.has_body()` to check if an instance has a body. @@ -118,6 +128,27 @@ impl BodyTransformation { } } + /// Apply all per-body pass followed by the global passes and store the results in cache that + /// can later be queried by `body`. + pub fn apply_passes( + &mut self, + tcx: TyCtxt, + starting_items: &[MonoItem], + instances: Vec, + call_graph: CallGraph, + ) { + let bodies: Vec<_> = + instances.into_iter().map(|instance| (self.body(tcx, instance), instance)).collect(); + for global_pass in self.global_passes.iter_mut() { + let results = global_pass.transform(tcx, starting_items, bodies.clone(), &call_graph); + for (modified, body, instance) in results { + if modified { + self.cache.insert(instance, TransformationResult::Modified(body)); + } + } + } + } + fn add_pass(&mut self, query_db: &QueryDb, pass: P) { if pass.is_enabled(&query_db) { match P::transformation_type() { @@ -126,6 +157,12 @@ impl BodyTransformation { } } } + + fn add_global_pass(&mut self, query_db: &QueryDb, pass: P) { + if pass.is_enabled(&query_db) { + self.global_passes.push(Box::new(pass)) + } + } } /// The type of transformation that a pass may perform. @@ -152,6 +189,22 @@ pub(crate) trait TransformPass: Debug { fn transform(&mut self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body); } +/// A trait to represent transformation passes that operate on the whole codegen unit. +pub(crate) trait GlobalPass: Debug { + fn is_enabled(&self, query_db: &QueryDb) -> bool + where + Self: Sized; + + /// Run a transformation pass on the whole codegen unit. + fn transform( + &mut self, + tcx: TyCtxt, + starting_items: &[MonoItem], + bodies: Vec<(Body, Instance)>, + call_graph: &CallGraph, + ) -> Vec<(bool, Body, Instance)>; +} + /// The transformation result. /// We currently only cache the body of functions that were instrumented. #[derive(Clone, Debug)]