mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Squashed commit message:
* add web
* add script and test
* fix lint
* add test/data/ops
* add test/data/node/ to .gitignore
* modify scripts
* add onnxjs
* fix tests
* fix test-runner
* fix sourcemap
* fix onnxjs profiling
* update test list
* update README
* resolve review comments
* set wasm as default backend
* rename package
* update copyright header
* do not use class "Buffer" in browser context
* revise README
62 lines
2.6 KiB
TypeScript
62 lines
2.6 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Split} from '../../../ops/split';
|
|
import {Tensor} from '../../../tensor';
|
|
import {ShapeUtil, SplitUtil} from '../../../util';
|
|
import {WebGLInferenceHandler} from '../inference-handler';
|
|
import {Artifact, ProgramInfo, RunData} from '../types';
|
|
|
|
/**
 * WebGL implementation of the Split operator.
 *
 * One GPU program is built per output. Program `i` produces output `i` by reading input 0 shifted
 * by `offsets[i]` along the (normalized) split axis. The shapes and offsets come from
 * `SplitUtil.splitShape(dims, axis, this.split, this.numOutputs)`. `axis`, `split` and
 * `numOutputs` are inherited from `Split` (presumably parsed from the node's attributes — confirm
 * in the base class).
 */
export class WebGLSplit extends Split {
  /**
   * Executes every per-output program and returns the output tensors in output order.
   *
   * Built artifacts are cached on the instance. They are reused only while the normalized axis and
   * the dims of input 0 are unchanged. Each program bakes in its output shape, its offset and the
   * input texture layout, so a different input shape must trigger a rebuild rather than silently
   * reuse stale programs.
   *
   * @param inferenceHandler WebGL handler that owns the session's program manager.
   * @param inputs Operator inputs; only `inputs[0]` is read.
   * @returns One tensor per split output.
   */
  run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
    const axis = ShapeUtil.normalizeAxis(this.axis, inputs[0].dims.length);
    const cacheKey = `${axis};${inputs[0].dims.join(',')}`;

    if (!this.artifacts || this.artifactsCacheKey !== cacheKey) {
      // Build into a local array and publish it only after every build succeeded. Otherwise a throw
      // part-way would leave a truthy, partially filled cache that later calls would reuse.
      const built: Artifact[] = [];
      const count = this.getProgramCount(inferenceHandler, inputs, axis);
      for (let i = 0; i < count; ++i) {
        const programInfo = this.createProgramInfo(inferenceHandler, inputs[0], axis, i);
        built.push(inferenceHandler.session.programManager.build(programInfo));
      }
      this.artifacts = built;
      this.artifactsCacheKey = cacheKey;
    }

    return this.artifacts.map(artifact => {
      const runData = this.createRunData(inferenceHandler, artifact.programInfo, inputs);
      inferenceHandler.session.programManager.run(artifact, runData);
      return runData.outputTextureData.tensor;
    });
  }

  /**
   * Returns how many programs (that is, outputs) the split produces for the given input and axis.
   *
   * @param inferenceHandler Unused here; kept for interface symmetry with the other helpers.
   * @param inputs Operator inputs; only `inputs[0].dims` is read.
   * @param axis Already-normalized split axis.
   */
  getProgramCount(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number): number {
    const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, this.split, this.numOutputs);
    return offsets.length;
  }

  /**
   * Builds the program description for output `index`.
   *
   * The generated shader receives output indices, shifts the split-axis coordinate by this output's
   * offset, and samples input `A` at the shifted position. The output rank equals the input rank;
   * only the extent along `axis` differs.
   *
   * @param inferenceHandler Used to create the input/output texture layouts.
   * @param input The tensor being split.
   * @param axis Already-normalized split axis.
   * @param index Which output this program produces.
   */
  createProgramInfo(inferenceHandler: WebGLInferenceHandler, input: Tensor, axis: number, index: number): ProgramInfo {
    const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, this.split, this.numOutputs);
    const offset = offsets[index];
    const outputShape = shapes[index];
    const rank = outputShape.length;
    // `_A` is presumably the accessor generated for sampler 'A' (input 0) — confirm in the
    // program manager / shader generator.
    const shaderSource = `
      float process(int indices[${rank}]) {
        indices[${axis}] += ${offset};
        return _A(indices);
      }`;
    return {
      inputLayouts: [inferenceHandler.getOrCreateTextureLayout(input)],
      outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
      samplers: ['A'],
      shaderSource,
    };
  }

  /**
   * Prepares the per-execution data for one program.
   *
   * Uploads (or reuses) the texture for `inputs[0]` and allocates an output texture with the
   * program's output layout and the same tensor type as the input. No uniforms are used.
   */
  createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
    const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
    return {
      inputTextureDatas: inputTDs,
      outputTextureData:
          inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
      uniformData: {}
    };
  }

  /** Cached per-output programs; undefined until the first `run`. */
  protected artifacts: Artifact[];

  /** Axis + input-dims key that `artifacts` was built for; a mismatch forces a rebuild. */
  private artifactsCacheKey?: string;
}
|