2023-08-08 16:09:37 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
2023-10-18 17:47:41 +00:00
|
|
|
import {DataType} from '../../../wasm-common';
|
2023-09-15 04:14:44 +00:00
|
|
|
import {TensorView} from '../../tensor-view';
|
2023-08-08 16:09:37 +00:00
|
|
|
import {ShapeUtil} from '../../util';
|
2024-01-09 22:56:00 +00:00
|
|
|
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
2023-08-08 16:09:37 +00:00
|
|
|
|
2024-01-09 22:56:00 +00:00
|
|
|
import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
|
2023-08-08 16:09:37 +00:00
|
|
|
|
2024-01-09 22:56:00 +00:00
|
|
|
/**
 * Attributes accepted by the instance-normalization program builder
 * (`createInstanceNormProgramInfo` takes a value of this type as its second parameter).
 */
export interface InstanceNormAttributes {
|
2023-08-08 16:09:37 +00:00
|
|
|
// NOTE(review): not read in the visible part of this file; presumably the small constant added
// to the variance for numerical stability — confirm against the shader code that consumes it.
epsilon: number;
|
|
|
|
|
// Layout tag; must be exactly one of the two string literals below.
// NOTE(review): presumably describes the memory layout of the input tensor — confirm with callers.
format: 'NHWC'|'NCHW';
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const createInstanceNormProgramInfo =
|
2023-10-10 07:31:12 +00:00
|
|
|
(inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
|
2023-08-08 16:09:37 +00:00
|
|
|
const xShape = inputs[0].dims;
|
|
|
|
|
const outputShape = xShape;
|
|
|
|
|
const axis = 2;
|
|
|
|
|
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
|
|
|
|
|
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
|
2023-12-15 19:26:15 +00:00
|
|
|
const components = getMaxComponents(normSize);
|
|
|
|
|
const normPackedSize = normSize / components;
|
2024-01-09 22:56:00 +00:00
|
|
|
const inputShape = [xShape[0], xShape[1], normPackedSize];
|
|
|
|
|
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
|
|
|
|
|
const programUniforms: ProgramUniform[] =
|
|
|
|
|
[{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}];
|
|
|
|
|
programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape));
|
|
|
|
|
|
|
|
|
|
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
|
|
|
|
const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
|
|
|
|
|
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
|
|
|
|
|
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
|
|
|
|
|
const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
|
|
|
|
|
const variables = [x, scale, bias, output];
|
|
|
|
|
const dataType = x.type.value;
|
|
|
|
|
const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
|
|
|
|
|
const workgroupSize = 64;
|
|
|
|
|
|
|
|
|
|
const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}];
|
|
|
|
|
return `
|
2023-12-15 19:26:15 +00:00
|
|
|
var<workgroup> meanShared : f32;
|
|
|
|
|
var<workgroup> squaredNormShared : f32;
|
|
|
|
|
var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In the previous implementation, a single thread ran two loops over the
H * W elements to calculate the `mean` and `squaredNorm` values, and the
same thread also wrote all H * W outputs. As a result, it is very slow
when H * W is large, and H * W usually is large in real models. For
example, in the `candy-8` model, the [H, W] shapes seen by the
`InstanceNormalization` op are [224,224], [112,112] and [56,56]. On my
ADL machine, `[1,224,224,32]` takes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
This PR uses workgroup memory to optimize the original algorithm. The
advantage is that all 64 (workgroupSize) threads in a workgroup work in
parallel to calculate the `mean` and `squaredNorm` values. In addition,
each thread writes only `H * W / workgroupSize` outputs, which greatly
reduces the per-thread work. With this optimization, `[1,224,224,32]`
takes about 3 ms, and most of that time is spent in the two extra
`transpose` operations; `createInstanceNormProgramInfo` itself needs only
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently applies the new algorithm only to the NCHW format. For
the NHWC format, one option is to transpose the input so that it can use
the new algorithm, at the cost of two extra transposes.
@dakenf has proposed another way to optimize NHWC; for details, see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. Its performance is similar to transpose +
optimized NCHW; which one is slightly faster depends on the GPU. So I
prefer that this PR handle only the NCHW part, and
@dakenf can submit his NHWC optimization separately.
2023-09-15 00:03:18 +00:00
|
|
|
const workgroupSize = ${workgroupSize}u;
|
2024-01-09 22:56:00 +00:00
|
|
|
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
${shaderHelper.mainStart(workgroupSize)}
|
|
|
|
|
let norm = global_idx / workgroupSize;
|
2024-01-09 22:56:00 +00:00
|
|
|
let batch = norm / uniforms.x_shape[1];
|
|
|
|
|
let channel = norm % uniforms.x_shape[1];
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
let localIndex = local_id.x;
|
|
|
|
|
|
|
|
|
|
// initialize workgroup memory
|
2023-12-15 19:26:15 +00:00
|
|
|
var initial = ${f32Type}(0);
|
2024-01-09 22:56:00 +00:00
|
|
|
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
|
2023-12-15 19:26:15 +00:00
|
|
|
initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
}
|
|
|
|
|
workgroupShared[localIndex] = initial;
|
|
|
|
|
workgroupBarrier();
|
2023-08-08 16:09:37 +00:00
|
|
|
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
// Calculate the mean of current channel data.
|
|
|
|
|
for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
|
|
|
|
|
if (localIndex < currSize) {
|
|
|
|
|
workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
|
|
|
|
|
}
|
|
|
|
|
workgroupBarrier();
|
|
|
|
|
}
|
|
|
|
|
if (localIndex == 0) {
|
2024-01-09 22:56:00 +00:00
|
|
|
meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
}
|
|
|
|
|
workgroupBarrier();
|
2023-08-08 16:09:37 +00:00
|
|
|
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
// reinitialize workgroup memory.
|
2023-12-15 19:26:15 +00:00
|
|
|
initial = ${f32Type}(0);
|
2024-01-09 22:56:00 +00:00
|
|
|
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
|
2023-12-15 19:26:15 +00:00
|
|
|
let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
initial = initial + deviation * deviation;
|
2023-08-08 16:09:37 +00:00
|
|
|
}
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
workgroupShared[localIndex] = initial;
|
|
|
|
|
workgroupBarrier();
|
2023-08-08 16:09:37 +00:00
|
|
|
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently only applies the new algorithm to NCHW format. For
NHWC format, one way is to transpose the input so that it can use the
new algorithm. But the disadvantage is that 2 extra transpose are added.
@dakenf also gives another way to optimize NHWC. Details see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. The perf is similar with transpose +
optimized NCHW. But on different GPUs, one is a little better than
another or vice versa. So I prefer this PR only does the NCHW part.
@dakenf can submit his optimization on NHWC.
2023-09-15 00:03:18 +00:00
|
|
|
// Calculate the sum of square of deviation of current channel data.
|
|
|
|
|
for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) {
|
|
|
|
|
if (localIndex < currSize) {
|
|
|
|
|
workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
|
|
|
|
|
}
|
|
|
|
|
workgroupBarrier();
|
2023-08-08 16:09:37 +00:00
|
|
|
}
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In previous implementation, there are two loops to iterate H * W
elements to calculate the `mean` and `squaredNorm` value in one thread,
meanwhile it outputs H * W elements in one thread. That results it's
very very slow when H * W is a large value. And usually, H * W does be a
large value in a model. For example, in the `candy-8` model, the shapes
of [H, W] are [224,224], [112,112], [56,56] for `InstanceNormalization`
op. And in my ADL, `[1,224,224,32]` consumes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
In this PR, it uses workgroup memory to optimize the original algorithm.
The advantage is that it can parallelly utilize the 64 (workgroupSize)
threads in one workgroup to calculate `mean` and `squaredNorm` value.
Meanwhile, it only outputs `H * W / workgroupSize` outputs for one
thread, which greatly reduces the overhead for one thread. With this
optimization, `[1,224,224,32]` becomes 3 ms and the main overhead is the
extra two `transpose`. The `createInstanceNormProgramInfo` only needs
`0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently applies the new algorithm only to the NCHW format. For
the NHWC format, one option is to transpose the input so that it can use
the new algorithm, but the disadvantage is that two extra transposes are
added. @dakenf also proposed another way to optimize NHWC; for details, see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. Its performance is similar to that of
transpose + optimized NCHW; which one is slightly faster depends on the
GPU. So I prefer that this PR only does the NCHW part, and @dakenf can
submit his NHWC optimization separately.
2023-09-15 00:03:18 +00:00
|
|
|
if (localIndex == 0) {
|
2023-12-15 19:26:15 +00:00
|
|
|
squaredNormShared = ${sumVector('workgroupShared[0]', components)};
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In the previous implementation, a single thread ran two loops over all
H * W elements to calculate the `mean` and `squaredNorm` values, and the
same thread also wrote all H * W outputs. As a result, it is very slow
when H * W is large, which is usually the case in real models. For
example, in the `candy-8` model, the [H, W] shapes seen by the
`InstanceNormalization` op are [224,224], [112,112] and [56,56]. On my
ADL machine, `[1,224,224,32]` takes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
This PR uses workgroup memory to optimize the original algorithm. The
advantage is that all 64 (workgroupSize) threads in a workgroup work in
parallel to calculate the `mean` and `squaredNorm` values. In addition,
each thread writes only `H * W / workgroupSize` outputs, which greatly
reduces the work per thread. With this optimization, `[1,224,224,32]`
takes about 3 ms, and the main overhead is the two extra `transpose`
operations. The `createInstanceNormProgramInfo` program itself only
needs `0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently applies the new algorithm only to the NCHW format. For
the NHWC format, one option is to transpose the input so that it can use
the new algorithm, but the disadvantage is that two extra transposes are
added. @dakenf also proposed another way to optimize NHWC; for details, see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. Its performance is similar to that of
transpose + optimized NCHW; which one is slightly faster depends on the
GPU. So I prefer that this PR only does the NCHW part, and @dakenf can
submit his NHWC optimization separately.
2023-09-15 00:03:18 +00:00
|
|
|
}
|
|
|
|
|
workgroupBarrier();
|
|
|
|
|
|
2024-01-09 22:56:00 +00:00
|
|
|
let invStdDev = 1 / sqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
|
2023-12-15 19:26:15 +00:00
|
|
|
let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
|
|
|
|
|
let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
|
2024-01-09 22:56:00 +00:00
|
|
|
for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
|
2023-12-15 19:26:15 +00:00
|
|
|
let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
|
2024-01-09 22:56:00 +00:00
|
|
|
f32Type}(channelShift));
|
[js/webgpu] Optimize InstanceNormalization (#17491)
### Description
<!-- Describe your changes. -->
In the previous implementation, a single thread ran two loops over all
H * W elements to calculate the `mean` and `squaredNorm` values, and the
same thread also wrote all H * W outputs. As a result, it is very slow
when H * W is large, which is usually the case in real models. For
example, in the `candy-8` model, the [H, W] shapes seen by the
`InstanceNormalization` op are [224,224], [112,112] and [56,56]. On my
ADL machine, `[1,224,224,32]` takes 17 ms. See below:
```
[profiling] kernel "23848328|[InstanceNormalization] 23848328" input[0]: [1,224,224,32] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,224,224,32] | float32, execution time: 17007914 ns
```
This PR uses workgroup memory to optimize the original algorithm. The
advantage is that all 64 (workgroupSize) threads in a workgroup work in
parallel to calculate the `mean` and `squaredNorm` values. In addition,
each thread writes only `H * W / workgroupSize` outputs, which greatly
reduces the work per thread. With this optimization, `[1,224,224,32]`
takes about 3 ms, and the main overhead is the two extra `transpose`
operations. The `createInstanceNormProgramInfo` program itself only
needs `0.64` ms. See below:
```
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,224,224,32] | float32, output[0]: [1,32,224,224] | float32, execution time: 1543792 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, input[1]: [32] | float32, input[2]: [32] | float32, output[0]: [1,32,224,224] | float32, execution time: 642652 ns
program-manager.ts:115
[profiling] kernel "23003600|[InstanceNormalization] 23003600" input[0]: [1,32,224,224] | float32, output[0]: [1,224,224,32] | float32, execution time: 991608 ns
```
This PR currently applies the new algorithm only to the NCHW format. For
the NHWC format, one option is to transpose the input so that it can use
the new algorithm, but the disadvantage is that two extra transposes are
added. @dakenf also proposed another way to optimize NHWC; for details, see
[here](https://github.com/microsoft/onnxruntime/blob/d45a96616da9843b037210f2d48d6b4e5bdae5c6/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts).
I checked @dakenf's method. Its performance is similar to that of
transpose + optimized NCHW; which one is slightly faster depends on the
GPU. So I prefer that this PR only does the NCHW part, and @dakenf can
submit his NHWC optimization separately.
2023-09-15 00:03:18 +00:00
|
|
|
${output.set('batch', 'channel', 'h', 'value')};
|
2023-08-08 16:09:37 +00:00
|
|
|
}
|
|
|
|
|
}`;
|
2024-01-09 22:56:00 +00:00
|
|
|
};
|
2023-08-08 16:09:37 +00:00
|
|
|
return {
|
2024-01-09 22:56:00 +00:00
|
|
|
...{name: 'InstanceNormalization'},
|
|
|
|
|
// TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
|
|
|
|
|
shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies},
|
2023-10-10 07:31:12 +00:00
|
|
|
getRunData: () => ({
|
|
|
|
|
outputs: [
|
2023-10-11 23:41:46 +00:00
|
|
|
{dims: outputShape, dataType: inputs[0].dataType},
|
2023-10-10 07:31:12 +00:00
|
|
|
],
|
2024-01-09 22:56:00 +00:00
|
|
|
dispatchGroup: {x: normCount},
|
|
|
|
|
programUniforms
|
2023-10-10 07:31:12 +00:00
|
|
|
}),
|
2023-08-08 16:09:37 +00:00
|
|
|
getShaderSource,
|
|
|
|
|
};
|
|
|
|
|
};
|
|
|
|
|
|
2023-10-18 17:47:41 +00:00
|
|
|
/**
 * Computes the per-(image, channel) scale and shift used by the channels-last (NHWC) instance
 * normalization path.
 *
 * Two GPU programs are run through `context.compute`:
 *  1. `InstanceNormComputeMean` splits the `h` positions of every (image, channel-group) pair into
 *     `WG` (64) contiguous slices and stores one partial `(sum, squaredSum)` pair per slice.
 *  2. `InstanceNormComputeChannelScaleShift` adds the `WG` partials together, derives the mean and the
 *     inverse standard deviation and stores `(channelScale, channelShift)` so that the normalized
 *     value can later be computed as `x * channelScale + channelShift`.
 *
 * @param context compute context used to run both programs.
 * @param input tensor to normalize; the shader reads it as `[n, h, c]` with channels innermost.
 * @param scale per-channel scale tensor.
 * @param bias per-channel bias tensor.
 * @param n number of images.
 * @param h number of positions per channel that are reduced over.
 * @param c number of channels.
 * @param epsilon value added to the variance; it is baked into the shader source and the cache hint.
 * @returns the first output of the second program, declared with dims `[n, c, 2]` (float32).
 */
const computeMean =
    (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
     epsilon: number) => {
      const components = getMaxComponents(c);
      const WG = 64;
      // we will store channel scale and channel shift in [2, components] matrix
      // or in vec2 when components == 1
      const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
      const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
      const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
      const unitsOfWork = n * c / components;
      // number of consecutive positions along `h` reduced by a single invocation of the first program
      const wgSize = Math.ceil(h / WG);

      const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
      const meanProgramUniforms: ProgramUniform[] = [
        {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)},
        {type: 'uint32', data: Math.floor(h * c / components)}
      ];

      // First pass: every invocation reduces its slice [wgOffset, wgMax) of the `h` axis.
      //
      // An invocation whose slice starts at or beyond H must still write its (zero) partial result:
      // the second pass unconditionally adds up all WG partials of a channel, so an entry that was
      // never written would feed whatever the output buffer happened to contain into the mean and
      // the variance. For such an invocation `wgMax` equals H, which is <= wgOffset, so the loop body
      // never runs and a zero partial is stored.
      const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
        const inputHelper = inputVariable('input', input.dataType, input.dims, components);
        return `
  ${shaderHelper.declareVariables(inputHelper)}
  @group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
  @group(0) @binding(2) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart(WG)}
    let currentImageNumber = global_idx / ${WG} / uniforms.C;
    let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
    let wgId = global_idx % ${WG};
    let wgOffset = wgId * uniforms.wg_size;
    let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);

    let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = wgOffset; i < wgMax; i++) {
        let value = ${sumCastType}(input[offset + i * uniforms.C]);
        sum += value;
        squaredSum += value * value;
    }
    output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
  }`;
      };

      const meanValues = context.compute(
          {
            name: 'InstanceNormComputeMean',
            shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, WG, 2], dataType: DataType.float},
              ],
              // one workgroup of WG invocations per (image, channel-group) pair
              dispatchGroup: {x: n * c / components},
              programUniforms: meanProgramUniforms
            }),
            getShaderSource: getMeanShaderSource,
          },
          {inputs: [input], outputs: [-1]})[0];

      const programUniforms: ProgramUniform[] = [
        {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h},
        {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)}
      ];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
      // Second pass: one invocation per (image, channel-group) pair folds the WG partial results.
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
        const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
        return `
  @group(0) @binding(0) var<storage, read> input : array<${outputType}>;
  @group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
  @group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
  @group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
  @group(0) @binding(4) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
    let currentImageNumber = global_idx / uniforms.C;
    let currentChannelNumber = global_idx % uniforms.C;

    let offset = currentImageNumber * uniforms.image_size;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = 0; i < ${WG}; i++) {
        let value = input[offset + i + currentChannelNumber * ${WG}];
        sum += value[0];
        squaredSum += value[1];
    }
    sum = sum / f32(uniforms.H);
    squaredSum = squaredSum / f32(uniforms.H);
    let invStdDev = 1 / sqrt(squaredSum - sum * sum + f32(${epsilon}));
    let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
    let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;

    output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
  }`;
      };
      return context.compute(
          {
            name: 'InstanceNormComputeChannelScaleShift',
            // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
            shaderCache: {hint: `${components};${epsilon}`, inputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, 2], dataType: DataType.float},
              ],
              dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [meanValues, scale, bias], outputs: [-1]})[0];
    };
|
|
|
|
|
|
2023-08-08 16:09:37 +00:00
|
|
|
/**
 * Runs instance normalization for a channels-last (NHWC) input.
 *
 * Despite its name this function does not return a `ProgramInfo`: it first obtains the
 * per-(image, channel) scale/shift pairs from `computeMean` and then immediately dispatches the
 * `InstanceNormalizationNHWC` program, which applies `fma(input, scale, shift)` to every (vectorized)
 * element of the input.
 *
 * @param context compute context used to run the programs.
 * @param inputs `[x, scale, bias]`; `x` is read with the batch in dim 0 and the channels in the last dim.
 * @param attributes operator attributes; only `epsilon` is used here.
 */
const createInstanceNormNHWCProgramInfo =
    (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => {
      const xShape = inputs[0].dims;
      const outputShape = xShape;
      const N = xShape[0];
      const C = xShape[xShape.length - 1];
      // all dims between batch and channel are flattened into one "H" axis
      const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
      const components = getMaxComponents(C);
      // number of vectorized output elements, i.e. the number of invocations that have real work
      const outputSize = ShapeUtil.size(outputShape) / components;
      const programUniforms: ProgramUniform[] = [
        {type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)},
        {type: 'uint32', data: outputSize}
      ];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
      // first compute mean
      const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
        const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;

        const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
        const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

        // The dispatch below is rounded up to a whole number of workgroups, so the trailing
        // invocations of the last workgroup have no element to process. They are stopped by the guard
        // at the top of main(); otherwise they would perform out-of-bounds reads and writes.
        return `
  @group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
  @group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
  @group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
  struct Uniforms {H: u32, C : u32, output_size : u32};
  @group(0) @binding(3) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
    let currentChannelNumber = global_idx % uniforms.C;

    let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
    let scale = scaleInput[scaleOffset];
    output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
  }`;
      };
      context.compute(
          {
            name: 'InstanceNormalizationNHWC',
            shaderCache: {hint: `${components}`, inputDependencies},
            getRunData: () => ({
              outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
              dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [inputs[0], channelScaleShift]});
    };
|
|
|
|
|
|
|
|
|
|
/**
 * Entry point of the InstanceNormalization kernel: picks the implementation that matches the
 * layout given by `attributes.format`.
 *
 * @param context compute context that provides the inputs and runs the GPU programs.
 * @param attributes operator attributes (`epsilon` and `format`).
 */
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
  const isChannelsLast = attributes.format === 'NHWC';
  if (!isChannelsLast) {
    // every other format goes through the single-program path
    context.compute(createInstanceNormProgramInfo(context.inputs, attributes));
    return;
  }
  // the channels-last path dispatches its own programs internally
  createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
};
|