[JS/Web] Add ConvTranspose implementation using MatMul (#17573)

### Description
Add ConvTranspose implementation using MatMul to increase perf.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-09-29 11:00:44 -07:00 committed by GitHub
parent caf98128c1
commit b4fbc25b1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 471 additions and 19 deletions

View file

@ -0,0 +1,243 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts
//
// modified to fit the needs of the project
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {ConvTransposeAttributes} from '../conv-transpose';
import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
const conv2dTransposeCommonSnippet =
(isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false,
innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return W[getIndexFromCoords4D(coord, wShape)];';
case 4:
return `
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
let v0 = W[getIndexFromCoords4D(coord, wShape)];
let v1 = W[getIndexFromCoords4D(coord1, wShape)];
let v2 = W[getIndexFromCoords4D(coord2, wShape)];
let v3 = W[getIndexFromCoords4D(coord3, wShape)];
return vec4<f32>(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
}
};
const coordASnippet = isChannelsLast ? `
let coord = vec4<i32>(batch, iXR, iXC, xCh);
` :
`
let coord = vec4<i32>(batch, xCh, iXR, iXC);
`;
const coordResSnippet = isChannelsLast ? `
let coords = vec4<i32>(
batch,
row / outWidth,
row % outWidth,
col);
` :
`
let coords = vec4<i32>(
batch,
row,
col / outWidth,
col % outWidth);
`;
const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
const row = isChannelsLast ? 'row' : 'col';
const col = isChannelsLast ? 'col' : 'row';
const readASnippet = `
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
return ${typeSnippet(innerElementSize)}(0.0);
}
if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) {
return ${typeSnippet(innerElementSize)}(0.0);
}
let iXR = i32(xR);
let iXC = i32(xC);
let xCh = ${col} % inChannels;
${coordASnippet}
return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`;
const sampleA = isChannelsLast ? `
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimInner) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);` :
`
let col = colIn * ${innerElementSize};
if (row < dimInner && col < dimBOuter) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);`;
const sampleW = `
let col = colIn * ${innerElementSize};
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
if (${
isChannelsLast ? 'row < dimInner && col < dimBOuter' :
'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) {
let rowInner = row % inChannels;
let coord = vec4<i32>(coordX, coordY, col, rowInner);
${getWSnippet(innerElementSize)}
}
return ${typeSnippet(innerElementSize)}(0.0);
`;
const userCode = `
${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} {
${isChannelsLast ? sampleA : sampleW}
}
fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} {
${isChannelsLast ? sampleW : sampleA}
}
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) {
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimBOuter) {
var value = valueInput;
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
${coordResSnippet}
${biasActivationSnippet(addBias, activation)}
result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value;
}
}`;
return userCode;
};
export const createConv2DTransposeMatMulProgramInfo =
(inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes,
outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean,
sequentialAccessByThreads: boolean): ProgramInfo => {
const isChannelsLast = attributes.format === 'NHWC';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
const batchSize = outputShape[0];
const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
const isVec4 =
isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0;
// TODO: fine tune size
const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
const workGroupSize: [number, number, number] = isVec4 ?
[8, 8, 1] :
[(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
const elementsPerThread =
isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1];
const dispatch = [
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2])
];
LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`);
const innerElementSize = isVec4 ? 4 : 1;
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`,
'@group(0) @binding(1) var<storage, read> W: array<f32>;'
];
let declareFunctions = '';
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}),
getShaderSource: () => `
${utilFunctions}
${declareInputs.join('\n')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? 'vec4<f32>' : 'f32'}>;
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
${
attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
${
attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
attributes.pads[0] + attributes.pads[2]})/2,
i32(effectiveFilterDims[1]) - 1 - (${
attributes.pads[1] + attributes.pads[3]})/2);
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined,
sequentialAccessByThreads)}`
};
};

View file

@ -197,14 +197,14 @@ const createConvTranspose2DOpProgramShaderSource =
continue;
}
let idyC: u32 = u32(dyC);
var inputChannel = groupId * ${inputChannelsPerGroup};
for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) {
let inputChannel = groupId * ${inputChannelsPerGroup} + d2;
let xValue = ${
isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') :
dy.get('batch', 'inputChannel', 'idyR', 'idyC')};
let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')};
dotProd = dotProd + xValue * wValue;
inputChannel = inputChannel + 1;
}
}
}

View file

@ -8,7 +8,9 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '.
import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu';
import {ConvAttributes} from './conv';
import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm';
import {parseInternalActivationAttributes} from './fuse-utils';
import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';
const computeTotalPad =
(inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) =>
@ -63,7 +65,7 @@ const getAdjustedConvTransposeAttributes =
<T extends ConvTransposeAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
const kernelShape = attributes.kernelShape.slice();
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) {
if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) {
kernelShape.length = 0;
for (let i = 2; i < inputs[1].dims.length; ++i) {
kernelShape.push(inputs[1].dims[i]);
@ -95,9 +97,11 @@ const getAdjustedConvTransposeAttributes =
// always return a new object so does not modify the original attributes
const newAttributes: T = Object.assign({}, attributes);
Object.assign(
newAttributes,
{kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey});
const cacheKey = attributes.cacheKey + [
kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','),
dilations.join(',')
].join('_');
Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey});
return newAttributes;
};
@ -226,12 +230,64 @@ const createConvTranspose2DProgramInfoLoader =
};
};
// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C]
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]});
const convTranspose2d =
(context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
const isChannelsLast = attributes.format === 'NHWC';
const hasBias = inputs.length === 3;
if (adjustedAttributes.group !== 1) {
context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes));
return;
}
const outputShape = adjustedAttributes.outputShape;
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
const outChannels = outputShape[isChannelsLast ? 3 : 1];
const weightHeight = inputs[1].dims[2];
const weightWidth = inputs[1].dims[3];
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes));
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
const dimInner = weightHeight * weightWidth * inputChannels;
const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;
// STEP.1: transpose weight
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
{
...transposeProgramMetadata,
cacheHint: weightTransposeAttribute.cacheKey,
get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
},
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
}
// STEP.2: prepare reshaped inputs
const convTransposeInputs = [inputs[0], transposedWeight];
if (hasBias) {
if (!isChannelsLast && inputs[2].dims.length === 1) {
convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
} else {
convTransposeInputs.push(inputs[2]);
}
}
// STEP.3: compute matmul
context.compute(
createConv2DTransposeMatMulProgramInfoLoader(
convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
sequentialAccessByThreads),
{inputs: convTransposeInputs});
};
const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => {
// extend the input to 2D by adding H dimension
const isChannelLast = attributes.format === 'NHWC';

View file

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {TensorView} from '../../tensor-view';
import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu';
import {ConvTransposeAttributes} from './conv-transpose';
const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({
name: 'Conv2DTransposeMatMul',
inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] :
[GpuDataType.default, GpuDataType.default],
cacheHint
});
export const createConv2DTransposeMatMulProgramInfoLoader =
(inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[],
dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean,
sequentialAccessByThreads: boolean): ProgramInfoLoader => {
const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey);
return {
...metadata,
get: () => createConv2DTransposeMatMulProgramInfo(
inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
sequentialAccessByThreads)
};
};

View file

@ -28,6 +28,37 @@
}
]
},
{
"name": "ConvTranspose without bias addition A - NHWC",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"operator": "ConvTranspose",
"attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40],
"dims": [1, 1, 2, 2],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [10, 40, 40, 60, 200, 160, 90, 240, 160],
"dims": [1, 1, 3, 3],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose without bias addition B",
"operator": "ConvTranspose",
@ -74,26 +105,22 @@
},
{
"data": [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
],
"dims": [4, 4, 2, 2],
"type": "float32"
},
{
"data": [0.1, 0.2, 0.3, 0.4],
"data": [65, 66, 67, 68],
"dims": [4],
"type": "float32"
}
],
"outputs": [
{
"data": [
100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219,
100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781,
100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789,
100.4000015258789
],
"data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868],
"dims": [1, 4, 2, 2],
"type": "float32"
}
@ -115,7 +142,7 @@
"type": "float32"
},
{
"data": [1, 1, 1, 1],
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
},
@ -127,7 +154,43 @@
],
"outputs": [
{
"data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14],
"data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41],
"dims": [1, 1, 4, 4],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose with bias addition B - NHWC",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [6, 8, 7, 9, 15, 11, 8, 12, 9],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
},
{
"data": [5],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41],
"dims": [1, 1, 4, 4],
"type": "float32"
}
@ -251,7 +314,6 @@
}
]
},
{
"name": "ConvTranspose- pointwise",
"operator": "ConvTranspose",
@ -285,5 +347,50 @@
]
}
]
},
{
"name": "ConvTranspose with bias addition C",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
],
"dims": [1, 4, 4, 4],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
"dims": [4, 4, 1, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [4],
"type": "float32"
}
],
"outputs": [
{
"data": [
1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122,
1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259,
1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404,
1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924
],
"dims": [1, 4, 4, 4],
"type": "float32"
}
]
}
]
}
]

View file

@ -108,6 +108,23 @@ class ConvTranspose : public JsKernel {
}
}
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* /* prepacked_weights */) override {
is_packed = false;
if (input_idx == 1) {
// Only handle the common case of conv2D
if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) {
return Status::OK();
}
w_is_const_ = true;
}
return Status::OK();
}
protected:
ConvTransposeAttributes conv_transpose_attrs_;
bool w_is_const_;