mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[js/webgpu] Support int32 type for binary (#16901)
Description: Enable typed binary operators and add int32 support for the binary element-wise ops (Add, Sub, Mul, Div, Pow). Co-authored-by: Xing Xu <xing.xu@intel.com>
This commit is contained in:
parent
c0b6c6c94b
commit
dd3b2cefd6
5 changed files with 113 additions and 41 deletions
|
|
@ -146,8 +146,6 @@ const createBinaryOpProgramInfo =
|
|||
if (sharedDimension % 4 === 0) {
|
||||
vectorize = true;
|
||||
}
|
||||
|
||||
|
||||
} else {
|
||||
// element-wise
|
||||
vectorize = true;
|
||||
|
|
@ -188,19 +186,24 @@ export const mul = (context: ComputeContext): void => {
|
|||
};
|
||||
|
||||
export const pow = (context: ComputeContext): void => {
|
||||
const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value;
|
||||
const roundStr = type === 'i32' ? 'round' : '';
|
||||
context.compute(createBinaryOpProgramInfoLoader(
|
||||
context.inputs, 'Pow', ({scalar: (a, b) => `pow_f32(${a},${b})`, vector: (a, b) => `pow_vf32(${a},${b})`}), `
|
||||
fn pow_f32(a : f32, b : f32) -> f32 {
|
||||
if (b == 0.0) {
|
||||
return 1.0;
|
||||
} else if (a < 0.0 && b != floor(b)) {
|
||||
return pow(a, b); // NaN
|
||||
context.inputs, 'Pow',
|
||||
({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}),
|
||||
`
|
||||
fn pow_custom(a : ${type}, b : ${type}) -> ${type} {
|
||||
if (b == ${type}(0.0)) {
|
||||
return ${type}(1.0);
|
||||
} else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) {
|
||||
return ${type}(pow(f32(a), f32(b))); // NaN
|
||||
}
|
||||
return select(sign(a), 1.0, round(abs(b) % 2.0) != 1.0) * pow(abs(a), b);
|
||||
return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${
|
||||
roundStr}(pow(f32(abs(a)), f32(b))));
|
||||
}
|
||||
fn pow_vf32(a : vec4<f32>, b : vec4<f32>) -> vec4<f32> {
|
||||
fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> {
|
||||
// TODO: implement vectorized pow
|
||||
return vec4<f32>(pow_f32(a.x, b.x), pow_f32(a.y, b.y), pow_f32(a.z, b.z), pow_f32(a.w, b.w));
|
||||
return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w));
|
||||
}
|
||||
`));
|
||||
};
|
||||
|
|
|
|||
31
js/web/test/data/ops/div_int32.jsonc
Normal file
31
js/web/test/data/ops/div_int32.jsonc
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
[
|
||||
{
|
||||
"name": "Div with no attributes",
|
||||
"operator": "Div",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2,4] T[2,4] (int32)",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 1, 8, 2, 12, 9, 2],
|
||||
"dims": [2, 4],
|
||||
"type": "int32"
|
||||
},
|
||||
{
|
||||
"data": [2, 1, 1, 2, 2, 3, 1, 4],
|
||||
"dims": [2, 4],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0, 2, 1, 4, 1, 4, 9, 0],
|
||||
"dims": [2, 4],
|
||||
"type": "int32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
31
js/web/test/data/ops/pow_int32.jsonc
Normal file
31
js/web/test/data/ops/pow_int32.jsonc
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
[
|
||||
{
|
||||
"name": "Pow with no attributes",
|
||||
"operator": "Pow",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2,4] T[2,4] (int32)",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 1, 8, 2, 12, 9, 2],
|
||||
"dims": [2, 4],
|
||||
"type": "int32"
|
||||
},
|
||||
{
|
||||
"data": [2, 1, 1, 2, 2, 3, 1, 4],
|
||||
"dims": [2, 4],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 2, 1, 64, 4, 1728, 9, 16],
|
||||
"dims": [2, 4],
|
||||
"type": "int32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1324,6 +1324,7 @@
|
|||
"abs.jsonc",
|
||||
"acos.jsonc",
|
||||
"add.jsonc",
|
||||
"add_int32.jsonc",
|
||||
//"and.jsonc",
|
||||
"asin.jsonc",
|
||||
"ceil.jsonc",
|
||||
|
|
@ -1331,6 +1332,7 @@
|
|||
"conv.jsonc",
|
||||
"cos.jsonc",
|
||||
"div.jsonc",
|
||||
"div_int32.jsonc",
|
||||
//"depth-to-space.jsonc",
|
||||
//"equal.jsonc",
|
||||
"exp.jsonc",
|
||||
|
|
@ -1343,6 +1345,7 @@
|
|||
"log.jsonc",
|
||||
//"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D)
|
||||
"mul.jsonc",
|
||||
"mul_int32.jsonc",
|
||||
//"neg.jsonc",
|
||||
//"not.jsonc",
|
||||
//"or.jsonc",
|
||||
|
|
@ -1354,6 +1357,7 @@
|
|||
//"pad.jsonc",
|
||||
//"pad-big.jsonc",
|
||||
"pow.jsonc",
|
||||
"pow_int32.jsonc",
|
||||
"pow-big-number.jsonc",
|
||||
"reshape.jsonc",
|
||||
"skip-layer-norm.jsonc",
|
||||
|
|
@ -1362,6 +1366,7 @@
|
|||
//"split.jsonc",
|
||||
"sqrt.jsonc",
|
||||
"sub.jsonc",
|
||||
"sub_int32.jsonc",
|
||||
"tan.jsonc",
|
||||
"tile.jsonc",
|
||||
"transpose.jsonc",
|
||||
|
|
|
|||
|
|
@ -6,49 +6,51 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION_FROM, VERSION_TO, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION_FROM, VERSION_TO, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
JSEP_KERNEL_IMPL(Add, Add)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, float, Add);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, float, Add);
|
||||
REG_ELEMENTWISE_KERNEL(Add, 14, float, Add);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, Add);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, Add);
|
||||
REG_ELEMENTWISE_KERNEL(Add, 14, Add);
|
||||
|
||||
JSEP_KERNEL_IMPL(Sub, Sub)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, float, Sub);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, float, Sub);
|
||||
REG_ELEMENTWISE_KERNEL(Sub, 14, float, Sub);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, Sub);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, Sub);
|
||||
REG_ELEMENTWISE_KERNEL(Sub, 14, Sub);
|
||||
|
||||
JSEP_KERNEL_IMPL(Mul, Mul)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, float, Mul);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, float, Mul);
|
||||
REG_ELEMENTWISE_KERNEL(Mul, 14, float, Mul);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, Mul);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, Mul);
|
||||
REG_ELEMENTWISE_KERNEL(Mul, 14, Mul);
|
||||
|
||||
JSEP_KERNEL_IMPL(Div, Div)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, float, Div);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, float, Div);
|
||||
REG_ELEMENTWISE_KERNEL(Div, 14, float, Div);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, Div);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, Div);
|
||||
REG_ELEMENTWISE_KERNEL(Div, 14, Div);
|
||||
|
||||
JSEP_KERNEL_IMPL(Pow, Pow)
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, float, Pow);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, float, Pow);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, float, Pow);
|
||||
REG_ELEMENTWISE_KERNEL(Pow, 15, float, Pow);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, Pow);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, Pow);
|
||||
REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, Pow);
|
||||
REG_ELEMENTWISE_KERNEL(Pow, 15, Pow);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue