onnxruntime/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
Xu Xing 0d7cf301a1
[js/webgpu] Add activation Tanh (#21540)
Bug:https://github.com/microsoft/onnxruntime/issues/21467

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-07-29 11:05:34 -07:00

82 lines
3.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {MAX_CLIP, MIN_CLIP} from '../../util';
import {ProgramUniform} from '../types';
import {UniformsArrayType} from './common';
/**
 * Activation settings that can be fused into the shader of a preceding op.
 *
 * Only the fields relevant to `activation` are populated; the rest stay `undefined`.
 */
export interface InternalActivationAttributes {
  /** Activation name, e.g. 'Relu', 'Sigmoid', 'Clip', 'HardSigmoid', 'LeakyRelu', 'Tanh'; '' means none. */
  readonly activation: string;
  /** Lower bound, read only when `activation` is 'Clip'. */
  readonly clipMin?: number;
  /** Upper bound, read only when `activation` is 'Clip'. */
  readonly clipMax?: number;
  /** Slope/scale, read when `activation` is 'HardSigmoid' or 'LeakyRelu'. */
  readonly alpha?: number;
  /** Offset, read only when `activation` is 'HardSigmoid'. */
  readonly beta?: number;
}
/**
 * Builds a WGSL statement snippet that applies `attributes.activation` in place to a variable named `value`.
 *
 * @param attributes Fused activation settings. 'Clip', 'HardSigmoid' and 'LeakyRelu' read their parameters
 *     from the `clip_min`/`clip_max` and `alpha`/`beta` uniforms (declared as f32 by `appendActivationUniforms`).
 * @param valueType WGSL type of `value` (e.g. `f32` or `vec4<f16>`), used to build typed constants.
 * @param baseType WGSL scalar type the f32 uniforms are converted to before use; defaults to `'f32'`.
 * @returns The snippet, or `''` when no activation is requested.
 * @throws Error if the activation name is not supported.
 */
export const getActivationSnippet =
    (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => {
      switch (attributes.activation) {
        case 'Relu':
          return `value = max(value, ${valueType}(0.0));`;
        case 'Sigmoid':
          return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
        case 'Clip':
          return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${
              baseType}(uniforms.clip_max)));`;
        case 'HardSigmoid':
          return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${
              baseType}(uniforms.beta)));`;
        case 'LeakyRelu':
          return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`;
        case 'Tanh':
          // tanh(x) = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)); exp() only ever sees a non-positive argument,
          // so it cannot overflow for large |x|.
          // This is the only snippet that needs a temporary, so it is wrapped in its own compound statement:
          // `e2x` then cannot clash with a name in the caller's scope or with a second copy of this snippet.
          return `{
  let e2x = exp(-2.0 * abs(value));
  value = sign(value) * (1.0 - e2x) / (1.0 + e2x);
}`;
        case '':
          return '';
        // TODO: adding other activations that can be fused.
        default:
          throw new Error(`Unsupported activation ${attributes.activation}`);
      }
    };
/**
 * Appends the uniform values required by the fused activation to `programUniform`.
 *
 * The push order (clip max before clip min, alpha before beta) mirrors the declaration order used by
 * `appendActivationUniforms`. Activations without parameters append nothing.
 */
export const appendActivationUniformsData =
    (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
      switch (attributes.activation) {
        case 'Clip':
          programUniform.push(
              {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
          break;
        case 'HardSigmoid':
          programUniform.push(
              {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!});
          break;
        case 'LeakyRelu':
          programUniform.push({type: DataType.float, data: attributes.alpha!});
          break;
        default:
          // No parameters to upload.
          break;
      }
    };
/**
 * Declares the f32 uniforms read by the fused activation snippet (`clip_max`/`clip_min` for 'Clip',
 * `alpha`/`beta` for 'HardSigmoid', `alpha` for 'LeakyRelu'). Other activations declare nothing.
 */
export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => {
  switch (attributes.activation) {
    case 'Clip':
      uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
      break;
    case 'HardSigmoid':
      uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'});
      break;
    case 'LeakyRelu':
      uniforms.push({name: 'alpha', type: 'f32'});
      break;
    default:
      break;
  }
};
/**
 * Extracts fused-activation settings from a raw attribute bag.
 *
 * Reads `attributes.activation` (defaulting to '' for none) and, for parameterized activations, the
 * `activation_params` array:
 * - 'HardSigmoid': `[alpha, beta]`, defaults `[0.2, 0.5]`
 * - 'Clip': `[clipMin, clipMax]`, defaults `[MIN_CLIP, MAX_CLIP]`
 * - 'LeakyRelu': `[alpha]`, default `[0.01]`
 *
 * Each element falls back to its default individually, so a missing, empty, or short `activation_params`
 * never produces `undefined` parameters (which would otherwise be uploaded as uniform data).
 */
export const parseInternalActivationAttributes =
    (attributes: Record<string, unknown>|undefined): InternalActivationAttributes => {
      const activation = attributes?.activation as string || '';
      const params = (attributes?.activation_params as readonly number[] | undefined) || [];
      if (activation === 'HardSigmoid') {
        const [alpha = 0.2, beta = 0.5] = params;
        return {activation, alpha, beta};
      } else if (activation === 'Clip') {
        const [clipMin = MIN_CLIP, clipMax = MAX_CLIP] = params;
        return {activation, clipMax, clipMin};
      } else if (activation === 'LeakyRelu') {
        const [alpha = 0.01] = params;
        return {activation, alpha};
      }
      return {activation};
    };