Skip to content

Commit

Permalink
DMHA plugin refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
rajeevsrao committed Jun 9, 2022
1 parent 133dec2 commit c3618fd
Show file tree
Hide file tree
Showing 6 changed files with 540 additions and 638 deletions.
2 changes: 1 addition & 1 deletion plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
| [leakyReluPlugin](leakyReluPlugin) | LReLU_TRT | 1 |
| [multilevelCropAndResizePlugin](multilevelCropAndResizePlugin) | MultilevelCropAndResize_TRT | 1 |
| [multilevelProposeROI](multilevelProposeROI) | MultilevelProposeROI_TRT | 1 |
| [multiscaleDeformableAttnPlugin](multiscaleDeformableAttnPlugin) | DMHA | 1 |
| [multiscaleDeformableAttnPlugin](multiscaleDeformableAttnPlugin) | MultiscaleDeformableAttnPlugin_TRT | 1 |
| [nmsPlugin](nmsPlugin) | NMS_TRT | 1 |
| [normalizePlugin](normalizePlugin) | Normalize_TRT | 1 |
| [nvFasterRCNN](nvFasterRCNN) | RPROI_TRT | 1 |
Expand Down
24 changes: 10 additions & 14 deletions plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@

#include "multiscaleDeformableIm2ColCuda.cuh"

int ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int batch,
int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint)
int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint)
{
auto perValueSize = mSpatialSize * mNumHeads * mChannels;
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;

int mIm2colStep = batch;
int32_t mIm2colStep = batch;

for (int n = 0; n < batch / mIm2colStep; ++n)
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
{
auto columns = output + perValueSize * n * mIm2colStep;
ms_deformable_im2col_cuda<float>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
Expand All @@ -57,18 +57,16 @@ int ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const i
return 0;
}

#if __CUDA_ARCH__ >= 530

int ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int batch,
int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint)
int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint)
{
auto perValueSize = mSpatialSize * mNumHeads * mChannels;
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;

int mIm2colStep = batch;
for (int n = 0; n < batch / mIm2colStep; ++n)
int32_t mIm2colStep = batch;
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
{
auto columns = output + perValueSize * n * mIm2colStep;
ms_deformable_im2col_cuda<__half>(stream, value + n * mIm2colStep * perValueSize, spatialShapes,
Expand All @@ -79,5 +77,3 @@ int ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const

return 0;
}

#endif
15 changes: 7 additions & 8 deletions plugin/multiscaleDeformableAttnPlugin/multiscaleDeformableAttn.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
#ifndef TRT_MULTISCALE_DEFORMABLE_ATTN_H
#define TRT_MULTISCALE_DEFORMABLE_ATTN_H

int ms_deform_attn_cuda_forward(cudaStream_t stream, const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int batch,
int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint);
int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, float const* value, int32_t const* spatialShapes,
int32_t const* levelStartIndex, float const* samplingLoc, float const* attnWeight, float* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint);

#if __CUDA_ARCH__ >= 530
int ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int batch,
int mSpatialSize, int mNumHeads, int mChannels, int mNumLevels, int mNumQuery, int mNumPoint);
#endif

int32_t ms_deform_attn_cuda_forward(cudaStream_t stream, const __half* value, int32_t const* spatialShapes,
int32_t const* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint);

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,35 @@ namespace plugin

namespace
{
static const char* DMHA_VERSION{"1"};
static const char* DMHA_NAME{"DMHA"};
static char const* DMHA_VERSION{"1"};
static char const* DMHA_NAME{"MultiscaleDeformableAttnPlugin_TRT"};
} // namespace

MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(const std::string& name)
: mLayerName(name)
MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin()
{
}

MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(const std::string& name, const void* data, size_t length)
: mLayerName(name)
MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(void const* data, size_t length)
{
}

nvinfer1::IPluginV2DynamicExt* MultiscaleDeformableAttnPlugin::clone() const PLUGIN_NOEXCEPT
{
MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin(mLayerName);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
try
{
MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin();
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
catch (const std::exception& e)
{
caughtError(e);
}
return nullptr;
}

nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int outputIndex,
const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) PLUGIN_NOEXCEPT
nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int32_t outputIndex,
nvinfer1::DimsExprs const* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) PLUGIN_NOEXCEPT
{
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
Expand All @@ -63,7 +69,7 @@ nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int outp
}

