[ROCm] add FP16 support for FusedConv Op (#15443)

Add FP16 support for FusedConv Op and update UT
This commit is contained in:
PeixuanZuo 2023-04-12 12:19:14 +08:00 committed by GitHub
parent ce1eb6d629
commit d49a8de9b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 23 deletions

View file

@ -441,15 +441,18 @@ class FusedConv : public onnxruntime::rocm::Conv<T, false> {
template <typename T>
typename FusedConv<T>::FusionPlanCache FusedConv<T>::plan_cache_;
ONNX_OPERATOR_TYPED_KERNEL_EX(
FusedConv,
kMSDomain,
1,
float,
kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConv<float>);
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
FusedConv, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
FusedConv<T>);
REGISTER_KERNEL_TYPED(float);
REGISTER_KERNEL_TYPED(MLFloat16);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -107,6 +107,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
@ -251,6 +252,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain,
1, float, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain,
1, MLFloat16, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu)>,

View file

@ -3,6 +3,8 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/util/include/default_providers.h"
namespace onnxruntime {
namespace test {
@ -10,6 +12,12 @@ namespace test {
#if !defined(DISABLE_CONTRIB_OPS)
using namespace std;
#if defined(USE_ROCM)
#define ROCM_GTEST_SKIP(message) GTEST_SKIP_(message)
#else
#define ROCM_GTEST_SKIP(message)
#endif
struct ConvOpAndTestAttributes {
string auto_pad;
vector<int64_t> dilations;
@ -49,7 +57,6 @@ static std::unordered_set<std::string> providers_except_cpu_gpu = {
kAclExecutionProvider,
kArmNNExecutionProvider};
void TestConvOp(const ConvOpAndTestAttributes& attributes,
const vector<vector<float>>& inputs,
const vector<vector<int64_t>>& input_shapes,
@ -58,7 +65,10 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
const std::unordered_set<std::string>& excluded_provider_types = providers_except_cpu_gpu,
bool weight_is_initializer = false,
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
const std::string& err_str = "") {
const std::string& err_str = "",
bool use_float16 = false) {
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
OpTester test("FusedConv", 1, onnxruntime::kMSDomain);
test.AddAttribute("group", attributes.group);
test.AddAttribute("kernel_shape", attributes.kernel_shape);
@ -86,17 +96,49 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
}
const char* szNames[] = {"X", "W", "B", "Z"};
test.AddInput<float>(szNames[0], input_shapes[0], inputs[0]);
test.AddInput<float>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
if (inputs.size() >= 3)
test.AddInput<float>(szNames[2], input_shapes[2], inputs[2]);
if (inputs.size() >= 4)
test.AddInput<float>(szNames[3], input_shapes[3], inputs[3]);
test.AddOutput<float>("Y", expected_output_shape, expected_output);
test.Run(expect_result, err_str, excluded_provider_types);
if (use_float16 && enable_rocm) {
// Only ROCm EP supports float16.
test.AddInput<MLFloat16>(szNames[0], input_shapes[0], ToFloat16(inputs[0]));
test.AddInput<MLFloat16>(szNames[1], input_shapes[1], ToFloat16(inputs[1]), weight_is_initializer);
if (inputs.size() >= 3)
test.AddInput<MLFloat16>(szNames[2], input_shapes[2], ToFloat16(inputs[2]));
if (inputs.size() >= 4)
test.AddInput<MLFloat16>(szNames[3], input_shapes[3], ToFloat16(inputs[3]));
test.AddOutput<MLFloat16>("Y", expected_output_shape, ToFloat16(expected_output));
test.Run(expect_result, err_str, excluded_provider_types);
} else {
test.AddInput<float>(szNames[0], input_shapes[0], inputs[0]);
test.AddInput<float>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
if (inputs.size() >= 3)
test.AddInput<float>(szNames[2], input_shapes[2], inputs[2]);
if (inputs.size() >= 4)
test.AddInput<float>(szNames[3], input_shapes[3], inputs[3]);
test.AddOutput<float>("Y", expected_output_shape, expected_output);
test.Run(expect_result, err_str, excluded_provider_types);
}
}
void RunConvOp(const ConvOpAndTestAttributes& attributes,
const vector<vector<float>>& inputs,
const vector<vector<int64_t>>& input_shapes,
const std::initializer_list<float>& expected_output,
const vector<int64_t>& expected_output_shape,
const std::unordered_set<std::string>& excluded_provider_types = providers_except_cpu_gpu,
bool weight_is_initializer = false,
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
const std::string& err_str = "") {
bool use_float16 = true;
TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, excluded_provider_types,
weight_is_initializer, expect_result, err_str, use_float16);
use_float16 = false;
TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, excluded_provider_types,
weight_is_initializer, expect_result, err_str, use_float16);
}
TEST(FusedConvTest, Conv2D_HardSigmoid) {
ROCM_GTEST_SKIP("ROCm does not support Conv2D_HardSigmoid");
ConvOpAndTestAttributes attrs = {
"", // auto_pad
vector<int64_t>{1, 1}, // dilations
@ -114,7 +156,7 @@ TEST(FusedConvTest, Conv2D_HardSigmoid) {
vector<int64_t> W_shape = {2, 1, 2, 2};
vector<int64_t> Y_shape = {1, 2, 2, 2};
auto expected_vals = {0.8f, 0.9f, 1.0f, 1.0f, 0.2f, 0.1f, 0.0f, 0.0f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, providers_except_cpu);
RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, providers_except_cpu);
}
TEST(FusedConvTest, Conv2D_Relu) {
@ -134,7 +176,7 @@ TEST(FusedConvTest, Conv2D_Relu) {
vector<int64_t> W_shape = {2, 1, 2, 2};
vector<int64_t> Y_shape = {1, 2, 2, 2};
auto expected_vals = {12.0f, 16.0f, 24.0f, 28.0f, 0.0f, 0.0f, 0.0f, 0.0f};
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}
TEST(FusedConvTest, Conv2D_Bias_Relu) {
@ -156,7 +198,7 @@ TEST(FusedConvTest, Conv2D_Bias_Relu) {
vector<float> B = {1.0f, -1.0f};
vector<int64_t> B_shape = {2};
auto expected_vals = {13.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 27.0f};
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
RunConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
}
#if defined(USE_CUDA) || defined(USE_ROCM)
@ -196,12 +238,13 @@ TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
vector<int64_t> Z_shape = {1, 2, 2, 2};
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_gpu);
RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_gpu);
}
#endif
TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) {
ROCM_GTEST_SKIP("ROCm skip Cpu_Conv2D_Bias_Z_Relu");
ConvOpAndTestAttributes attrs = {
"", // auto_pad
vector<int64_t>{1, 1}, // dilations
@ -222,7 +265,7 @@ TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) {
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
vector<int64_t> Z_shape = {1, 2, 2, 2};
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cpu);
RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cpu);
}
#endif