Adding webgl shape kernel (#7971)

This commit is contained in:
Du Li 2021-06-08 06:22:45 -07:00 committed by GitHub
parent 0f01de3b0b
commit b50e9d9d74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 1 deletions

View file

@ -139,7 +139,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [SequenceErase](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceErase) | |
| [SequenceInsert](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceInsert) | |
| [SequenceLength](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceLength) | |
| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | |
| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-1), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-13) |
| [Shrink](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shrink) | |
| [Sigmoid](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-13) |
| [Sign](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sign) | |

View file

@ -24,6 +24,7 @@ import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPo
import * as reduceOps from './ops/reduce';
import {WebGLReshape} from './ops/reshape';
import {WebGLResizePacked} from './ops/resize-packed';
import {WebGLShape} from './ops/shape';
import {WebGLSlice, WebGLSliceV10} from './ops/slice';
import {WebGLSoftmax} from './ops/softmax';
import {WebGLSplit} from './ops/split';
@ -89,6 +90,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Reshape', '', '5+', () => new WebGLReshape()],
['Resize', '', '10', () => new WebGLResizePacked(10)],
['Resize', '', '11+', () => new WebGLResizePacked(11)],
['Shape', '', '1+', () => new WebGLShape()],
['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())],
['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())],
['Slice', '', '10+', () => new WebGLSliceV10()], // TODO: support 'steps' for Slice-10

View file

@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Shape} from '../../../ops/shape';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
export class WebGLShape extends Shape {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))];
}
}

View file

@ -0,0 +1,26 @@
[
{
"name": "Shape op test",
"operator": "Shape",
"attributes": [],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [1, 1, 1, 1],
"dims": [2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [2, 2],
"dims": [2],
"type": "int32"
}
]
}
]
}
]