2021-05-10 18:41:50 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
import { Attribute } from '../../../attribute';
|
|
|
|
|
import { MAX_CLIP, MIN_CLIP } from '../../../util';
|
|
|
|
|
import { GlslValueFunction } from '../glsl-definitions';
|
2022-05-04 06:41:36 +00:00
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
import { glslClip, glslRelu, glslSigmoid } from './unary-op';
|
2021-05-10 18:41:50 +00:00
|
|
|
|
2021-08-12 19:30:49 +00:00
|
|
|
/**
 * Internal description of an optional activation appended to a GLSL program.
 * Produced by `parseInternalActivationAttributes` and consumed by `getActivationSnippet`.
 */
export interface InternalActivationAttributes {
  /**
   * Key identifying this activation configuration.
   * It is the activation name alone, or `Clip:<min>,<max>` for Clip.
   * NOTE(review): presumably folded into a program cache key by callers; confirm.
   */
  readonly activationCacheKey: string;
  /**
   * Activation op name. `getActivationSnippet` handles 'Relu', 'Sigmoid' and 'Clip'.
   * Any other value (including '') means "no activation".
   */
  readonly activation: string;
  /** Lower bound used when `activation` is 'Clip'. */
  readonly clipMin?: number;
  /** Upper bound used when `activation` is 'Clip'. */
  readonly clipMax?: number;
}
|
|
|
|
|
|
2022-07-27 20:57:12 +00:00
|
|
|
/**
 * Builds the GLSL snippets that apply the requested activation to a shader variable named `value`.
 *
 * @param attributes - Activation description, typically from `parseInternalActivationAttributes`.
 * @returns An object with two fields:
 *   - `activationFunction`: GLSL source of the activation helper.
 *   - `applyActivation`: a GLSL statement that overwrites `value` with the activated result.
 *   Both are '' when `attributes.activation` is not one of the supported activations.
 */
export function getActivationSnippet(attributes: InternalActivationAttributes): {
  activationFunction: string;
  applyActivation: string;
} {
  let func: GlslValueFunction;
  switch (attributes.activation) {
    case 'Relu':
      func = glslRelu();
      break;
    case 'Sigmoid':
      func = glslSigmoid();
      break;
    case 'Clip':
      // Use the same defaults as parseInternalActivationAttributes. A hand-built attribute
      // object that omits the bounds then never hands `undefined` to glslClip.
      func = glslClip(attributes.clipMin ?? MIN_CLIP, attributes.clipMax ?? MAX_CLIP);
      break;
    // TODO: adding other activations that can be fused.
    default:
      return { activationFunction: '', applyActivation: '' };
  }

  // The helper is called as `<name>_`.
  // NOTE(review): relies on the glsl*() helpers emitting a function with that suffix; confirm in unary-op.
  const applyActivation = `value = ${func.name}_(value);`;
  return { activationFunction: func.body, applyActivation };
}
|
2021-08-12 19:30:49 +00:00
|
|
|
|
|
|
|
|
/**
 * Reads the internal activation attributes from the given attribute set.
 *
 * - `activation` is read as a string and defaults to '' (no activation).
 * - For 'Clip', `activation_params` supplies `[clipMin, clipMax]`. A missing entry defaults to
 *   `MIN_CLIP` / `MAX_CLIP` respectively, and the bounds are embedded in the cache key.
 *
 * @param attributes - Attribute set carrying `activation` and, for Clip, `activation_params`.
 * @returns The parsed attributes, including a cache key describing the configuration.
 */
export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => {
  const activation = attributes.getString('activation', '');

  if (activation === 'Clip') {
    // The destructuring defaults also cover an `activation_params` list with fewer than two
    // entries. Without them a bound would be `undefined` and "undefined" would leak into the
    // cache key.
    const [clipMin = MIN_CLIP, clipMax = MAX_CLIP] = attributes.getFloats('activation_params', [MIN_CLIP, MAX_CLIP]);
    return { activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}` };
  }

  return { activation, activationCacheKey: activation };
};
|