mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
fix fused relu activation (#18303)
This commit is contained in:
parent
2c22b49876
commit
25fbc2b0ab
2 changed files with 5 additions and 2 deletions
|
|
@ -15,7 +15,10 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, i
|
|||
} => {
|
||||
switch (attributes.activation) {
|
||||
case 'Relu':
|
||||
return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'};
|
||||
return {
|
||||
activationFunction: '',
|
||||
applyActivation: isVec4 ? 'value = max(value, vec4(0.0));' : 'value = max(value, 0.0);'
|
||||
};
|
||||
case 'Sigmoid':
|
||||
return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'};
|
||||
case 'Clip':
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ export class ProgramManager {
|
|||
const userCode = programInfo.getShaderSource(shaderHelper);
|
||||
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
|
||||
const shaderModule = device.createShaderModule({code, label: programInfo.name});
|
||||
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);
|
||||
LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);
|
||||
|
||||
const computePipeline = device.createComputePipeline(
|
||||
{compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name});
|
||||
|
|
|
|||
Loading…
Reference in a new issue