onnxruntime/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts

49 lines
1.6 KiB
TypeScript
Raw Normal View History

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Attribute } from '../../../attribute';
import { MAX_CLIP, MIN_CLIP } from '../../../util';
import { GlslValueFunction } from '../glsl-definitions';
import { glslClip, glslRelu, glslSigmoid } from './unary-op';
export interface InternalActivationAttributes {
readonly activation: string;
readonly clipMin?: number;
readonly clipMax?: number;
readonly activationCacheKey: string;
}
export function getActivationSnippet(attributes: InternalActivationAttributes) {
let func: GlslValueFunction;
switch (attributes.activation) {
case 'Relu':
func = glslRelu();
break;
case 'Sigmoid':
func = glslSigmoid();
break;
2021-06-21 23:30:12 +00:00
case 'Clip':
func = glslClip(attributes.clipMin!, attributes.clipMax!);
2021-06-21 23:30:12 +00:00
break;
// TODO: adding other activations that can be fused.
default:
return { activationFunction: '', applyActivation: '' };
}
2021-06-21 23:30:12 +00:00
const activationName = func.name;
const activationFunction = func.body;
const applyActivation = `value = ${activationName}_(value);`;
return { activationFunction, applyActivation };
}
export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => {
const activation = attributes.getString('activation', '');
if (activation === 'Clip') {
const [clipMin, clipMax] = attributes.getFloats('activation_params', [MIN_CLIP, MAX_CLIP]);
return { activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}` };
}
return { activation, activationCacheKey: activation };
};