From dd3b2cefd6abdd6a0fe6f6701d26d1299f87f1fa Mon Sep 17 00:00:00 2001 From: xhcao Date: Sat, 19 Aug 2023 03:19:01 +0800 Subject: [PATCH] [js/webgpu] Support int32 type for binary (#16901) ### Description Enable typed binary and support int32 type for binary. Co-authored-by: Xing Xu --------- Co-authored-by: Xing Xu --- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 25 ++++---- js/web/test/data/ops/div_int32.jsonc | 31 ++++++++++ js/web/test/data/ops/pow_int32.jsonc | 31 ++++++++++ js/web/test/suite-test-list.jsonc | 5 ++ .../core/providers/js/operators/binary.cc | 62 ++++++++++--------- 5 files changed, 113 insertions(+), 41 deletions(-) create mode 100644 js/web/test/data/ops/div_int32.jsonc create mode 100644 js/web/test/data/ops/pow_int32.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 853c4229e4..5f3d156466 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -146,8 +146,6 @@ const createBinaryOpProgramInfo = if (sharedDimension % 4 === 0) { vectorize = true; } - - } else { // element-wise vectorize = true; @@ -188,19 +186,24 @@ export const mul = (context: ComputeContext): void => { }; export const pow = (context: ComputeContext): void => { + const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value; + const roundStr = type === 'i32' ? 'round' : ''; context.compute(createBinaryOpProgramInfoLoader( - context.inputs, 'Pow', ({scalar: (a, b) => `pow_f32(${a},${b})`, vector: (a, b) => `pow_vf32(${a},${b})`}), ` - fn pow_f32(a : f32, b : f32) -> f32 { - if (b == 0.0) { - return 1.0; - } else if (a < 0.0 && b != floor(b)) { - return pow(a, b); // NaN + context.inputs, 'Pow', + ({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}), + ` + fn pow_custom(a : ${type}, b : ${type}) -> ${type} { + if (b == ${type}(0.0)) { + return ${type}(1.0); + } else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) { + return ${type}(pow(f32(a), f32(b))); // NaN } - return select(sign(a), 1.0, round(abs(b) % 2.0) != 1.0) * pow(abs(a), b); + return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${ + roundStr}(pow(f32(abs(a)), f32(b)))); } - fn pow_vf32(a : vec4, b : vec4) -> vec4 { + fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> { // TODO: implement vectorized pow - return vec4(pow_f32(a.x, b.x), pow_f32(a.y, b.y), pow_f32(a.z, b.z), pow_f32(a.w, b.w)); + return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w)); } `)); }; diff --git a/js/web/test/data/ops/div_int32.jsonc b/js/web/test/data/ops/div_int32.jsonc new file mode 100644 index 0000000000..f86ff82759 --- /dev/null +++ b/js/web/test/data/ops/div_int32.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Div with no attributes", + "operator": "Div", + "attributes": [], + "cases": [ + { + "name": "T[2,4] T[2,4] (int32)", + "inputs": [ + { + "data": [1, 2, 1, 8, 2, 12, 9, 2], + "dims": [2, 4], + "type": "int32" + }, + { + "data": [2, 1, 1, 2, 2, 3, 1, 4], + "dims": [2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0, 2, 1, 4, 1, 4, 9, 0], + "dims": [2, 4], + "type": "int32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/pow_int32.jsonc b/js/web/test/data/ops/pow_int32.jsonc new file mode 100644 index 0000000000..5364561584 --- /dev/null +++ b/js/web/test/data/ops/pow_int32.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Pow with no attributes", + "operator": "Pow", + "attributes": [], + "cases": [ + { + "name": "T[2,4] T[2,4] (int32)", + "inputs": [ + { + "data": [1, 2, 1, 8, 2, 12, 9, 2], + "dims": [2, 4], + "type": "int32" + }, + { + "data": [2, 1, 1, 2, 2, 3, 1, 4], + "dims": [2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 64, 4, 1728, 9, 16], + "dims": [2, 4], + "type": "int32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 5ee62f9bd1..e0f9addc16 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1324,6 +1324,7 @@ "abs.jsonc", "acos.jsonc", "add.jsonc", + "add_int32.jsonc", //"and.jsonc", "asin.jsonc", "ceil.jsonc", @@ -1331,6 +1332,7 @@ "conv.jsonc", "cos.jsonc", "div.jsonc", + "div_int32.jsonc", //"depth-to-space.jsonc", //"equal.jsonc", "exp.jsonc", @@ -1343,6 +1345,7 @@ "log.jsonc", //"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D) "mul.jsonc", + "mul_int32.jsonc", //"neg.jsonc", //"not.jsonc", //"or.jsonc", @@ -1354,6 +1357,7 @@ //"pad.jsonc", //"pad-big.jsonc", "pow.jsonc", + "pow_int32.jsonc", "pow-big-number.jsonc", "reshape.jsonc", "skip-layer-norm.jsonc", @@ -1362,6 +1366,7 @@ //"split.jsonc", "sqrt.jsonc", "sub.jsonc", + "sub_int32.jsonc", "tan.jsonc", "tile.jsonc", "transpose.jsonc", diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc index ffad51f7e5..7e0223a98b 100644 --- a/onnxruntime/core/providers/js/operators/binary.cc +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -6,49 +6,51 @@ namespace onnxruntime { namespace js { -#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, \ - kOnnxDomain, \ - VERSION, \ - kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); -#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, \ - kOnnxDomain, \ - VERSION_FROM, VERSION_TO, \ - kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); JSEP_KERNEL_IMPL(Add, Add) -REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, float, Add); -REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, float, Add); -REG_ELEMENTWISE_KERNEL(Add, 14, float, Add); +REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, Add); +REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, Add); +REG_ELEMENTWISE_KERNEL(Add, 14, Add); JSEP_KERNEL_IMPL(Sub, Sub) -REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, float, Sub); -REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, float, Sub); -REG_ELEMENTWISE_KERNEL(Sub, 14, float, Sub); +REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, Sub); +REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, Sub); +REG_ELEMENTWISE_KERNEL(Sub, 14, Sub); JSEP_KERNEL_IMPL(Mul, Mul) -REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, float, Mul); -REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, float, Mul); -REG_ELEMENTWISE_KERNEL(Mul, 14, float, Mul); +REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, Mul); +REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, Mul); +REG_ELEMENTWISE_KERNEL(Mul, 14, Mul); JSEP_KERNEL_IMPL(Div, Div) -REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, float, Div); -REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, float, Div); -REG_ELEMENTWISE_KERNEL(Div, 14, float, Div); +REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, Div); +REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, Div); +REG_ELEMENTWISE_KERNEL(Div, 14, Div); JSEP_KERNEL_IMPL(Pow, Pow) -REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, float, Pow); -REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, float, Pow); -REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, float, Pow); -REG_ELEMENTWISE_KERNEL(Pow, 15, float, Pow); +REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, Pow); +REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, Pow); +REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, Pow); +REG_ELEMENTWISE_KERNEL(Pow, 15, Pow); } // namespace js } // namespace onnxruntime