mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Adding webgl shape kernel (#7971)
This commit is contained in:
parent
0f01de3b0b
commit
b50e9d9d74
4 changed files with 42 additions and 1 deletions
|
|
@ -139,7 +139,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
|
|||
| [SequenceErase](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceErase) | |
|
||||
| [SequenceInsert](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceInsert) | |
|
||||
| [SequenceLength](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceLength) | |
|
||||
| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | |
|
||||
| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-1), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-13) |
|
||||
| [Shrink](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shrink) | |
|
||||
| [Sigmoid](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-13) |
|
||||
| [Sign](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sign) | |
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPo
|
|||
import * as reduceOps from './ops/reduce';
|
||||
import {WebGLReshape} from './ops/reshape';
|
||||
import {WebGLResizePacked} from './ops/resize-packed';
|
||||
import {WebGLShape} from './ops/shape';
|
||||
import {WebGLSlice, WebGLSliceV10} from './ops/slice';
|
||||
import {WebGLSoftmax} from './ops/softmax';
|
||||
import {WebGLSplit} from './ops/split';
|
||||
|
|
@ -89,6 +90,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
|
|||
['Reshape', '', '5+', () => new WebGLReshape()],
|
||||
['Resize', '', '10', () => new WebGLResizePacked(10)],
|
||||
['Resize', '', '11+', () => new WebGLResizePacked(11)],
|
||||
['Shape', '', '1+', () => new WebGLShape()],
|
||||
['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())],
|
||||
['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())],
|
||||
['Slice', '', '10+', () => new WebGLSliceV10()], // TODO: support 'steps' for Slice-10
|
||||
|
|
|
|||
13
js/web/lib/onnxjs/backends/webgl/ops/shape.ts
Normal file
13
js/web/lib/onnxjs/backends/webgl/ops/shape.ts
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Shape} from '../../../ops/shape';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
|
||||
|
||||
export class WebGLShape extends Shape {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))];
|
||||
}
|
||||
}
|
||||
26
js/web/test/data/ops/shape.jsonc
Normal file
26
js/web/test/data/ops/shape.jsonc
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
[
|
||||
{
|
||||
"name": "Shape op test",
|
||||
"operator": "Shape",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 1, 1, 1],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [2, 2],
|
||||
"dims": [2],
|
||||
"type": "int32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
Loading…
Reference in a new issue