mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[js/webgpu] Fix f16 errors in unary (#18839)
### Description This PR fixes below errors: ``` no matching overload for operator > (vec4<f16>, vec4<f32>)
This commit is contained in:
parent
f52668cc68
commit
4bbed4c71a
1 changed files with 16 additions and 12 deletions
|
|
@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common';
|
||||
|
||||
type BuiltinFunctionName = string;
|
||||
type ElementwiseCustomExpression = (expression: string) => string;
|
||||
|
|
@ -132,7 +132,7 @@ const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAt
|
|||
|
||||
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
|
||||
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(
|
||||
createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
|
||||
|
|
@ -163,15 +163,16 @@ export const parseAlphaAttributes = (attributes: Record<string, unknown>): Alpha
|
|||
createAttributeWithCacheKey(attributes as {alpha: number});
|
||||
|
||||
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
|
||||
const elu_alpha_: f32 = f32(${attributes.alpha});
|
||||
const elu_alpha_ = ${dataType}(${attributes.alpha});
|
||||
|
||||
fn elu_f32(a: f32) -> f32 {
|
||||
fn elu_f32(a: ${dataType}) -> ${dataType} {
|
||||
return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
|
||||
}
|
||||
|
||||
fn elu_vf32(v: vec4<f32>) -> vec4<f32> {
|
||||
fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> {
|
||||
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
|
||||
}`,
|
||||
attributes.cacheKey));
|
||||
|
|
@ -192,7 +193,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
|
|||
}`;
|
||||
|
||||
export const erf = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
|
||||
};
|
||||
|
|
@ -206,16 +207,17 @@ export const floor = (context: ComputeContext): void => {
|
|||
};
|
||||
|
||||
export const gelu = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
|
||||
erfImpl(`vec4<${dataType}>`, dataType)));
|
||||
};
|
||||
|
||||
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<f32>(0.0))`,
|
||||
`const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey));
|
||||
context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
|
||||
`const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, attributes.cacheKey));
|
||||
};
|
||||
|
||||
export const not = (context: ComputeContext): void => {
|
||||
|
|
@ -231,8 +233,9 @@ export const reciprocal = (context: ComputeContext): void => {
|
|||
};
|
||||
|
||||
export const relu = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Relu', a => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`));
|
||||
context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`));
|
||||
};
|
||||
|
||||
export const sigmoid = (context: ComputeContext): void => {
|
||||
|
|
@ -260,9 +263,10 @@ export const tanh = (context: ComputeContext): void => {
|
|||
};
|
||||
|
||||
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'ThresholdedRelu', a => `select(vec4<f32>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
|
||||
`const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`, attributes.cacheKey));
|
||||
context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
|
||||
`const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey));
|
||||
return 0;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue