[js/webgpu] Fix f16 errors in unary (#18839)

### Description
This PR fixes below errors:
```
no matching overload for operator > (vec4<f16>, vec4<f32>)
This commit is contained in:
Jiajia Qin 2023-12-16 03:25:12 +08:00 committed by GitHub
parent f52668cc68
commit 4bbed4c71a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common';
type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@ -132,7 +132,7 @@ const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAt
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
@ -163,15 +163,16 @@ export const parseAlphaAttributes = (attributes: Record<string, unknown>): Alpha
createAttributeWithCacheKey(attributes as {alpha: number});
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
const elu_alpha_: f32 = f32(${attributes.alpha});
const elu_alpha_ = ${dataType}(${attributes.alpha});
fn elu_f32(a: f32) -> f32 {
fn elu_f32(a: ${dataType}) -> ${dataType} {
return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
}
fn elu_vf32(v: vec4<f32>) -> vec4<f32> {
fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> {
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
}`,
attributes.cacheKey));
@ -192,7 +193,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;
export const erf = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};
@ -206,16 +207,17 @@ export const floor = (context: ComputeContext): void => {
};
export const gelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl(`vec4<${dataType}>`, dataType)));
};
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<f32>(0.0))`,
`const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey));
context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
`const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, attributes.cacheKey));
};
export const not = (context: ComputeContext): void => {
@ -231,8 +233,9 @@ export const reciprocal = (context: ComputeContext): void => {
};
export const relu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Relu', a => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`));
context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`));
};
export const sigmoid = (context: ComputeContext): void => {
@ -260,9 +263,10 @@ export const tanh = (context: ComputeContext): void => {
};
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'ThresholdedRelu', a => `select(vec4<f32>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
`const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`, attributes.cacheKey));
context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
`const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey));
return 0;
};