From 5e8d94cec8fc3554791243851e3cadf005ffa04c Mon Sep 17 00:00:00 2001 From: xhcao Date: Sat, 26 Aug 2023 03:11:25 +0800 Subject: [PATCH] [js/webgpu] support Greater and Less operators (#17296) ### Description ### Motivation and Context --- js/web/docs/webgpu-operators.md | 2 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 107 ++++++++++++------ js/web/test/suite-test-list.jsonc | 28 ++--- .../providers/js/js_execution_provider.cc | 12 ++ .../core/providers/js/operators/binary.cc | 10 ++ 6 files changed, 110 insertions(+), 51 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index e33854819c..a210071bc1 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -41,9 +41,11 @@ Do not modify directly.* | Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | +| Greater | ai.onnx(7-8,9-12,13+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | | LeakyRelu | ai.onnx(6-15,16+) | | +| Less | ai.onnx(7-8,9-12,13+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(11,12+) | need perf optimization; need implementing activation | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 6fe4eb5d49..11e54545c4 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -61,9 +61,11 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], + ['Greater', [binaryOps.greater]], ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], + ['Less', [binaryOps.less]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 02b978a381..a16aed7ae4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; @@ -50,26 +51,27 @@ const createBinaryOpProgramShader = }; broadcastImpl = ` - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsA)}; - } + fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { + return ${calcOffsetImpl(dimsA)}; + } - fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsB)}; - } - `; + fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 { + return ${calcOffsetImpl(dimsB)}; + } + `; } let assignment: string; if (vectorize) { if (doBroadcast) { assignment = ` - let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; - let offsetA = calcOffsetA(outputIndices); - let offsetB = calcOffsetB(outputIndices); - ${ + let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; + let offsetA = calcOffsetA(outputIndices); + let offsetB = calcOffsetB(outputIndices); + ${ output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}`; + 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + `; } else { assignment = output.setByOffset( 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'))); @@ -78,37 +80,49 @@ const createBinaryOpProgramShader = if (!doBroadcast) { throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.'); } - const singleAssignment = (x: number) => { + + const singleAssignment = (resStr: string, x: number, typeCast = '') => { const expressionA = `aData[indexA${x}][componentA${x}]`; const expressionB = `bData[indexB${x}][componentB${x}]`; return ` - let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = calcOffsetA(outputIndices${x}); - let offsetB${x} = calcOffsetB(outputIndices${x}); - let indexA${x} = offsetA${x} / 4u; - let indexB${x} = offsetB${x} / 4u; - let componentA${x} = offsetA${x} % 4u; - let componentB${x} = offsetB${x} % 4u; - outputData[global_idx][${x}] = ${expressionScalar(expressionA, expressionB)};`; + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offsetA${x} = calcOffsetA(outputIndices${x}); + let offsetB${x} = calcOffsetB(outputIndices${x}); + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)}); + `; }; - - assignment = ` - ${singleAssignment(0)} - ${singleAssignment(1)} - ${singleAssignment(2)} - ${singleAssignment(3)}`; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } } return ` - ${shaderHelper.declareVariables(a, b, output)} + ${shaderHelper.declareVariables(a, b, output)} - ${additionalImplementation ?? ''} - ${broadcastImpl} + ${additionalImplementation ?? ''} + ${broadcastImpl} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} - ${assignment} - }`; + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; }; const createBinaryOpProgramInfo = @@ -132,7 +146,7 @@ const createBinaryOpProgramInfo = // check whether vectorize can be enabled let sharedDimension = 1; - for (let i = 0; i < outputShape.length; i++) { + for (let i = 1; i < outputShape.length; i++) { const dimA = a.dims[a.dims.length - i] ?? 1; const dimB = b.dims[b.dims.length - i] ?? 1; if (dimA === dimB) { @@ -162,12 +176,13 @@ const createBinaryOpProgramInfo = const createBinaryOpProgramInfoLoader = (inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string, - cacheKey?: string): ProgramInfoLoader => { + cacheKey?: string, outputDataType?: number): ProgramInfoLoader => { const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey}; return { ...metadata, - get: () => createBinaryOpProgramInfo(metadata, inputs[0], inputs[1], funcCall, additionalImplementation) + get: () => createBinaryOpProgramInfo( + metadata, inputs[0], inputs[1], funcCall, additionalImplementation, outputDataType) }; }; @@ -209,3 +224,21 @@ export const pow = (context: ComputeContext): void => { export const sub = (context: ComputeContext): void => { context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`)); }; + +export const greater = (context: ComputeContext): void => { + context.compute(createBinaryOpProgramInfoLoader( + context.inputs, 'Greater', ({ + scalar: (a, b) => `select(0, 1, ${a}>${b})`, + vector: (a, b) => `select(vec4(0), vec4(1), ${a}>${b})` + }), + undefined, undefined, DataType.bool)); +}; + +export const less = (context: ComputeContext): void => { + context.compute(createBinaryOpProgramInfoLoader( + context.inputs, 'Less', ({ + scalar: (a, b) => `select(0, 1, ${a}<${b})`, + vector: (a, b) => `select(vec4(0), vec4(1), ${a}<${b})` + }), + undefined, undefined, DataType.bool)); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 97ffd4c20d..a7964d9ca1 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -563,12 +563,12 @@ "test_globalaveragepool", "test_globalmaxpool_precomputed", "test_globalmaxpool", - // "test_greater_bcast", - // "test_greater_equal_bcast_expanded", - // "test_greater_equal_bcast", - // "test_greater_equal_expanded", - // "test_greater_equal", - // "test_greater", + "test_greater_bcast", + "test_greater_equal_bcast_expanded", + "test_greater_equal_bcast", + "test_greater_equal_expanded", + "test_greater_equal", + "test_greater", // // "test_gridsample_aligncorners_true", // // "test_gridsample_bicubic", // // "test_gridsample_bilinear", @@ -648,12 +648,12 @@ "test_leakyrelu_default", "test_leakyrelu_example", "test_leakyrelu", - // "test_less_bcast", - // "test_less_equal_bcast_expanded", - // "test_less_equal_bcast", - // "test_less_equal_expanded", - // "test_less_equal", - // "test_less", + "test_less_bcast", + "test_less_equal_bcast_expanded", + "test_less_equal_bcast", + "test_less_equal_expanded", + "test_less_equal", + "test_less", "test_log_example", "test_log", // // "test_logsoftmax_axis_0_expanded", @@ -1341,8 +1341,8 @@ "floor.jsonc", "gemm.jsonc", "global-average-pool.jsonc", - //"greater.jsonc", - //"less.jsonc", + "greater.jsonc", + "less.jsonc", "log.jsonc", //"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D) "mul.jsonc", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 3fe1980141..2146d9a0c5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -197,6 +197,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Pow); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Greater); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Less); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Less); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape); @@ -381,6 +387,12 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), KERNEL_CREATE_INFO(15, Pow), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + KERNEL_CREATE_INFO(13, Greater), + KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + KERNEL_CREATE_INFO(13, Less), BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc index 7e0223a98b..e26bb0e49f 100644 --- a/onnxruntime/core/providers/js/operators/binary.cc +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -52,5 +52,15 @@ REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, Pow); REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, Pow); REG_ELEMENTWISE_KERNEL(Pow, 15, Pow); +JSEP_KERNEL_IMPL(Greater, Greater) +REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 7, 8, Greater); +REG_ELEMENTWISE_VERSIONED_KERNEL(Greater, 9, 12, Greater); +REG_ELEMENTWISE_KERNEL(Greater, 13, Greater); + +JSEP_KERNEL_IMPL(Less, Less) +REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 7, 8, Less); +REG_ELEMENTWISE_VERSIONED_KERNEL(Less, 9, 12, Less); +REG_ELEMENTWISE_KERNEL(Less, 13, Less); + } // namespace js } // namespace onnxruntime