diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index c78ae0fc83..76575ef7b9 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -106,6 +106,12 @@ export declare namespace Env { * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; + /** + * Set or get whether validate input content. + * + * @defaultValue `false` + */ + validateInputContent?: boolean; } } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b..f8ac29e5f8 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9..cbe845b882 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,6 +16,7 @@ import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; +import {range} from './ops/range'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; @@ -83,6 +84,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], + ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin, parseReduceAttributes]], ['ReduceMean', [reduceMean, parseReduceAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts new file mode 100644 index 0000000000..3ecb3308b1 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {DataType} from '../../../wasm-common'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {outputVariable, ShaderHelper} from './common'; + +const validateInputsContent = (start: number, limit: number, delta: number): void => { + const sameStartLimit = start === limit; + const increasingRangeNegativeStep = start < limit && delta < 0; + const decreasingRangePositiveStep = start > limit && delta > 0; + + if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) { + throw new Error('Range these inputs\' contents are invalid.'); + } +}; + +const createRangeProgramInfo = + (metadata: ProgramMetadata, start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => { + const numElements = Math.abs(Math.ceil((limit - start) / delta)); + const outputShape: number[] = [numElements]; + const outputSize = numElements; + + const output = outputVariable('output', dataType, outputShape); + const wgslType = output.type.storage; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.declareVariables(output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + output[global_idx] = ${wgslType}(${start}) + ${wgslType}(global_idx) * ${wgslType}(${delta}); + }`; + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const range = (context: ComputeContext): void => { + let start = 0; + let limit = 0; + let delta = 0; + if (context.inputs[0].dataType === DataType.int32) { + start = context.inputs[0].getInt32Array()[0]; + limit = context.inputs[1].getInt32Array()[0]; + delta = context.inputs[2].getInt32Array()[0]; + } else if (context.inputs[0].dataType === DataType.float) { + start = context.inputs[0].getFloat32Array()[0]; + limit = context.inputs[1].getFloat32Array()[0]; + delta = context.inputs[2].getFloat32Array()[0]; + } + if (env.webgpu.validateInputContent) { + validateInputsContent(start, limit, delta); + } + + const cacheHint = [start, limit, delta].map(x => x.toString()).join('_'); + const metadata: ProgramMetadata = {name: 'Range', inputTypes: [], cacheHint}; + context.compute( + {...metadata, get: () => createRangeProgramInfo(metadata, start, limit, delta, context.inputs[0].dataType)}, + {inputs: []}); +}; diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 3f90351569..31bca8b943 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -333,7 +333,11 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - return {profilingMode}; + const validateInputContent = args['webgpu-validate-input-content']; + if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { + throw new Error('Flag "webgpu-validate-input-content" is invalid'); + } + return {profilingMode, validateInputContent}; } function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4..96ced2bdf9 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -885,10 +885,10 @@ // // "test_qlinearmatmul_3D", // // "test_quantizelinear_axis", // // "test_quantizelinear", - // "test_range_float_type_positive_delta_expanded", - // "test_range_float_type_positive_delta", - // "test_range_int32_type_negative_delta_expanded", - // "test_range_int32_type_negative_delta", + "test_range_float_type_positive_delta_expanded", + "test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", "test_reciprocal_example", "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 49d0ac225b..d3592875bb 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -57,6 +57,9 @@ if (options.globalEnvFlags) { if (flags.webgpu?.profilingMode !== undefined) { ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; } + if (flags.webgpu?.validateInputContent !== undefined) { + ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; + } } // Set logging configuration diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6ced8d4d4a..ae33fb752f 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -318,6 +318,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -584,7 +587,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/range.cc b/onnxruntime/core/providers/js/operators/range.cc new file mode 100644 index 0000000000..e15861f7f2 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "range.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_KERNEL_EX( + Range, + kOnnxDomain, + 11, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .InputMemoryType(OrtMemTypeCPU, 0) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Range); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/range.h b/onnxruntime/core/providers/js/operators/range.h new file mode 100644 index 0000000000..8b32bfc3d9 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Range, Range); + +} // namespace js +} // namespace onnxruntime