[Web/JS] Added Expand operator support. (#16577)

### Description
Added Expand operator support.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-07-11 09:38:16 -07:00 committed by GitHub
parent 1b07bbceaa
commit d41bbac7b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 158 additions and 2 deletions

View file

@ -30,6 +30,7 @@ Do not modify directly.*
| Elu | ai.onnx(6+) | |
| Erf | ai.onnx(9-12,13+) | |
| Exp | ai.onnx(6-12,13+) | |
| Expand | ai.onnx(8-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gemm | ai.onnx(7-8,9-10,11+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |

View file

@ -5,6 +5,7 @@ import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {expand} from './ops/expand';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {matMul} from './ops/matmul';
import * as pool from './ops/pool';
@ -41,6 +42,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
['Erf', [unaryOps.erf]],
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],

View file

@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
import {createIndicesHelper, ShaderHelper} from './common';
export const expandProgramMetadata = {
name: 'Expand',
inputTypes: [GpuDataType.default]
};
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('Expand requires 2 input.');
}
const inputShape = inputs[0].dims;
const shape: number[] = [];
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => shape.push(Number(v)));
}
let shapeIndex = shape.length < inputShape.length ? 0 : shape.length - inputShape.length;
let inputShapeIndex = inputShape.length < shape.length ? 0 : inputShape.length - shape.length;
for (; shapeIndex < shape.length && inputShapeIndex < inputShape.length; ++shapeIndex, ++inputShapeIndex) {
if (shape[shapeIndex] !== inputShape[inputShapeIndex] && shape[shapeIndex] !== 1 &&
inputShape[inputShapeIndex] !== 1) {
throw new Error('Expand requires shape to be broadcastable to input');
}
}
};
const getAdjustedShape = (shape1: readonly number[], shape2: readonly number[]): number[] => {
const diff = shape1.length - shape2.length;
const shape: number[] = [];
for (let i = 0; i < diff; ++i) {
shape.push(shape1[i]);
}
for (let i = 0; i < shape2.length; ++i) {
shape.push(shape2[i] === 1 ? shape1[i + diff] : shape2[i]);
}
return shape;
};
const calculateOutputShape = (inputShape: readonly number[], shape: readonly number[]): number[] =>
(inputShape.length > shape.length) ? getAdjustedShape(inputShape, shape) : getAdjustedShape(shape, inputShape);
const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
const inputShape = inputs[0].dims;
const shape: number[] = [];
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => shape.push(Number(v)));
}
const outputShape: number[] = calculateOutputShape(inputShape, shape);
const outputSize = ShapeUtil.size(outputShape);
const inputIndicesHelper = createIndicesHelper('input', inputShape);
const outputIndicesHelper = createIndicesHelper('output', outputShape);
const dataType = 'f32';
const calculateInputIndexImpl = (): string => `
fn calculateInputIndex(outputIndices: array<u32, ${outputShape.length}>) -> array<u32,${inputShape.length}> {
${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
for (var i = 0; i < ${inputShape.length}; i++) {
if (inputShape[i] == 1) {
inputIndices[i] = 0;
} else {
inputIndices[i] = outputIndices[i + ${outputShape.length - inputShape.length}];
}
}
return inputIndices;
}`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
${calculateInputIndexImpl()};
@group(0) @binding(0) var<storage, read> input : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> output : array<${dataType}>;
${outputIndicesHelper.o2iImpl}
${inputIndicesHelper.i2oImpl}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
inputIndices = calculateInputIndex(outputIndices);
output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}];
}`;
return {
...metadata,
getShaderSource,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};
export const expand = (context: ComputeContext): void => {
validateInputs(context.inputs);
context.compute(
{...expandProgramMetadata, get: () => createExpandProgramInfo(expandProgramMetadata, context.inputs)},
{inputs: [0]});
};

View file

@ -519,8 +519,8 @@
"test_erf",
"test_exp_example",
"test_exp",
// "test_expand_dim_changed",
// "test_expand_dim_unchanged",
"test_expand_dim_changed",
"test_expand_dim_unchanged",
// "test_eyelike_populate_off_main_diagonal",
// "test_eyelike_with_dtype",
// "test_eyelike_without_dtype",

View file

@ -243,6 +243,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);
std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
@ -424,6 +426,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/js/js_kernel.h"
#include "expand.h"
namespace onnxruntime {
namespace js {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand,
kOnnxDomain,
8,
12,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.InputMemoryType(OrtMemTypeCPU, 1),
Expand);
ONNX_OPERATOR_KERNEL_EX(
Expand,
kOnnxDomain,
13,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.InputMemoryType(OrtMemTypeCPU, 1),
Expand);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
JSEP_KERNEL_IMPL(Expand, Expand);
} // namespace js
} // namespace onnxruntime