// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
|
|
|
|
|
import { Graph } from '../../../graph';
|
|
|
|
|
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
|
|
|
|
|
import { Tensor } from '../../../tensor';
|
|
|
|
|
import { ShapeUtil, SplitUtil } from '../../../util';
|
|
|
|
|
import { WebGLInferenceHandler } from '../inference-handler';
|
|
|
|
|
import { ProgramInfo, TextureType } from '../types';
|
2021-08-12 19:30:49 +00:00
|
|
|
|
|
|
|
|
/**
 * Attributes of a Split node after parsing (see `parseSplitAttributes`).
 * The cache key inherited from `AttributeWithCacheKey` is used by `split`
 * when building per-output cache hints.
 */
export interface SplitAttributes extends AttributeWithCacheKey {
  /** Split axis as read from the node; `split` normalizes it against the input rank before use. */
  readonly axis: number;

  /** Value of the node's `split` attribute, read via `getInts('split', [])`. */
  readonly split: number[];

  /** Number of outputs declared on the graph node (`node.outputs.length`). */
  readonly numOutputs: number;
}
|
|
|
|
|
|
|
|
|
|
/**
 * Program fields shared by every Split program: the program name and a single
 * unpacked input texture exposed to the shader under the name `A`.
 */
const splitProgramMetadata = {
  name: 'Split',
  inputNames: ['A'],
  inputTypes: [TextureType.unpacked],
};
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * WebGL implementation of the Split operator.
 *
 * Validates the inputs, normalizes the split axis against the input rank, and
 * runs one program per split piece.
 *
 * @param inferenceHandler handler used to run each generated program
 * @param inputs operator inputs; exactly one tensor is accepted
 * @param attributes parsed Split attributes
 * @returns one tensor per split piece, in piece order
 * @throws Error when `validateInputs` rejects the inputs
 */
export const split: OperatorImplementation<SplitAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: SplitAttributes,
): Tensor[] => {
  validateInputs(inputs);

  const normalizedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
  const pieceCount = getProgramCount(inferenceHandler, inputs, normalizedAxis, attributes);

  // One program run per piece; the callback index is the piece index.
  return Array.from({ length: pieceCount }, (_slot, pieceIndex) =>
    inferenceHandler.run(
      {
        ...splitProgramMetadata,
        // The hint combines the attribute cache key with the piece index.
        cacheHint: `${attributes.cacheKey};${pieceIndex}`,
        // Program info is built on demand through this loader callback.
        get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, normalizedAxis, pieceIndex),
      },
      inputs,
    ),
  );
};
|
2021-08-12 19:30:49 +00:00
|
|
|
|
|
|
|
|
/**
 * Builds the `SplitAttributes` for a graph node.
 *
 * Reads `axis` (with 0 passed as the fallback) and `split` (with `[]` passed as
 * the fallback) from the node attributes, takes the output count from
 * `node.outputs`, and wraps the result with `createAttributeWithCacheKey`.
 *
 * @param node the Split graph node
 * @returns the parsed attributes including their cache key
 */
export const parseSplitAttributes: OperatorInitialization<SplitAttributes> = (node: Graph.Node): SplitAttributes =>
  createAttributeWithCacheKey({
    axis: node.attributes.getInt('axis', 0),
    split: node.attributes.getInts('split', []),
    numOutputs: node.outputs.length,
  });
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * Returns how many programs `split` has to run, i.e. the number of offsets
 * reported by `SplitUtil.splitShape` for the first input.
 *
 * @param _inferenceHandler unused
 * @param inputs operator inputs; only `inputs[0].dims` is read
 * @param axis split axis (already normalized by the caller)
 * @param attributes parsed Split attributes (`split` and `numOutputs` are read)
 */
const getProgramCount = (
  _inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  axis: number,
  attributes: SplitAttributes,
): number => {
  // splitShape returns [shapes, offsets]; only the offsets are needed here.
  const shapesAndOffsets = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs);
  return shapesAndOffsets[1].length;
};
|
2021-08-12 19:30:49 +00:00
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * Builds the program info for one piece of the split.
 *
 * @param _inferenceHandler unused
 * @param input tensor being split
 * @param attributes parsed Split attributes
 * @param axis split axis (already normalized by the caller)
 * @param index which piece to build, as an index into the shapes and offsets
 *   returned by `SplitUtil.splitShape`
 * @returns program info whose output has the shape of piece `index` and the
 *   element type of `input`
 */
const createSplitProgramInfo = (
  _inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  attributes: SplitAttributes,
  axis: number,
  index: number,
): ProgramInfo => {
  const [pieceShapes, pieceOffsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs);
  const axisOffset = pieceOffsets[index];
  const pieceDims = pieceShapes[index];

  // The shader shifts the incoming index along the split axis by this piece's
  // offset and then reads from `_A` (presumably the accessor generated for the
  // input named 'A' -- confirm against the GLSL code generator).
  const shaderSource = `
      float process(int indices[${pieceDims.length}]) {
        indices[${axis}] += ${axisOffset};
        return _A(indices);
      }
    `;

  const output = { dims: pieceDims, type: input.type, textureType: TextureType.unpacked };

  return {
    ...splitProgramMetadata,
    cacheHint: `${attributes.cacheKey}:${index}`,
    output,
    shaderSource,
  };
};
|
2021-08-12 19:30:49 +00:00
|
|
|
|
|
|
|
|
/** Tensor element types accepted by the WebGL Split implementation. */
const SPLIT_SUPPORTED_INPUT_TYPES: ReadonlySet<string> = new Set([
  'int8',
  'uint8',
  'int16',
  'uint16',
  'int32',
  'uint32',
  'float32',
  'float64',
  'bool',
]);

/**
 * Checks that Split received exactly one input tensor of a supported type.
 *
 * @param inputs operator inputs
 * @throws Error 'Split requires one input.' when `inputs` is missing or does
 *   not contain exactly one entry
 * @throws Error 'Invalid input type.' when the single entry is null/undefined
 *   or its `type` is not one of the supported element types
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 1) {
    throw new Error('Split requires one input.');
  }

  const input = inputs[0];
  // A missing tensor slot is reported as an invalid input rather than
  // surfacing as a TypeError from the property access below.
  if (input == null || !SPLIT_SUPPORTED_INPUT_TYPES.has(input.type)) {
    throw new Error('Invalid input type.');
  }
};
|