From c3618fd9df071f7d6416b08bca9ead9b1a0e39dc Mon Sep 17 00:00:00 2001 From: Rajeev Rao Date: Wed, 8 Jun 2022 17:16:10 -0700 Subject: [PATCH] DMHA plugin refactoring Signed-off-by: Rajeev Rao --- plugin/README.md | 2 +- .../multiscaleDeformableAttn.cu | 24 +- .../multiscaleDeformableAttn.h | 15 +- .../multiscaleDeformableAttnPlugin.cpp | 139 +-- .../multiscaleDeformableAttnPlugin.h | 53 +- .../multiscaleDeformableIm2ColCuda.cuh | 945 ++++++++---------- 6 files changed, 540 insertions(+), 638 deletions(-) diff --git a/plugin/README.md b/plugin/README.md index 1fb45ddd..acbc417b 100644 --- a/plugin/README.md +++ b/plugin/README.md @@ -27,7 +27,7 @@ | [leakyReluPlugin](leakyReluPlugin) | LReLU_TRT | 1 | | [multilevelCropAndResizePlugin](multilevelCropAndResizePlugin) | MultilevelCropAndResize_TRT | 1 | | [multilevelProposeROI](multilevelProposeROI) | MultilevelProposeROI_TRT | 1 | -| [multiscaleDeformableAttnPlugin](multiscaleDeformableAttnPlugin) | DMHA | 1 | +| [multiscaleDeformableAttnPlugin](multiscaleDeformableAttnPlugin) | MultiscaleDeformableAttnPlugin_TRT | 1 | | [nmsPlugin](nmsPlugin) | NMS_TRT | 1 | | [normalizePlugin](normalizePlugin) | Normalize_TRT | 1 | | [nvFasterRCNN](nvFasterRCNN) | RPROI_TRT | 1 | diff --git a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.cu b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.cu index 52bf1576..e52833d6 100644 --- a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.cu +++ b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.cu @@ -36,17 +36,17 @@ #include "multiscaleDeformableIm2ColCuda.cuh" -int ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const int32_t* spatialShapes, - const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int batch, - int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint) +int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const int32_t* spatialShapes, + const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch, + int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint) { auto perValueSize = mSpatialSize * mNumHeads * mChannels; auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2; auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint; - int mIm2colStep = batch; + int32_t mIm2colStep = batch; - for (int n = 0; n < batch / mIm2colStep; ++n) + for (int32_t n = 0; n < batch / mIm2colStep; ++n) { auto columns = output + perValueSize * n * mIm2colStep; ms_deformable_im2col_cuda(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex, @@ -57,18 +57,16 @@ int ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const i return 0; } -#if __CUDA_ARCH__ >= 530 - -int ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const int32_t* spatialShapes, - const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int batch, - int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint) +int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const int32_t* spatialShapes, + const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch, + int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, 
int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint) { auto perValueSize = mSpatialSize * mNumHeads * mChannels; auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2; auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint; - int mIm2colStep = batch; - for (int n = 0; n < batch / mIm2colStep; ++n) + int32_t mIm2colStep = batch; + for (int32_t n = 0; n < batch / mIm2colStep; ++n) { auto columns = output + perValueSize * n * mIm2colStep; ms_deformable_im2col_cuda<__half>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, @@ -79,5 +77,3 @@ int ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const return 0; } - -#endif diff --git a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.h b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.h index a219904a..b2203125 100644 --- a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.h +++ b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.h @@ -32,14 +32,13 @@ #ifndef TRT_MULTISCALE_DEFORMABLE_ATTN_H #define TRT_MULTISCALE_DEFORMABLE_ATTN_H -int ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const int32_t* spatialShapes, - const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int batch, - int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint); +int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, float const* value, int32_t const* spatialShapes, + int32_t const* levelStartIndex, float const* samplingLoc, float const* attnWeight, float* output, int32_t batch, + int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint); -#if __CUDA_ARCH__ >= 530 -int ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const int32_t* spatialShapes, - const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int batch, - int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint); -#endif + +int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, int32_t const* spatialShapes, + int32_t const* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch, + int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint); #endif diff --git a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.cpp b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.cpp index 3183c22e..36f045cb 100644 --- a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.cpp +++ b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.cpp @@ -28,29 +28,35 @@ namespace plugin namespace { -static const char* DMHA_VERSION{"1"}; -static const char* DMHA_NAME{"DMHA"}; +static char const* DMHA_VERSION{"1"}; +static char const* DMHA_NAME{"MultiscaleDeformableAttnPlugin_TRT"}; } // namespace -MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(const std::string& name) - : mLayerName(name) +MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin() { } -MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(const std::string& name, const void* data, size_t length) - : mLayerName(name) +MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(void const* data, size_t length) { } nvinfer1::IPluginV2DynamicExt* 
MultiscaleDeformableAttnPlugin::clone() const PLUGIN_NOEXCEPT { - MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin(mLayerName); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; + try + { + MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin(); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + catch (const std::exception& e) + { + caughtError(e); + } + return nullptr; } -nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int outputIndex, - const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) PLUGIN_NOEXCEPT +nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int32_t outputIndex, + nvinfer1::DimsExprs const* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) PLUGIN_NOEXCEPT { nvinfer1::DimsExprs ret; ret.nbDims = 4; @@ -63,7 +69,7 @@ nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int outp } bool MultiscaleDeformableAttnPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) PLUGIN_NOEXCEPT + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) PLUGIN_NOEXCEPT { ASSERT((nbInputs == 5)); ASSERT((nbOutputs == 1)); @@ -76,12 +82,8 @@ bool MultiscaleDeformableAttnPlugin::supportsFormatCombination( } else { -#if __CUDA_ARCH__ >= 530 - return ((inOut[pos].type == inOut[0].type) - && ((inOut[pos].type == nvinfer1::DataType::kFLOAT) || (inOut[pos].type == nvinfer1::DataType::kHALF))); -#else - return ((inOut[pos].type == inOut[0].type) && ((inOut[pos].type == nvinfer1::DataType::kFLOAT))); -#endif + return ((inOut[pos].type == inOut[0].type) && + ((inOut[pos].type == nvinfer1::DataType::kFLOAT) || (inOut[pos].type == nvinfer1::DataType::kHALF))); } } else @@ -90,8 +92,8 @@ bool MultiscaleDeformableAttnPlugin::supportsFormatCombination( } } -void MultiscaleDeformableAttnPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) PLUGIN_NOEXCEPT +void MultiscaleDeformableAttnPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) PLUGIN_NOEXCEPT { // Check for valid input dimensions ASSERT(inputs[0].desc.dims.nbDims==4); @@ -116,42 +118,41 @@ void MultiscaleDeformableAttnPlugin::configurePlugin(const nvinfer1::DynamicPlug ASSERT(inputs[3].desc.dims.d[1] == inputs[4].desc.dims.d[1]); } -size_t MultiscaleDeformableAttnPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const PLUGIN_NOEXCEPT +size_t MultiscaleDeformableAttnPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const PLUGIN_NOEXCEPT { return 0; } -int MultiscaleDeformableAttnPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workSpace, +int32_t MultiscaleDeformableAttnPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workSpace, cudaStream_t stream) PLUGIN_NOEXCEPT { - const int batch = inputDesc[0].dims.d[0]; - int spatial_size = 
inputDesc[0].dims.d[1];
-    int num_heads = inputDesc[0].dims.d[2];
-    int channels = inputDesc[0].dims.d[3];
-    int num_levels = inputDesc[1].dims.d[0];
-    int num_query = inputDesc[3].dims.d[1];
-    int num_point = inputDesc[3].dims.d[4];
-    int rc = 0;
+    int32_t const batch = inputDesc[0].dims.d[0];
+    int32_t spatial_size = inputDesc[0].dims.d[1];
+    int32_t num_heads = inputDesc[0].dims.d[2];
+    int32_t channels = inputDesc[0].dims.d[3];
+    int32_t num_levels = inputDesc[1].dims.d[0];
+    int32_t num_query = inputDesc[3].dims.d[1];
+    int32_t num_point = inputDesc[3].dims.d[4];
+    int32_t rc = 0;
     if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
     {
-        const float* value = static_cast<const float*>(inputs[0]);
-        const int32_t* spatialShapes = static_cast<const int32_t*>(inputs[1]);
-        const int32_t* levelStartIndex = static_cast<const int32_t*>(inputs[2]);
-        const float* samplingLoc = static_cast<const float*>(inputs[3]);
-        const float* attnWeight = static_cast<const float*>(inputs[4]);
+        float const* value = static_cast<float const*>(inputs[0]);
+        int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
+        int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
+        float const* samplingLoc = static_cast<float const*>(inputs[3]);
+        float const* attnWeight = static_cast<float const*>(inputs[4]);
         float* output = static_cast<float*>(outputs[0]);

         rc = ms_deform_attn_cuda_forward(stream, value, spatialShapes, levelStartIndex, samplingLoc, attnWeight,
             output, batch, spatial_size, num_heads, channels, num_levels, num_query, num_point);
     }
-#if __CUDA_ARCH__ >= 530
     else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
     {
         const __half* value = static_cast<const __half*>(inputs[0]);
-        const int32_t* spatialShapes = static_cast<const int32_t*>(inputs[1]);
-        const int32_t* levelStartIndex = static_cast<const int32_t*>(inputs[2]);
+        int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
+        int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
         const __half* samplingLoc = static_cast<const __half*>(inputs[3]);
         const __half* attnWeight = static_cast<const __half*>(inputs[4]);
         __half* output = static_cast<__half*>(outputs[0]);
@@ -159,7 +160,6 @@ int MultiscaleDeformableAttnPlugin::enqueue(const nvinfer1::PluginTensorDesc* in

         rc = ms_deform_attn_cuda_forward(stream, value, spatialShapes, levelStartIndex, samplingLoc, attnWeight,
             output, batch, spatial_size, num_heads, channels, num_levels, num_query, num_point);
     }
-#endif

     return rc;
 }
@@ -167,35 +167,34 @@ int MultiscaleDeformableAttnPlugin::enqueue(const nvinfer1::PluginTensorDesc* in
 void MultiscaleDeformableAttnPlugin::attachToContext(
     cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) PLUGIN_NOEXCEPT
 {
-    mCublasHandle = cublasContext;
 }

 void MultiscaleDeformableAttnPlugin::detachFromContext() PLUGIN_NOEXCEPT {}

 // IPluginV2Ext Methods
 nvinfer1::DataType MultiscaleDeformableAttnPlugin::getOutputDataType(
-    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const PLUGIN_NOEXCEPT
+    int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const PLUGIN_NOEXCEPT
 {
     return inputTypes[0];
 }

 // IPluginV2 Methods
-const char* MultiscaleDeformableAttnPlugin::getPluginType() const PLUGIN_NOEXCEPT
+char const* MultiscaleDeformableAttnPlugin::getPluginType() const PLUGIN_NOEXCEPT
 {
     return DMHA_NAME;
 }

-const char* MultiscaleDeformableAttnPlugin::getPluginVersion() const PLUGIN_NOEXCEPT
+char const* MultiscaleDeformableAttnPlugin::getPluginVersion() const PLUGIN_NOEXCEPT
 {
     return DMHA_VERSION;
 }

-int MultiscaleDeformableAttnPlugin::getNbOutputs() const PLUGIN_NOEXCEPT
+int32_t MultiscaleDeformableAttnPlugin::getNbOutputs() const PLUGIN_NOEXCEPT
 {
     return 1;
 }

-int MultiscaleDeformableAttnPlugin::initialize()
PLUGIN_NOEXCEPT +int32_t MultiscaleDeformableAttnPlugin::initialize() PLUGIN_NOEXCEPT { return 0; } @@ -216,11 +215,11 @@ void MultiscaleDeformableAttnPlugin::destroy() PLUGIN_NOEXCEPT delete this; } -void MultiscaleDeformableAttnPlugin::setPluginNamespace(const char* pluginNamespace) PLUGIN_NOEXCEPT +void MultiscaleDeformableAttnPlugin::setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT { mNamespace = pluginNamespace; } -const char* MultiscaleDeformableAttnPlugin::getPluginNamespace() const PLUGIN_NOEXCEPT +char const* MultiscaleDeformableAttnPlugin::getPluginNamespace() const PLUGIN_NOEXCEPT { return mNamespace.c_str(); } @@ -234,45 +233,61 @@ MultiscaleDeformableAttnPluginCreator::MultiscaleDeformableAttnPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* MultiscaleDeformableAttnPluginCreator::getPluginName() const PLUGIN_NOEXCEPT +char const* MultiscaleDeformableAttnPluginCreator::getPluginName() const PLUGIN_NOEXCEPT { return DMHA_NAME; } -const char* MultiscaleDeformableAttnPluginCreator::getPluginVersion() const PLUGIN_NOEXCEPT +char const* MultiscaleDeformableAttnPluginCreator::getPluginVersion() const PLUGIN_NOEXCEPT { return DMHA_VERSION; } -const nvinfer1::PluginFieldCollection* MultiscaleDeformableAttnPluginCreator::getFieldNames() PLUGIN_NOEXCEPT +nvinfer1::PluginFieldCollection const* MultiscaleDeformableAttnPluginCreator::getFieldNames() PLUGIN_NOEXCEPT { return &mFC; } IPluginV2* MultiscaleDeformableAttnPluginCreator::createPlugin( - const char* name, const PluginFieldCollection* fc) PLUGIN_NOEXCEPT + char const* name, PluginFieldCollection const* fc) PLUGIN_NOEXCEPT { - MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin(name); - return plugin; + try + { + MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin(); + return plugin; + } + catch (const std::exception& e) + { + caughtError(e); + } + return nullptr; } IPluginV2* MultiscaleDeformableAttnPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) PLUGIN_NOEXCEPT + char const* name, void const* serialData, size_t serialLength) PLUGIN_NOEXCEPT { - auto plugin = new MultiscaleDeformableAttnPlugin(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; + try + { + auto plugin = new MultiscaleDeformableAttnPlugin(serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + catch (const std::exception& e) + { + caughtError(e); + } + return nullptr; } -void MultiscaleDeformableAttnPluginCreator::setPluginNamespace(const char* pluginNamespace) PLUGIN_NOEXCEPT +void MultiscaleDeformableAttnPluginCreator::setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT { mNamespace = pluginNamespace; } -const char* MultiscaleDeformableAttnPluginCreator::getPluginNamespace() const PLUGIN_NOEXCEPT +char const* MultiscaleDeformableAttnPluginCreator::getPluginNamespace() const PLUGIN_NOEXCEPT { return mNamespace.c_str(); } } // namespace plugin -} // namespace nvinfer1 +} // namespace nvinfer1 \ No newline at end of file diff --git a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.h b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.h index 30757bc4..1e75a916 100644 --- a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.h +++ b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttnPlugin.h @@ -57,49 +57,44 @@ namespace plugin class MultiscaleDeformableAttnPlugin 
: public nvinfer1::IPluginV2DynamicExt { public: - MultiscaleDeformableAttnPlugin() = delete; + MultiscaleDeformableAttnPlugin(); - MultiscaleDeformableAttnPlugin(const std::string& name); - - MultiscaleDeformableAttnPlugin(const std::string& name, const void* data, size_t length); + MultiscaleDeformableAttnPlugin(void const* data, size_t length); // IPluginV2DynamicExt methods nvinfer1::IPluginV2DynamicExt* clone() const PLUGIN_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) PLUGIN_NOEXCEPT override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) PLUGIN_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) PLUGIN_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const PLUGIN_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) PLUGIN_NOEXCEPT override; + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) PLUGIN_NOEXCEPT override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) PLUGIN_NOEXCEPT override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const PLUGIN_NOEXCEPT override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) PLUGIN_NOEXCEPT override; void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) PLUGIN_NOEXCEPT override; void detachFromContext() PLUGIN_NOEXCEPT override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const PLUGIN_NOEXCEPT override; + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const PLUGIN_NOEXCEPT override; // IPluginV2 Methods - const char* getPluginType() const PLUGIN_NOEXCEPT override; - const char* getPluginVersion() const PLUGIN_NOEXCEPT override; - int getNbOutputs() const PLUGIN_NOEXCEPT override; - int initialize() PLUGIN_NOEXCEPT override; + char const* getPluginType() const PLUGIN_NOEXCEPT override; + char const* getPluginVersion() const PLUGIN_NOEXCEPT override; + int32_t getNbOutputs() const PLUGIN_NOEXCEPT override; + int32_t initialize() PLUGIN_NOEXCEPT override; void terminate() PLUGIN_NOEXCEPT override; size_t getSerializationSize() const PLUGIN_NOEXCEPT override; void serialize(void* buffer) const PLUGIN_NOEXCEPT override; void destroy() PLUGIN_NOEXCEPT override; - void setPluginNamespace(const char* pluginNamespace) PLUGIN_NOEXCEPT override; - const char* getPluginNamespace() const PLUGIN_NOEXCEPT override; + void setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT override; + char const* 
getPluginNamespace() const PLUGIN_NOEXCEPT override;

 private:
     std::string mNamespace;
-    const std::string mLayerName;
-
-    cublasHandle_t mCublasHandle;

 #if NV_TENSORRT_MAJOR < 8
     using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
@@ -116,15 +111,15 @@ class MultiscaleDeformableAttnPluginCreator : public nvinfer1::IPluginCreator
 {
 public:
     MultiscaleDeformableAttnPluginCreator();
-    const char* getPluginName() const PLUGIN_NOEXCEPT override;
-    const char* getPluginVersion() const PLUGIN_NOEXCEPT override;
-    const nvinfer1::PluginFieldCollection* getFieldNames() PLUGIN_NOEXCEPT override;
+    char const* getPluginName() const PLUGIN_NOEXCEPT override;
+    char const* getPluginVersion() const PLUGIN_NOEXCEPT override;
+    nvinfer1::PluginFieldCollection const* getFieldNames() PLUGIN_NOEXCEPT override;
     nvinfer1::IPluginV2* createPlugin(
-        const char* name, const nvinfer1::PluginFieldCollection* fc) PLUGIN_NOEXCEPT override;
+        char const* name, nvinfer1::PluginFieldCollection const* fc) PLUGIN_NOEXCEPT override;
     nvinfer1::IPluginV2* deserializePlugin(
-        const char* name, const void* serialData, size_t serialLength) PLUGIN_NOEXCEPT override;
-    void setPluginNamespace(const char* pluginNamespace) PLUGIN_NOEXCEPT override;
-    const char* getPluginNamespace() const PLUGIN_NOEXCEPT override;
+        char const* name, void const* serialData, size_t serialLength) PLUGIN_NOEXCEPT override;
+    void setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT override;
+    char const* getPluginNamespace() const PLUGIN_NOEXCEPT override;

 private:
     nvinfer1::PluginFieldCollection mFC;
diff --git a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableIm2ColCuda.cuh b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableIm2ColCuda.cuh
index 85e81f36..7390f876 100644
--- a/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableIm2ColCuda.cuh
+++ b/plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableIm2ColCuda.cuh
@@ -33,115 +33,124 @@
 #include
 #include

-#define CUDA_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
+#include "common/checkMacrosPlugin.h"

-const int CUDA_NUM_THREADS = 768;
-inline int GET_BLOCKS(const int N, const int numThreads)
+#define CUDA_KERNEL_LOOP(i, n) for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
+
+constexpr int32_t CUDA_NUM_THREADS = 768;
+inline int32_t GET_BLOCKS(int32_t const N, int32_t const numThreads)
 {
     return (N + numThreads - 1) / numThreads;
 }

 template <typename scalar_t>
-__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t*& bottomData, const int& height, const int& width,
-    const int& nHeads, const int& channels, const scalar_t& h, const scalar_t& w, const int& m, const int& c)
+__device__ scalar_t ms_deform_attn_im2col_bilinear(scalar_t const*& bottomData, int32_t const& height, int32_t const& width,
+    int32_t const& nHeads, int32_t const& channels, scalar_t const& h, scalar_t const& w, int32_t const& m, int32_t const& c)
 {
-    const int hLow = floor(h);
-    const int wLow = floor(w);
-    const int hHigh = hLow + 1;
-    const int wHigh = wLow + 1;
-
-    const scalar_t lh = h - hLow;
-    const scalar_t lw = w - wLow;
-    const scalar_t hh = 1 - lh, hw = 1 - lw;
-
-    const int wStride = nHeads * channels;
-    const int hStride = width * wStride;
-    const int hLowPtrOffset = hLow * hStride;
-    const int hHighPtrOffset = hLowPtrOffset + hStride;
-    const int wLowPtrOffset = wLow * wStride;
-    const int wHighPtrOffset = wLowPtrOffset + wStride;
-    const int basePtr = m * channels +
c; + int32_t const hLow = floor(h); + int32_t const wLow = floor(w); + int32_t const hHigh = hLow + 1; + int32_t const wHigh = wLow + 1; + + scalar_t const lh = h - hLow; + scalar_t const lw = w - wLow; + scalar_t const hh = 1 - lh, hw = 1 - lw; + + int32_t const wStride = nHeads * channels; + int32_t const hStride = width * wStride; + int32_t const hLowPtrOffset = hLow * hStride; + int32_t const hHighPtrOffset = hLowPtrOffset + hStride; + int32_t const wLowPtrOffset = wLow * wStride; + int32_t const wHighPtrOffset = wLowPtrOffset + wStride; + int32_t const basePtr = m * channels + c; scalar_t v1 = 0; if (hLow >= 0 && wLow >= 0) { - const int ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; v1 = bottomData[ptr1]; } scalar_t v2 = 0; if (hLow >= 0 && wHigh <= width - 1) { - const int ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; v2 = bottomData[ptr2]; } scalar_t v3 = 0; if (hHigh <= height - 1 && wLow >= 0) { - const int ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; v3 = bottomData[ptr3]; } scalar_t v4 = 0; if (hHigh <= height - 1 && wHigh <= width - 1) { - const int ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; v4 = bottomData[ptr4]; } - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + scalar_t const w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + scalar_t const val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } -#if __CUDA_ARCH__ >= 530 template <> -__device__ __half ms_deform_attn_im2col_bilinear<__half>(const __half*& bottomData, const int& height, const int& width, - const int& nHeads, const int& channels, const __half& h, const __half& w, const int& m, const int& c) +__device__ __half ms_deform_attn_im2col_bilinear<__half>(const __half*& bottomData, int32_t const& height, int32_t const& width, + int32_t const& nHeads, int32_t const& channels, const __half& h, const __half& w, int32_t const& m, int32_t const& c) { - const int hLow = __half2int_rd(h); - const int wLow = __half2int_rd(w); - const int hHigh = hLow + 1; - const int wHigh = wLow + 1; + int32_t const hLow = __half2int_rd(h); + int32_t const wLow = __half2int_rd(w); + int32_t const hHigh = hLow + 1; + int32_t const wHigh = wLow + 1; - const __half zero = __int2half_rz(0); + const __half kZERO = __int2half_rz(0); const __half one = __int2half_rz(1); + +#if __CUDA_ARCH__>=530 const __half lh = __hsub(h, __int2half_rd(hLow)); const __half lw = __hsub(w, __int2half_rd(wLow)); const __half hh = __hsub(one, lh), hw = __hsub(one, lw); - - const int wStride = nHeads * channels; - const int hStride = width * wStride; - const int hLowPtrOffset = hLow * hStride; - const int hHighPtrOffset = hLowPtrOffset + hStride; - const int wLowPtrOffset = wLow * wStride; - const int wHighPtrOffset = wLowPtrOffset + wStride; - const int basePtr = m * channels + c; - - __half v1 = zero; +#else + const __half lh = __float2half(__half2float(h) - hLow); + const __half lw = __float2half(__half2float(w) - wLow); + const __half hh = __float2half(__half2float(one) - __half2float(lh)); + const __half hw = __float2half(__half2float(one) - __half2float(lw)); +#endif + int32_t const wStride = nHeads * channels; + int32_t const hStride = width * wStride; + int32_t const hLowPtrOffset = hLow * 
hStride; + int32_t const hHighPtrOffset = hLowPtrOffset + hStride; + int32_t const wLowPtrOffset = wLow * wStride; + int32_t const wHighPtrOffset = wLowPtrOffset + wStride; + int32_t const basePtr = m * channels + c; + + __half v1 = kZERO; if (hLow >= 0 && wLow >= 0) { - const int ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; v1 = bottomData[ptr1]; } - __half v2 = zero; + __half v2 = kZERO; if (hLow >= 0 && wHigh <= width - 1) { - const int ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; v2 = bottomData[ptr2]; } - __half v3 = zero; + __half v3 = kZERO; if (hHigh <= height - 1 && wLow >= 0) { - const int ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; v3 = bottomData[ptr3]; } - __half v4 = zero; + __half v4 = kZERO; if (hHigh <= height - 1 && wHigh <= width - 1) { - const int ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; v4 = bottomData[ptr4]; } +#if __CUDA_ARCH__>=530 __half w1 = __hmul(__hmul(hh, hw), v1); __half w2 = __hmul(__hmul(hh, lw), v2); __half w3 = __hmul(__hmul(lh, hw), v3); @@ -151,41 +160,51 @@ __device__ __half ms_deform_attn_im2col_bilinear<__half>(const __half*& bottomDa w3 = __hadd(w3, w4); const __half val = __hadd(w1, w3); +#else + __half w1 = __float2half((__half2float(hh) * __half2float(hw)) * __half2float(v1)); + __half w2 = __float2half((__half2float(hh) * __half2float(lw)) * __half2float(v2)); + __half w3 = __float2half((__half2float(lh) * __half2float(hw)) * __half2float(v3)); + __half w4 = __float2half((__half2float(lh) * __half2float(lw)) * __half2float(v4)); + + w1 = __float2half(__half2float(w1) + __half2float(w2)); + w3 = __float2half(__half2float(w3) + __half2float(w4)); + + const __half val = __float2half(__half2float(w1) + __half2float(w3)); +#endif return val; } -#endif template -__device__ void ms_deform_attn_col2im_bilinear(const scalar_t*& bottomData, const int& height, const int& width, - const int& nHeads, const int& channels, const scalar_t& h, const scalar_t& w, const int& m, const int& c, - const scalar_t& topGrad, const scalar_t& attnWeight, scalar_t*& gradValue, scalar_t* gradSamplingLoc, +__device__ void ms_deform_attn_col2im_bilinear(scalar_t const*& bottomData, int32_t const& height, int32_t const& width, + int32_t const& nHeads, int32_t const& channels, scalar_t const& h, scalar_t const& w, int32_t const& m, int32_t const& c, + scalar_t const& topGrad, scalar_t const& attnWeight, scalar_t*& gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight) { - const int hLow = floor(h); - const int wLow = floor(w); - const int hHigh = hLow + 1; - const int wHigh = wLow + 1; - - const scalar_t lh = h - hLow; - const scalar_t lw = w - wLow; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int wStride = nHeads * channels; - const int hStride = width * wStride; - const int hLowPtrOffset = hLow * hStride; - const int hHighPtrOffset = hLowPtrOffset + hStride; - const int wLowPtrOffset = wLow * wStride; - const int wHighPtrOffset = wLowPtrOffset + wStride; - const int basePtr = m * channels + c; - - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t topGradvalue = topGrad * attnWeight; + int32_t const hLow = floor(h); + int32_t const wLow = floor(w); + int32_t const hHigh = hLow + 1; + int32_t const wHigh = wLow + 1; + + scalar_t const lh = h - 
hLow; + scalar_t const lw = w - wLow; + scalar_t const hh = 1 - lh, hw = 1 - lw; + + int32_t const wStride = nHeads * channels; + int32_t const hStride = width * wStride; + int32_t const hLowPtrOffset = hLow * hStride; + int32_t const hHighPtrOffset = hLowPtrOffset + hStride; + int32_t const wLowPtrOffset = wLow * wStride; + int32_t const wHighPtrOffset = wLowPtrOffset + wStride; + int32_t const basePtr = m * channels + c; + + scalar_t const w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + scalar_t const topGradvalue = topGrad * attnWeight; scalar_t gradHWeight = 0, gradWWeight = 0; scalar_t v1 = 0; if (hLow >= 0 && wLow >= 0) { - const int ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; v1 = bottomData[ptr1]; gradHWeight -= hw * v1; gradWWeight -= hh * v1; @@ -194,7 +213,7 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t*& bottomData, cons scalar_t v2 = 0; if (hLow >= 0 && wHigh <= width - 1) { - const int ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; v2 = bottomData[ptr2]; gradHWeight -= lw * v2; gradWWeight += hh * v2; @@ -203,7 +222,7 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t*& bottomData, cons scalar_t v3 = 0; if (hHigh <= height - 1 && wLow >= 0) { - const int ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; v3 = bottomData[ptr3]; gradHWeight += hw * v3; gradWWeight -= lh * v3; @@ -212,50 +231,50 @@ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t*& bottomData, cons scalar_t v4 = 0; if (hHigh <= height - 1 && wHigh <= width - 1) { - const int ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; v4 = bottomData[ptr4]; gradHWeight += lw * v4; gradWWeight += lh * v4; atomicAdd(gradValue + ptr4, w4 * topGradvalue); } - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + scalar_t const val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); *gradAttnWeight = topGrad * val; *gradSamplingLoc = width * gradWWeight * topGradvalue; *(gradSamplingLoc + 1) = height * gradHWeight * topGradvalue; } template -__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t*& bottomData, const int& height, const int& width, - const int& nHeads, const int& channels, const scalar_t& h, const scalar_t& w, const int& m, const int& c, - const scalar_t& topGrad, const scalar_t& attnWeight, scalar_t*& gradValue, scalar_t* gradSamplingLoc, +__device__ void ms_deform_attn_col2im_bilinear_gm(scalar_t const*& bottomData, int32_t const& height, int32_t const& width, + int32_t const& nHeads, int32_t const& channels, scalar_t const& h, scalar_t const& w, int32_t const& m, int32_t const& c, + scalar_t const& topGrad, scalar_t const& attnWeight, scalar_t*& gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight) { - const int hLow = floor(h); - const int wLow = floor(w); - const int hHigh = hLow + 1; - const int wHigh = wLow + 1; - - const scalar_t lh = h - hLow; - const scalar_t lw = w - wLow; - const scalar_t hh = 1 - lh, hw = 1 - lw; - - const int wStride = nHeads * channels; - const int hStride = width * wStride; - const int hLowPtrOffset = hLow * hStride; - const int hHighPtrOffset = hLowPtrOffset + hStride; - const int wLowPtrOffset = wLow * wStride; - const int wHighPtrOffset = wLowPtrOffset + wStride; - const int basePtr = m * channels + c; - - const scalar_t w1 = hh * hw, w2 = hh * 
lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t topGradvalue = topGrad * attnWeight; + int32_t const hLow = floor(h); + int32_t const wLow = floor(w); + int32_t const hHigh = hLow + 1; + int32_t const wHigh = wLow + 1; + + scalar_t const lh = h - hLow; + scalar_t const lw = w - wLow; + scalar_t const hh = 1 - lh, hw = 1 - lw; + + int32_t const wStride = nHeads * channels; + int32_t const hStride = width * wStride; + int32_t const hLowPtrOffset = hLow * hStride; + int32_t const hHighPtrOffset = hLowPtrOffset + hStride; + int32_t const wLowPtrOffset = wLow * wStride; + int32_t const wHighPtrOffset = wLowPtrOffset + wStride; + int32_t const basePtr = m * channels + c; + + scalar_t const w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + scalar_t const topGradvalue = topGrad * attnWeight; scalar_t gradHWeight = 0, gradWWeight = 0; scalar_t v1 = 0; if (hLow >= 0 && wLow >= 0) { - const int ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; v1 = bottomData[ptr1]; gradHWeight -= hw * v1; gradWWeight -= hh * v1; @@ -264,7 +283,7 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t*& bottomData, c scalar_t v2 = 0; if (hLow >= 0 && wHigh <= width - 1) { - const int ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; v2 = bottomData[ptr2]; gradHWeight -= lw * v2; gradWWeight += hh * v2; @@ -273,7 +292,7 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t*& bottomData, c scalar_t v3 = 0; if (hHigh <= height - 1 && wLow >= 0) { - const int ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; + int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; v3 = bottomData[ptr3]; gradHWeight += hw * v3; gradWWeight -= lh * v3; @@ -282,190 +301,59 @@ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t*& bottomData, c scalar_t v4 = 0; if (hHigh <= height - 1 && wHigh <= width - 1) { - const int ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; + int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; v4 = bottomData[ptr4]; gradHWeight += lw * v4; gradWWeight += lh * v4; atomicAdd(gradValue + ptr4, w4 * topGradvalue); } - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + scalar_t const val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); atomicAdd(gradAttnWeight, topGrad * val); atomicAdd(gradSamplingLoc, width * gradWWeight * topGradvalue); atomicAdd(gradSamplingLoc + 1, height * gradHWeight * topGradvalue); } -#if 0 -template -__global__ void ms_deformable_im2col_gpu_kernel(const int n, - const scalar_t *dataValue, - const int32_t *dataSpatialShapes, - const int32_t *dataLevelStartIndex, - const scalar_t *dataSamplingLoc, - const scalar_t *dataAttnWeight, - const int batchSize, - const int spatialSize, - const int numHeads, - const int channels, - const int numLevels, - const int numQuery, - const int numPoint, - scalar_t *dataCol) -{ - CUDA_KERNEL_LOOP(index, n) - { - int _temp = index; - int cCol = _temp % channels; - _temp /= channels; - int samplingIndex = _temp; - int mCol = _temp % numHeads; - _temp /= numHeads; - _temp /= numQuery; - int bCol = _temp; - - scalar_t *dataColPtr = dataCol + index; - int dataWeightPtr = samplingIndex * numLevels * numPoint; - int dataLocWPtr = dataWeightPtr << 1; - int qidStride = numHeads * channels; - int dataValuePtrInitOffset = bCol * spatialSize * qidStride; - scalar_t col = 0; - - for (int lCol = 0; lCol < numLevels; ++lCol) - { - const int &spatialH = dataSpatialShapes[lCol << 
1]; - const int &spatialW = dataSpatialShapes[lCol << 1 + 1]; - const scalar_t *dataValuePtr = dataValue + (dataValuePtrInitOffset + dataLevelStartIndex[lCol] * qidStride); - for (int pCol = 0; pCol < numPoint; ++pCol) - { - scalar_t locW = dataSamplingLoc[dataLocWPtr]; - scalar_t locH = dataSamplingLoc[dataLocWPtr + 1]; - scalar_t weight = dataAttnWeight[dataWeightPtr]; - - scalar_t hIm = locH * spatialH - 0.5; - scalar_t wIm = locW * spatialW - 0.5; - - if (hIm > -1 && hIm < spatialH && wIm > -1 && wIm < spatialW) - { - col += ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol) * weight; - } - - dataWeightPtr += 1; - dataLocWPtr += 2; - } - } - *dataColPtr = col; - } -} - -#if __CUDA_ARCH__ >= 530 -template <> -__global__ void ms_deformable_im2col_gpu_kernel<__half>(const int n, - const __half *dataValue, - const int32_t *dataSpatialShapes, - const int32_t *dataLevelStartIndex, - const __half *dataSamplingLoc, - const __half *dataAttnWeight, - const int batchSize, - const int spatialSize, - const int numHeads, - const int channels, - const int numLevels, - const int numQuery, - const int numPoint, - __half *dataCol) -{ - CUDA_KERNEL_LOOP(index, n) - { - int _temp = index; - int cCol = _temp % channels; - _temp /= channels; - int samplingIndex = _temp; - int mCol = _temp % numHeads; - _temp /= numHeads; - _temp /= numQuery; - int bCol = _temp; - - __half *dataColPtr = dataCol + index; - int dataWeightPtr = samplingIndex * numLevels * numPoint; - int dataLocWPtr = dataWeightPtr << 1; - int qidStride = numHeads * channels; - int dataValuePtrInitOffset = bCol * spatialSize * qidStride; - __half zeroPointFive = __float2half(0.5f); - __half minusOne = __float2half(-1.0f); - __half zero = __int2half_rz(0); - __half tpVal = zero; - __half col = zero; - - for (int lCol = 0; lCol < numLevels; ++lCol) - { - const int &spatialH = dataSpatialShapes[lCol << 1]; - const int &spatialW = dataSpatialShapes[lCol << 1 + 1]; - __half spatialHHalf = __int2half_rd(spatialH); - __half spatialWHalf = __int2half_rd(spatialW); - const __half *dataValuePtr = dataValue + (dataValuePtrInitOffset + dataLevelStartIndex[lCol] * qidStride); - for (int pCol = 0; pCol < numPoint; ++pCol) - { - __half locW = dataSamplingLoc[dataLocWPtr]; - __half locH = dataSamplingLoc[dataLocWPtr + 1]; - __half weight = dataAttnWeight[dataWeightPtr]; - - __half hIm = __hsub(__hmul(locH, spatialHHalf), zeroPointFive); - __half wIm = __hsub(__hmul(locW, spatialWHalf), zeroPointFive); - - if (__hgt(hIm, minusOne) && __hlt(hIm, spatialHHalf) && - __hgt(wIm, minusOne) && __hlt(wIm, spatialWHalf)) { - tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol); - col = __hadd(col, __hmul(tpVal, weight)); - } - dataWeightPtr += 1; - dataLocWPtr += 2; - } - } - *dataColPtr = col; - } -} -#endif // CUDA_ARCH>=530 check -#endif #if 1 template -__global__ void ms_deformable_im2col_gpu_kernel(const int n, const scalar_t* dataValue, - const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex, const scalar_t* dataSamplingLoc, - const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize, const int numHeads, const int channels, - const int numLevels, const int numQuery, const int numPoint, scalar_t* dataCol) +__global__ void ms_deformable_im2col_gpu_kernel(int32_t const n, scalar_t const* dataValue, + int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, + scalar_t const* 
dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, + int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol) { CUDA_KERNEL_LOOP(index, n) { - int _temp = index; - const int cCol = _temp % channels; + int32_t _temp = index; + int32_t const cCol = _temp % channels; _temp /= channels; - const int samplingIndex = _temp; - const int mCol = _temp % numHeads; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; _temp /= numHeads; _temp /= numQuery; - const int bCol = _temp; + int32_t const bCol = _temp; scalar_t* dataColPtr = dataCol + index; - int dataWeightPtr = samplingIndex * numLevels * numPoint; - int dataLocWPtr = dataWeightPtr << 1; - const int qidStride = numHeads * channels; - const int dataValuePtrInitOffset = bCol * spatialSize * qidStride; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; scalar_t col = 0; - for (int lCol = 0; lCol < numLevels; ++lCol) + for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - const int levelStartId = dataLevelStartIndex[lCol]; - const int spatialHPtr = lCol << 1; - const int spatialH = dataSpatialShapes[spatialHPtr]; - const int spatialW = dataSpatialShapes[spatialHPtr + 1]; - const scalar_t* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); - for (int pCol = 0; pCol < numPoint; ++pCol) + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + scalar_t const* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); + for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - const scalar_t locW = dataSamplingLoc[dataLocWPtr]; - const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1]; - const scalar_t weight = dataAttnWeight[dataWeightPtr]; + scalar_t const locW = dataSamplingLoc[dataLocWPtr]; + scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; + scalar_t const weight = dataAttnWeight[dataWeightPtr]; - const scalar_t hIm = locH * spatialH - 0.5; - const scalar_t wIm = locW * spatialW - 0.5; + scalar_t const hIm = locH * spatialH - 0.5; + scalar_t const wIm = locW * spatialW - 0.5; if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW) { @@ -482,61 +370,71 @@ __global__ void ms_deformable_im2col_gpu_kernel(const int n, const scalar_t* dat } } -#if __CUDA_ARCH__ >= 530 template <> -__global__ void ms_deformable_im2col_gpu_kernel<__half>(const int n, const __half* dataValue, - const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex, const __half* dataSamplingLoc, - const __half* dataAttnWeight, const int batchSize, const int spatialSize, const int numHeads, const int channels, - const int numLevels, const int numQuery, const int numPoint, __half* dataCol) +__global__ void ms_deformable_im2col_gpu_kernel<__half>(int32_t const n, const __half* dataValue, + int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, const __half* dataSamplingLoc, + const __half* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, + int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, __half* dataCol) { CUDA_KERNEL_LOOP(index, n) { - int _temp = index; 
- const int cCol = _temp % channels; + int32_t _temp = index; + int32_t const cCol = _temp % channels; _temp /= channels; - const int samplingIndex = _temp; - const int mCol = _temp % numHeads; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; _temp /= numHeads; _temp /= numQuery; - const int bCol = _temp; + int32_t const bCol = _temp; __half* dataColPtr = dataCol + index; - int dataWeightPtr = samplingIndex * numLevels * numPoint; - int dataLocWPtr = dataWeightPtr << 1; - const int qidStride = numHeads * channels; - const int dataValuePtrInitOffset = bCol * spatialSize * qidStride; - const __half zeroPointFive = __float2half(0.5f); - const __half minusOne = __float2half(-1.0f); - const __half zero = __int2half_rz(0); - __half tpVal = zero; - __half col = zero; - - for (int lCol = 0; lCol < numLevels; ++lCol) + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; + const __half kZERO_POINT_FIVE = __float2half(0.5f); + const __half kMINUS_ONE = __float2half(-1.0f); + const __half kZERO = __int2half_rz(0); + __half tpVal = kZERO; + __half col = kZERO; + + for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - const int levelStartId = dataLevelStartIndex[lCol]; - const int spatialHPtr = lCol << 1; - const int spatialH = dataSpatialShapes[spatialHPtr]; - const int spatialW = dataSpatialShapes[spatialHPtr + 1]; + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; const __half spatialHHalf = __int2half_rd(spatialH); const __half spatialWHalf = __int2half_rd(spatialW); const __half* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); - for (int pCol = 0; pCol < numPoint; ++pCol) + for (int32_t pCol = 0; pCol < numPoint; ++pCol) { const __half locW = dataSamplingLoc[dataLocWPtr]; const __half locH = dataSamplingLoc[dataLocWPtr + 1]; const __half weight = dataAttnWeight[dataWeightPtr]; +#if __CUDA_ARCH__ >= 530 + const __half hIm = __hsub(__hmul(locH, spatialHHalf), kZERO_POINT_FIVE); + const __half wIm = __hsub(__hmul(locW, spatialWHalf), kZERO_POINT_FIVE); - const __half hIm = __hsub(__hmul(locH, spatialHHalf), zeroPointFive); - const __half wIm = __hsub(__hmul(locW, spatialWHalf), zeroPointFive); - - if (__hgt(hIm, minusOne) && __hgt(wIm, minusOne) && __hlt(hIm, spatialHHalf) + if (__hgt(hIm, kMINUS_ONE) && __hgt(wIm, kMINUS_ONE) && __hlt(hIm, spatialHHalf) && __hlt(wIm, spatialWHalf)) { tpVal = ms_deform_attn_im2col_bilinear( dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol); col = __hadd(col, __hmul(tpVal, weight)); } +#else + const __half hIm = __float2half(__half2float(locH) * __half2float(spatialHHalf) - __half2float(kZERO_POINT_FIVE)); + const __half wIm = __float2half(__half2float(locW) * __half2float(spatialWHalf) - __half2float(kZERO_POINT_FIVE)); + if((__half2float(hIm)>__half2float(kMINUS_ONE)) && (__half2float(wIm)>__half2float(kMINUS_ONE)) + && (__half2float(hIm)<__half2float(spatialHHalf)) && (__half2float(wIm)<__half2float(spatialWHalf))) + { + tpVal = ms_deform_attn_im2col_bilinear( + dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol); + col = __float2half(__half2float(col) + (__half2float(tpVal) * __half2float(weight))); + } +#endif 
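            // Editor's note, not part of the patch: the #else branch above reproduces
            // the sm_53+ __half arithmetic in float via __half2float/__float2half, so the
            // kernel also compiles for pre-sm_53 targets, which lack native half-precision
            // operators (__hmul/__hadd/__hgt/__hlt). This per-statement fallback is what
            // allowed the file-level "#if __CUDA_ARCH__ >= 530" guards to be dropped.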
dataWeightPtr += 1;
                dataLocWPtr += 2;
            }
        }
        *dataColPtr = col;
    }
}
-#endif // CUDA_ARCH >=530 check
 #endif

-template <typename scalar_t, unsigned int blockSize>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, const scalar_t* grad_col,
-    const scalar_t* dataValue, const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex,
-    const scalar_t* dataSamplingLoc, const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize,
-    const int numHeads, const int channels, const int numLevels, const int numQuery, const int numPoint,
+template <typename scalar_t, uint32_t blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(int32_t const n, scalar_t const* grad_col,
+    scalar_t const* dataValue, int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex,
+    scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize,
+    int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint,
     scalar_t* gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight)
 {
     CUDA_KERNEL_LOOP(index, n)
     {
         __shared__ scalar_t cacheGradSamplingLoc[blockSize * 2];
         __shared__ scalar_t cacheGradAttnWeight[blockSize];
-        unsigned int tid = threadIdx.x;
-        int _temp = index;
-        const int cCol = _temp % channels;
+        uint32_t tid = threadIdx.x;
+        int32_t _temp = index;
+        int32_t const cCol = _temp % channels;
         _temp /= channels;
-        const int samplingIndex = _temp;
-        const int mCol = _temp % numHeads;
+        int32_t const samplingIndex = _temp;
+        int32_t const mCol = _temp % numHeads;
         _temp /= numHeads;
-        const int qCol = _temp % numQuery;
+        int32_t const qCol = _temp % numQuery;
         _temp /= numQuery;
-        const int bCol = _temp;
+        int32_t const bCol = _temp;

-        const scalar_t topGrad = grad_col[index];
+        scalar_t const topGrad = grad_col[index];

-        int dataWeightPtr = samplingIndex * numLevels * numPoint;
-        int dataLocWPtr = dataWeightPtr << 1;
-        const int gradSamplingPtr = dataWeightPtr;
+        int32_t dataWeightPtr = samplingIndex * numLevels * numPoint;
+        int32_t dataLocWPtr = dataWeightPtr << 1;
+        int32_t const gradSamplingPtr = dataWeightPtr;
         gradSamplingLoc += gradSamplingPtr << 1;
         gradAttnWeight += gradSamplingPtr;
-        const int gradWeightStride = 1;
-        const int gradLocStride = 2;
-        const int qidStride = numHeads * channels;
-        const int dataValuePtrInitOffset = bCol * spatialSize * qidStride;
+        int32_t const gradWeightStride = 1;
+        int32_t const gradLocStride = 2;
+        int32_t const qidStride = numHeads * channels;
+        int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride;

-        for (int lCol = 0; lCol < numLevels; ++lCol)
+        for (int32_t lCol = 0; lCol < numLevels; ++lCol)
         {
-            const int levelStartId = dataLevelStartIndex[lCol];
-            const int spatialHPtr = lCol << 1;
-            const int spatialH = dataSpatialShapes[spatialHPtr];
-            const int spatialW = dataSpatialShapes[spatialHPtr + 1];
-            const int valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
-            const scalar_t* dataValuePtr = dataValue + valuePtrOffset;
+            int32_t const levelStartId = dataLevelStartIndex[lCol];
+            int32_t const spatialHPtr = lCol << 1;
+            int32_t const spatialH = dataSpatialShapes[spatialHPtr];
+            int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1];
+            int32_t const valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
+            scalar_t const* dataValuePtr = dataValue + valuePtrOffset;
             scalar_t* gradValuePtr = gradValue +
valuePtrOffset; - for (int pCol = 0; pCol < numPoint; ++pCol) + for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - const scalar_t locW = dataSamplingLoc[dataLocWPtr]; - const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1]; - const scalar_t weight = dataAttnWeight[dataWeightPtr]; + scalar_t const locW = dataSamplingLoc[dataLocWPtr]; + scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; + scalar_t const weight = dataAttnWeight[dataWeightPtr]; - const scalar_t hIm = locH * spatialH - 0.5; - const scalar_t wIm = locW * spatialW - 0.5; + scalar_t const hIm = locH * spatialH - 0.5; + scalar_t const wIm = locW * spatialW - 0.5; *(cacheGradSamplingLoc + (threadIdx.x << 1)) = 0; *(cacheGradSamplingLoc + ((threadIdx.x << 1) + 1)) = 0; *(cacheGradAttnWeight + threadIdx.x) = 0; @@ -614,8 +511,8 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co { scalar_t _gradW = cacheGradSamplingLoc[0], _gradH = cacheGradSamplingLoc[1], _gradA = cacheGradAttnWeight[0]; - int sid = 2; - for (unsigned int tid = 1; tid < blockSize; ++tid) + int32_t sid = 2; + for (uint32_t tid = 1; tid < blockSize; ++tid) { _gradW += cacheGradSamplingLoc[sid]; _gradH += cacheGradSamplingLoc[sid + 1]; @@ -638,58 +535,58 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co } } -template -__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, const scalar_t* grad_col, - const scalar_t* dataValue, const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex, - const scalar_t* dataSamplingLoc, const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize, - const int numHeads, const int channels, const int numLevels, const int numQuery, const int numPoint, +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(int32_t const n, scalar_t const* grad_col, + scalar_t const* dataValue, int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, + scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, + int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, scalar_t* gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight) { CUDA_KERNEL_LOOP(index, n) { __shared__ scalar_t cacheGradSamplingLoc[blockSize * 2]; __shared__ scalar_t cacheGradAttnWeight[blockSize]; - unsigned int tid = threadIdx.x; - int _temp = index; - const int cCol = _temp % channels; + uint32_t tid = threadIdx.x; + int32_t _temp = index; + int32_t const cCol = _temp % channels; _temp /= channels; - const int samplingIndex = _temp; - const int mCol = _temp % numHeads; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; _temp /= numHeads; - const int qCol = _temp % numQuery; + int32_t const qCol = _temp % numQuery; _temp /= numQuery; - const int bCol = _temp; + int32_t const bCol = _temp; - const scalar_t topGrad = grad_col[index]; + scalar_t const topGrad = grad_col[index]; - int dataWeightPtr = samplingIndex * numLevels * numPoint; - int dataLocWPtr = dataWeightPtr << 1; - const int gradSamplingPtr = dataWeightPtr; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const gradSamplingPtr = dataWeightPtr; gradSamplingLoc += gradSamplingPtr << 1; gradAttnWeight += gradSamplingPtr; - const int gradWeightStride = 1; - const int gradLocStride = 2; - const int qidStride = numHeads * channels; 
- const int dataValuePtrInitOffset = bCol * spatialSize * qidStride; + int32_t const gradWeightStride = 1; + int32_t const gradLocStride = 2; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; - for (int lCol = 0; lCol < numLevels; ++lCol) + for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - const int levelStartId = dataLevelStartIndex[lCol]; - const int spatialHPtr = lCol << 1; - const int spatialH = dataSpatialShapes[spatialHPtr]; - const int spatialW = dataSpatialShapes[spatialHPtr + 1]; - const int valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride; - const scalar_t* dataValuePtr = dataValue + valuePtrOffset; + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + int32_t const valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride; + scalar_t const* dataValuePtr = dataValue + valuePtrOffset; scalar_t* gradValuePtr = gradValue + valuePtrOffset; - for (int pCol = 0; pCol < numPoint; ++pCol) + for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - const scalar_t locW = dataSamplingLoc[dataLocWPtr]; - const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1]; - const scalar_t weight = dataAttnWeight[dataWeightPtr]; + scalar_t const locW = dataSamplingLoc[dataLocWPtr]; + scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; + scalar_t const weight = dataAttnWeight[dataWeightPtr]; - const scalar_t hIm = locH * spatialH - 0.5; - const scalar_t wIm = locW * spatialW - 0.5; + scalar_t const hIm = locH * spatialH - 0.5; + scalar_t const wIm = locW * spatialW - 0.5; *(cacheGradSamplingLoc + (threadIdx.x << 1)) = 0; *(cacheGradSamplingLoc + ((threadIdx.x << 1) + 1)) = 0; *(cacheGradAttnWeight + threadIdx.x) = 0; @@ -702,12 +599,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co __syncthreads(); - for (unsigned int s = blockSize / 2; s > 0; s >>= 1) + for (uint32_t s = blockSize / 2; s > 0; s >>= 1) { if (tid < s) { - const unsigned int xid1 = tid << 1; - const unsigned int xid2 = (tid + s) << 1; + uint32_t const xid1 = tid << 1; + uint32_t const xid2 = (tid + s) << 1; cacheGradAttnWeight[tid] += cacheGradAttnWeight[tid + s]; cacheGradSamplingLoc[xid1] += cacheGradSamplingLoc[xid2]; cacheGradSamplingLoc[xid1 + 1] += cacheGradSamplingLoc[xid2 + 1]; @@ -733,58 +630,58 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co } template -__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, const scalar_t* grad_col, - const scalar_t* dataValue, const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex, - const scalar_t* dataSamplingLoc, const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize, - const int numHeads, const int channels, const int numLevels, const int numQuery, const int numPoint, +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(int32_t const n, scalar_t const* grad_col, + scalar_t const* dataValue, int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, + scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, + int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, scalar_t* gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight) { CUDA_KERNEL_LOOP(index, 
n) { - extern __shared__ int _s[]; + extern __shared__ int32_t _s[]; scalar_t* cacheGradSamplingLoc = (scalar_t*) _s; scalar_t* cacheGradAttnWeight = cacheGradSamplingLoc + 2 * blockDim.x; - unsigned int tid = threadIdx.x; - int _temp = index; - const int cCol = _temp % channels; + uint32_t tid = threadIdx.x; + int32_t _temp = index; + int32_t const cCol = _temp % channels; _temp /= channels; - const int samplingIndex = _temp; - const int mCol = _temp % numHeads; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; _temp /= numHeads; - const int qCol = _temp % numQuery; + int32_t const qCol = _temp % numQuery; _temp /= numQuery; - const int bCol = _temp; + int32_t const bCol = _temp; - const scalar_t topGrad = grad_col[index]; + scalar_t const topGrad = grad_col[index]; - int dataWeightPtr = samplingIndex * numLevels * numPoint; - int dataLocWPtr = dataWeightPtr << 1; - const int gradSamplingPtr = dataWeightPtr; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const gradSamplingPtr = dataWeightPtr; gradSamplingLoc += gradSamplingPtr << 1; gradAttnWeight += gradSamplingPtr; - const int gradWeightStride = 1; - const int gradLocStride = 2; - const int qidStride = numHeads * channels; - const int dataValuePtrInitOffset = bCol * spatialSize * qidStride; + int32_t const gradWeightStride = 1; + int32_t const gradLocStride = 2; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; - for (int lCol = 0; lCol < numLevels; ++lCol) + for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - const int levelStartId = dataLevelStartIndex[lCol]; - const int spatialHPtr = lCol << 1; - const int spatialH = dataSpatialShapes[spatialHPtr]; - const int spatialW = dataSpatialShapes[spatialHPtr + 1]; - const int valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride; - const scalar_t* dataValuePtr = dataValue + valuePtrOffset; + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + int32_t const valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride; + scalar_t const* dataValuePtr = dataValue + valuePtrOffset; scalar_t* gradValuePtr = gradValue + valuePtrOffset; - for (int pCol = 0; pCol < numPoint; ++pCol) + for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - const scalar_t locW = dataSamplingLoc[dataLocWPtr]; - const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1]; - const scalar_t weight = dataAttnWeight[dataWeightPtr]; + scalar_t const locW = dataSamplingLoc[dataLocWPtr]; + scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; + scalar_t const weight = dataAttnWeight[dataWeightPtr]; - const scalar_t hIm = locH * spatialH - 0.5; - const scalar_t wIm = locW * spatialW - 0.5; + scalar_t const hIm = locH * spatialH - 0.5; + scalar_t const wIm = locW * spatialW - 0.5; *(cacheGradSamplingLoc + (threadIdx.x << 1)) = 0; *(cacheGradSamplingLoc + ((threadIdx.x << 1) + 1)) = 0; *(cacheGradAttnWeight + threadIdx.x) = 0; @@ -800,8 +697,8 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, const { scalar_t _gradW = cacheGradSamplingLoc[0], _gradH = cacheGradSamplingLoc[1], _gradA = cacheGradAttnWeight[0]; - int sid = 2; - for (unsigned int tid = 1; tid < blockDim.x; ++tid) + int32_t sid = 2; + for (uint32_t tid = 1; tid < blockDim.x; ++tid) { _gradW 
+= cacheGradSamplingLoc[sid];
                 _gradH += cacheGradSamplingLoc[sid + 1];
@@ -825,58 +722,58 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, const
 }
 
 template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, const scalar_t* grad_col,
-    const scalar_t* dataValue, const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex,
-    const scalar_t* dataSamplingLoc, const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize,
-    const int numHeads, const int channels, const int numLevels, const int numQuery, const int numPoint,
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(int32_t const n, scalar_t const* grad_col,
+    scalar_t const* dataValue, int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex,
+    scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize,
+    int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint,
     scalar_t* gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight)
 {
     CUDA_KERNEL_LOOP(index, n)
     {
-        extern __shared__ int _s[];
+        extern __shared__ int32_t _s[];
         scalar_t* cacheGradSamplingLoc = (scalar_t*) _s;
         scalar_t* cacheGradAttnWeight = cacheGradSamplingLoc + 2 * blockDim.x;
-        unsigned int tid = threadIdx.x;
-        int _temp = index;
-        const int cCol = _temp % channels;
+        uint32_t tid = threadIdx.x;
+        int32_t _temp = index;
+        int32_t const cCol = _temp % channels;
         _temp /= channels;
-        const int samplingIndex = _temp;
-        const int mCol = _temp % numHeads;
+        int32_t const samplingIndex = _temp;
+        int32_t const mCol = _temp % numHeads;
         _temp /= numHeads;
-        const int qCol = _temp % numQuery;
+        int32_t const qCol = _temp % numQuery;
         _temp /= numQuery;
-        const int bCol = _temp;
+        int32_t const bCol = _temp;
 
-        const scalar_t topGrad = grad_col[index];
+        scalar_t const topGrad = grad_col[index];
 
-        int dataWeightPtr = samplingIndex * numLevels * numPoint;
-        int dataLocWPtr = dataWeightPtr << 1;
-        const int gradSamplingPtr = dataWeightPtr;
+        int32_t dataWeightPtr = samplingIndex * numLevels * numPoint;
+        int32_t dataLocWPtr = dataWeightPtr << 1;
+        int32_t const gradSamplingPtr = dataWeightPtr;
         gradSamplingLoc += gradSamplingPtr << 1;
         gradAttnWeight += gradSamplingPtr;
-        const int gradWeightStride = 1;
-        const int gradLocStride = 2;
-        const int qidStride = numHeads * channels;
-        const int dataValuePtrInitOffset = bCol * spatialSize * qidStride;
+        int32_t const gradWeightStride = 1;
+        int32_t const gradLocStride = 2;
+        int32_t const qidStride = numHeads * channels;
+        int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride;
 
-        for (int lCol = 0; lCol < numLevels; ++lCol)
+        for (int32_t lCol = 0; lCol < numLevels; ++lCol)
         {
-            const int levelStartId = dataLevelStartIndex[lCol];
-            const int spatialHPtr = lCol << 1;
-            const int spatialH = dataSpatialShapes[spatialHPtr];
-            const int spatialW = dataSpatialShapes[spatialHPtr + 1];
-            const int valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
-            const scalar_t* dataValuePtr = dataValue + valuePtrOffset;
+            int32_t const levelStartId = dataLevelStartIndex[lCol];
+            int32_t const spatialHPtr = lCol << 1;
+            int32_t const spatialH = dataSpatialShapes[spatialHPtr];
+            int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1];
+            int32_t const valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
+            scalar_t const* dataValuePtr = dataValue + valuePtrOffset;
             scalar_t* gradValuePtr = gradValue + valuePtrOffset;
 
-            for (int pCol = 0; pCol < numPoint; ++pCol)
+            for (int32_t pCol = 0; pCol < numPoint; ++pCol)
             {
-                const scalar_t locW = dataSamplingLoc[dataLocWPtr];
-                const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1];
-                const scalar_t weight = dataAttnWeight[dataWeightPtr];
+                scalar_t const locW = dataSamplingLoc[dataLocWPtr];
+                scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1];
+                scalar_t const weight = dataAttnWeight[dataWeightPtr];
 
-                const scalar_t hIm = locH * spatialH - 0.5;
-                const scalar_t wIm = locW * spatialW - 0.5;
+                scalar_t const hIm = locH * spatialH - 0.5;
+                scalar_t const wIm = locW * spatialW - 0.5;
 
                 *(cacheGradSamplingLoc + (threadIdx.x << 1)) = 0;
                 *(cacheGradSamplingLoc + ((threadIdx.x << 1) + 1)) = 0;
                 *(cacheGradAttnWeight + threadIdx.x) = 0;
@@ -889,12 +786,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, const
 
                 __syncthreads();
 
-                for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
+                for (uint32_t s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
                 {
                     if (tid < s)
                     {
-                        const unsigned int xid1 = tid << 1;
-                        const unsigned int xid2 = (tid + s) << 1;
+                        uint32_t const xid1 = tid << 1;
+                        uint32_t const xid2 = (tid + s) << 1;
                         cacheGradAttnWeight[tid] += cacheGradAttnWeight[tid + s];
                         cacheGradSamplingLoc[xid1] += cacheGradSamplingLoc[xid2];
                         cacheGradSamplingLoc[xid1 + 1] += cacheGradSamplingLoc[xid2 + 1];
@@ -926,58 +823,58 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, const
 }
 
 template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, const scalar_t* grad_col,
-    const scalar_t* dataValue, const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex,
-    const scalar_t* dataSamplingLoc, const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize,
-    const int numHeads, const int channels, const int numLevels, const int numQuery, const int numPoint,
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(int32_t const n, scalar_t const* grad_col,
+    scalar_t const* dataValue, int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex,
+    scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize,
+    int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint,
     scalar_t* gradValue, scalar_t* gradSamplingLoc, scalar_t* gradAttnWeight)
 {
     CUDA_KERNEL_LOOP(index, n)
     {
-        extern __shared__ int _s[];
+        extern __shared__ int32_t _s[];
         scalar_t* cacheGradSamplingLoc = (scalar_t*) _s;
         scalar_t* cacheGradAttnWeight = cacheGradSamplingLoc + 2 * blockDim.x;
-        unsigned int tid = threadIdx.x;
-        int _temp = index;
-        const int cCol = _temp % channels;
+        uint32_t tid = threadIdx.x;
+        int32_t _temp = index;
+        int32_t const cCol = _temp % channels;
         _temp /= channels;
-        const int samplingIndex = _temp;
-        const int mCol = _temp % numHeads;
+        int32_t const samplingIndex = _temp;
+        int32_t const mCol = _temp % numHeads;
         _temp /= numHeads;
-        const int qCol = _temp % numQuery;
+        int32_t const qCol = _temp % numQuery;
         _temp /= numQuery;
-        const int bCol = _temp;
+        int32_t const bCol = _temp;
 
-        const scalar_t topGrad = grad_col[index];
+        scalar_t const topGrad = grad_col[index];
 
-        int dataWeightPtr = samplingIndex * numLevels * numPoint;
-        int dataLocWPtr = dataWeightPtr << 1;
-        const int gradSamplingPtr = dataWeightPtr;
+        int32_t dataWeightPtr = samplingIndex * numLevels * numPoint;
+        int32_t dataLocWPtr = dataWeightPtr << 1;
+        int32_t const gradSamplingPtr = dataWeightPtr;
         gradSamplingLoc += gradSamplingPtr << 1;
         gradAttnWeight += gradSamplingPtr;
-        const int gradWeightStride = 1;
-        const int gradLocStride = 2;
-        const int qidStride = numHeads * channels;
-        const int dataValuePtrInitOffset = bCol * spatialSize * qidStride;
+        int32_t const gradWeightStride = 1;
+        int32_t const gradLocStride = 2;
+        int32_t const qidStride = numHeads * channels;
+        int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride;
 
-        for (int lCol = 0; lCol < numLevels; ++lCol)
+        for (int32_t lCol = 0; lCol < numLevels; ++lCol)
         {
-            const int levelStartId = dataLevelStartIndex[lCol];
-            const int spatialHPtr = lCol << 1;
-            const int spatialH = dataSpatialShapes[spatialHPtr];
-            const int spatialW = dataSpatialShapes[spatialHPtr + 1];
-            const int valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
-            const scalar_t* dataValuePtr = dataValue + valuePtrOffset;
+            int32_t const levelStartId = dataLevelStartIndex[lCol];
+            int32_t const spatialHPtr = lCol << 1;
+            int32_t const spatialH = dataSpatialShapes[spatialHPtr];
+            int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1];
+            int32_t const valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
+            scalar_t const* dataValuePtr = dataValue + valuePtrOffset;
             scalar_t* gradValuePtr = gradValue + valuePtrOffset;
 
-            for (int pCol = 0; pCol < numPoint; ++pCol)
+            for (int32_t pCol = 0; pCol < numPoint; ++pCol)
             {
-                const scalar_t locW = dataSamplingLoc[dataLocWPtr];
-                const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1];
-                const scalar_t weight = dataAttnWeight[dataWeightPtr];
+                scalar_t const locW = dataSamplingLoc[dataLocWPtr];
+                scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1];
+                scalar_t const weight = dataAttnWeight[dataWeightPtr];
 
-                const scalar_t hIm = locH * spatialH - 0.5;
-                const scalar_t wIm = locW * spatialW - 0.5;
+                scalar_t const hIm = locH * spatialH - 0.5;
+                scalar_t const wIm = locW * spatialW - 0.5;
 
                 *(cacheGradSamplingLoc + (threadIdx.x << 1)) = 0;
                 *(cacheGradSamplingLoc + ((threadIdx.x << 1) + 1)) = 0;
                 *(cacheGradAttnWeight + threadIdx.x) = 0;
@@ -990,12 +887,12 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const
 
                 __syncthreads();
 
-                for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
+                for (uint32_t s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
                 {
                     if (tid < s)
                     {
-                        const unsigned int xid1 = tid << 1;
-                        const unsigned int xid2 = (tid + s) << 1;
+                        uint32_t const xid1 = tid << 1;
+                        uint32_t const xid2 = (tid + s) << 1;
                         cacheGradAttnWeight[tid] += cacheGradAttnWeight[tid + s];
                         cacheGradSamplingLoc[xid1] += cacheGradSamplingLoc[xid2];
                         cacheGradSamplingLoc[xid1 + 1] += cacheGradSamplingLoc[xid2 + 1];
@@ -1027,54 +924,54 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const
 }
 
 template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, const scalar_t* grad_col, const scalar_t* dataValue,
-    const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex, const scalar_t* dataSamplingLoc,
-    const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize, const int numHeads, const int channels,
-    const int numLevels, const int numQuery, const int numPoint, scalar_t* gradValue, scalar_t* gradSamplingLoc,
+__global__ void ms_deformable_col2im_gpu_kernel_gm(int32_t const n, scalar_t const* grad_col, scalar_t const* dataValue,
+    int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc,
+    scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels,
+    int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, scalar_t* gradValue, scalar_t* gradSamplingLoc,
     scalar_t* gradAttnWeight)
 {
     CUDA_KERNEL_LOOP(index, n)
     {
-        int _temp = index;
-        const int cCol = _temp % channels;
+        int32_t _temp = index;
+        int32_t const cCol = _temp % channels;
         _temp /= channels;
-        const int samplingIndex = _temp;
-        const int mCol = _temp % numHeads;
+        int32_t const samplingIndex = _temp;
+        int32_t const mCol = _temp % numHeads;
         _temp /= numHeads;
-        const int qCol = _temp % numQuery;
+        int32_t const qCol = _temp % numQuery;
         _temp /= numQuery;
-        const int bCol = _temp;
+        int32_t const bCol = _temp;
 
-        const scalar_t topGrad = grad_col[index];
+        scalar_t const topGrad = grad_col[index];
 
-        int dataWeightPtr = samplingIndex * numLevels * numPoint;
-        int dataLocWPtr = dataWeightPtr << 1;
-        const int gradSamplingPtr = dataWeightPtr;
+        int32_t dataWeightPtr = samplingIndex * numLevels * numPoint;
+        int32_t dataLocWPtr = dataWeightPtr << 1;
+        int32_t const gradSamplingPtr = dataWeightPtr;
         gradSamplingLoc += gradSamplingPtr << 1;
         gradAttnWeight += gradSamplingPtr;
-        const int gradWeightStride = 1;
-        const int gradLocStride = 2;
-        const int qidStride = numHeads * channels;
-        const int dataValuePtrInitOffset = bCol * spatialSize * qidStride;
+        int32_t const gradWeightStride = 1;
+        int32_t const gradLocStride = 2;
+        int32_t const qidStride = numHeads * channels;
+        int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride;
 
-        for (int lCol = 0; lCol < numLevels; ++lCol)
+        for (int32_t lCol = 0; lCol < numLevels; ++lCol)
         {
-            const int levelStartId = dataLevelStartIndex[lCol];
-            const int spatialHPtr = lCol << 1;
-            const int spatialH = dataSpatialShapes[spatialHPtr];
-            const int spatialW = dataSpatialShapes[spatialHPtr + 1];
-            const int valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
-            const scalar_t* dataValuePtr = dataValue + valuePtrOffset;
+            int32_t const levelStartId = dataLevelStartIndex[lCol];
+            int32_t const spatialHPtr = lCol << 1;
+            int32_t const spatialH = dataSpatialShapes[spatialHPtr];
+            int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1];
+            int32_t const valuePtrOffset = dataValuePtrInitOffset + levelStartId * qidStride;
+            scalar_t const* dataValuePtr = dataValue + valuePtrOffset;
             scalar_t* gradValuePtr = gradValue + valuePtrOffset;
 
-            for (int pCol = 0; pCol < numPoint; ++pCol)
+            for (int32_t pCol = 0; pCol < numPoint; ++pCol)
             {
-                const scalar_t locW = dataSamplingLoc[dataLocWPtr];
-                const scalar_t locH = dataSamplingLoc[dataLocWPtr + 1];
-                const scalar_t weight = dataAttnWeight[dataWeightPtr];
+                scalar_t const locW = dataSamplingLoc[dataLocWPtr];
+                scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1];
+                scalar_t const weight = dataAttnWeight[dataWeightPtr];
 
-                const scalar_t hIm = locH * spatialH - 0.5;
-                const scalar_t wIm = locW * spatialW - 0.5;
+                scalar_t const hIm = locH * spatialH - 0.5;
+                scalar_t const wIm = locW * spatialW - 0.5;
                 if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW)
                 {
                     ms_deform_attn_col2im_bilinear_gm(dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm,
@@ -1090,14 +987,14 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, const scalar_t*
 }
 
 template <typename scalar_t>
-void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t* dataValue, const int32_t* dataSpatialShapes,
-    const int32_t* dataLevelStartIndex, const scalar_t* dataSamplingLoc, const scalar_t* dataAttnWeight,
-    const int batchSize, const int spatialSize, const int numHeads, const int channels, const int numLevels,
-    const int numQuery, const int numPoint, scalar_t* dataCol)
+void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, int32_t const* dataSpatialShapes,
+    int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
+    int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, int32_t const numLevels,
+    int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol)
 {
-    const int numKernels = batchSize * numQuery * numHeads * channels;
-    const int numActualKernels = batchSize * numQuery * numHeads * channels;
-    const int numThreads = CUDA_NUM_THREADS;
+    int32_t const numKernels = batchSize * numQuery * numHeads * channels;
+    int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;
+    int32_t const numThreads = CUDA_NUM_THREADS;
     cudaError_t err = cudaSuccess;
 
     ms_deformable_im2col_gpu_kernel<scalar_t>
        <<<GET_BLOCKS(numActualKernels, numThreads), numThreads, 0, stream>>>(
@@ -1106,20 +1003,20 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t* dataValue, c
     err = cudaGetLastError();
     if (err != cudaSuccess)
     {
-        printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+        nvinfer1::plugin::gLogError << "error in ms_deformable_im2col_cuda: " << cudaGetErrorString(err) << std::endl;
     }
 }
 
 template <typename scalar_t>
-void ms_deformable_col2im_cuda(cudaStream_t stream, const scalar_t* grad_col, const scalar_t* dataValue,
-    const int32_t* dataSpatialShapes, const int32_t* dataLevelStartIndex, const scalar_t* dataSamplingLoc,
-    const scalar_t* dataAttnWeight, const int batchSize, const int spatialSize, const int numHeads, const int channels,
-    const int numLevels, const int numQuery, const int numPoint, scalar_t* gradValue, scalar_t* gradSamplingLoc,
+void ms_deformable_col2im_cuda(cudaStream_t stream, scalar_t const* grad_col, scalar_t const* dataValue,
+    int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc,
+    scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels,
+    int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, scalar_t* gradValue, scalar_t* gradSamplingLoc,
     scalar_t* gradAttnWeight)
 {
-    const int numThreads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
-    const int numKernels = batchSize * numQuery * numHeads * channels;
-    const int numActualKernels = batchSize * numQuery * numHeads * channels;
+    int32_t const numThreads = (channels > CUDA_NUM_THREADS) ?
CUDA_NUM_THREADS : channels; + int32_t const numKernels = batchSize * numQuery * numHeads * channels; + int32_t const numActualKernels = batchSize * numQuery * numHeads * channels; if (channels > 1024) { if ((channels & 1023) == 0) @@ -1228,16 +1125,16 @@ void ms_deformable_col2im_cuda(cudaStream_t stream, const scalar_t* grad_col, co cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { - printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + nvinfer1::plugin::gLogError << "error in ms_deformable_col2im_cuda: " << cudaGetErrorString(err) << std::endl; } } #define CUDA_KERNEL_LOOP_RANGE(tid, nDataMin, nDataMax) \ - for (int tid = blockIdx.x * blockDim.x + threadIdx.x; ((tid >= (nDataMin)) && (tid < (nDataMax))); \ + for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; ((tid >= (nDataMin)) && (tid < (nDataMax))); \ tid += blockDim.x * gridDim.x) -__global__ void float2half_input(const int nData1, const int nData2, const int nData3, const float* data1Float, - const float* data2Float, const float* data3Float, __half* data1Half, __half* data2Half, __half* data3Half) +__global__ void float2half_input(int32_t const nData1, int32_t const nData2, int32_t const nData3, float const* data1Float, + float const* data2Float, float const* data3Float, __half* data1Half, __half* data2Half, __half* data3Half) { CUDA_KERNEL_LOOP(index, nData1) { @@ -1258,7 +1155,7 @@ __global__ void float2half_input(const int nData1, const int nData2, const int n } } -__global__ void half2float_output(const int n_data, const __half* data_half, float* data_float) +__global__ void half2float_output(int32_t const n_data, const __half* data_half, float* data_float) { CUDA_KERNEL_LOOP(index, n_data) { @@ -1266,4 +1163,4 @@ __global__ void half2float_output(const int n_data, const __half* data_half, flo } } -#endif \ No newline at end of file +#endif
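
Notes for reviewers (illustrative sketches, not part of the patch):

All of the *_shm_reduce_* kernels in this file fold their per-thread partial
gradients with the classic power-of-two shared-memory tree reduction. A
minimal standalone CUDA sketch of that pattern, reduced to one float per
thread (blockReduceSum and kBlockSize are illustrative names, not plugin
code; the kernel assumes it is launched with exactly kBlockSize threads):

    #include <cstdint>

    constexpr int32_t kBlockSize = 256; // power of two, as the scheme requires

    __global__ void blockReduceSum(float const* in, float* out)
    {
        __shared__ float cache[kBlockSize];
        uint32_t const tid = threadIdx.x;
        cache[tid] = in[blockIdx.x * blockDim.x + tid];
        __syncthreads();

        // Halve the number of active threads each step; each survivor adds
        // its partner's partial sum, just as cacheGradAttnWeight and
        // cacheGradSamplingLoc are folded in the kernels above.
        for (uint32_t s = blockDim.x / 2; s > 0; s >>= 1)
        {
            if (tid < s)
            {
                cache[tid] += cache[tid + s];
            }
            __syncthreads();
        }
        if (tid == 0)
        {
            out[blockIdx.x] = cache[0];
        }
    }

The _v2 variants appear to extend this with the spre bookkeeping to cover
block sizes that are not powers of two, and the _gm variant skips shared
memory entirely, accumulating straight into global memory.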
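The im2col/col2im launchers above share one launch idiom: size the grid with
a ceil-divide over a fixed thread count, let a grid-stride loop
(CUDA_KERNEL_LOOP) cover the whole index space, and check cudaGetLastError()
afterwards. A self-contained sketch of that idiom, with getBlocks, saxpy,
and kNumThreads as stand-ins (the plugin's own GET_BLOCKS and
CUDA_NUM_THREADS are defined elsewhere in its sources):

    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int32_t kNumThreads = 256; // stand-in for CUDA_NUM_THREADS

    inline int32_t getBlocks(int32_t const n, int32_t const numThreads)
    {
        return (n + numThreads - 1) / numThreads; // ceil(n / numThreads)
    }

    __global__ void saxpy(int32_t const n, float const a, float const* x, float* y)
    {
        // Grid-stride loop, the expansion of CUDA_KERNEL_LOOP(index, n):
        // any grid size covers all n elements.
        for (int32_t index = blockIdx.x * blockDim.x + threadIdx.x; index < n;
             index += blockDim.x * gridDim.x)
        {
            y[index] = a * x[index] + y[index];
        }
    }

    void launchSaxpy(cudaStream_t stream, int32_t n, float a, float const* x, float* y)
    {
        saxpy<<<getBlocks(n, kNumThreads), kNumThreads, 0, stream>>>(n, a, x, y);
        cudaError_t const err = cudaGetLastError(); // reports launch failures
        if (err != cudaSuccess)
        {
            std::printf("error in launchSaxpy: %s\n", cudaGetErrorString(err));
        }
    }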
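Finally, float2half_input and half2float_output appear to exist so the fp16
path can keep fp32 buffers at its boundary and convert on the device. The
underlying idiom is simply the __float2half/__half2float intrinsic pair over
the same grid-stride loop; a minimal sketch (kernel names illustrative):

    #include <cstdint>
    #include <cuda_fp16.h>

    __global__ void floatToHalf(int32_t const n, float const* src, __half* dst)
    {
        for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
             i += blockDim.x * gridDim.x)
        {
            dst[i] = __float2half(src[i]); // round-to-nearest fp32 -> fp16
        }
    }

    __global__ void halfToFloat(int32_t const n, __half const* src, float* dst)
    {
        for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
             i += blockDim.x * gridDim.x)
        {
            dst[i] = __half2float(src[i]); // exact fp16 -> fp32 widening
        }
    }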