bool MultiscaleDeformableAttnPlugin::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) PLUGIN_NOEXCEPT
int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) PLUGIN_NOEXCEPT
{
ASSERT((nbInputs == 5));
ASSERT((nbOutputs == 1));
Expand All @@ -76,12 +82,8 @@ bool MultiscaleDeformableAttnPlugin::supportsFormatCombination(
}
else
{
#if __CUDA_ARCH__ >= 530
return ((inOut[pos].type == inOut[0].type)
&& ((inOut[pos].type == nvinfer1::DataType::kFLOAT) || (inOut[pos].type == nvinfer1::DataType::kHALF)));
#else
return ((inOut[pos].type == inOut[0].type) && ((inOut[pos].type == nvinfer1::DataType::kFLOAT)));
#endif
return ((inOut[pos].type == inOut[0].type) &&
((inOut[pos].type == nvinfer1::DataType::kFLOAT) || (inOut[pos].type == nvinfer1::DataType::kHALF)));
}
}
else
Expand All @@ -90,8 +92,8 @@ bool MultiscaleDeformableAttnPlugin::supportsFormatCombination(
}
}

void MultiscaleDeformableAttnPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) PLUGIN_NOEXCEPT
void MultiscaleDeformableAttnPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) PLUGIN_NOEXCEPT
{
// Check for valid input dimensions
ASSERT(inputs[0].desc.dims.nbDims==4);
Expand All @@ -116,86 +118,83 @@ void MultiscaleDeformableAttnPlugin::configurePlugin(const nvinfer1::DynamicPlug
ASSERT(inputs[3].desc.dims.d[1] == inputs[4].desc.dims.d[1]);
}

size_t MultiscaleDeformableAttnPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const PLUGIN_NOEXCEPT
size_t MultiscaleDeformableAttnPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const PLUGIN_NOEXCEPT
{
return 0;
}

int MultiscaleDeformableAttnPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workSpace,
int32_t MultiscaleDeformableAttnPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workSpace,
cudaStream_t stream) PLUGIN_NOEXCEPT
{
const int batch = inputDesc[0].dims.d[0];
int spatial_size = inputDesc[0].dims.d[1];
int num_heads = inputDesc[0].dims.d[2];
int channels = inputDesc[0].dims.d[3];
int num_levels = inputDesc[1].dims.d[0];
int num_query = inputDesc[3].dims.d[1];
int num_point = inputDesc[3].dims.d[4];
int rc = 0;
int32_t const batch = inputDesc[0].dims.d[0];
int32_t spatial_size = inputDesc[0].dims.d[1];
int32_t num_heads = inputDesc[0].dims.d[2];
int32_t channels = inputDesc[0].dims.d[3];
int32_t num_levels = inputDesc[1].dims.d[0];
int32_t num_query = inputDesc[3].dims.d[1];
int32_t num_point = inputDesc[3].dims.d[4];
int32_t rc = 0;
if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
{
const float* value = static_cast<const float*>(inputs[0]);
const int32_t* spatialShapes = static_cast<const int32_t*>(inputs[1]);
const int32_t* levelStartIndex = static_cast<const int32_t*>(inputs[2]);
const float* samplingLoc = static_cast<const float*>(inputs[3]);
const float* attnWeight = static_cast<const float*>(inputs[4]);
float const* value = static_cast<float const*>(inputs[0]);
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
float const* samplingLoc = static_cast<float const*>(inputs[3]);
float const* attnWeight = static_cast<float const*>(inputs[4]);
float* output = static_cast<float*>(outputs[0]);

rc = ms_deform_attn_cuda_forward(stream, value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point);
}
#if __CUDA_ARCH__ >= 530
else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
{
const __half* value = static_cast<const __half*>(inputs[0]);
const int32_t* spatialShapes = static_cast<const int32_t*>(inputs[1]);
const int32_t* levelStartIndex = static_cast<const int32_t*>(inputs[2]);
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
const __half* samplingLoc = static_cast<const __half*>(inputs[3]);
const __half* attnWeight = static_cast<const __half*>(inputs[4]);
__half* output = static_cast<__half*>(outputs[0]);

rc = ms_deform_attn_cuda_forward(stream, value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point);
}
#endif

return rc;
}

void MultiscaleDeformableAttnPlugin::attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) PLUGIN_NOEXCEPT
{
mCublasHandle = cublasContext;
}

void MultiscaleDeformableAttnPlugin::detachFromContext() PLUGIN_NOEXCEPT {}

// IPluginV2Ext Methods
nvinfer1::DataType MultiscaleDeformableAttnPlugin::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const PLUGIN_NOEXCEPT
int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const PLUGIN_NOEXCEPT
{
return inputTypes[0];
}

