Skip to content

Commit

Permalink
semantic: model r/w access (#1102)
Browse files Browse the repository at this point in the history
* semantic: model r/w access
* chore: tests and review 🔧
* dev: handle module name error
  • Loading branch information
shramee authored Oct 27, 2023
1 parent 0a772af commit 640bfd3
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 15 deletions.
45 changes: 41 additions & 4 deletions crates/dojo-lang/src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::collections::{BTreeMap, HashMap};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::iter::zip;
use std::ops::{Deref, DerefMut};

use anyhow::{anyhow, Context, Result};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_defs::db::DefsGroup;
use cairo_lang_defs::ids::{ModuleId, ModuleItemId};
use cairo_lang_filesystem::db::FilesGroup;
use cairo_lang_filesystem::ids::{CrateId, CrateLongId};
Expand All @@ -27,7 +28,9 @@ use starknet::core::types::contract::SierraClass;
use starknet::core::types::FieldElement;
use tracing::{debug, trace, trace_span};

use crate::inline_macros::utils::{SYSTEM_READS, SYSTEM_WRITES};
use crate::plugin::DojoAuxData;
use crate::semantics::utils::find_module_rw;

const CAIRO_PATH_SEPARATOR: &str = "::";

Expand Down Expand Up @@ -202,7 +205,7 @@ pub fn collect_external_crate_ids(

fn update_manifest(
manifest: &mut dojo_world::manifest::Manifest,
db: &dyn SemanticGroup,
db: &RootDatabase,
crate_ids: &[CrateId],
compiled_artifacts: HashMap<SmolStr, (FieldElement, Option<abi::Contract>)>,
) -> anyhow::Result<()> {
Expand Down Expand Up @@ -254,7 +257,12 @@ fn update_manifest(
.filter_map(|aux_data| aux_data.as_ref().map(|aux_data| aux_data.0.as_any()))
{
if let Some(aux_data) = aux_data.downcast_ref::<StarkNetContractAuxData>() {
contracts.extend(get_dojo_contract_artifacts(aux_data, &compiled_artifacts)?);
contracts.extend(get_dojo_contract_artifacts(
db,
module_id,
aux_data,
&compiled_artifacts,
)?);
}

if let Some(dojo_aux_data) = aux_data.downcast_ref::<DojoAuxData>() {
Expand Down Expand Up @@ -315,6 +323,8 @@ fn get_dojo_model_artifacts(
}

fn get_dojo_contract_artifacts(
db: &RootDatabase,
module_id: &ModuleId,
aux_data: &StarkNetContractAuxData,
compiled_classes: &HashMap<SmolStr, (FieldElement, Option<abi::Contract>)>,
) -> anyhow::Result<HashMap<SmolStr, Contract>> {
Expand All @@ -323,11 +333,38 @@ fn get_dojo_contract_artifacts(
.iter()
.filter(|name| !matches!(name.as_ref(), "world" | "executor" | "base"))
.map(|name| {
let module_name = module_id.full_path(db);
let module_last_name = module_name.split("::").last().unwrap();

let reads = match SYSTEM_READS.lock().unwrap().get(module_last_name) {
Some(models) => {
models.clone().into_iter().collect::<BTreeSet<_>>().into_iter().collect()
}
None => vec![],
};

let write_entries = SYSTEM_WRITES.lock().unwrap();
let writes = match write_entries.get(module_last_name) {
Some(write_ops) => find_module_rw(db, module_id, write_ops),
None => vec![],
};

let (class_hash, abi) = compiled_classes
.get(name)
.cloned()
.ok_or(anyhow!("Contract {name} not found in target."))?;
Ok((name.clone(), Contract { name: name.clone(), class_hash, abi, address: None }))

Ok((
name.clone(),
Contract {
name: name.clone(),
class_hash,
abi,
writes,
reads,
..Default::default()
},
))
})
.collect::<anyhow::Result<_>>()
}
Expand Down
23 changes: 22 additions & 1 deletion crates/dojo-lang/src/inline_macros/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use cairo_lang_defs::plugin::{
InlineMacroExprPlugin, InlinePluginResult, PluginDiagnostic, PluginGeneratedFile,
};
use cairo_lang_semantic::inline_macros::unsupported_bracket_diagnostic;
use cairo_lang_syntax::node::ast::Expr;
use cairo_lang_syntax::node::ast::{Expr, ItemModule};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{ast, TypedSyntaxNode};
use itertools::Itertools;

use super::utils::{parent_of_kind, SYSTEM_READS};
use super::{extract_models, unsupported_arg_diagnostic, CAIRO_ERR_MSG_LEN};

#[derive(Debug)]
Expand Down Expand Up @@ -78,7 +80,26 @@ impl InlineMacroExprPlugin for GetMacro {
let __get_macro_keys__ = array::ArrayTrait::span(@__get_macro_keys__);\n"
));

let mut system_reads = SYSTEM_READS.lock().unwrap();

let module_syntax_node =
parent_of_kind(db, &syntax.as_syntax_node(), SyntaxKind::ItemModule);
let module_name = if let Some(module_syntax_node) = &module_syntax_node {
let mod_ast = ItemModule::from_syntax_node(db, module_syntax_node.clone());
mod_ast.name(db).as_syntax_node().get_text_without_trivia(db)
} else {
eprintln!("Error: Couldn't get the module name.");
"".into()
};

for model in &models {
if !module_name.is_empty() {
if system_reads.get(&module_name).is_none() {
system_reads.insert(module_name.clone(), vec![model.to_string()]);
} else {
system_reads.get_mut(&module_name).unwrap().push(model.to_string());
}
}
let mut lookup_err_msg = format!("{} not found", model.to_string());
lookup_err_msg.truncate(CAIRO_ERR_MSG_LEN);
let mut deser_err_msg = format!("{} failed to deserialize", model.to_string());
Expand Down
2 changes: 2 additions & 0 deletions crates/dojo-lang/src/inline_macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use smol_str::SmolStr;
pub mod emit;
pub mod get;
pub mod set;
pub mod utils;

const CAIRO_ERR_MSG_LEN: usize = 31;

Expand Down Expand Up @@ -71,6 +72,7 @@ pub fn extract_models(

Ok(models)
}

pub fn unsupported_arg_diagnostic(
db: &dyn SyntaxGroup,
macro_ast: &ast::ExprInlineMacro,
Expand Down
92 changes: 86 additions & 6 deletions crates/dojo-lang/src/inline_macros/set.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
use std::collections::HashMap;

use cairo_lang_defs::patcher::PatchBuilder;
use cairo_lang_defs::plugin::{
InlineMacroExprPlugin, InlinePluginResult, PluginDiagnostic, PluginGeneratedFile,
};
use cairo_lang_semantic::inline_macros::unsupported_bracket_diagnostic;
use cairo_lang_syntax::node::ast::{ExprPath, ExprStructCtorCall, FunctionWithBody, ItemModule};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{ast, TypedSyntaxNode};

use super::unsupported_arg_diagnostic;
use super::utils::{parent_of_kind, SystemRWOpRecord, SYSTEM_WRITES};

#[derive(Debug)]
pub struct SetMacro;
impl SetMacro {
pub const NAME: &'static str = "set";
// Parents of set!()
// -----------------
// StatementExpr
// StatementList
// ExprBlock
// FunctionWithBody
// ImplItemList
// ImplBody
// ItemImpl
// ItemList
// ModuleBody
// ItemModule
// ItemList
// SyntaxFile
}
impl InlineMacroExprPlugin for SetMacro {
fn generate_code(
Expand Down Expand Up @@ -46,12 +65,19 @@ impl InlineMacroExprPlugin for SetMacro {

match models.value(db) {
ast::Expr::Parenthesized(parens) => {
bundle.push(parens.expr(db).as_syntax_node().get_text(db))
let syntax_node = parens.expr(db).as_syntax_node();
bundle.push((syntax_node.get_text(db), syntax_node));
}
ast::Expr::Tuple(list) => {
list.expressions(db).elements(db).into_iter().for_each(|expr| {
let syntax_node = expr.as_syntax_node();
bundle.push((syntax_node.get_text(db), syntax_node));
})
}
ast::Expr::StructCtorCall(ctor) => {
let syntax_node = ctor.as_syntax_node();
bundle.push((syntax_node.get_text(db), syntax_node));
}
ast::Expr::Tuple(list) => list.expressions(db).elements(db).iter().for_each(|expr| {
bundle.push(expr.as_syntax_node().get_text(db));
}),
ast::Expr::StructCtorCall(ctor) => bundle.push(ctor.as_syntax_node().get_text(db)),
_ => {
return InlinePluginResult {
code: None,
Expand All @@ -73,7 +99,61 @@ impl InlineMacroExprPlugin for SetMacro {
};
}

for entity in bundle {
let module_syntax_node =
parent_of_kind(db, &syntax.as_syntax_node(), SyntaxKind::ItemModule);
let module_name = if let Some(module_syntax_node) = &module_syntax_node {
let mod_ast = ItemModule::from_syntax_node(db, module_syntax_node.clone());
mod_ast.name(db).as_syntax_node().get_text_without_trivia(db)
} else {
eprintln!("Error: Couldn't get the module name.");
"".into()
};

let fn_syntax_node =
parent_of_kind(db, &syntax.as_syntax_node(), SyntaxKind::FunctionWithBody);
let fn_name = if let Some(fn_syntax_node) = &fn_syntax_node {
let fn_ast = FunctionWithBody::from_syntax_node(db, fn_syntax_node.clone());
fn_ast.declaration(db).name(db).as_syntax_node().get_text_without_trivia(db)
} else {
// Unlikely to get here, but if we do.
eprintln!("Error: Couldn't get the function name.");
"".into()
};

for (entity, syntax_node) in bundle {
// db.lookup_intern_file(key0);
if !module_name.is_empty() && !fn_name.is_empty() {
let mut system_writes = SYSTEM_WRITES.lock().unwrap();
// fn_syntax_node
if system_writes.get(&module_name).is_none() {
system_writes.insert(module_name.clone(), HashMap::new());
}
let fns = system_writes.get_mut(&module_name).unwrap();
if fns.get(&fn_name).is_none() {
fns.insert(fn_name.clone(), vec![]);
}

match syntax_node.kind(db) {
SyntaxKind::ExprPath => {
fns.get_mut(&fn_name).unwrap().push(SystemRWOpRecord::Path(
ExprPath::from_syntax_node(db, syntax_node),
));
}
// SyntaxKind::StatementExpr => {
// todo!()
// }
SyntaxKind::ExprStructCtorCall => {
fns.get_mut(&fn_name).unwrap().push(SystemRWOpRecord::StructCtor(
ExprStructCtorCall::from_syntax_node(db, syntax_node.clone()),
));
}
_ => eprintln!(
"Unsupport component value type {} for semantic writer analysis",
syntax_node.kind(db)
),
}
}

builder.add_str(&format!(
"
let __set_macro_value__ = {};
Expand Down
33 changes: 33 additions & 0 deletions crates/dojo-lang/src/inline_macros/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::collections::HashMap;
use std::sync::Mutex;

use cairo_lang_syntax::node::ast::{ExprPath, ExprStructCtorCall};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::SyntaxNode;

type ModuleName = String;
type FunctionName = String;
lazy_static::lazy_static! {
pub static ref SYSTEM_WRITES: Mutex<HashMap<ModuleName, HashMap<FunctionName, Vec<SystemRWOpRecord>>>> = Default::default();
pub static ref SYSTEM_READS: Mutex<HashMap<ModuleName, Vec<String>>> = Default::default();
}

pub enum SystemRWOpRecord {
StructCtor(ExprStructCtorCall),
Path(ExprPath),
}

pub fn parent_of_kind(
db: &dyn cairo_lang_syntax::node::db::SyntaxGroup,
target: &SyntaxNode,
kind: SyntaxKind,
) -> Option<SyntaxNode> {
let mut new_target = target.clone();
while let Some(parent) = new_target.parent() {
if kind == parent.kind(db) {
return Some(parent);
}
new_target = parent;
}
None
}
16 changes: 14 additions & 2 deletions crates/dojo-lang/src/manifest_test_data/manifest
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,9 @@ test_manifest_file
}
]
}
]
],
"reads": [],
"writes": []
},
"executor": {
"name": "executor",
Expand Down Expand Up @@ -797,7 +799,9 @@ test_manifest_file
"kind": "enum",
"variants": []
}
]
],
"reads": [],
"writes": []
},
"base": {
"name": "base",
Expand Down Expand Up @@ -985,6 +989,14 @@ test_manifest_file
}
]
}
],
"reads": [
"Moves",
"Position"
],
"writes": [
"Moves",
"Position"
]
}
],
Expand Down
2 changes: 2 additions & 0 deletions crates/dojo-lang/src/semantics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod utils;

#[cfg(test)]
pub mod test_utils;

Expand Down
Loading

0 comments on commit 640bfd3

Please sign in to comment.