mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/web] FP16 Gemm, Softmax & Transpose (#17494)
### Description First three OPs to support fp16. Will add more once this gets merged since others depend on changes in js_data_types
This commit is contained in:
parent
f20e475e67
commit
65249f42e4
12 changed files with 73 additions and 88 deletions
|
|
@ -1,13 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {GemmUtil, ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs) {
|
||||
|
|
@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
|
|||
throw new Error('Invalid input shape of C');
|
||||
}
|
||||
|
||||
if ((inputs[0].dataType !== DataType.float) || (inputs[1].dataType !== DataType.float) ||
|
||||
(inputs.length === 3 && inputs[2].dataType !== DataType.float)) {
|
||||
throw new Error('Invalid input type.');
|
||||
}
|
||||
|
||||
if ((inputs[0].dataType !== inputs[1].dataType) ||
|
||||
(inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) {
|
||||
throw new Error('Input types are mismatched');
|
||||
|
|
@ -81,7 +75,7 @@ const createGemmProgramInfo =
|
|||
line = 'value += a[m * K + k] * b[k * N + n];';
|
||||
}
|
||||
|
||||
const dataType = 'f32'; // TODO: support other data type
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;';
|
||||
const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : '';
|
||||
const inputStorageBuffersDeclarations = [
|
||||
|
|
|
|||
|
|
@ -5,21 +5,17 @@
|
|||
// performance limitations when the reduced axis is long. Need to add
|
||||
// a optimized codepath for this.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
throw new Error('Softmax op requires 1 input.');
|
||||
}
|
||||
if (inputs[0].dataType !== DataType.float) {
|
||||
throw new Error('Softmax input needs to be float.');
|
||||
}
|
||||
};
|
||||
|
||||
export interface SoftmaxAttributes extends AttributeWithCacheKey {
|
||||
|
|
@ -33,7 +29,7 @@ export const softmaxProgramMetadata = {
|
|||
|
||||
|
||||
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
|
||||
const dataType = 'f32';
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType);
|
||||
const shape = input.dims;
|
||||
const outputSize = ShapeUtil.size(shape);
|
||||
const WG = 64;
|
||||
|
|
@ -48,6 +44,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
const cols = shape[axis];
|
||||
const rows = outputSize / cols;
|
||||
|
||||
// 6.2.4 in wgsl spec
|
||||
const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
|
||||
const getShaderSource = (_shaderHelper: ShaderHelper) => `
|
||||
var<workgroup> rowMaxShared : ${dataType};
|
||||
var<workgroup> rowSumShared : ${dataType};
|
||||
|
|
@ -76,7 +74,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
let row_stride : i32 = ${cols};
|
||||
|
||||
// find the rows max
|
||||
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
|
||||
${threadMaxDecl}
|
||||
for (var col = lindex; col < cols; col += wg) {
|
||||
let value = getValue(row, col, row_stride);
|
||||
threadMax = max(threadMax, value);
|
||||
|
|
@ -100,7 +98,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
workgroupBarrier();
|
||||
|
||||
// find the rows sum
|
||||
var threadSum = 0.0;
|
||||
var threadSum: ${dataType} = 0.0;
|
||||
for (var col = lindex; col < cols; col += wg) {
|
||||
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
|
||||
threadSum += subExp;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
|
|||
if (!inputs || inputs.length !== 1) {
|
||||
throw new Error('Transpose requires 1 input.');
|
||||
}
|
||||
|
||||
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 &&
|
||||
inputs[0].dataType !== DataType.uint32) {
|
||||
throw new Error('Transpose only support float, int32, and uint32 data types');
|
||||
}
|
||||
};
|
||||
|
||||
const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] =>
|
||||
|
|
|
|||
|
|
@ -9,12 +9,24 @@ namespace js {
|
|||
using SupportedTypes =
|
||||
TypeList<
|
||||
float,
|
||||
MLFloat16,
|
||||
int32_t,
|
||||
uint32_t>;
|
||||
|
||||
using SupportedFloats =
|
||||
TypeList<
|
||||
float,
|
||||
MLFloat16>;
|
||||
|
||||
const std::vector<MLDataType>& JsepSupportedDataTypes() {
|
||||
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedTypes>();
|
||||
return supportedDataTypes;
|
||||
}
|
||||
|
||||
const std::vector<MLDataType>& JsepSupportedFloatTypes() {
|
||||
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedFloats>();
|
||||
return supportedDataTypes;
|
||||
}
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,5 +6,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
std::vector<MLDataType>& JsepSupportedDataTypes();
|
||||
}
|
||||
std::vector<MLDataType>& JsepSupportedFloatTypes();
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -244,10 +244,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);
|
||||
|
||||
|
|
@ -269,9 +269,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat);
|
||||
|
|
@ -498,10 +498,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,
|
||||
|
||||
|
|
@ -524,9 +524,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/js/js_execution_provider.h"
|
||||
#include "core/providers/js/js_data_types.h"
|
||||
|
||||
struct pthreadpool;
|
||||
|
||||
|
|
|
|||
|
|
@ -8,41 +8,34 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
kOnnxDomain, \
|
||||
13, \
|
||||
T, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Gemm<T>); \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
kOnnxDomain, \
|
||||
11, 12, \
|
||||
T, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Gemm<T>); \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
kOnnxDomain, \
|
||||
9, 10, \
|
||||
T, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Gemm<T>); \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
Gemm, \
|
||||
kOnnxDomain, \
|
||||
7, 8, \
|
||||
T, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Gemm<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Gemm,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Gemm);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Gemm,
|
||||
kOnnxDomain,
|
||||
11, 12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Gemm);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Gemm,
|
||||
kOnnxDomain,
|
||||
9, 10,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Gemm);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Gemm,
|
||||
kOnnxDomain,
|
||||
7, 8,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Gemm);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
template <typename T>
|
||||
class Gemm : public JsKernel {
|
||||
public:
|
||||
Gemm(const OpKernelInfo& info) : JsKernel(info) {
|
||||
|
|
|
|||
|
|
@ -7,27 +7,25 @@ namespace onnxruntime {
|
|||
namespace js {
|
||||
|
||||
#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
SoftmaxOp, \
|
||||
kOnnxDomain, \
|
||||
sinceVersion, endVersion, \
|
||||
float, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
|
||||
SoftmaxOp<float>);
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()), \
|
||||
SoftmaxOp);
|
||||
|
||||
#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
SoftmaxOp, \
|
||||
kOnnxDomain, \
|
||||
sinceVersion, \
|
||||
float, \
|
||||
kJsExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()) \
|
||||
.InputMemoryType(OrtMemTypeCPU, 1), \
|
||||
SoftmaxOp<float>);
|
||||
SoftmaxOp);
|
||||
|
||||
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10);
|
||||
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12);
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
template <typename T>
|
||||
class Softmax : public JsKernel {
|
||||
public:
|
||||
Softmax(const OpKernelInfo& info) : JsKernel(info) {
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
1, 12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>()}),
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
Transpose);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -23,9 +21,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
13,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>()}),
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
Transpose);
|
||||
|
||||
} // namespace js
|
||||
|
|
|
|||
Loading…
Reference in a new issue