Skip to content

Commit

Permalink
refactor(cnnl): 基于 digit-layout 封装 Tensor
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Jun 21, 2024
1 parent ce05eaf commit 0a567e2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
1 change: 1 addition & 0 deletions cnnl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ categories = ["hardware-support"]

[dependencies]
cndrv = { version = "0.0", path = "../cndrv" }
digit-layout = "0.0"

[build-dependencies]
bindgen.workspace = true
Expand Down
44 changes: 38 additions & 6 deletions cnnl/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
use crate::bindings::cnnlTensorDescriptor_t;
use crate::bindings::{cnnlDataType_t, cnnlTensorDescriptor_t};
use digit_layout::DigitLayout;
use std::ptr::null_mut;

#[repr(transparent)]
pub struct Tensor(cnnlTensorDescriptor_t);

impl Tensor {
#[inline]
pub fn new() -> Self {
let mut ptr = null_mut();
cnnl!(cnnlCreateTensorDescriptor(&mut ptr));
Tensor(ptr)
pub fn new(dl: DigitLayout, shape: &[i64], strides: &[i64]) -> Self {
assert_eq!(shape.len(), strides.len());
let mut desc = null_mut();
cnnl!(cnnlCreateTensorDescriptor(&mut desc));
cnnl!(cnnlSetTensorDescriptorEx_v2(
desc,
cnnlTensorLayout_t::CNNL_LAYOUT_ARRAY,
convert_dt(dl),
shape.len() as _,
shape.as_ptr(),
strides.as_ptr(),
));
Tensor(desc)
}
}

Expand All @@ -20,8 +30,30 @@ impl Drop for Tensor {
}
}

fn convert_dt(dl: DigitLayout) -> cnnlDataType_t {
use cnnlDataType_t::*;
use digit_layout::types::*;
match dl {
F16 => CNNL_DTYPE_HALF,
BF16 => CNNL_DTYPE_BFLOAT16,
F32 => CNNL_DTYPE_FLOAT,
F64 => CNNL_DTYPE_DOUBLE,
I8 => CNNL_DTYPE_INT8,
I16 => CNNL_DTYPE_INT16,
I32 => CNNL_DTYPE_INT32,
I64 => CNNL_DTYPE_INT64,
U8 => CNNL_DTYPE_UINT8,
U16 => CNNL_DTYPE_UINT16,
U32 => CNNL_DTYPE_UINT32,
U64 => CNNL_DTYPE_UINT64,
BOOL => CNNL_DTYPE_BOOL,
_ => CNNL_DTYPE_INVALID,
}
}

#[test]
fn test() {
let _tensor = Tensor::new();
use digit_layout::types::F16;
let _tensor = Tensor::new(F16, &[2, 3, 4], &[12, 4, 1]);
println!("test passed");
}

0 comments on commit 0a567e2

Please sign in to comment.