mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[ROCm] add FP16 support for FusedConv Op (#15443)
Add FP16 support for FusedConv Op and update UT
This commit is contained in:
parent
ce1eb6d629
commit
d49a8de9b1
3 changed files with 72 additions and 23 deletions
|
|
@ -441,15 +441,18 @@ class FusedConv : public onnxruntime::rocm::Conv<T, false> {
|
|||
template <typename T>
|
||||
typename FusedConv<T>::FusionPlanCache FusedConv<T>::plan_cache_;
|
||||
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(
|
||||
FusedConv,
|
||||
kMSDomain,
|
||||
1,
|
||||
float,
|
||||
kRocmExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
FusedConv<float>);
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
FusedConv, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
FusedConv<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float);
|
||||
REGISTER_KERNEL_TYPED(MLFloat16);
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
|
||||
|
|
@ -251,6 +252,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain,
|
||||
1, float, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain,
|
||||
1, MLFloat16, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu)>,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
|
@ -10,6 +12,12 @@ namespace test {
|
|||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
using namespace std;
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define ROCM_GTEST_SKIP(message) GTEST_SKIP_(message)
|
||||
#else
|
||||
#define ROCM_GTEST_SKIP(message)
|
||||
#endif
|
||||
|
||||
struct ConvOpAndTestAttributes {
|
||||
string auto_pad;
|
||||
vector<int64_t> dilations;
|
||||
|
|
@ -49,7 +57,6 @@ static std::unordered_set<std::string> providers_except_cpu_gpu = {
|
|||
kAclExecutionProvider,
|
||||
kArmNNExecutionProvider};
|
||||
|
||||
|
||||
void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
||||
const vector<vector<float>>& inputs,
|
||||
const vector<vector<int64_t>>& input_shapes,
|
||||
|
|
@ -58,7 +65,10 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
|||
const std::unordered_set<std::string>& excluded_provider_types = providers_except_cpu_gpu,
|
||||
bool weight_is_initializer = false,
|
||||
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
|
||||
const std::string& err_str = "") {
|
||||
const std::string& err_str = "",
|
||||
bool use_float16 = false) {
|
||||
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
|
||||
|
||||
OpTester test("FusedConv", 1, onnxruntime::kMSDomain);
|
||||
test.AddAttribute("group", attributes.group);
|
||||
test.AddAttribute("kernel_shape", attributes.kernel_shape);
|
||||
|
|
@ -86,17 +96,49 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
|
|||
}
|
||||
|
||||
const char* szNames[] = {"X", "W", "B", "Z"};
|
||||
test.AddInput<float>(szNames[0], input_shapes[0], inputs[0]);
|
||||
test.AddInput<float>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
|
||||
if (inputs.size() >= 3)
|
||||
test.AddInput<float>(szNames[2], input_shapes[2], inputs[2]);
|
||||
if (inputs.size() >= 4)
|
||||
test.AddInput<float>(szNames[3], input_shapes[3], inputs[3]);
|
||||
test.AddOutput<float>("Y", expected_output_shape, expected_output);
|
||||
test.Run(expect_result, err_str, excluded_provider_types);
|
||||
|
||||
if (use_float16 && enable_rocm) {
|
||||
// Only ROCm EP supports float16.
|
||||
test.AddInput<MLFloat16>(szNames[0], input_shapes[0], ToFloat16(inputs[0]));
|
||||
test.AddInput<MLFloat16>(szNames[1], input_shapes[1], ToFloat16(inputs[1]), weight_is_initializer);
|
||||
if (inputs.size() >= 3)
|
||||
test.AddInput<MLFloat16>(szNames[2], input_shapes[2], ToFloat16(inputs[2]));
|
||||
if (inputs.size() >= 4)
|
||||
test.AddInput<MLFloat16>(szNames[3], input_shapes[3], ToFloat16(inputs[3]));
|
||||
test.AddOutput<MLFloat16>("Y", expected_output_shape, ToFloat16(expected_output));
|
||||
test.Run(expect_result, err_str, excluded_provider_types);
|
||||
} else {
|
||||
test.AddInput<float>(szNames[0], input_shapes[0], inputs[0]);
|
||||
test.AddInput<float>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
|
||||
if (inputs.size() >= 3)
|
||||
test.AddInput<float>(szNames[2], input_shapes[2], inputs[2]);
|
||||
if (inputs.size() >= 4)
|
||||
test.AddInput<float>(szNames[3], input_shapes[3], inputs[3]);
|
||||
test.AddOutput<float>("Y", expected_output_shape, expected_output);
|
||||
test.Run(expect_result, err_str, excluded_provider_types);
|
||||
}
|
||||
}
|
||||
|
||||
void RunConvOp(const ConvOpAndTestAttributes& attributes,
|
||||
const vector<vector<float>>& inputs,
|
||||
const vector<vector<int64_t>>& input_shapes,
|
||||
const std::initializer_list<float>& expected_output,
|
||||
const vector<int64_t>& expected_output_shape,
|
||||
const std::unordered_set<std::string>& excluded_provider_types = providers_except_cpu_gpu,
|
||||
bool weight_is_initializer = false,
|
||||
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
|
||||
const std::string& err_str = "") {
|
||||
bool use_float16 = true;
|
||||
TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, excluded_provider_types,
|
||||
weight_is_initializer, expect_result, err_str, use_float16);
|
||||
|
||||
use_float16 = false;
|
||||
TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, excluded_provider_types,
|
||||
weight_is_initializer, expect_result, err_str, use_float16);
|
||||
}
|
||||
|
||||
TEST(FusedConvTest, Conv2D_HardSigmoid) {
|
||||
ROCM_GTEST_SKIP("ROCm does not support Conv2D_HardSigmoid");
|
||||
ConvOpAndTestAttributes attrs = {
|
||||
"", // auto_pad
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
|
|
@ -114,7 +156,7 @@ TEST(FusedConvTest, Conv2D_HardSigmoid) {
|
|||
vector<int64_t> W_shape = {2, 1, 2, 2};
|
||||
vector<int64_t> Y_shape = {1, 2, 2, 2};
|
||||
auto expected_vals = {0.8f, 0.9f, 1.0f, 1.0f, 0.2f, 0.1f, 0.0f, 0.0f};
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, providers_except_cpu);
|
||||
RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, providers_except_cpu);
|
||||
}
|
||||
|
||||
TEST(FusedConvTest, Conv2D_Relu) {
|
||||
|
|
@ -134,7 +176,7 @@ TEST(FusedConvTest, Conv2D_Relu) {
|
|||
vector<int64_t> W_shape = {2, 1, 2, 2};
|
||||
vector<int64_t> Y_shape = {1, 2, 2, 2};
|
||||
auto expected_vals = {12.0f, 16.0f, 24.0f, 28.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(FusedConvTest, Conv2D_Bias_Relu) {
|
||||
|
|
@ -156,7 +198,7 @@ TEST(FusedConvTest, Conv2D_Bias_Relu) {
|
|||
vector<float> B = {1.0f, -1.0f};
|
||||
vector<int64_t> B_shape = {2};
|
||||
auto expected_vals = {13.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 27.0f};
|
||||
TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
|
||||
RunConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
|
@ -196,12 +238,13 @@ TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
|
|||
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
|
||||
vector<int64_t> Z_shape = {1, 2, 2, 2};
|
||||
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
|
||||
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_gpu);
|
||||
RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_gpu);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) {
|
||||
ROCM_GTEST_SKIP("ROCm skip Cpu_Conv2D_Bias_Z_Relu");
|
||||
ConvOpAndTestAttributes attrs = {
|
||||
"", // auto_pad
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
|
|
@ -222,7 +265,7 @@ TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) {
|
|||
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
|
||||
vector<int64_t> Z_shape = {1, 2, 2, 2};
|
||||
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
|
||||
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cpu);
|
||||
RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cpu);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue