onnxruntime/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts
Yulong Wang abdc31de40
[js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728)
### Description

See
454996d496
for manual changes (excluded auto-generated formatting changes)

### Why

Because the toolsets for old clang-format is out-of-date. This reduces
the development efficiency.

- The NPM package `clang-format` is already in maintenance mode. not
updated since 2 years ago.
- The VSCode extension for clang-format is not maintained for a while,
and a recent Node.js security update made it not working at all in
Windows.

No one in community seems interested in fixing those.

Choose Prettier as it is the most popular TS/JS formatter.

### How to merge

It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.

So, make sure there is no new commits before merging this one, and
invalidate js PRs that already passed CI, force them to merge to latest.
2024-08-14 16:51:22 -07:00

184 lines
7.9 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE } from './common';
export interface RotaryEmbeddingAttributes {
readonly interleaved: boolean;
readonly numHeads: number;
readonly rotaryEmbeddingDim: number;
readonly scale: number;
}
const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): void => {
const [input, positionIds, cosCache, sinCache] = inputs;
const { numHeads, rotaryEmbeddingDim } = attributes;
if (input.dims.length !== 3 && input.dims.length !== 4) {
throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${input.dims.length}`);
}
if (
!ShapeUtil.areEqual(positionIds.dims, []) &&
!ShapeUtil.areEqual(positionIds.dims, [1]) &&
positionIds.dims.length !== 2
) {
throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${positionIds.dims.length}`);
}
if (cosCache.dims.length !== 2) {
throw new Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${cosCache.dims.length}`);
}
if (sinCache.dims.length !== 2) {
throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${sinCache.dims.length}`);
}
if (!ShapeUtil.areEqual(cosCache.dims, sinCache.dims)) {
throw new Error("Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape");
}
if (rotaryEmbeddingDim > 0 && numHeads === 0) {
throw new Error('num_heads must be provided if rotary_embedding_dim is specified');
}
const batchSize = input.dims[0];
const sequenceLength = input.dims[input.dims.length - 2];
const maxSequenceLength = cosCache.dims[0];
const hiddenSize = ShapeUtil.sizeFromDimension(input.dims, 1) / sequenceLength;
const headSize = rotaryEmbeddingDim === 0 ? cosCache.dims[1] * 2 : hiddenSize / numHeads;
if (rotaryEmbeddingDim > headSize) {
throw new Error('rotary_embedding_dim must be less than or equal to head_size');
}
if (positionIds.dims.length === 2) {
if (batchSize !== positionIds.dims[0]) {
throw new Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${positionIds.dims[0]}`);
}
if (sequenceLength !== positionIds.dims[1]) {
throw new Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${positionIds.dims[1]}`);
}
}
if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) {
throw new Error(
`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${
cosCache.dims[1]
}`,
);
}
if (sequenceLength > maxSequenceLength) {
throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported');
}
};
const createRotaryEmbeddingProgramInfo = (
inputs: readonly TensorView[],
attributes: RotaryEmbeddingAttributes,
): ProgramInfo => {
const { interleaved, numHeads, rotaryEmbeddingDim, scale } = attributes;
const batchSize = inputs[0].dims[0];
const batchStride = ShapeUtil.sizeFromDimension(inputs[0].dims, 1);
const sequenceLength = inputs[0].dims[inputs[0].dims.length - 2];
const hiddenSize = batchStride / sequenceLength;
const halfRotaryEmbeddingDim = inputs[2].dims[1];
const headSize = rotaryEmbeddingDim === 0 ? halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads;
// Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
// to unfold the global index in shader.
const globalShape = new Array<number>(
batchSize,
sequenceLength,
hiddenSize / headSize,
headSize - halfRotaryEmbeddingDim,
);
const globalStrides = ShapeUtil.computeStrides(globalShape);
const programUniforms: ProgramUniform[] = [
{ type: DataType.float, data: scale },
{ type: DataType.uint32, data: globalShape },
{ type: DataType.uint32, data: globalStrides },
// strides for addressing the input/output tensor, in permutated order to align with the unfolded global index,
// i.e. BSNH
...(inputs[0].dims.length === 3
? new Array<ProgramUniform>({ type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1] })
: []),
...(inputs[0].dims.length === 4
? new Array<ProgramUniform>({
type: DataType.uint32,
data: [batchStride, headSize, sequenceLength * headSize, 1],
})
: []),
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims),
];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length);
const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length);
const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length);
const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length);
shaderHelper.registerUniforms([
{ name: 'scale', type: 'f32' },
{ name: 'global_shape', type: 'u32', length: globalShape.length },
{ name: 'global_strides', type: 'u32', length: globalStrides.length },
{ name: 'input_output_strides', type: 'u32', length: globalStrides.length },
]);
return `
${shaderHelper.declareVariables(input, positionIds, cosCache, sinCache, output)}
${shaderHelper.mainStart(WORKGROUP_SIZE)}
let half_rotary_emb_dim = uniforms.${cosCache.name}_shape[1];
let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;
let size = uniforms.global_shape[0] * uniforms.global_strides[0];
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('size')}
if (bsnh[3] < half_rotary_emb_dim) {
let position_ids_idx =
${positionIds.broadcastedIndicesToOffset('bsnh.xy', outputVariable('', positionIds.type.tensor, 2))};
let position_id =
u32(${positionIds.getByOffset('position_ids_idx')}) + select(0, bsnh[1], position_ids_idx == 0);
let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${interleaved});
let j = i + select(half_rotary_emb_dim, 1, ${interleaved});
let re = ${input.getByOffset('i')} * ${cosCache.get('position_id', 'bsnh[3]')} -
${input.getByOffset('j')} * ${sinCache.get('position_id', 'bsnh[3]')};
${output.setByOffset('i', 're')}
let im = ${input.getByOffset('i')} * ${sinCache.get('position_id', 'bsnh[3]')} +
${input.getByOffset('j')} * ${cosCache.get('position_id', 'bsnh[3]')};
${output.setByOffset('j', 'im')}
} else {
let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;
${output.setByOffset('k', input.getByOffset('k'))}
}
}`;
};
return {
name: 'RotaryEmbedding',
shaderCache: {
hint: createAttributeWithCacheKey({
interleaved,
}).cacheKey,
inputDependencies: ['rank', 'rank', 'rank', 'rank'],
},
getShaderSource,
getRunData: () => ({
outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }],
dispatchGroup: { x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE) },
programUniforms,
}),
};
};
export const rotaryEmbedding = (context: ComputeContext, attributes: RotaryEmbeddingAttributes): void => {
validateInputs(context.inputs, attributes);
context.compute(createRotaryEmbeddingProgramInfo(context.inputs, attributes));
};