mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
177 lines
6.4 KiB
TypeScript
177 lines
6.4 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import { DataType } from '../../../wasm-common';
|
|
import { TensorView } from '../../tensor-view';
|
|
import { ShapeUtil } from '../../util';
|
|
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
|
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
|
|
|
|
import {
|
|
atomicOutputVariable,
|
|
createTensorShapeVariables,
|
|
inputVariable,
|
|
outputVariable,
|
|
ShaderHelper,
|
|
} from './common';
|
|
|
|
/**
 * Attributes accepted by the ScatterND kernel in this file.
 */
export interface ScatterNDAttributes extends AttributeWithCacheKey {
  /**
   * Reduction mode. The shader generator in this file emits code for 'none', 'add', 'max', 'min' and 'mul'.
   */
  reduction: string;
}
|
|
|
|
/** WGSL scalar types for which a reduction snippet can be generated. */
type ReductionType = 'i32' | 'u32' | 'f32';

/**
 * Builds the WGSL statement(s) that combine one update value into one destination element.
 *
 * @param reduction - Reduction mode: 'none' or '' (plain store), 'add', 'max', 'min' or 'mul'.
 * @param ptr - WGSL expression naming the destination element, e.g. `output[data_offset + i]`.
 * @param v - WGSL expression that yields the update value.
 * @param type - WGSL scalar type of the values being reduced.
 * @returns WGSL source text to splice into the shader body.
 * @throws Error when a reduction is requested for a type other than i32/u32/f32, or when the reduction
 *   mode is not one of the supported names.
 */
const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: ReductionType) => {
  // createScatterNDProgramInfo selects a plain (non-atomic) output for both 'none' and '', so both must
  // produce a plain store here instead of '' falling through to the "not supported" error.
  if (reduction === 'none' || reduction === '') {
    return `${ptr}=${v};`;
  }

  // `type` usually arrives through a cast, so it is re-validated at runtime.
  if (type !== 'i32' && type !== 'u32' && type !== 'f32') {
    throw new Error(`Input ${type} is not supported with reduction ${reduction}.`);
  }

  // Compare-and-swap retry loop used where no native atomic builtin applies. The candidate value is
  // spliced between the head and the tail; the loop repeats until the exchange succeeds.
  const casLoopHead = `{
                var oldValue = 0;
                loop {
                  let newValueF32 =`;
  const casLoopTail = `;
                  let newValue = bitcast<i32>(newValueF32);
                  let res = atomicCompareExchangeWeak(&${ptr}, oldValue, newValue);
                  if res.exchanged {
                    break;
                  }
                  oldValue = res.old_value;
                }
              }`;

  const isIntegerType = type === 'i32' || type === 'u32';

  switch (reduction) {
    case 'add':
      if (isIntegerType) {
        return `atomicAdd(&${ptr}, bitcast<${type}>(${v}));`;
      }
      // atomicAdd only supports uint/int type. For float, we use
      // atomicCompareExchangeWeak to simulate.
      return `
              ${casLoopHead}bitcast<${type}>(oldValue) + (${v})${casLoopTail}`;
    case 'max':
      if (isIntegerType) {
        return `atomicMax(&${ptr}, bitcast<${type}>(${v}));`;
      }
      // atomicMax only supports uint/int type. For float, we use
      // atomicCompareExchangeWeak to simulate.
      return `
                ${casLoopHead}max(bitcast<f32>(oldValue), (${v}))${casLoopTail}`;
    case 'min':
      if (isIntegerType) {
        return `atomicMin(&${ptr}, bitcast<${type}>(${v}));`;
      }
      // atomicMin only supports uint/int type. For float, we use
      // atomicCompareExchangeWeak to simulate.
      return `${casLoopHead}min(bitcast<${type}>(oldValue), (${v}))${casLoopTail}`;
    case 'mul':
      // atomicMul is not supported, we use atomicCompareExchangeWeak to simulate.
      return `${casLoopHead}(bitcast<${type}>(oldValue) * (${v}))${casLoopTail}`;

    default:
      throw new Error(`Reduction ${reduction} is not supported.`);
  }
};
|
|
|
|
/**
 * Builds the program description (uniforms, dispatch size and WGSL source) for ScatterND.
 *
 * Each shader invocation handles one index tuple: it reads `last_index_dimension` consecutive entries
 * from `indices`, turns them into a flat offset into `output` using the output strides, and then
 * writes (or reduces) `num_updates_elements` consecutive values from `updates` at that offset.
 *
 * @param inputs - Tensors read as [data, indices, updates]; only the shape and data type of `data` are
 *   used here.
 * @param attributes - Parsed ScatterND attributes; `reduction` selects a plain store or a reduction.
 * @returns The program info consumed by `context.compute`.
 */
const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const indicesShape = inputs[1].dims;
  const outputShape = inputShape;
  // TODO: support bool with components 4.
  const components = 1;
  const lastIndexDimension = indicesShape[indicesShape.length - 1];
  // One invocation consumes a whole index tuple, so the amount of work is the product of every indices
  // dimension except the last one. Using the total element count of `indices` would let surplus
  // invocations pass the bounds guard and access `indices` and `updates` out of range.
  const numIndexTuples = indicesShape.slice(0, -1).reduce((count, dim) => count * dim, 1);
  const outputSize = Math.ceil(numIndexTuples / components);
  const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension);

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: lastIndexDimension },
    { type: DataType.uint32, data: numUpdatesElements },
    ...createTensorShapeVariables(inputs[1].dims, inputs[2].dims, outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const indices = inputVariable('indices', inputs[1].dataType, inputs[1].dims.length);
    const updates = inputVariable('updates', inputs[2].dataType, inputs[2].dims.length, components);
    // A reduction needs an atomic output binding; a plain store does not.
    const output =
      attributes.reduction !== 'none' && attributes.reduction !== ''
        ? atomicOutputVariable('output', inputs[0].dataType, outputShape.length)
        : outputVariable('output', inputs[0].dataType, outputShape.length, components);

    // For a rank-1 output the shape/strides uniforms are used without indexing. Otherwise the k-th entry
    // of an index tuple addresses axis k, so both the stride and the axis extent are looked up at
    // position `i - indices_start`.
    const axisLookup =
      inputs[0].dims.length === 1
        ? `
    let element_count_dim = uniforms.output_strides;
    let dim_value = uniforms.output_shape;`
        : `
    let element_count_dim = uniforms.output_strides[i - indices_start];
    let dim_value = uniforms.output_shape[i - indices_start];`;

    return `
      ${shaderHelper
        .registerUniform('output_size', 'u32')
        .registerUniform('last_index_dimension', 'u32')
        .registerUniform('num_updates_elements', 'u32')
        .declareVariables(indices, updates, output)}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
  var data_offset = 0u;
  let indices_start = uniforms.last_index_dimension * global_idx;
  let indices_end = indices_start + uniforms.last_index_dimension;
  for (var i = indices_start; i < indices_end; i++) {
    var index = i32(indices[i].x);
    ${axisLookup}
    if (index >= 0) {
      if (index >= i32(dim_value)) {
        index = i32(dim_value - 1);
      }
    } else {
      if (index < -i32(dim_value)) {
        index = 0;
      } else {
        index += i32(dim_value);
      }
    }
    data_offset += u32((u32(index) * element_count_dim));
  }

  for (var i = 0u; i < uniforms.num_updates_elements; i++) {
    let value = updates[uniforms.num_updates_elements * global_idx + i];
    ${atomicReductionSnippet(
      attributes.reduction,
      'output[data_offset + i]',
      'value',
      output.type.value as ReductionType,
    )}
  }

      }`;
  };
  return {
    name: 'ScatterND',
    shaderCache: {
      // The generated shader also depends on the rank of the data tensor, which is not one of the two
      // program inputs described by inputDependencies, so the rank is made part of the hint.
      hint: `${attributes.cacheKey}_${attributes.reduction}_${outputShape.length}`,
      inputDependencies: ['rank', 'rank'],
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
|
|
|
|
/**
 * Reads the `reduction` entry from the raw attribute record and wraps it with
 * `createAttributeWithCacheKey`.
 *
 * @param attributes - Raw operator attributes keyed by name.
 * @returns The ScatterND attributes object.
 */
export const parseScatterNDAttributes = (attributes: Record<string, unknown>): ScatterNDAttributes => {
  const reduction = attributes.reduction as string;
  return createAttributeWithCacheKey({ reduction });
};
|
|
|
|
/**
 * Entry point of the ScatterND kernel: builds the program from all context inputs and runs it with the
 * second and third inputs (indices and updates) bound as program inputs.
 *
 * @param context - Compute context supplying the input tensors and the `compute` call.
 * @param attributes - Parsed ScatterND attributes.
 */
export const scatterND = (context: ComputeContext, attributes: ScatterNDAttributes): void => {
  const programInfo = createScatterNDProgramInfo(context.inputs, attributes);
  const [, indicesTensor, updatesTensor] = context.inputs;
  context.compute(programInfo, {
    inputs: [indicesTensor, updatesTensor],
    outputs: [],
  });
};
|