// IPluginV2 Methods
const char* MultiscaleDeformableAttnPlugin::getPluginType() const PLUGIN_NOEXCEPT
char const* MultiscaleDeformableAttnPlugin::getPluginType() const PLUGIN_NOEXCEPT
{
return DMHA_NAME;
}

const char* MultiscaleDeformableAttnPlugin::getPluginVersion() const PLUGIN_NOEXCEPT
char const* MultiscaleDeformableAttnPlugin::getPluginVersion() const PLUGIN_NOEXCEPT
{
return DMHA_VERSION;
}

int MultiscaleDeformableAttnPlugin::getNbOutputs() const PLUGIN_NOEXCEPT
int32_t MultiscaleDeformableAttnPlugin::getNbOutputs() const PLUGIN_NOEXCEPT
{
return 1;
}

int MultiscaleDeformableAttnPlugin::initialize() PLUGIN_NOEXCEPT
int32_t MultiscaleDeformableAttnPlugin::initialize() PLUGIN_NOEXCEPT
{
return 0;
}
Expand All @@ -216,11 +215,11 @@ void MultiscaleDeformableAttnPlugin::destroy() PLUGIN_NOEXCEPT
delete this;
}

void MultiscaleDeformableAttnPlugin::setPluginNamespace(const char* pluginNamespace) PLUGIN_NOEXCEPT
void MultiscaleDeformableAttnPlugin::setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT
{
mNamespace = pluginNamespace;
}
const char* MultiscaleDeformableAttnPlugin::getPluginNamespace() const PLUGIN_NOEXCEPT
char const* MultiscaleDeformableAttnPlugin::getPluginNamespace() const PLUGIN_NOEXCEPT
{
return mNamespace.c_str();
}
Expand All @@ -234,45 +233,61 @@ MultiscaleDeformableAttnPluginCreator::MultiscaleDeformableAttnPluginCreator()
mFC.fields = mPluginAttributes.data();
}

const char* MultiscaleDeformableAttnPluginCreator::getPluginName() const PLUGIN_NOEXCEPT
char const* MultiscaleDeformableAttnPluginCreator::getPluginName() const PLUGIN_NOEXCEPT
{
return DMHA_NAME;
}

const char* MultiscaleDeformableAttnPluginCreator::getPluginVersion() const PLUGIN_NOEXCEPT
char const* MultiscaleDeformableAttnPluginCreator::getPluginVersion() const PLUGIN_NOEXCEPT
{
return DMHA_VERSION;
}

const nvinfer1::PluginFieldCollection* MultiscaleDeformableAttnPluginCreator::getFieldNames() PLUGIN_NOEXCEPT
nvinfer1::PluginFieldCollection const* MultiscaleDeformableAttnPluginCreator::getFieldNames() PLUGIN_NOEXCEPT
{
return &mFC;
}

IPluginV2* MultiscaleDeformableAttnPluginCreator::createPlugin(
const char* name, const PluginFieldCollection* fc) PLUGIN_NOEXCEPT
char const* name, PluginFieldCollection const* fc) PLUGIN_NOEXCEPT
{
MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin(name);
return plugin;
try
{
MultiscaleDeformableAttnPlugin* plugin = new MultiscaleDeformableAttnPlugin();
return plugin;
}
catch (const std::exception& e)
{
caughtError(e);
}
return nullptr;
}

IPluginV2* MultiscaleDeformableAttnPluginCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) PLUGIN_NOEXCEPT
char const* name, void const* serialData, size_t serialLength) PLUGIN_NOEXCEPT
{
auto plugin = new MultiscaleDeformableAttnPlugin(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
try
{
auto plugin = new MultiscaleDeformableAttnPlugin(serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
catch (const std::exception& e)
{
caughtError(e);
}
return nullptr;
}

void MultiscaleDeformableAttnPluginCreator::setPluginNamespace(const char* pluginNamespace) PLUGIN_NOEXCEPT
void MultiscaleDeformableAttnPluginCreator::setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT
{
mNamespace = pluginNamespace;
}

const char* MultiscaleDeformableAttnPluginCreator::getPluginNamespace() const PLUGIN_NOEXCEPT
char const* MultiscaleDeformableAttnPluginCreator::getPluginNamespace() const PLUGIN_NOEXCEPT
{
return mNamespace.c_str();
}

} // namespace plugin
} // namespace nvinfer1
} // namespace nvinfer1
Loading

0 comments on commit c3618fd

Please sign in to comment.