diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 835d43037e..ab8ddbfe91 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -558,7 +558,11 @@ public: { ML_CHECK_VALID_ARGUMENT(axis < outputShapeDimCount); uint32_t broadcastAxisLength = outputShape[axis]; - ML_CHECK_VALID_ARGUMENT(inputTensorShape[0] == broadcastAxisLength); + ML_CHECK_VALID_ARGUMENT( + (inputTensorShape[0] == broadcastAxisLength) || + // Treat as broadcast dimension to match CPU behavior. + (inputTensorShape[0] == 1) + ); inputTensorShape.insert(inputTensorShape.begin(), axis, 1); inputTensorShape.insert(inputTensorShape.end(), outputShapeDimCount - 1 - axis, 1); } diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc index 64a97ed4f9..db685967ae 100644 --- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc @@ -76,6 +76,16 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int32_cpu) { test.Run(); } +TEST(DequantizeLinearOpTest, DequantizeLinearOpTest_BroadcastTensorOfOne) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + + test.AddInput("x", {4}, {-30, -3, 100, 127}); + test.AddInput("x_scale", {1}, {2.0f}, true); + test.AddInput("zero_point", {1}, {0}, true); + test.AddOutput("y", {4}, {-60.f, -6.f, 200.f, 254.f}); + test.Run(); +} + #ifdef USE_CUDA TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_half_uint8) { OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index f4b21823a4..026bb07edf 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -47,6 +47,16 @@ TEST(DequantizeLinearOpTest, Int32) { test.Run(); } +TEST(DequantizeLinearOpTest_BroadcastTensor, Int32) { + OpTester test("DequantizeLinear", 13); + test.AddInput("x", {4}, {-30, -3, 100, 127}); + test.AddAttribute("axis", 0); + test.AddInput("x_scale", {1}, {2.0f}); + test.AddInput("x_zero_point", {1}, {0}); + test.AddOutput("y", {4}, {-60.f, -6.f, 200.f, 254.f}); + test.Run(); +} + // 2d inputs TEST(DequantizeLinearOpTest, 2D) { OpTester test("DequantizeLinear", 10);