mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
[js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596)
This is used in sam-h-decoder-f16. ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
3bdb10d5ca
commit
fe82fccf1a
1 changed files with 13 additions and 9 deletions
|
|
@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
|
|||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
|
||||
import {ConvTransposeAttributes} from '../conv-transpose';
|
||||
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
|
||||
|
||||
import {biasSnippet, typeSnippet} from './activation_util';
|
||||
import {biasSnippet} from './activation_util';
|
||||
import {utilFunctions} from './conv_util';
|
||||
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
|
||||
|
||||
const conv2dTransposeCommonSnippet =
|
||||
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
|
||||
const type = typeSnippet(innerElementSize, 'f32');
|
||||
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
|
||||
innerElementSize = 4): string => {
|
||||
const getWSnippet = (innerElementSize: number) => {
|
||||
switch (innerElementSize) {
|
||||
case 1:
|
||||
|
|
@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
|
|||
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
|
||||
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
|
||||
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
|
||||
return vec4<f32>(v0, v1, v2, v3);
|
||||
return ${type}(v0, v1, v2, v3);
|
||||
`;
|
||||
default:
|
||||
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
|
||||
|
|
@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
|
|||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
|
||||
inputVariables.push(bias);
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
|
||||
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${bias.type.value} {
|
||||
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
}
|
||||
|
|
@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
|
|||
{name: 'pads', type: 'i32', length: pads.length}
|
||||
];
|
||||
appendActivationUniforms(attributes, uniforms);
|
||||
const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
|
||||
if (elemType !== 'f16' && elemType !== 'f32') {
|
||||
throw new Error(`elemType ${elemType} is not supported.`);
|
||||
}
|
||||
return `
|
||||
${utilFunctions('uniforms.result_strides')}
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
|
||||
${declareFunctions}
|
||||
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
|
||||
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
|
||||
${
|
||||
isVec4 ? makeMatMulPackedVec4Source(
|
||||
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
|
||||
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
|
||||
makeMatMulPackedSource(
|
||||
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
|
||||
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
|
||||
undefined, sequentialAccessByThreads)}`;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue