Fix shape-related issues in FuseConv (#12410)

* fix shape mismatch in FuseConv

* remove zeroed bias

* offset Z dim

* append UT

* add testing model

* remove output

* remove commented-out code

* fix comments

* refactor output msg

* narrowly restrict the use of cudnn...ActFwd

* reset changes in cudnn_common

* add test cases covering all paths

* move cases to conv test

* remove extra space

* fix build error

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2022-08-29 10:47:19 -07:00 committed by GitHub
parent 233f8c210e
commit 17ccd6fa02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 340 additions and 61 deletions

View file

@ -37,7 +37,7 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
Status ComputeInternal(OpKernelContext* context) const override {
CUDNN_RETURN_IF_ERROR(status_);
std::lock_guard<OrtMutex> lock(Base::s_.mutex);
ORT_RETURN_IF_ERROR(Base::UpdateState(context, true));
ORT_RETURN_IF_ERROR(Base::UpdateState(context));
if (Base::s_.Y->Shape().Size() == 0) {
return Status::OK();
}
@ -47,25 +47,27 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
const auto alpha = onnxruntime::cuda::Consts<CudaT>::One;
const auto beta = onnxruntime::cuda::Consts<CudaT>::Zero;
IAllocatorUniquePtr<void> workspace = Base::GetWorkSpace();
auto cudnn_status = cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
&alpha,
Base::s_.x_tensor,
Base::s_.x_data,
Base::s_.w_desc,
Base::s_.w_data,
Base::s_.conv_desc,
Base::s_.algo,
workspace.get(),
Base::s_.workspace_bytes,
has_z ? &alpha : &beta,
has_z ? Base::s_.z_tensor : Base::s_.y_tensor,
has_z ? Base::s_.z_data : Base::s_.y_data,
Base::s_.b_tensor,
has_b ? Base::s_.b_data : Base::s_.b_zero,
activation_desc_,
Base::s_.y_tensor,
Base::s_.y_data);
if (CUDNN_STATUS_SUCCESS != cudnn_status) {
if (has_b && has_z && !Base::s_.post_slicing_required) {
CUDNN_RETURN_IF_ERROR(cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
&alpha,
Base::s_.x_tensor,
Base::s_.x_data,
Base::s_.w_desc,
Base::s_.w_data,
Base::s_.conv_desc,
Base::s_.algo,
workspace.get(),
Base::s_.workspace_bytes,
&alpha,
Base::s_.z_tensor,
Base::s_.z_data,
Base::s_.b_tensor,
Base::s_.b_data,
activation_desc_,
Base::s_.y_tensor,
Base::s_.y_data));
} else {
CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(Base::CudnnHandle(),
&alpha,
Base::s_.x_tensor,
@ -79,21 +81,38 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
&beta,
Base::s_.y_tensor,
Base::s_.y_data));
if (has_b) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
if (Base::s_.post_slicing_required) {
ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
onnxruntime::cuda::CudnnTensor sliced_y_tensor;
ORT_RETURN_IF_ERROR(sliced_y_tensor.Set(Base::s_.y_dims.GetDims(), onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));
if (has_b) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
&alpha, sliced_y_tensor, Base::s_.Y->MutableDataRaw()));
}
if (has_z) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
&alpha, sliced_y_tensor, Base::s_.Y->MutableDataRaw()));
}
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, sliced_y_tensor,
Base::s_.y_data, &beta, sliced_y_tensor, Base::s_.y_data));
} else {
if (has_b) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
if (has_z) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, Base::s_.y_tensor,
Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
}
if (has_z) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, Base::s_.y_tensor,
Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
}
if (Base::s_.post_slicing_required) {
ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
}
return Status::OK();
}

View file

@ -87,7 +87,7 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream,
}
template <typename T>
Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const {
Status Conv<T>::UpdateState(OpKernelContext* context) const {
//set X
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& x_shape = X->Shape();
@ -109,8 +109,7 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
//set Z
if (context->InputCount() >= 4) {
const Tensor* Z = context->Input<Tensor>(3);
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType<CudaT>()));
s_.z_data = reinterpret_cast<const CudaT*>(Z->Data<T>());
s_.z_data = reinterpret_cast<const CudaT*>(Z->template Data<T>());
} else {
s_.z_data = nullptr;
}
@ -237,22 +236,43 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
if (context->InputCount() >= 3) {
const Tensor* B = context->Input<Tensor>(2);
const auto& b_shape = B->Shape();
ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
b_dims[1] = b_shape[0];
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
//s_.b_data = reinterpret_cast<const CudaT*>(B->Data<T>());
} else if (bias_expected) {
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
b_dims[1] = w_dims[0];
auto malloc_size = b_dims[1] * sizeof(CudaT);
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
if (s_.b_zero) {
CUDA_CALL_THROW(cudaFree(s_.b_zero));
s_.b_zero = nullptr;
if (b_shape.NumDimensions() == 1) {
TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
b_dims[1] = b_shape[0];
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
} else {
const auto& y_rank = y_dims_cudnn.size();
const auto& b_rank = b_shape.GetDims().size();
ORT_RETURN_IF_NOT(b_rank <= y_rank, "rank of B is ", b_rank, ", which is bigger than the rank of Y - ", y_rank);
if (b_rank == y_rank) {
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_shape.GetDims(), CudnnTensor::GetDataType<CudaT>()));
} else {
TensorShapeVector b_extended_dims = b_shape.AsShapeVector();
for (auto i = b_rank; i < y_rank; ++i) {
ORT_RETURN_IF_NOT(y_dims_cudnn[i] == 1, "dim ", i, " of Y is ", y_dims_cudnn[i], ", cannot apply it to that dim of B");
b_extended_dims.push_back(1);
}
ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_extended_dims, CudnnTensor::GetDataType<CudaT>()));
}
}
}
if (context->InputCount() >= 4) {
const Tensor* Z = context->Input<Tensor>(3);
const auto& z_shape = Z->Shape();
const auto& z_rank = z_shape.GetDims().size();
const auto& y_rank = y_dims_cudnn.size();
ORT_RETURN_IF_NOT(z_rank <= y_rank, "rank of Z is ", z_rank, ", which is bigger than the rank of Y - ", y_rank);
if (z_rank == y_rank) {
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(z_shape.GetDims(), CudnnTensor::GetDataType<CudaT>()));
} else {
TensorShapeVector z_extended_dims = z_shape.AsShapeVector();
for (auto i = z_rank; i < y_rank; ++i) {
ORT_RETURN_IF_NOT(y_dims_cudnn[i] == 1, "dim ", i, " of Y is ", y_dims_cudnn[i], ", cannot apply it to that dim of Z");
z_extended_dims.push_back(1);
}
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(z_extended_dims, CudnnTensor::GetDataType<CudaT>()));
}
CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size));
CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream()));
}
if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {

View file

@ -141,7 +141,6 @@ struct CudnnConvState {
const void* w_data = nullptr;
CudnnTensor b_tensor;
const void* b_data = nullptr;
void* b_zero = nullptr;
CudnnTensor y_tensor;
Tensor* Y = nullptr;
void* y_data = nullptr;
@ -166,13 +165,6 @@ struct CudnnConvState {
// note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing
OrtMutex mutex;
IAllocatorUniquePtr<void> memory_for_cudnn_conv_results;
~CudnnConvState() {
if (b_zero) {
CUDA_CALL_THROW(cudaFree(b_zero));
b_zero = nullptr;
}
}
};
enum : size_t {
@ -197,7 +189,7 @@ class Conv : public CudaKernel {
return GetScratchBuffer<void>(s_.workspace_bytes);
}
Status UpdateState(OpKernelContext* context, bool bias_expected = false) const;
Status UpdateState(OpKernelContext* context) const;
ConvAttributes conv_attrs_;
mutable CudnnConvState<cudnnConvolutionFwdAlgoPerf_t> s_;
constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;

View file

@ -3,6 +3,9 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/session/inference_session.h"
#include "test/framework/test_utils.h"
using namespace std;
namespace onnxruntime {
namespace test {
@ -725,5 +728,141 @@ TEST(ConvTest, Conv_AutoPad_with_non_default_strides) {
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}
#ifdef USE_CUDA
TEST(ConvTest, Fuse_Conv_Bias) {
auto model_uri = ORT_TSTR("testdata/fuse_conv_bias.onnx");
SessionOptions so;
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_TRUE(session.Initialize().IsOK());
NameMLValMap feeds;
OrtValue ml_value;
size_t X_count = 1 * 3 * 32 * 32;
std::vector<float> X_data(X_count, 1.f);
std::vector<int64_t> X_shape{1, 3, 32, 32};
size_t W_count = 1 * 3 * 5 * 32;
std::vector<float> W_data(W_count, 2.f);
std::vector<int64_t> W_shape{1, 3, 5, 32};
size_t B_count = 1;
std::vector<float> B_data(B_count, 5.f);
std::vector<int64_t> B_shape{1};
size_t Z_count = 1 * 1 * 28;
std::vector<float> Z_data(Z_count, 1.f);
std::vector<int64_t> Z_shape{1, 1, 28};
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), X_shape, X_data, &ml_value);
feeds.insert(std::make_pair("X", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), W_shape, W_data, &ml_value);
feeds.insert(std::make_pair("W", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), B_shape, B_data, &ml_value);
feeds.insert(std::make_pair("B", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), Z_shape, Z_data, &ml_value);
feeds.insert(std::make_pair("Z", ml_value));
std::vector<std::string> output_names{"R"};
std::vector<OrtValue> fetches;
onnxruntime::RunOptions run_options;
auto st = session.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(st.IsOK()) << st;
ASSERT_EQ(1u, fetches.size());
}
TEST(ConvTest, Fuse_Conv_Bias_Slice) {
auto model_uri = ORT_TSTR("testdata/fuse_conv_bias_slice.onnx");
SessionOptions so;
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_TRUE(session.Initialize().IsOK());
NameMLValMap feeds;
OrtValue ml_value;
size_t X_count = 1 * 2 * 6 * 6;
std::vector<float> X_data(X_count, 1.f);
std::vector<int64_t> X_shape{1, 2, 6, 6};
size_t W_count = 1 * 2 * 4 * 4;
std::vector<float> W_data(W_count, 2.f);
std::vector<int64_t> W_shape{1, 2, 4, 4};
size_t B_count = 1;
std::vector<float> B_data(B_count, 5.f);
std::vector<int64_t> B_shape{1};
size_t Z_count = 1 * 1 * 4 * 2;
std::vector<float> Z_data(Z_count, 1.f);
std::vector<int64_t> Z_shape{1, 1, 4, 2};
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), X_shape, X_data, &ml_value);
feeds.insert(std::make_pair("X", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), W_shape, W_data, &ml_value);
feeds.insert(std::make_pair("W", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), B_shape, B_data, &ml_value);
feeds.insert(std::make_pair("B", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), Z_shape, Z_data, &ml_value);
feeds.insert(std::make_pair("Z", ml_value));
std::vector<std::string> output_names{"R"};
std::vector<OrtValue> fetches;
onnxruntime::RunOptions run_options;
auto st = session.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(st.IsOK()) << st;
ASSERT_EQ(1u, fetches.size());
}
TEST(ConvTest, Fuse_Conv_No_Bias) {
auto model_uri = ORT_TSTR("testdata/fuse_conv_no_bias.onnx");
SessionOptions so;
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_TRUE(session.Initialize().IsOK());
NameMLValMap feeds;
OrtValue ml_value;
size_t X_count = 1 * 3 * 32 * 32;
std::vector<float> X_data(X_count, 1.f);
std::vector<int64_t> X_shape{1, 3, 32, 32};
size_t W_count = 1 * 3 * 5 * 32;
std::vector<float> W_data(W_count, 2.f);
std::vector<int64_t> W_shape{1, 3, 5, 32};
size_t Z_count = 1 * 1 * 28;
std::vector<float> Z_data(Z_count, 1.f);
std::vector<int64_t> Z_shape{1, 1, 28};
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), X_shape, X_data, &ml_value);
feeds.insert(std::make_pair("X", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), W_shape, W_data, &ml_value);
feeds.insert(std::make_pair("W", ml_value));
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), Z_shape, Z_data, &ml_value);
feeds.insert(std::make_pair("Z", ml_value));
std::vector<std::string> output_names{"R"};
std::vector<OrtValue> fetches;
onnxruntime::RunOptions run_options;
auto st = session.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(st.IsOK()) << st;
ASSERT_EQ(1u, fetches.size());
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,37 @@


X
W
BY"Conv

Y
ZC"Add
CR"RelugraphZ
X




 Z
W




 Z
B

Z
Z



b?
R:
84
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙B

View file

@ -0,0 +1,40 @@
:‡
7
X
W
BY"Conv*
pads@@@@ *
strides@@ 

Y
ZC"Add
CR"RelugraphZ
X




Z
W




Z
B

Z
Z




b?
R:
84
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙B

View file

@ -0,0 +1,32 @@


X
WY"Conv

Y
ZC"Add
CR"RelugraphZ
X




 Z
W




 Z
Z



b?
R:
84
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙
˙˙˙˙˙˙˙˙˙B