Skip to content

Commit

Permalink
fix: add a derive for generate contract (#58)
Browse files Browse the repository at this point in the history
* feat: add derives specific to contract

* fix: add examples for contract derives

* docs: add example for derives in README
  • Loading branch information
glihm authored Sep 23, 2024
1 parent 9f7cd5c commit 4e3924f
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 6 deletions.
10 changes: 10 additions & 0 deletions crates/rs-macro/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The `abigen!` macro takes 2 or 3 inputs:
3. Optional parameters:
- `output_path`: if provided, the content will be generated in the given file instead of being expanded at the location of the macro invocation.
- `type_aliases`: to avoid type name conflicts between components / contracts, you can rename some type by providing an alias for the full type path. It is important to give the **full** type path to ensure aliases are applied correctly.
- `derive`: to specify the derive for the generated structs/enums.
- `contract_derives`: to specify the derive for the generated contract type.

```rust
use cainome::rs::abigen;
Expand All @@ -66,6 +68,14 @@ abigen!(
},
);

// Example with custom derives:
abigen!(
MyContract,
"./contracts/abi/components.abi.json",
derive(Debug, Clone),
contract_derives(Debug, Clone)
);

fn main() {
// ... use the generated types here, which all of them
// implement CairoSerde trait.
Expand Down
2 changes: 2 additions & 0 deletions crates/rs-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ fn abigen_internal(input: TokenStream) -> TokenStream {
&abi_tokens,
contract_abi.execution_version,
&contract_abi.derives,
&contract_abi.contract_derives,
);

if let Some(out_path) = contract_abi.output_path {
Expand Down Expand Up @@ -66,6 +67,7 @@ fn abigen_internal_legacy(input: TokenStream) -> TokenStream {
&abi_tokens,
cainome_rs::ExecutionVersion::V1,
&contract_abi.derives,
&contract_abi.contract_derives,
);

if let Some(out_path) = contract_abi.output_path {
Expand Down
12 changes: 12 additions & 0 deletions crates/rs-macro/src/macro_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub(crate) struct ContractAbi {
pub type_aliases: HashMap<String, String>,
pub execution_version: ExecutionVersion,
pub derives: Vec<String>,
pub contract_derives: Vec<String>,
}

impl Parse for ContractAbi {
Expand Down Expand Up @@ -92,6 +93,7 @@ impl Parse for ContractAbi {
let mut execution_version = ExecutionVersion::V1;
let mut type_aliases = HashMap::new();
let mut derives = Vec::new();
let mut contract_derives = Vec::new();

loop {
if input.parse::<Token![,]>().is_err() {
Expand Down Expand Up @@ -153,6 +155,15 @@ impl Parse for ContractAbi {
derives.push(derive.to_token_stream().to_string());
}
}
"contract_derives" => {
let content;
parenthesized!(content in input);
let parsed = content.parse_terminated(Spanned::<Type>::parse, Token![,])?;

for derive in parsed {
contract_derives.push(derive.to_token_stream().to_string());
}
}
_ => emit_error!(name.span(), format!("unexpected named parameter `{name}`")),
}
}
Expand All @@ -164,6 +175,7 @@ impl Parse for ContractAbi {
type_aliases,
execution_version,
derives,
contract_derives,
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions crates/rs-macro/src/macro_inputs_legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) struct ContractAbiLegacy {
pub output_path: Option<String>,
pub type_aliases: HashMap<String, String>,
pub derives: Vec<String>,
pub contract_derives: Vec<String>,
}

impl Parse for ContractAbiLegacy {
Expand Down Expand Up @@ -89,6 +90,7 @@ impl Parse for ContractAbiLegacy {
let mut output_path: Option<String> = None;
let mut type_aliases = HashMap::new();
let mut derives = Vec::new();
let mut contract_derives = Vec::new();

loop {
if input.parse::<Token![,]>().is_err() {
Expand Down Expand Up @@ -142,6 +144,15 @@ impl Parse for ContractAbiLegacy {
derives.push(derive.to_token_stream().to_string());
}
}
"contract_derives" => {
let content;
parenthesized!(content in input);
let parsed = content.parse_terminated(Spanned::<Type>::parse, Token![,])?;

for derive in parsed {
contract_derives.push(derive.to_token_stream().to_string());
}
}
_ => emit_error!(name.span(), format!("unexpected named parameter `{name}`")),
}
}
Expand All @@ -152,6 +163,7 @@ impl Parse for ContractAbiLegacy {
output_path,
type_aliases,
derives,
contract_derives,
})
}
}
Expand Down
12 changes: 9 additions & 3 deletions crates/rs/src/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ use super::utils;
pub struct CairoContract;

impl CairoContract {
pub fn expand(contract_name: Ident) -> TokenStream2 {
pub fn expand(contract_name: Ident, contract_derives: &[String]) -> TokenStream2 {
let reader = utils::str_to_ident(format!("{}Reader", contract_name).as_str());

let snrs_types = utils::snrs_types();
let snrs_accounts = utils::snrs_accounts();
let snrs_providers = utils::snrs_providers();

let mut internal_derives = vec![];

for d in contract_derives {
internal_derives.push(utils::str_to_type(d));
}

let q = quote! {

#[derive(Debug)]
#[derive(#(#internal_derives,)*)]
pub struct #contract_name<A: #snrs_accounts::ConnectedAccount + Sync> {
pub address: #snrs_types::Felt,
pub account: A,
Expand Down Expand Up @@ -45,7 +51,7 @@ impl CairoContract {
}
}

#[derive(Debug)]
#[derive(#(#internal_derives,)*)]
pub struct #reader<P: #snrs_providers::Provider + Sync> {
pub address: #snrs_types::Felt,
pub provider: P,
Expand Down
23 changes: 22 additions & 1 deletion crates/rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub struct Abigen {
pub execution_version: ExecutionVersion,
/// Derives to be added to the generated types.
pub derives: Vec<String>,
/// Derives to be added to the generated contract.
pub contract_derives: Vec<String>,
}

impl Abigen {
Expand All @@ -90,6 +92,7 @@ impl Abigen {
types_aliases: HashMap::new(),
execution_version: ExecutionVersion::V1,
derives: vec![],
contract_derives: vec![],
}
}

Expand Down Expand Up @@ -123,6 +126,16 @@ impl Abigen {
self
}

/// Sets the derives to be added to the generated contract.
///
/// # Arguments
///
/// * `derives` - Derives to be added to the generated contract.
pub fn with_contract_derives(mut self, derives: Vec<String>) -> Self {
self.contract_derives = derives;
self
}

/// Generates the contract bindings.
pub fn generate(&self) -> Result<ContractBindings> {
let file_content = std::fs::read_to_string(&self.abi_source)?;
Expand All @@ -134,6 +147,7 @@ impl Abigen {
&tokens,
self.execution_version,
&self.derives,
&self.contract_derives,
);

Ok(ContractBindings {
Expand All @@ -157,17 +171,24 @@ impl Abigen {
///
/// * `contract_name` - Name of the contract.
/// * `abi_tokens` - Tokenized ABI.
/// * `execution_version` - The version of transaction to be executed.
/// * `derives` - Derives to be added to the generated types.
/// * `contract_derives` - Derives to be added to the generated contract.
pub fn abi_to_tokenstream(
contract_name: &str,
abi_tokens: &TokenizedAbi,
execution_version: ExecutionVersion,
derives: &[String],
contract_derives: &[String],
) -> TokenStream2 {
let contract_name = utils::str_to_ident(contract_name);

let mut tokens: Vec<TokenStream2> = vec![];

tokens.push(CairoContract::expand(contract_name.clone()));
tokens.push(CairoContract::expand(
contract_name.clone(),
contract_derives,
));

let mut sorted_structs = abi_tokens.structs.clone();
sorted_structs.sort_by(|a, b| {
Expand Down
4 changes: 3 additions & 1 deletion examples/abigen_generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ async fn main() {
"MyContract",
"./contracts/target/dev/contracts_simple_get_set.contract_class.json",
)
.with_types_aliases(aliases);
.with_types_aliases(aliases)
.with_derives(vec!["Debug".to_string(), "PartialEq".to_string()])
.with_contract_derives(vec!["Debug".to_string(), "Clone".to_string()]);

abigen
.generate()
Expand Down
4 changes: 3 additions & 1 deletion examples/simple_get_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ const KATANA_CHAIN_ID: &str = "0x4b4154414e41";
// Or you can use the extracted abi entries with jq in contracts/abi/.
abigen!(
MyContract,
"./contracts/target/dev/contracts_simple_get_set.contract_class.json"
"./contracts/target/dev/contracts_simple_get_set.contract_class.json",
derives(Debug, PartialEq),
contract_derives(Debug, Clone)
);
//abigen!(MyContract, "./contracts/abi/simple_get_set.abi.json");

Expand Down
5 changes: 5 additions & 0 deletions src/bin/cli/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub struct CainomeArgs {
#[arg(value_name = "DERIVES")]
#[arg(help = "Derives to be added to the generated types.")]
pub derives: Option<Vec<String>>,

#[arg(long)]
#[arg(value_name = "CONTRACT_DERIVES")]
#[arg(help = "Derives to be added to the generated contract.")]
pub contract_derives: Option<Vec<String>>,
}

#[derive(Debug, Args, Clone)]
Expand Down
1 change: 1 addition & 0 deletions src/bin/cli/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async fn main() -> CainomeCliResult<()> {
contracts,
execution_version: args.execution_version,
derives: args.derives.unwrap_or_default(),
contract_derives: args.contract_derives.unwrap_or_default(),
})
.await?;

Expand Down
1 change: 1 addition & 0 deletions src/bin/cli/plugins/builtins/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl BuiltinPlugin for RustPlugin {
&contract.tokens,
input.execution_version,
&input.derives,
&input.contract_derives,
);
let filename = format!(
"{}.rs",
Expand Down
1 change: 1 addition & 0 deletions src/bin/cli/plugins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct PluginInput {
pub contracts: Vec<ContractData>,
pub execution_version: ExecutionVersion,
pub derives: Vec<String>,
pub contract_derives: Vec<String>,
}

#[derive(Debug)]
Expand Down

0 comments on commit 4e3924f

Please sign in to comment.