[js/web] FP16 Gemm, Softmax & Transpose (#17494)

### Description
First three OPs to support fp16. Will add more once this gets merged
since others depend on changes in js_data_types
This commit is contained in:
Arthur Islamov 2023-09-12 08:09:37 +04:00 committed by GitHub
parent f20e475e67
commit 65249f42e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 73 additions and 88 deletions

View file

@ -1,13 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {GemmUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {ShaderHelper} from './common';
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs) {
@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
throw new Error('Invalid input shape of C');
}
if ((inputs[0].dataType !== DataType.float) || (inputs[1].dataType !== DataType.float) ||
(inputs.length === 3 && inputs[2].dataType !== DataType.float)) {
throw new Error('Invalid input type.');
}
if ((inputs[0].dataType !== inputs[1].dataType) ||
(inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) {
throw new Error('Input types are mismatched');
@ -81,7 +75,7 @@ const createGemmProgramInfo =
line = 'value += a[m * K + k] * b[k * N + n];';
}
const dataType = 'f32'; // TODO: support other data type
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;';
const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : '';
const inputStorageBuffersDeclarations = [

View file

@ -5,21 +5,17 @@
// performance limitations when the reduced axis is long. Need to add
// a optimized codepath for this.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
import {ShaderHelper} from './common';
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Softmax op requires 1 input.');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Softmax input needs to be float.');
}
};
export interface SoftmaxAttributes extends AttributeWithCacheKey {
@ -33,7 +29,7 @@ export const softmaxProgramMetadata = {
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const dataType = 'f32';
const dataType = tensorTypeToWsglStorageType(input.dataType);
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const WG = 64;
@ -48,6 +44,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
const cols = shape[axis];
const rows = outputSize / cols;
// 6.2.4 in wgsl spec
const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
@ -76,7 +74,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
let row_stride : i32 = ${cols};
// find the rows max
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
${threadMaxDecl}
for (var col = lindex; col < cols; col += wg) {
let value = getValue(row, col, row_stride);
threadMax = max(threadMax, value);
@ -100,7 +98,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
// find the rows sum
var threadSum = 0.0;
var threadSum: ${dataType} = 0.0;
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;

View file

@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Transpose requires 1 input.');
}
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 &&
inputs[0].dataType !== DataType.uint32) {
throw new Error('Transpose only support float, int32, and uint32 data types');
}
};
const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] =>

View file

@ -9,12 +9,24 @@ namespace js {
using SupportedTypes =
TypeList<
float,
MLFloat16,
int32_t,
uint32_t>;
using SupportedFloats =
TypeList<
float,
MLFloat16>;
const std::vector<MLDataType>& JsepSupportedDataTypes() {
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedTypes>();
return supportedDataTypes;
}
const std::vector<MLDataType>& JsepSupportedFloatTypes() {
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedFloats>();
return supportedDataTypes;
}
} // namespace js
} // namespace onnxruntime

View file

@ -6,5 +6,6 @@
namespace onnxruntime {
namespace js {
std::vector<MLDataType>& JsepSupportedDataTypes();
}
std::vector<MLDataType>& JsepSupportedFloatTypes();
} // namespace js
} // namespace onnxruntime

View file

@ -244,10 +244,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);
@ -269,9 +269,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat);
@ -498,10 +498,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,
@ -524,9 +524,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,

View file

@ -11,6 +11,7 @@
#include "core/framework/op_kernel.h"
#include "core/providers/js/js_execution_provider.h"
#include "core/providers/js/js_data_types.h"
struct pthreadpool;

View file

@ -8,41 +8,34 @@
namespace onnxruntime {
namespace js {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
13, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
11, 12, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
9, 10, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
7, 8, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>);
REGISTER_KERNEL_TYPED(float)
ONNX_OPERATOR_KERNEL_EX(
Gemm,
kOnnxDomain,
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gemm,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gemm,
kOnnxDomain,
9, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gemm,
kOnnxDomain,
7, 8,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Gemm);
} // namespace js
} // namespace onnxruntime

View file

@ -8,7 +8,6 @@
namespace onnxruntime {
namespace js {
template <typename T>
class Gemm : public JsKernel {
public:
Gemm(const OpKernelInfo& info) : JsKernel(info) {

View file

@ -7,27 +7,25 @@ namespace onnxruntime {
namespace js {
#define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
SoftmaxOp, \
kOnnxDomain, \
sinceVersion, endVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
SoftmaxOp<float>);
.TypeConstraint("T", JsepSupportedFloatTypes()), \
SoftmaxOp);
#define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ONNX_OPERATOR_KERNEL_EX( \
SoftmaxOp, \
kOnnxDomain, \
sinceVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
.TypeConstraint("T", JsepSupportedFloatTypes()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
SoftmaxOp<float>);
SoftmaxOp);
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10);
REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12);

View file

@ -8,7 +8,6 @@
namespace onnxruntime {
namespace js {
template <typename T>
class Softmax : public JsKernel {
public:
Softmax(const OpKernelInfo& info) : JsKernel(info) {

View file

@ -12,9 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Transpose);
ONNX_OPERATOR_KERNEL_EX(
@ -23,9 +21,7 @@ ONNX_OPERATOR_KERNEL_EX(
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Transpose);
} // namespace js