[js/webgpu] Support int32 type for binary (#16901)

### Description
Enable typed binary and support int32 type for binary.

Co-authored-by: Xing Xu <xing.xu@intel.com>

---------

Co-authored-by: Xing Xu <xing.xu@intel.com>
This commit is contained in:
xhcao 2023-08-19 03:19:01 +08:00 committed by GitHub
parent c0b6c6c94b
commit dd3b2cefd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 41 deletions

View file

@ -146,8 +146,6 @@ const createBinaryOpProgramInfo =
if (sharedDimension % 4 === 0) {
vectorize = true;
}
} else {
// element-wise
vectorize = true;
@ -188,19 +186,24 @@ export const mul = (context: ComputeContext): void => {
};
export const pow = (context: ComputeContext): void => {
const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value;
const roundStr = type === 'i32' ? 'round' : '';
context.compute(createBinaryOpProgramInfoLoader(
context.inputs, 'Pow', ({scalar: (a, b) => `pow_f32(${a},${b})`, vector: (a, b) => `pow_vf32(${a},${b})`}), `
fn pow_f32(a : f32, b : f32) -> f32 {
if (b == 0.0) {
return 1.0;
} else if (a < 0.0 && b != floor(b)) {
return pow(a, b); // NaN
context.inputs, 'Pow',
({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}),
`
fn pow_custom(a : ${type}, b : ${type}) -> ${type} {
if (b == ${type}(0.0)) {
return ${type}(1.0);
} else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) {
return ${type}(pow(f32(a), f32(b))); // NaN
}
return select(sign(a), 1.0, round(abs(b) % 2.0) != 1.0) * pow(abs(a), b);
return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${
roundStr}(pow(f32(abs(a)), f32(b))));
}
fn pow_vf32(a : vec4<f32>, b : vec4<f32>) -> vec4<f32> {
fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> {
// TODO: implement vectorized pow
return vec4<f32>(pow_f32(a.x, b.x), pow_f32(a.y, b.y), pow_f32(a.z, b.z), pow_f32(a.w, b.w));
return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w));
}
`));
};

View file

@ -0,0 +1,31 @@
[
{
"name": "Div with no attributes",
"operator": "Div",
"attributes": [],
"cases": [
{
"name": "T[2,4] T[2,4] (int32)",
"inputs": [
{
"data": [1, 2, 1, 8, 2, 12, 9, 2],
"dims": [2, 4],
"type": "int32"
},
{
"data": [2, 1, 1, 2, 2, 3, 1, 4],
"dims": [2, 4],
"type": "int32"
}
],
"outputs": [
{
"data": [0, 2, 1, 4, 1, 4, 9, 0],
"dims": [2, 4],
"type": "int32"
}
]
}
]
}
]

View file

@ -0,0 +1,31 @@
[
{
"name": "Pow with no attributes",
"operator": "Pow",
"attributes": [],
"cases": [
{
"name": "T[2,4] T[2,4] (int32)",
"inputs": [
{
"data": [1, 2, 1, 8, 2, 12, 9, 2],
"dims": [2, 4],
"type": "int32"
},
{
"data": [2, 1, 1, 2, 2, 3, 1, 4],
"dims": [2, 4],
"type": "int32"
}
],
"outputs": [
{
"data": [1, 2, 1, 64, 4, 1728, 9, 16],
"dims": [2, 4],
"type": "int32"
}
]
}
]
}
]

View file

@ -1324,6 +1324,7 @@
"abs.jsonc",
"acos.jsonc",
"add.jsonc",
"add_int32.jsonc",
//"and.jsonc",
"asin.jsonc",
"ceil.jsonc",
@ -1331,6 +1332,7 @@
"conv.jsonc",
"cos.jsonc",
"div.jsonc",
"div_int32.jsonc",
//"depth-to-space.jsonc",
//"equal.jsonc",
"exp.jsonc",
@ -1343,6 +1345,7 @@
"log.jsonc",
//"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D)
"mul.jsonc",
"mul_int32.jsonc",
//"neg.jsonc",
//"not.jsonc",
//"or.jsonc",
@ -1354,6 +1357,7 @@
//"pad.jsonc",
//"pad-big.jsonc",
"pow.jsonc",
"pow_int32.jsonc",
"pow-big-number.jsonc",
"reshape.jsonc",
"skip-layer-norm.jsonc",
@ -1362,6 +1366,7 @@
//"split.jsonc",
"sqrt.jsonc",
"sub.jsonc",
"sub_int32.jsonc",
"tan.jsonc",
"tile.jsonc",
"transpose.jsonc",

View file

@ -6,49 +6,51 @@
namespace onnxruntime {
namespace js {
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KERNEL_CLASS);
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION_FROM, VERSION_TO, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION_FROM, VERSION_TO, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KERNEL_CLASS);
JSEP_KERNEL_IMPL(Add, Add)
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, float, Add);
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, float, Add);
REG_ELEMENTWISE_KERNEL(Add, 14, float, Add);
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, Add);
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, Add);
REG_ELEMENTWISE_KERNEL(Add, 14, Add);
JSEP_KERNEL_IMPL(Sub, Sub)
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, float, Sub);
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, float, Sub);
REG_ELEMENTWISE_KERNEL(Sub, 14, float, Sub);
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, Sub);
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, Sub);
REG_ELEMENTWISE_KERNEL(Sub, 14, Sub);
JSEP_KERNEL_IMPL(Mul, Mul)
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, float, Mul);
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, float, Mul);
REG_ELEMENTWISE_KERNEL(Mul, 14, float, Mul);
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, Mul);
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, Mul);
REG_ELEMENTWISE_KERNEL(Mul, 14, Mul);
JSEP_KERNEL_IMPL(Div, Div)
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, float, Div);
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, float, Div);
REG_ELEMENTWISE_KERNEL(Div, 14, float, Div);
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, Div);
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, Div);
REG_ELEMENTWISE_KERNEL(Div, 14, Div);
JSEP_KERNEL_IMPL(Pow, Pow)
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, float, Pow);
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, float, Pow);
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, float, Pow);
REG_ELEMENTWISE_KERNEL(Pow, 15, float, Pow);
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, Pow);
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, Pow);
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, Pow);
REG_ELEMENTWISE_KERNEL(Pow, 15, Pow);
} // namespace js
} // namespace onnxruntime