mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
[js/web] FP16 binary and unary ops (#17515)
### Description
Binary and unary ops with fp16 support.
This commit is contained in:
parent
dea425e7c1
commit
0f406ca1d3
3 changed files with 81 additions and 69 deletions
|
|
@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
type BuiltinFunctionName = string;
|
||||
type ElementwiseCustomExpression = (expression: string) => string;
|
||||
|
|
@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
|
|||
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
|
||||
let func: ElementwiseFunctionCall;
|
||||
switch (attributes.to) {
|
||||
case DataType.float16:
|
||||
func = 'vec4<f16>';
|
||||
break;
|
||||
case DataType.float:
|
||||
func = 'vec4<f32>';
|
||||
break;
|
||||
|
|
@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
|
|||
}
|
||||
|
||||
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
context.compute(
|
||||
createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
|
||||
const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
|
||||
const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
|
||||
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
|
||||
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
|
||||
`,
|
||||
attributes.cacheKey),
|
||||
{inputs: [0]});
|
||||
|
|
@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
|
|||
attributes.cacheKey));
|
||||
};
|
||||
|
||||
export const erfImpl = (dataType: string) => `
|
||||
const r0: f32 = 0.3275911;
|
||||
const r1: f32 = 0.254829592;
|
||||
const r2: f32 = -0.284496736;
|
||||
const r3: f32 = 1.421413741;
|
||||
const r4: f32 = -1.453152027;
|
||||
const r5: f32 = 1.061405429;
|
||||
export const erfImpl = (dataType: string, varType = 'f32') => `
|
||||
const r0: ${varType} = 0.3275911;
|
||||
const r1: ${varType} = 0.254829592;
|
||||
const r2: ${varType} = -0.284496736;
|
||||
const r3: ${varType} = 1.421413741;
|
||||
const r4: ${varType} = -1.453152027;
|
||||
const r5: ${varType} = 1.061405429;
|
||||
|
||||
fn erf_vf32(v: ${dataType}) -> ${dataType} {
|
||||
let absv = abs(v);
|
||||
|
|
@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
|
|||
}`;
|
||||
|
||||
export const erf = (context: ComputeContext): void => {
|
||||
context.compute(
|
||||
createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
|
||||
};
|
||||
|
||||
export const exp = (context: ComputeContext): void => {
|
||||
|
|
@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
|
|||
};
|
||||
|
||||
export const gelu = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
|
||||
erfImpl('vec4<f32>')));
|
||||
erfImpl(`vec4<${dataType}>`, dataType)));
|
||||
};
|
||||
|
||||
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
|
||||
|
|
|
|||
|
|
@ -6,14 +6,13 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}), \
|
||||
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, \
|
||||
kOnnxDomain, \
|
||||
VERSION, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
|
||||
|
|
@ -22,8 +21,7 @@ namespace js {
|
|||
kOnnxDomain, \
|
||||
VERSION_FROM, VERSION_TO, \
|
||||
kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}), \
|
||||
KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
JSEP_KERNEL_IMPL(Add, Add)
|
||||
|
|
|
|||
|
|
@ -6,22 +6,29 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
|
||||
#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
|
||||
#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
#define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
|
||||
ONNX_OPERATOR_KERNEL_EX( \
|
||||
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}), \
|
||||
KERNEL_CLASS);
|
||||
|
||||
|
|
@ -29,6 +36,7 @@ namespace js {
|
|||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
|
||||
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
|
||||
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(), \
|
||||
DataTypeImpl::GetTensorType<int32_t>()}), \
|
||||
KERNEL_CLASS);
|
||||
// math
|
||||
|
|
@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg)
|
|||
JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg)
|
||||
|
||||
JSEP_KERNEL_IMPL(Floor, Floor)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor)
|
||||
JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor)
|
||||
JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor)
|
||||
|
||||
JSEP_KERNEL_IMPL(Ceil, Ceil)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil)
|
||||
JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil)
|
||||
JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil)
|
||||
|
||||
JSEP_KERNEL_IMPL(Reciprocal, Reciprocal)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal)
|
||||
JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal)
|
||||
JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal)
|
||||
|
||||
JSEP_KERNEL_IMPL(Sqrt, Sqrt)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, float, Sqrt)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt)
|
||||
|
||||
JSEP_KERNEL_IMPL(Exp, Exp)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp)
|
||||
JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp)
|
||||
JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp)
|
||||
|
||||
JSEP_KERNEL_IMPL(Erf, Erf)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf)
|
||||
JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf)
|
||||
JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf)
|
||||
|
||||
JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)
|
||||
|
||||
JSEP_KERNEL_IMPL(Log, Log)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log)
|
||||
JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
|
||||
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)
|
||||
|
||||
JSEP_KERNEL_IMPL(Sin, Sin)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin)
|
||||
|
||||
JSEP_KERNEL_IMPL(Cos, Cos)
|
||||
JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos)
|
||||
JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos)
|
||||
|
||||
JSEP_KERNEL_IMPL(Tan, Tan)
|
||||
JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan)
|
||||
JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan)
|
||||
|
||||
JSEP_KERNEL_IMPL(Asin, Asin)
|
||||
JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin)
|
||||
JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin)
|
||||
|
||||
JSEP_KERNEL_IMPL(Acos, Acos)
|
||||
JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos)
|
||||
JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos)
|
||||
|
||||
JSEP_KERNEL_IMPL(Atan, Atan)
|
||||
JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan)
|
||||
JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan)
|
||||
|
||||
JSEP_KERNEL_IMPL(Sinh, Sinh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh)
|
||||
|
||||
JSEP_KERNEL_IMPL(Cosh, Cosh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh)
|
||||
|
||||
JSEP_KERNEL_IMPL(Asinh, Asinh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh)
|
||||
|
||||
JSEP_KERNEL_IMPL(Acosh, Acosh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh)
|
||||
|
||||
JSEP_KERNEL_IMPL(Atanh, Atanh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh)
|
||||
|
||||
JSEP_KERNEL_IMPL(Tanh, Tanh)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh)
|
||||
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh)
|
||||
|
||||
JSEP_KERNEL_IMPL(Not, Not)
|
||||
JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not)
|
||||
JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)
|
||||
|
||||
// activation
|
||||
|
||||
// Clip (opset 6-10) receives its bounds as the float attributes `min` and `max`.
// When an attribute is omitted the clamp must leave the input unchanged, so the
// defaults follow the ONNX Clip-6 spec: `min` is the lowest finite float and `max`
// the largest. The defaults were previously swapped (min = +3.402823e+38,
// max = -3.402823e+38), which made the lower bound exceed the upper bound.
// NOTE(review): assumes the macro arguments are (class, op, attr1, default1,
// attr2, default2) as the argument order suggests -- confirm against the macro
// definition.
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, -3.402823e+38f, max, 3.402823e+38f)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
|
||||
JSEP_KERNEL_IMPL(Clip, Clip)
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2),
|
||||
Clip);
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2),
|
||||
Clip);
|
||||
ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2),
|
||||
Clip);
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
|
||||
JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
|
||||
JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu)
|
||||
|
||||
JSEP_KERNEL_IMPL(Relu, Relu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
|
||||
JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu)
|
||||
JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu)
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
|
||||
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu)
|
||||
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu)
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
|
||||
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
|
||||
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu)
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue