fix fused relu activation (#18303)

This commit is contained in:
Guenther Schmuelling 2023-11-09 08:18:21 -08:00 committed by GitHub
parent 2c22b49876
commit 25fbc2b0ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View file

@ -15,7 +15,10 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, i
} => {
switch (attributes.activation) {
case 'Relu':
return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'};
return {
activationFunction: '',
applyActivation: isVec4 ? 'value = max(value, vec4(0.0));' : 'value = max(value, 0.0);'
};
case 'Sigmoid':
return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'};
case 'Clip':

View file

@ -127,7 +127,7 @@ export class ProgramManager {
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({code, label: programInfo.name});
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);
LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);
const computePipeline = device.createComputePipeline(
{compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name});