mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-10 00:38:54 +00:00
fix webgpu split (#17258)
fix webgpu split for the case of split_sizes coming from input[1]
This commit is contained in:
parent
d76dbc4fc3
commit
d3d3dde844
2 changed files with 8 additions and 4 deletions
|
|
@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
|
|||
const createSplitAttributesFromInputs =
|
||||
(inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
|
||||
const splitSizes: number[] = [];
|
||||
let numOutputs: number = attributes.numOutputs;
|
||||
if (inputs[1].dims[0] > 0) {
|
||||
inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
|
||||
numOutputs = splitSizes.length;
|
||||
}
|
||||
return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
|
||||
return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
|
||||
};
|
||||
|
||||
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
|
||||
|
|
@ -114,7 +116,7 @@ const createSplitProgramInfoLoader =
|
|||
const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
|
||||
const metadata:
|
||||
ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
|
||||
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
|
||||
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
|
||||
};
|
||||
|
||||
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
|
||||
|
|
|
|||
|
|
@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase {
|
|||
if (num_outputs_ < 0) {
|
||||
num_outputs_ = split_sizes.size();
|
||||
}
|
||||
} else if (split_sizes_.size() == 0) {
|
||||
// Compute split_sizes from input shape and num_outputs
|
||||
} else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) {
|
||||
// Compute split_sizes from input shape and num_outputs.
|
||||
// TODO: Shape might not be known at this point, better to handle this in javascript
|
||||
auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int32_t>(axis_)).dim_value();
|
||||
int64_t split_size_sum = 0;
|
||||
if (num_outputs_ < 0) {
|
||||
|
|
@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase {
|
|||
ORT_ENFORCE(split_size_sum == total_split_size,
|
||||
"Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")");
|
||||
}
|
||||
// else: let javascript handle all other cases, ie. split_sizes come as input[1]
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
|
||||
"numOutputs" : $2,
|
||||
|
|
|
|||
Loading…
Reference in a new issue