fix webgpu split (#17258)

fix webgpu split for the case of split_sizes coming from input[1]
This commit is contained in:
Guenther Schmuelling 2023-08-22 16:49:22 -07:00 committed by GitHub
parent d76dbc4fc3
commit d3d3dde844
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 4 deletions

View file

@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
const createSplitAttributesFromInputs =
(inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
const splitSizes: number[] = [];
let numOutputs: number = attributes.numOutputs;
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
numOutputs = splitSizes.length;
}
return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
};
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
@ -114,7 +116,7 @@ const createSplitProgramInfoLoader =
const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
const metadata:
ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
};
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {

View file

@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase {
if (num_outputs_ < 0) {
num_outputs_ = split_sizes.size();
}
} else if (split_sizes_.size() == 0) {
// Compute split_sizes from input shape and num_outputs
} else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) {
// Compute split_sizes from input shape and num_outputs.
// TODO: Shape might not be known at this point, better to handle this in javascript
auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int32_t>(axis_)).dim_value();
int64_t split_size_sum = 0;
if (num_outputs_ < 0) {
@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase {
ORT_ENFORCE(split_size_sum == total_split_size,
"Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")");
}
// else: let javascript handle all other cases, ie. split_sizes come as input[1]
JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
"numOutputs" : $2,