Skip to content

Commit

Permalink
feat(cndrv): 导出 kernel 属性探测接口
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Jul 9, 2024
1 parent 7aea68f commit 7143130
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 49 deletions.
13 changes: 12 additions & 1 deletion cndrv/src/cnrtc/binary.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use crate::bindings::{cnrtcCompileCode, cnrtcStatus};
use crate::{
bindings::{cnrtcCompileCode, cnrtcStatus},
MemSize,
};
use std::{
ffi::{c_int, CString},
os::raw::c_void,
ptr::{null, null_mut},
};

#[repr(transparent)]
pub struct CnrtcBinary(Vec<u8>);

impl CnrtcBinary {
Expand Down Expand Up @@ -53,6 +57,13 @@ impl CnrtcBinary {
};
(ans, log)
}

#[inline]
pub fn memory_usage(&self) -> MemSize {
let mut bytes = 0;
cndrv!(cnModuleQueryFatBinaryMemoryUsage(self.as_ptr(), &mut bytes));
MemSize(bytes)
}
}

impl CnrtcBinary {
Expand Down
55 changes: 48 additions & 7 deletions cndrv/src/cnrtc/kernel_fn.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
use crate::{bindings::CNkernel, AsRaw, Queue};
use super::module::Module;
use crate::{
bindings::{
CNkernel,
CNkernel_attribute_enum::{self, *},
},
AsRaw, MemSize, Queue,
};
use std::{
ffi::{c_void, CStr},
marker::PhantomData,
ptr::null_mut,
};

use super::module::Module;

pub struct KernelFn<'m>(CNkernel, #[allow(unused)] &'m Module<'m>);
#[repr(transparent)]
pub struct KernelFn<'m>(CNkernel, PhantomData<&'m ()>);

impl<'m> Module<'m> {
pub fn get_kernel(&'m self, name: impl AsRef<CStr>) -> KernelFn<'m> {
impl Module<'_> {
#[inline]
pub fn get_kernel(&self, name: impl AsRef<CStr>) -> KernelFn {
let name = name.as_ref();
let mut kernel = null_mut();
cndrv!(cnModuleGetKernel(self.as_raw(), name.as_ptr(), &mut kernel));
KernelFn(kernel, self)
KernelFn(kernel, PhantomData)
}
}

impl KernelFn<'_> {
#[inline]
pub fn launch(
&self,
dimz: u32,
Expand All @@ -38,4 +47,36 @@ impl KernelFn<'_> {
null_mut(),
));
}

#[inline]
pub fn nram_usage(&self) -> MemSize {
MemSize(self.get_attribute(CN_KERNEL_ATTRIBUTE_NRAM_SIZE_BYTES) as _)
}

#[inline]
pub fn wram_usage(&self) -> MemSize {
MemSize(self.get_attribute(CN_KERNEL_ATTRIBUTE_WEIGHT_RAM_SIZE_BYTES) as _)
}

#[inline]
pub fn smem_usage(&self) -> MemSize {
MemSize(self.get_attribute(CN_KERNEL_ATTRIBUTE_SHARED_SIZE_BYTES) as _)
}

#[inline]
pub fn const_usage(&self) -> MemSize {
MemSize(self.get_attribute(CN_KERNEL_ATTRIBUTE_CONST_SIZE_BYTES) as _)
}

#[inline]
pub fn binary_version(&self) -> usize {
self.get_attribute(CN_KERNEL_ATTRIBUTE_BINARY_VERSION) as _
}

#[inline]
fn get_attribute(&self, attr: CNkernel_attribute_enum) -> i64 {
let mut value = 0;
cndrv!(cnKernelGetAttribute(&mut value, attr, self.0));
value
}
}
2 changes: 1 addition & 1 deletion cndrv/src/cnrtc/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ impl_spore!(Module and ModuleSpore by CNmodule);

impl CurrentCtx {
#[inline]
pub fn load(&self, bin: CnrtcBinary) -> Module {
pub fn load(&self, bin: &CnrtcBinary) -> Module {
let mut module = null_mut();
cndrv!(cnModuleLoadFatBinary(bin.as_ptr(), &mut module));
Module(unsafe { self.wrap_raw(module) }, PhantomData)
Expand Down
41 changes: 1 addition & 40 deletions cndrv/src/device.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{bindings::CNdev, AsRaw};
use crate::{bindings::CNdev, AsRaw, MemSize};
use std::{ffi::c_int, fmt};

#[repr(transparent)]
Expand Down Expand Up @@ -96,45 +96,6 @@ impl fmt::Display for InfoFmt<'_> {
}
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[repr(transparent)]
pub struct MemSize(pub usize);

impl fmt::Display for MemSize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.0 == 0 {
write!(f, "0")
} else {
let zeros = self.0.trailing_zeros();
if zeros >= 40 {
write!(f, "{}TiB", self.0 >> 40)
} else if zeros >= 30 {
write!(f, "{}GiB", self.0 >> 30)
} else if zeros >= 20 {
write!(f, "{}MiB", self.0 >> 20)
} else if zeros >= 10 {
write!(f, "{}KiB", self.0 >> 10)
} else {
write!(f, "{}B", self.0)
}
}
}
}

impl From<c_int> for MemSize {
#[inline]
fn from(value: c_int) -> Self {
Self(value as _)
}
}

impl From<usize> for MemSize {
#[inline]
fn from(value: usize) -> Self {
Self(value)
}
}

#[test]
fn test() {
crate::init();
Expand Down
41 changes: 41 additions & 0 deletions cndrv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,44 @@ struct Blob<P> {
ptr: P,
len: usize,
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[repr(transparent)]
pub struct MemSize(pub usize);

use std::{ffi::c_int, fmt};

impl fmt::Display for MemSize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.0 == 0 {
write!(f, "0")
} else {
let zeros = self.0.trailing_zeros();
if zeros >= 40 {
write!(f, "{}TiB", self.0 >> 40)
} else if zeros >= 30 {
write!(f, "{}GiB", self.0 >> 30)
} else if zeros >= 20 {
write!(f, "{}MiB", self.0 >> 20)
} else if zeros >= 10 {
write!(f, "{}KiB", self.0 >> 10)
} else {
write!(f, "{}B", self.0)
}
}
}
}

impl From<c_int> for MemSize {
#[inline]
fn from(value: c_int) -> Self {
Self(value as _)
}
}

impl From<usize> for MemSize {
#[inline]
fn from(value: usize) -> Self {
Self(value)
}
}

0 comments on commit 7143130

Please sign in to comment.