diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 4a8c92bb97..61b68a777b 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -87,6 +87,7 @@ Do not modify directly.* | Relu | ai.onnx(6-12,13,14+) | | | Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | | Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | +| RotaryEmbedding | com.microsoft(1+) | | | Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ba874c8dd0..575cf296aa 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -27,6 +27,7 @@ import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; +import {rotaryEmbedding} from './ops/rotary-embedding'; import {skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; @@ -116,6 +117,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ReduceSumSquare', [reduceSumSquare]], ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], + ['RotaryEmbedding', [rotaryEmbedding]], ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts new file mode 100644 index 0000000000..a58087072e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE} from './common'; + +export interface RotaryEmbeddingAttributes { + readonly interleaved: boolean; + readonly numHeads: number; + readonly rotaryEmbeddingDim: number; + readonly scale: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): void => { + const [input, positionIds, cosCache, sinCache] = inputs; + const {numHeads, rotaryEmbeddingDim} = attributes; + + if (input.dims.length !== 3 && input.dims.length !== 4) { + throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${input.dims.length}`); + } + if (!ShapeUtil.areEqual(positionIds.dims, []) && !ShapeUtil.areEqual(positionIds.dims, [1]) && + positionIds.dims.length !== 2) { + throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${positionIds.dims.length}`); + } + if (cosCache.dims.length !== 2) { + throw new Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${cosCache.dims.length}`); + } + if (sinCache.dims.length !== 2) { + throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${sinCache.dims.length}`); + } + if (!ShapeUtil.areEqual(cosCache.dims, sinCache.dims)) { + throw new Error('Inputs \'cos_cache\' and \'sin_cache\' are expected to have the same shape'); + } + + if (rotaryEmbeddingDim > 0 && numHeads === 0) { + throw new Error('num_heads must be provided if rotary_embedding_dim is specified'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[input.dims.length - 2]; + const maxSequenceLength = cosCache.dims[0]; + const hiddenSize = ShapeUtil.sizeFromDimension(input.dims, 1) / sequenceLength; + const headSize = rotaryEmbeddingDim === 0 ? cosCache.dims[1] * 2 : hiddenSize / numHeads; + if (rotaryEmbeddingDim > headSize) { + throw new Error('rotary_embedding_dim must be less than or equal to head_size'); + } + + if (positionIds.dims.length === 2) { + if (batchSize !== positionIds.dims[0]) { + throw new Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${positionIds.dims[0]}`); + } + if (sequenceLength !== positionIds.dims[1]) { + throw new Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${positionIds.dims[1]}`); + } + } + + if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) { + throw new Error(`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${ + cosCache.dims[1]}`); + } + + if (sequenceLength > maxSequenceLength) { + throw new Error('Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported'); + } +}; + +const createRotaryEmbeddingProgramInfo = + (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): ProgramInfo => { + const {interleaved, numHeads, rotaryEmbeddingDim, scale} = attributes; + const batchSize = inputs[0].dims[0]; + const batchStride = ShapeUtil.sizeFromDimension(inputs[0].dims, 1); + const sequenceLength = inputs[0].dims[inputs[0].dims.length - 2]; + const hiddenSize = batchStride / sequenceLength; + const halfRotaryEmbeddingDim = inputs[2].dims[1]; + const headSize = rotaryEmbeddingDim === 0 ? halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const globalShape = + new Array(batchSize, sequenceLength, hiddenSize / headSize, headSize - halfRotaryEmbeddingDim); + const globalStrides = ShapeUtil.computeStrides(globalShape); + + const programUniforms: ProgramUniform[] = [ + {type: DataType.float, data: scale}, + {type: DataType.uint32, data: globalShape}, + {type: DataType.uint32, data: globalStrides}, + + // strides for addressing the input/output tensor, in permutated order to align with the unfolded global index, + // i.e. BSNH + ...(inputs[0].dims.length === 3 ? + new Array({type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1]}) : + []), + ...(inputs[0].dims.length === 4 ? + new Array( + {type: DataType.uint32, data: [batchStride, headSize, sequenceLength * headSize, 1]}) : + []), + + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); + const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length); + const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length); + const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length); + const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length); + + shaderHelper.registerUniforms([ + {name: 'scale', type: 'f32'}, + {name: 'global_shape', type: 'u32', length: globalShape.length}, + {name: 'global_strides', type: 'u32', length: globalStrides.length}, + {name: 'input_output_strides', type: 'u32', length: globalStrides.length}, + ]); + + return ` + ${shaderHelper.declareVariables(input, positionIds, cosCache, sinCache, output)} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + let half_rotary_emb_dim = uniforms.${cosCache.name}_shape[1]; + let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape; + let size = uniforms.global_shape[0] * uniforms.global_strides[0]; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('size')} + + if (bsnh[3] < half_rotary_emb_dim) { + let position_ids_idx = + ${positionIds.broadcastedIndicesToOffset('bsnh.xy', outputVariable('', positionIds.type.tensor, 2))}; + let position_id = + u32(${positionIds.getByOffset('position_ids_idx')}) + select(0, bsnh[1], position_ids_idx == 0); + let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${interleaved}); + let j = i + select(half_rotary_emb_dim, 1, ${interleaved}); + let re = ${input.getByOffset('i')} * ${cosCache.get('position_id', 'bsnh[3]')} - + ${input.getByOffset('j')} * ${sinCache.get('position_id', 'bsnh[3]')}; + ${output.setByOffset('i', 're')} + let im = ${input.getByOffset('i')} * ${sinCache.get('position_id', 'bsnh[3]')} + + ${input.getByOffset('j')} * ${cosCache.get('position_id', 'bsnh[3]')}; + ${output.setByOffset('j', 'im')} + } else { + let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim; + ${output.setByOffset('k', input.getByOffset('k'))} + } + }`; + }; + + return { + name: 'RotaryEmbedding', + shaderCache: { + hint: createAttributeWithCacheKey({ + interleaved, + }).cacheKey, + inputDependencies: ['rank', 'rank', 'rank', 'rank'], + }, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE)}, + programUniforms, + }), + }; + }; + +export const rotaryEmbedding = (context: ComputeContext, attributes: RotaryEmbeddingAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createRotaryEmbeddingProgramInfo(context.inputs, attributes)); +}; diff --git a/js/web/test/data/ops/rotary-embedding.jsonc b/js/web/test/data/ops/rotary-embedding.jsonc new file mode 100644 index 0000000000..1b564ecc77 --- /dev/null +++ b/js/web/test/data/ops/rotary-embedding.jsonc @@ -0,0 +1,925 @@ +[ + { + "name": "RotaryEmbedding with no attributes", + "operator": "RotaryEmbedding", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [], + "cases": [ + { + "name": "T[2,8,24] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, 0.6923, 1.1571, 0.7572, -1.1471, -0.5302, -0.4391, + 0.5516, 1.0461, -0.4812, -0.1443, -0.4862, -0.6423, 0.674, -0.4614, 0.5475, 1.1495, 0.2389, 0.8582, + -0.0259, -0.6099, -0.223, 1.0963, -1.5704, -0.4595, 0.9507, 0.6696, -0.7721, -1.7415, 1.2087, -0.6387, + -1.1052, -0.5243, -0.04, -0.4671, 0.4909, -0.1931, -0.1937, -0.0447, -0.3171, 2.6839, -0.0076, 1.5185, + 0.8465, 0.3737, 0.0242, -0.0703, 1.1279, 0.8862, 1.2275, -0.1786, -0.8767, -1.8072, -0.263, 0.9387, + -0.8021, 0.7813, 0.5001, -1.4202, -0.385, 0.9263, -0.0443, -0.2323, 0.548, 1.5696, 0.6193, -1.1346, + 1.7878, -0.516, 0.1192, -2.1572, 0.046, 1.1202, -1.4812, -0.9082, 0.1728, -1.5132, -0.4489, 0.337, + -0.1541, -0.9266, 0.2416, 0.927, -1.1146, 1.8758, -0.4312, 1.3714, 1.2106, -0.4272, -0.8529, 1.0328, + 1.8441, 1.7698, -0.762, 0.2168, 0.1322, -0.2802, 0.146, 2.1002, 0.8437, -0.1534, 0.4321, 0.836, 0.5955, + -1.5452, -0.0491, -0.8794, 0.2418, -1.4203, 0.3635, 0.2362, 0.3672, -0.1128, -0.8664, -0.6354, -1.4409, + -0.3413, -0.2409, -0.3188, 1.1054, 0.4265, 0.5867, -1.3279, 0.3201, 0.0125, 1.8157, 1.0745, 0.7372, + -0.2429, 0.71, -0.4299, -0.2304, 0.1645, 0.9489, -0.1816, -0.5968, 1.0394, 0.0204, 1.1786, -0.3315, + -0.3997, -0.9304, -1.4268, -1.1526, -0.1132, 0.149, 1.3967, -1.4634, -0.1412, -0.6339, -1.5995, -0.1366, + 0.7604, 0.1514, 0.0824, -1.183, -1.6572, 2.0099, -0.9108, -0.2256, 0.4527, -1.8254, 0.6475, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -1.4979, + -1.1358, 1.632, 0.2493, 0.8266, 0.3424, -0.4992, 0.2964, 0.7298, 1.8544, 0.3516, 0.0454, 1.5415, -0.2822, + -2.0774, 1.2323, 0.3963, -1.1503, -0.4775, -1.9287, -1.6164, 0.3998, 0.902, -0.0764, -1.8059, -0.5762, + -1.4362, -0.2706, -1.0183, -0.462, 2.0891, 0.1782, 1.1591, -0.8151, 1.3, -1.2464, -0.5099, 0.5098, + -3.3525, 0.4326, 0.7414, -0.7775, -0.4271, -0.3807, 1.3245, 2.4936, 0.3139, 1.0095, 0.2323, 0.845, + -1.2244, -0.4511, 0.6266, 0.9095, -1.7981, 1.5241, -0.4121, 0.2341, -0.4737, -1.3333, -1.615, 0.4164, + 0.71, -0.2429, -0.5656, 0.0863, 0.0352, -0.7227, -1.3613, -0.0988, -1.9114, -0.3009, 0.1435, 0.7029, + -0.3467, 0.5092, -0.0828, 0.6253, 0.7113, -1.2138, 1.5964, -0.8346, -1.1515, -0.7923, -0.8254, -3.0038, + 2.4033, -0.3398, 0.0922, 1.7053, 1.1114, 0.7462, 2.366, -0.8409, -0.6654, -0.653, -0.7899, -1.0957, + -0.7149, -0.1072, -0.1967, -2.3416, -1.2609, -1.6375, -0.3576, 0.9413, -0.5694, 0.3954, 0.1383, -0.7477, + -0.8689, 1.8286, 0.851, -1.4793, -0.1597, 0.8541, 0.238, 1.4392, -0.5644, 0.3158, -1.0686, -0.1313, + -0.0181, 0.2438, -0.8801, 0.1413, -0.3587, 0.8002, -0.5982, -1.4301, -0.662, 0.7324, -0.725, 0.061, + 0.9293, -0.6902, -0.0125, -0.2089, -0.1664, 0.5428, 0.4245, -0.7901, 0.5665, 0.9044, 0.1948, -0.1723, + 1.2705, 1.0303, 1.2202, 1.3762, -0.2959, 0.7237, -1.2077, 0.7937, -0.6705, 0.9287, 1.0583, 0.0496, + -1.3118, 0.5556, 0.0459, -0.1324, -0.5513, -0.7409, -1.8002, 0.9892, 0.3619, -1.4522 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,8,24] Scalar T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, 0.6923, 1.1571, 0.7572, -1.1471, -0.5302, -0.4391, + 0.5516, 1.0461, -0.4812, -0.1443, -0.4862, -0.6423, 0.674, -0.4614, 0.5475, 1.1495, 0.2389, 0.8582, + -0.0259, -0.6099, -0.223, 1.0963, -1.5704, -0.4595, 0.9507, 0.6696, -0.7721, -1.7415, 1.2087, -0.6387, + -1.1052, -0.5243, -0.04, -0.4671, 0.4909, -0.1931, -0.1937, -0.0447, -0.3171, 2.6839, -0.0076, 1.5185, + 0.8465, 0.3737, 0.0242, -0.0703, 1.1279, 0.8862, 1.2275, -0.1786, -0.8767, -1.8072, -0.263, 0.9387, + -0.8021, 0.7813, 0.5001, -1.4202, -0.385, 0.9263, -0.0443, -0.2323, 0.548, 1.5696, 0.6193, -1.1346, + 1.7878, -0.516, 0.1192, -2.1572, 0.046, 1.1202, -1.4812, -0.9082, 0.1728, -1.5132, -0.4489, 0.337, + -0.1541, -0.9266, 0.2416, 0.927, -1.1146, 1.8758, -0.4312, 1.3714, 1.2106, -0.4272, -0.8529, 1.0328, + 1.8441, 1.7698, -0.762, 0.2168, 0.1322, -0.2802, 0.146, 2.1002, 0.8437, -0.1534, 0.4321, 0.836, 0.5955, + -1.5452, -0.0491, -0.8794, 0.2418, -1.4203, 0.3635, 0.2362, 0.3672, -0.1128, -0.8664, -0.6354, -1.4409, + -0.3413, -0.2409, -0.3188, 1.1054, 0.4265, 0.5867, -1.3279, 0.3201, 0.0125, 1.8157, 1.0745, 0.7372, + -0.2429, 0.71, -0.4299, -0.2304, 0.1645, 0.9489, -0.1816, -0.5968, 1.0394, 0.0204, 1.1786, -0.3315, + -0.3997, -0.9304, -1.4268, -1.1526, -0.1132, 0.149, 1.3967, -1.4634, -0.1412, -0.6339, -1.5995, -0.1366, + 0.7604, 0.1514, 0.0824, -1.183, -1.6572, 2.0099, -0.9108, -0.2256, 0.4527, -1.8254, 0.6475, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -1.4979, + -1.1358, 1.632, 0.2493, 0.8266, 0.3424, -0.4992, 0.2964, 0.7298, 1.8544, 0.3516, 0.0454, 1.5415, -0.2822, + -2.0774, 1.2323, 0.3963, -1.1503, -0.4775, -1.9287, -1.6164, 0.3998, 0.902, -0.0764, -1.8059, -0.5762, + -1.4362, -0.2706, -1.0183, -0.462, 2.0891, 0.1782, 1.1591, -0.8151, 1.3, -1.2464, -0.5099, 0.5098, + -3.3525, 0.4326, 0.7414, -0.7775, -0.4271, -0.3807, 1.3245, 2.4936, 0.3139, 1.0095, 0.2323, 0.845, + -1.2244, -0.4511, 0.6266, 0.9095, -1.7981, 1.5241, -0.4121, 0.2341, -0.4737, -1.3333, -1.615, 0.4164, + 0.71, -0.2429, -0.5656, 0.0863, 0.0352, -0.7227, -1.3613, -0.0988, -1.9114, -0.3009, 0.1435, 0.7029, + -0.3467, 0.5092, -0.0828, 0.6253, 0.7113, -1.2138, 1.5964, -0.8346, -1.1515, -0.7923, -0.8254, -3.0038, + 2.4033, -0.3398, 0.0922, 1.7053, 1.1114, 0.7462, 2.366, -0.8409, -0.6654, -0.653, -0.7899, -1.0957, + -0.7149, -0.1072, -0.1967, -2.3416, -1.2609, -1.6375, -0.3576, 0.9413, -0.5694, 0.3954, 0.1383, -0.7477, + -0.8689, 1.8286, 0.851, -1.4793, -0.1597, 0.8541, 0.238, 1.4392, -0.5644, 0.3158, -1.0686, -0.1313, + -0.0181, 0.2438, -0.8801, 0.1413, -0.3587, 0.8002, -0.5982, -1.4301, -0.662, 0.7324, -0.725, 0.061, + 0.9293, -0.6902, -0.0125, -0.2089, -0.1664, 0.5428, 0.4245, -0.7901, 0.5665, 0.9044, 0.1948, -0.1723, + 1.2705, 1.0303, 1.2202, 1.3762, -0.2959, 0.7237, -1.2077, 0.7937, -0.6705, 0.9287, 1.0583, 0.0496, + -1.3118, 0.5556, 0.0459, -0.1324, -0.5513, -0.7409, -1.8002, 0.9892, 0.3619, -1.4522 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,8,6] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 0.3152, 1.7528, -0.765, 1.8299, -0.2784, -0.2719, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, -1.1874, -0.7468, -0.932, -0.8579, -0.9647, -0.0991, -1.019, + 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, -1.9791, + 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, 2.1134, + -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 0.5599, + -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, 0.1036, + -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 0.0352, + -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 1.0122, -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, -0.0489, 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.0311, + -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, 2.5211, + -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.6052, + 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, 0.6207, -0.169, + -0.5816, 1.2632, 0.0695, 1.1862, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.5996, -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -0.2937, 0.9238, -1.2185, 0.4138, 0.5033, 0.9174, -0.4792, 0.6756, + -0.3413, -0.2242, -0.2111, 0.6282, -0.1213, -1.1116, -0.7401, -0.7879, 0.0606, -2.3337, -1.0941, -0.3682, + -0.0163, -0.0645, -0.8101, 0.1415, 0.8238, 0.2262, 1.2912, 0.6488, 1.2114, 1.3569, -0.2852, 0.6051, + 0.2167, -0.2181, -1.6306, 1.4788, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, -1.6105, 0.2979, 1.1537, + -1.5604, 1.2779, -1.2514, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 0.1667, -1.4055, 1.5895, + 1.0838, -0.9077, -0.806, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, 0.0551, 0.5873, -0.5887, + -1.4733, -0.8565, 0.74, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 0.2754, -0.0261, -0.4618, + -0.5646, -1.0389, 0.5819, 1.8698, -0.2635, -2.0799, -0.6313, 0.409, -1.1458, 0.6056, 0.5763, -3.3558, + 0.2836, 0.6909, -0.7631, 1.5646, 0.3338, 0.7105, 0.4683, -0.6179, 0.0818, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -1.7142, -0.5319, -0.8848, 0.6513, 1.0002, -1.4699, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 1.1629, 0.0616, -1.3136, -0.2764, 0.0277, -0.1126, 1.3697, 0.0002, 1.5333, + -1.0556, -0.1254, 0.1527, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 2.4451, -0.35, 1.3289, + -0.6494, 0.3478, 1.0038, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.0901, 0.6106, 2.3603, + 1.3908, -0.7917, -0.6734, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, 0.061, 0.6776, 0.4361, + -0.8052, 0.3955, 0.8988, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, + -0.0259, -0.6099, -0.223, 1.0963, -1.5704, -0.4595, 0.8465, 0.3737, 0.0242, -0.0703, 1.1279, 0.8862, + 1.7878, -0.516, 0.1192, -2.1572, 0.046, 1.1202, 1.8441, 1.7698, -0.762, 0.2168, 0.1322, -0.2802, -1.4409, + -0.3413, -0.2409, -0.3188, 1.1054, 0.4265, -0.3315, -0.3997, -0.9304, -1.4268, -1.1526, -0.1132, -1.019, + 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, 0.6923, 1.1571, 0.7572, -1.1471, -0.5302, -0.4391, 0.9507, + 0.6696, -0.7721, -1.7415, 1.2087, -0.6387, 1.2275, -0.1786, -0.8767, -1.8072, -0.263, 0.9387, -1.4812, + -0.9082, 0.1728, -1.5132, -0.4489, 0.337, 0.146, 2.1002, 0.8437, -0.1534, 0.4321, 0.836, 0.5867, -1.3279, + 0.3201, 0.0125, 1.8157, 1.0745, 0.149, 1.3967, -1.4634, -0.1412, -0.6339, -1.5995, 0.1036, -0.3514, + 0.2421, 0.6463, 0.873, -0.9276, 0.5516, 1.0461, -0.4812, -0.1443, -0.4862, -0.6423, -1.1052, -0.5243, + -0.04, -0.4671, 0.4909, -0.1931, -0.8021, 0.7813, 0.5001, -1.4202, -0.385, 0.9263, -0.1541, -0.9266, + 0.2416, 0.927, -1.1146, 1.8758, 0.5955, -1.5452, -0.0491, -0.8794, 0.2418, -1.4203, 0.7372, -0.2429, 0.71, + -0.4299, -0.2304, 0.1645, -0.1366, 0.7604, 0.1514, 0.0824, -1.183, -1.6572, 1.0311, -1.9557, -0.1482, + 1.7376, 2.2039, -0.6589, 0.674, -0.4614, 0.5475, 1.1495, 0.2389, 0.8582, -0.1937, -0.0447, -0.3171, + 2.6839, -0.0076, 1.5185, -0.0443, -0.2323, 0.548, 1.5696, 0.6193, -1.1346, -0.4312, 1.3714, 1.2106, + -0.4272, -0.8529, 1.0328, 0.3635, 0.2362, 0.3672, -0.1128, -0.8664, -0.6354, 0.9489, -0.1816, -0.5968, + 1.0394, 0.0204, 1.1786, 2.0099, -0.9108, -0.2256, 0.4527, -1.8254, 0.6475, 0.8964, 0.5717, -0.239, 0.6983, + -1.3416, 0.2715, -1.4979, -1.1358, 1.632, 0.2493, 0.8266, 0.3424, -1.8059, -0.5762, -1.4362, -0.2706, + -1.0183, -0.462, 0.2323, 0.845, -1.2244, -0.4511, 0.6266, 0.9095, 0.1435, 0.7029, -0.3467, 0.5092, + -0.0828, 0.6253, -0.7899, -1.0957, -0.7149, -0.1072, -0.1967, -2.3416, -1.0686, -0.1313, -0.0181, 0.2438, + -0.8801, 0.1413, 0.1948, -0.1723, 1.2705, 1.0303, 1.2202, 1.3762, -0.2852, 0.6051, 0.2167, -0.2181, + -1.6306, 1.4788, -0.4992, 0.2964, 0.7298, 1.8544, 0.3516, 0.0454, 2.0891, 0.1782, 1.1591, -0.8151, 1.3, + -1.2464, -1.7981, 1.5241, -0.4121, 0.2341, -0.4737, -1.3333, 0.7113, -1.2138, 1.5964, -0.8346, -1.1515, + -0.7923, -1.2609, -1.6375, -0.3576, 0.9413, -0.5694, 0.3954, -0.3587, 0.8002, -0.5982, -1.4301, -0.662, + 0.7324, -0.2959, 0.7237, -1.2077, 0.7937, -0.6705, 0.9287, 0.2754, -0.0261, -0.4618, -0.5646, -1.0389, + 0.5819, 1.5415, -0.2822, -2.0774, 1.2323, 0.3963, -1.1503, -0.5099, 0.5098, -3.3525, 0.4326, 0.7414, + -0.7775, -1.615, 0.4164, 0.71, -0.2429, -0.5656, 0.0863, -0.8254, -3.0038, 2.4033, -0.3398, 0.0922, + 1.7053, 0.1383, -0.7477, -0.8689, 1.8286, 0.851, -1.4793, -0.725, 0.061, 0.9293, -0.6902, -0.0125, + -0.2089, 1.0583, 0.0496, -1.3118, 0.5556, 0.0459, -0.1324, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, + 0.1527, -0.4775, -1.9287, -1.6164, 0.3998, 0.902, -0.0764, -0.4271, -0.3807, 1.3245, 2.4936, 0.3139, + 1.0095, 0.0352, -0.7227, -1.3613, -0.0988, -1.9114, -0.3009, 1.1114, 0.7462, 2.366, -0.8409, -0.6654, + -0.653, -0.1597, 0.8541, 0.238, 1.4392, -0.5644, 0.3158, -0.1664, 0.5428, 0.4245, -0.7901, 0.5665, 0.9044, + -0.5513, -0.7409, -1.8002, 0.9892, 0.3619, -1.4522 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + } + ] + }, + { + "name": "T[1,2,18] T[1,2] T[4,3] T[4,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + -0.848, 0.5266, -1.2944, -0.0243, -0.2354, -0.7087, -0.9647, -0.0991, -0.2994, -0.065, -1.572, -1.3211 + ], + "dims": [1, 2, 18], + "type": "float32" + }, + { + "data": [0, 1], + "dims": [1, 2], + "type": "int64" + }, + { + "data": [1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065], + "dims": [4, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, + -0.4377, 0.537, -1.2929, -0.7267, -0.2107, -0.7115, -0.4666, -0.0261, -0.2965, -0.8469, -1.5749, -1.3217 + ], + "dims": [1, 2, 18], + "type": "float32" + } + ] + }, + { + "name": "T[1,3,2,6] T[1,2] T[4,3] T[4,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.848, 0.5266, -1.2944, -0.0243, -0.2354, -0.7087, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -0.9647, -0.0991, -0.2994, -0.065, -1.572, -1.3211 + ], + "dims": [1, 3, 2, 6], + "type": "float32" + }, + { + "data": [0, 1], + "dims": [1, 2], + "type": "int64" + }, + { + "data": [1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065], + "dims": [4, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -0.8618, -0.0922, -0.9073, -0.7032, -0.5762, -0.2371, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.4377, 0.537, -1.2929, -0.7267, -0.2107, -0.7115, + -0.8663, -0.2656, 0.1665, 0.7911, -0.932, -0.8579, -0.4666, -0.0261, -0.2965, -0.8469, -1.5749, -1.3217 + ], + "dims": [1, 3, 2, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "RotaryEmbedding with interleaved pattern", + "operator": "RotaryEmbedding", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "interleaved", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[1,3,8] T[1] T[8,2] T[8,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.2188, 1.1676, -1.0574, -0.1188, + -0.7396, -1.2425, -0.1752, 0.699, -0.811, 0.6737, -1.1233, -0.0919, -0.6861, 0.7202, 0.1963, 0.6142 + ], + "dims": [1, 3, 8], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 0.5403, 0.9999, -0.4161, 0.9998, -0.99, 0.9996, -0.6536, 0.9992, 0.2837, 0.9988, 0.9602, 0.9982, + 0.7539, 0.9976 + ], + "dims": [8, 2], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.8415, 0.01, 0.9093, 0.02, 0.1411, 0.03, -0.7568, 0.04, -0.9589, 0.05, -0.2794, 0.06, 0.657, + 0.0699 + ], + "dims": [8, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.6411, -0.3948, -1.0561, -0.1294, + 0.646, -1.2937, -0.1822, 0.6972, -0.2751, -1.0178, -1.1212, -0.1143, -0.3694, -0.9235, 0.184, 0.618 + ], + "dims": [1, 3, 8], + "type": "float32" + } + ] + }, + { + "name": "T[1,3,8] Scalar T[8,2] T[8,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.2188, 1.1676, -1.0574, -0.1188, + -0.7396, -1.2425, -0.1752, 0.699, -0.811, 0.6737, -1.1233, -0.0919, -0.6861, 0.7202, 0.1963, 0.6142 + ], + "dims": [1, 3, 8], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 0.5403, 0.9999, -0.4161, 0.9998, -0.99, 0.9996, -0.6536, 0.9992, 0.2837, 0.9988, 0.9602, 0.9982, + 0.7539, 0.9976 + ], + "dims": [8, 2], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.8415, 0.01, 0.9093, 0.02, 0.1411, 0.03, -0.7568, 0.04, -0.9589, 0.05, -0.2794, 0.06, 0.657, + 0.0699 + ], + "dims": [8, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -0.132, -0.2751, -0.235, 0.0937, -1.6411, -0.3948, -1.0561, -0.1294, + 0.646, -1.2937, -0.1822, 0.6972, -0.2751, -1.0178, -1.1212, -0.1143, -0.3694, -0.9235, 0.184, 0.618 + ], + "dims": [1, 3, 8], + "type": "float32" + } + ] + }, + { + "name": "T[1,2,3,4] T[1] T[8,2] T[8,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.811, 0.6737, -1.1233, -0.0919, + -0.132, -0.2751, -0.235, 0.0937, -0.7396, -1.2425, -0.1752, 0.699, -0.6861, 0.7202, 0.1963, 0.6142 + ], + "dims": [1, 2, 3, 4], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 0.5403, 0.9999, -0.4161, 0.9998, -0.99, 0.9996, -0.6536, 0.9992, 0.2837, 0.9988, 0.9602, 0.9982, + 0.7539, 0.9976 + ], + "dims": [8, 2], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.8415, 0.01, 0.9093, 0.02, 0.1411, 0.03, -0.7568, 0.04, -0.9589, 0.05, -0.2794, 0.06, 0.657, + 0.0699 + ], + "dims": [8, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.6411, -0.3948, -1.0561, -0.1294, -0.2751, -1.0178, -1.1212, -0.1143, + -0.132, -0.2751, -0.235, 0.0937, 0.646, -1.2937, -0.1822, 0.6972, -0.3694, -0.9235, 0.184, 0.618 + ], + "dims": [1, 2, 3, 4], + "type": "float32" + } + ] + }, + { + "name": "T[2,8,24] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.4713, -0.954, -0.9229, 0.3027, -0.5708, -0.2363, -1.2713, 0.1137, 0.8112, -1.1659, -0.5824, -0.4419, + -0.7649, 0.7011, -0.4569, -0.5639, -0.5328, -0.6424, 1.0979, 0.8773, 0.5462, 0.0793, 0.2582, 0.8576, + 0.2653, 1.2295, -0.1839, -0.4517, -1.5052, -0.4651, 0.1155, -2.1237, -0.7586, -0.211, 1.1441, -0.6304, + 0.4186, 0.2303, -0.1519, 1.1903, 0.5382, -0.1906, -1.008, 2.3112, -0.222, -0.9655, -0.0099, 1.5198, + 0.7652, -0.641, 0.0365, -0.0452, 1.0593, 0.8929, 1.4856, 0.0038, -1.0865, 1.4794, -0.2417, 0.9428, + -0.6894, -0.6293, 0.2904, 1.5747, -0.4956, 0.9199, -0.2424, 0.1801, 0.7503, -1.4576, 0.6529, -1.134, + -0.6807, -0.0252, -0.3834, 2.7394, 0.1308, 1.1203, -2.1196, -0.9618, 0.197, -0.0972, -0.2764, 0.3332, + -0.4522, 1.1844, 0.3867, -0.6626, -0.9405, 1.8656, 0.5053, -1.2361, 1.2072, 0.1789, -1.1002, 1.0129, + 1.7702, 0.1949, -1.1653, 1.6049, -0.2755, -0.2749, 2.1087, 0.4272, 0.8076, 0.29, -0.0714, 0.8261, -1.1016, + -1.3814, -0.1366, 0.2981, 0.606, -1.4132, 0.0893, -0.1939, 0.2779, 0.391, -0.8906, -0.6489, -1.2496, + 0.3383, -0.0315, -0.7461, 1.151, 0.4445, 0.3203, -0.9031, 0.2727, 0.2609, 2.0968, 1.0974, 0.712, -0.5164, + 0.7415, -0.0031, -0.1568, 0.1533, 0.5487, -0.3357, -0.9064, 1.0546, 0.0542, 1.187, -0.4045, -1.3431, + -0.6094, -1.1105, -0.9631, -0.1137, -0.7219, 0.8582, -1.3443, -0.6684, -1.0227, -1.5929, -0.2622, 0.2264, + 0.0713, 0.1843, -1.3387, -1.6797, 2.3165, 0.1009, 0.1081, -0.9969, -1.4488, 0.6291, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, -0.0261, + -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, 0.5985, -1.0968, + 1.5662, 1.4693, 0.8776, 0.3408, 0.4345, 1.2549, 0.6631, 1.4543, 0.3374, 0.0445, 1.232, 1.4311, -2.0483, + -0.7272, 0.4114, -1.1449, 1.6283, -0.9524, -1.6435, 0.5422, 0.9907, -0.0708, 0.3972, 0.7376, -1.5947, + 1.6138, -0.9586, -0.46, 0.3993, -1.5884, 1.2934, -1.4467, 1.2833, -1.2459, -0.776, 0.3108, -3.3677, + -0.0287, 0.6942, -0.7601, -0.6993, 2.369, 1.3834, -0.5234, 0.3435, 1.0053, 0.1604, -0.956, -1.2641, + 0.2406, 0.4973, 0.9206, -1.9987, -1.1733, -0.4197, -0.0366, -0.672, -1.335, -1.596, -0.1097, 0.6386, + 0.5624, -0.6184, 0.0778, 0.1867, 0.9643, -1.3629, -0.0972, -1.7907, -0.3037, 0.8245, -0.0789, -0.294, + -0.2833, -0.2165, 0.6264, -1.1726, 0.7926, 1.3621, 1.3586, -0.9007, -0.8138, -2.7421, 1.3155, 2.4507, + 0.0507, 0.6305, 1.69, 0.521, -0.3309, 2.063, 1.8026, -0.7859, -0.6802, -1.1003, -0.199, -0.5391, -0.937, + 0.0857, -2.333, -2.0112, 0.7193, -0.1272, -0.9981, -0.1818, 0.3973, -0.9963, 1.4929, -1.0109, 0.4304, + 1.016, -1.459, 0.2682, 1.5658, 0.1762, 0.3038, -0.7491, 0.3052, -1.1534, -0.0478, 0.0021, -0.0665, + -0.8118, 0.131, 0.2171, 0.5485, -0.161, -1.5784, -0.866, 0.7289, -0.4678, 0.1937, 1.1287, -0.5772, + -0.0259, -0.2212, 0.2479, 0.6336, 0.6407, -0.6543, 0.3838, 0.9039, 0.4724, 0.7117, 1.0165, 1.027, 1.1908, + 1.375, -0.085, 0.5517, -1.3842, 0.3703, -0.8806, 0.9336, 0.8362, 0.8105, -1.1566, -0.6813, 0.0294, + -0.1122, 0.562, -0.2884, -2.0803, 0.4684, 0.6009, -1.416 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,8,24] Scalar T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, + 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -1.9791, 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, + 0.0352, -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 2.5211, -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, + -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, + 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 2.1134, -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 0.6052, 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.3152, + 1.7528, -0.765, 1.8299, -0.2784, -0.2719, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 1.0122, + -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, 0.5599, -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, 0.6207, -0.169, -0.5816, 1.2632, 0.0695, 1.1862, -1.1874, + -0.7468, -0.932, -0.8579, -0.9647, -0.0991, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, -0.0489, + 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, + 0.5717, -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, + -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, -0.5996, + -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, 1.8698, -0.2635, + -2.0799, -0.6313, 0.409, -1.1458, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -1.6105, 0.2979, 1.1537, -1.5604, 1.2779, -1.2514, 0.6056, 0.5763, + -3.3558, 0.2836, 0.6909, -0.7631, 2.4451, -0.35, 1.3289, -0.6494, 0.3478, 1.0038, -0.2937, 0.9238, + -1.2185, 0.4138, 0.5033, 0.9174, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 1.5646, 0.3338, 0.7105, + 0.4683, -0.6179, 0.0818, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.4792, 0.6756, -0.3413, + -0.2242, -0.2111, 0.6282, 0.1667, -1.4055, 1.5895, 1.0838, -0.9077, -0.806, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -0.0901, 0.6106, 2.3603, 1.3908, -0.7917, -0.6734, -0.1213, -1.1116, -0.7401, + -0.7879, 0.0606, -2.3337, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, -1.7142, -0.5319, -0.8848, + 0.6513, 1.0002, -1.4699, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, -1.0941, -0.3682, -0.0163, + -0.0645, -0.8101, 0.1415, 0.0551, 0.5873, -0.5887, -1.4733, -0.8565, 0.74, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 0.061, 0.6776, 0.4361, -0.8052, 0.3955, 0.8988, 0.8238, 0.2262, 1.2912, 0.6488, + 1.2114, 1.3569, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 1.1629, 0.0616, -1.3136, -0.2764, + 0.0277, -0.1126, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 8, 24], + "type": "float32" + }, + { + "data": [0], + "dims": [], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, + 0.1036, -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 1.0311, -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, + -0.4713, -0.954, -0.9229, 0.3027, -0.5708, -0.2363, -1.2713, 0.1137, 0.8112, -1.1659, -0.5824, -0.4419, + -0.7649, 0.7011, -0.4569, -0.5639, -0.5328, -0.6424, 1.0979, 0.8773, 0.5462, 0.0793, 0.2582, 0.8576, + 0.2653, 1.2295, -0.1839, -0.4517, -1.5052, -0.4651, 0.1155, -2.1237, -0.7586, -0.211, 1.1441, -0.6304, + 0.4186, 0.2303, -0.1519, 1.1903, 0.5382, -0.1906, -1.008, 2.3112, -0.222, -0.9655, -0.0099, 1.5198, + 0.7652, -0.641, 0.0365, -0.0452, 1.0593, 0.8929, 1.4856, 0.0038, -1.0865, 1.4794, -0.2417, 0.9428, + -0.6894, -0.6293, 0.2904, 1.5747, -0.4956, 0.9199, -0.2424, 0.1801, 0.7503, -1.4576, 0.6529, -1.134, + -0.6807, -0.0252, -0.3834, 2.7394, 0.1308, 1.1203, -2.1196, -0.9618, 0.197, -0.0972, -0.2764, 0.3332, + -0.4522, 1.1844, 0.3867, -0.6626, -0.9405, 1.8656, 0.5053, -1.2361, 1.2072, 0.1789, -1.1002, 1.0129, + 1.7702, 0.1949, -1.1653, 1.6049, -0.2755, -0.2749, 2.1087, 0.4272, 0.8076, 0.29, -0.0714, 0.8261, -1.1016, + -1.3814, -0.1366, 0.2981, 0.606, -1.4132, 0.0893, -0.1939, 0.2779, 0.391, -0.8906, -0.6489, -1.2496, + 0.3383, -0.0315, -0.7461, 1.151, 0.4445, 0.3203, -0.9031, 0.2727, 0.2609, 2.0968, 1.0974, 0.712, -0.5164, + 0.7415, -0.0031, -0.1568, 0.1533, 0.5487, -0.3357, -0.9064, 1.0546, 0.0542, 1.187, -0.4045, -1.3431, + -0.6094, -1.1105, -0.9631, -0.1137, -0.7219, 0.8582, -1.3443, -0.6684, -1.0227, -1.5929, -0.2622, 0.2264, + 0.0713, 0.1843, -1.3387, -1.6797, 2.3165, 0.1009, 0.1081, -0.9969, -1.4488, 0.6291, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, 1.4788, 0.2754, -0.0261, + -0.4618, -0.5646, -1.0389, 0.5819, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, 0.5985, -1.0968, + 1.5662, 1.4693, 0.8776, 0.3408, 0.4345, 1.2549, 0.6631, 1.4543, 0.3374, 0.0445, 1.232, 1.4311, -2.0483, + -0.7272, 0.4114, -1.1449, 1.6283, -0.9524, -1.6435, 0.5422, 0.9907, -0.0708, 0.3972, 0.7376, -1.5947, + 1.6138, -0.9586, -0.46, 0.3993, -1.5884, 1.2934, -1.4467, 1.2833, -1.2459, -0.776, 0.3108, -3.3677, + -0.0287, 0.6942, -0.7601, -0.6993, 2.369, 1.3834, -0.5234, 0.3435, 1.0053, 0.1604, -0.956, -1.2641, + 0.2406, 0.4973, 0.9206, -1.9987, -1.1733, -0.4197, -0.0366, -0.672, -1.335, -1.596, -0.1097, 0.6386, + 0.5624, -0.6184, 0.0778, 0.1867, 0.9643, -1.3629, -0.0972, -1.7907, -0.3037, 0.8245, -0.0789, -0.294, + -0.2833, -0.2165, 0.6264, -1.1726, 0.7926, 1.3621, 1.3586, -0.9007, -0.8138, -2.7421, 1.3155, 2.4507, + 0.0507, 0.6305, 1.69, 0.521, -0.3309, 2.063, 1.8026, -0.7859, -0.6802, -1.1003, -0.199, -0.5391, -0.937, + 0.0857, -2.333, -2.0112, 0.7193, -0.1272, -0.9981, -0.1818, 0.3973, -0.9963, 1.4929, -1.0109, 0.4304, + 1.016, -1.459, 0.2682, 1.5658, 0.1762, 0.3038, -0.7491, 0.3052, -1.1534, -0.0478, 0.0021, -0.0665, + -0.8118, 0.131, 0.2171, 0.5485, -0.161, -1.5784, -0.866, 0.7289, -0.4678, 0.1937, 1.1287, -0.5772, + -0.0259, -0.2212, 0.2479, 0.6336, 0.6407, -0.6543, 0.3838, 0.9039, 0.4724, 0.7117, 1.0165, 1.027, 1.1908, + 1.375, -0.085, 0.5517, -1.3842, 0.3703, -0.8806, 0.9336, 0.8362, 0.8105, -1.1566, -0.6813, 0.0294, + -0.1122, 0.562, -0.2884, -2.0803, 0.4684, 0.6009, -1.416 + ], + "dims": [2, 8, 24], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,8,6] T[1] T[16,3] T[16,3]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -1.0574, -0.1188, -0.9078, 0.3452, -0.5713, -0.2351, + 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586, -0.848, 0.5266, 0.0299, -0.0498, 1.0651, 0.886, 0.464, + -0.4986, 0.1289, 2.7631, 0.1405, 1.1191, 0.3152, 1.7528, -0.765, 1.8299, -0.2784, -0.2719, -1.2944, + -0.0243, -0.2354, -0.7087, 1.1566, 0.4296, -1.1874, -0.7468, -0.932, -0.8579, -0.9647, -0.0991, -1.019, + 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, -0.5912, 1.1312, 0.7562, -1.2023, -0.5833, -0.4407, -1.9791, + 0.7787, -0.7749, -0.1398, 1.1414, -0.6354, -1.4702, -0.2134, -0.8707, 1.6159, -0.2356, 0.9444, 2.1134, + -0.9754, 0.1757, -0.1319, -0.2735, 0.3355, 0.1885, 2.1432, 0.8527, 0.0965, -0.0625, 0.8269, 0.5599, + -0.7776, 0.3339, 0.1759, 2.1108, 1.0702, 0.0195, 1.1213, -1.4873, -0.2043, -1.0466, -1.5772, 0.1036, + -0.3514, 0.2421, 0.6463, 0.873, -0.9276, 0.1766, 1.0224, -0.4826, -0.5421, -0.5342, -0.6413, 0.0352, + -0.4765, -0.0409, 1.1993, 0.5374, -0.193, 0.5937, 0.7203, 0.5061, 1.5192, -0.4897, 0.9231, -0.6008, + -1.1164, 0.2577, -0.7226, -0.9244, 1.8737, 1.0122, -1.4482, -0.0644, 0.3215, 0.5908, -1.4197, 0.8279, + -0.2969, 0.712, -0.2068, -0.1548, 0.1553, -0.0489, 0.343, 0.1264, 0.1519, -1.3639, -1.6593, 1.0311, + -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, 1.3314, -0.4498, 0.5493, 0.0539, 0.2601, 0.857, 2.5211, + -0.0452, -0.3105, -0.9407, -0.0034, 1.5199, 0.2654, -0.1441, 0.5407, -1.5476, 0.6455, -1.1382, 0.6052, + 1.1904, 1.2195, -0.047, -1.0914, 1.0223, 0.2113, 0.0306, 0.3604, 0.3166, -0.8975, -0.6393, 0.6207, -0.169, + -0.5816, 1.2632, 0.0695, 1.1862, 1.8127, -1.4459, -0.2158, -0.9792, -1.4392, 0.6508, 0.8964, 0.5717, + -0.239, 0.6983, -1.3416, 0.2715, -0.5996, -1.0962, 1.6327, 1.3951, 0.8784, 0.3389, 0.5054, -0.6681, + -1.4382, 1.7547, -0.9605, -0.4558, -0.2937, 0.9238, -1.2185, 0.4138, 0.5033, 0.9174, -0.4792, 0.6756, + -0.3413, -0.2242, -0.2111, 0.6282, -0.1213, -1.1116, -0.7401, -0.7879, 0.0606, -2.3337, -1.0941, -0.3682, + -0.0163, -0.0645, -0.8101, 0.1415, 0.8238, 0.2262, 1.2912, 0.6488, 1.2114, 1.3569, -0.2852, 0.6051, + 0.2167, -0.2181, -1.6306, 1.4788, 1.2907, 0.3124, 0.7299, 1.422, 0.3375, 0.0438, -1.6105, 0.2979, 1.1537, + -1.5604, 1.2779, -1.2514, 1.8131, 1.4436, -0.4207, 0.022, -0.6807, -1.3306, 0.1667, -1.4055, 1.5895, + 1.0838, -0.9077, -0.806, -1.2603, -1.7245, -0.3533, -0.9421, -0.1776, 0.3992, 0.0551, 0.5873, -0.5887, + -1.4733, -0.8565, 0.74, 0.2983, 0.4718, -1.1936, 0.7928, -0.8665, 0.9468, 0.2754, -0.0261, -0.4618, + -0.5646, -1.0389, 0.5819, 1.8698, -0.2635, -2.0799, -0.6313, 0.409, -1.1458, 0.6056, 0.5763, -3.3558, + 0.2836, 0.6909, -0.7631, 1.5646, 0.3338, 0.7105, 0.4683, -0.6179, 0.0818, 0.7967, -2.9351, 2.4179, + -0.4026, 0.6451, 1.6845, -1.7142, -0.5319, -0.8848, 0.6513, 1.0002, -1.4699, -0.5033, 0.0553, 0.9265, + -0.8652, -0.0288, -0.2209, 1.1629, 0.0616, -1.3136, -0.2764, 0.0277, -0.1126, 1.3697, 0.0002, 1.5333, + -1.0556, -0.1254, 0.1527, 0.0784, -1.8848, -1.6165, 0.6179, 0.9905, -0.0729, 2.4451, -0.35, 1.3289, + -0.6494, 0.3478, 1.0038, -0.0488, -0.981, -1.3632, 0.0929, -1.7926, -0.2921, -0.0901, 0.6106, 2.3603, + 1.3908, -0.7917, -0.6734, -1.4254, 0.7013, 0.2414, 0.2551, -0.7457, 0.3133, 0.061, 0.6776, 0.4361, + -0.8052, 0.3955, 0.8988, 0.2342, -0.5866, -1.8219, 1.1079, 0.5795, -1.4249 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "int64" + }, + { + "data": [ + 1.0, 1.0, 1.0, 0.5403, 0.9989, 1.0, -0.4161, 0.9957, 1.0, -0.99, 0.9903, 1.0, -0.6536, 0.9828, 1.0, + 0.2837, 0.9732, 0.9999, 0.9602, 0.9615, 0.9999, 0.7539, 0.9477, 0.9999, -0.1455, 0.9318, 0.9999, -0.9111, + 0.914, 0.9998, -0.8391, 0.8942, 0.9998, 0.0044, 0.8725, 0.9997, 0.8439, 0.8488, 0.9997, 0.9074, 0.8234, + 0.9996, 0.1367, 0.7962, 0.9995, -0.7597, 0.7673, 0.9995 + ], + "dims": [16, 3], + "type": "float32" + }, + { + "data": [ + 0.0, 0.0, 0.0, 0.8415, 0.0464, 0.0022, 0.9093, 0.0927, 0.0043, 0.1411, 0.1388, 0.0065, -0.7568, 0.1846, + 0.0086, -0.9589, 0.23, 0.0108, -0.2794, 0.2749, 0.0129, 0.657, 0.3192, 0.0151, 0.9894, 0.3629, 0.0172, + 0.4121, 0.4057, 0.0194, -0.544, 0.4477, 0.0215, -1.0, 0.4887, 0.0237, -0.5366, 0.5286, 0.0259, 0.4202, + 0.5675, 0.028, 0.9906, 0.605, 0.0302, 0.6503, 0.6413, 0.0323 + ], + "dims": [16, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, -0.4713, -0.954, -0.9229, 0.3027, -0.5708, -0.2363, + 0.2653, 1.2295, -0.1839, -0.4517, -1.5052, -0.4651, 0.7652, -0.641, 0.0365, -0.0452, 1.0593, 0.8929, + -0.6807, -0.0252, -0.3834, 2.7394, 0.1308, 1.1203, 1.7702, 0.1949, -1.1653, 1.6049, -0.2755, -0.2749, + -1.2496, 0.3383, -0.0315, -0.7461, 1.151, 0.4445, -0.4045, -1.3431, -0.6094, -1.1105, -0.9631, -0.1137, + -1.019, 0.3157, -1.6036, 1.8493, 0.0447, 1.5853, -1.2713, 0.1137, 0.8112, -1.1659, -0.5824, -0.4419, + 0.1155, -2.1237, -0.7586, -0.211, 1.1441, -0.6304, 1.4856, 0.0038, -1.0865, 1.4794, -0.2417, 0.9428, + -2.1196, -0.9618, 0.197, -0.0972, -0.2764, 0.3332, 2.1087, 0.4272, 0.8076, 0.29, -0.0714, 0.8261, 0.3203, + -0.9031, 0.2727, 0.2609, 2.0968, 1.0974, -0.7219, 0.8582, -1.3443, -0.6684, -1.0227, -1.5929, 0.1036, + -0.3514, 0.2421, 0.6463, 0.873, -0.9276, -0.7649, 0.7011, -0.4569, -0.5639, -0.5328, -0.6424, 0.4186, + 0.2303, -0.1519, 1.1903, 0.5382, -0.1906, -0.6894, -0.6293, 0.2904, 1.5747, -0.4956, 0.9199, -0.4522, + 1.1844, 0.3867, -0.6626, -0.9405, 1.8656, -1.1016, -1.3814, -0.1366, 0.2981, 0.606, -1.4132, 0.712, + -0.5164, 0.7415, -0.0031, -0.1568, 0.1533, -0.2622, 0.2264, 0.0713, 0.1843, -1.3387, -1.6797, 1.0311, + -1.9557, -0.1482, 1.7376, 2.2039, -0.6589, 1.0979, 0.8773, 0.5462, 0.0793, 0.2582, 0.8576, -1.008, 2.3112, + -0.222, -0.9655, -0.0099, 1.5198, -0.2424, 0.1801, 0.7503, -1.4576, 0.6529, -1.134, 0.5053, -1.2361, + 1.2072, 0.1789, -1.1002, 1.0129, 0.0893, -0.1939, 0.2779, 0.391, -0.8906, -0.6489, 0.5487, -0.3357, + -0.9064, 1.0546, 0.0542, 1.187, 2.3165, 0.1009, 0.1081, -0.9969, -1.4488, 0.6291, 0.8964, 0.5717, -0.239, + 0.6983, -1.3416, 0.2715, 0.5985, -1.0968, 1.5662, 1.4693, 0.8776, 0.3408, 0.3972, 0.7376, -1.5947, 1.6138, + -0.9586, -0.46, 0.1604, -0.956, -1.2641, 0.2406, 0.4973, 0.9206, 0.8245, -0.0789, -0.294, -0.2833, + -0.2165, 0.6264, -1.1003, -0.199, -0.5391, -0.937, 0.0857, -2.333, -1.1534, -0.0478, 0.0021, -0.0665, + -0.8118, 0.131, 0.4724, 0.7117, 1.0165, 1.027, 1.1908, 1.375, -0.2852, 0.6051, 0.2167, -0.2181, -1.6306, + 1.4788, 0.4345, 1.2549, 0.6631, 1.4543, 0.3374, 0.0445, 0.3993, -1.5884, 1.2934, -1.4467, 1.2833, -1.2459, + -1.9987, -1.1733, -0.4197, -0.0366, -0.672, -1.335, -1.1726, 0.7926, 1.3621, 1.3586, -0.9007, -0.8138, + -2.0112, 0.7193, -0.1272, -0.9981, -0.1818, 0.3973, 0.2171, 0.5485, -0.161, -1.5784, -0.866, 0.7289, + -0.085, 0.5517, -1.3842, 0.3703, -0.8806, 0.9336, 0.2754, -0.0261, -0.4618, -0.5646, -1.0389, 0.5819, + 1.232, 1.4311, -2.0483, -0.7272, 0.4114, -1.1449, -0.776, 0.3108, -3.3677, -0.0287, 0.6942, -0.7601, + -1.596, -0.1097, 0.6386, 0.5624, -0.6184, 0.0778, -2.7421, 1.3155, 2.4507, 0.0507, 0.6305, 1.69, -0.9963, + 1.4929, -1.0109, 0.4304, 1.016, -1.459, -0.4678, 0.1937, 1.1287, -0.5772, -0.0259, -0.2212, 0.8362, + 0.8105, -1.1566, -0.6813, 0.0294, -0.1122, 1.3697, 0.0002, 1.5333, -1.0556, -0.1254, 0.1527, 1.6283, + -0.9524, -1.6435, 0.5422, 0.9907, -0.0708, -0.6993, 2.369, 1.3834, -0.5234, 0.3435, 1.0053, 0.1867, + 0.9643, -1.3629, -0.0972, -1.7907, -0.3037, 0.521, -0.3309, 2.063, 1.8026, -0.7859, -0.6802, 0.2682, + 1.5658, 0.1762, 0.3038, -0.7491, 0.3052, 0.2479, 0.6336, 0.6407, -0.6543, 0.3838, 0.9039, 0.562, -0.2884, + -2.0803, 0.4684, 0.6009, -1.416 + ], + "dims": [2, 4, 8, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "RotaryEmbedding with custom rotary dim", + "operator": "RotaryEmbedding", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "rotary_embedding_dim", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "T[1,2,6] T[1,2] T[2,2] T[2,2]", + "inputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.7529, -0.225, -0.4327, -1.5071, -0.4586 + ], + "dims": [1, 2, 6], + "type": "float32" + }, + { + "data": [0, 1], + "dims": [1, 2], + "type": "int64" + }, + { + "data": [1.0, 1.0, 1.0, 0.5403], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0.0, 0.0, 0.0, 0.8415], + "dims": [2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.0408, 0.9166, -1.3042, -1.1097, -1.2188, 1.1676, 1.0076, -0.0427, -0.225, -0.8673, -1.5071, -0.4586 + ], + "dims": [1, 2, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index e96a0aa045..3a4eac7890 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1385,6 +1385,7 @@ "pow_int32.jsonc", "pow-big-number.jsonc", "reshape.jsonc", + "rotary-embedding.jsonc", "skip-layer-norm.jsonc", "slice.jsonc", //"softmax.jsonc", diff --git a/onnxruntime/contrib_ops/js/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/js/bert/rotary_embedding.cc new file mode 100644 index 0000000000..7ee168e27f --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/rotary_embedding.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX(RotaryEmbedding, kMSDomain, 1, kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/rotary_embedding.h b/onnxruntime/contrib_ops/js/bert/rotary_embedding.h new file mode 100644 index 0000000000..376b4e7082 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/rotary_embedding.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class RotaryEmbedding final : public JsKernel { + public: + explicit RotaryEmbedding(const OpKernelInfo& info) : JsKernel(info) { + int64_t interleaved = info.GetAttrOrDefault("interleaved", 0); + int64_t num_heads = info.GetAttrOrDefault("num_heads", 0); + int64_t rotary_embedding_dim = info.GetAttrOrDefault("rotary_embedding_dim", 0); + float scale = info.GetAttrOrDefault("scale", 1.0); + + JSEP_INIT_KERNEL_ATTRIBUTE(RotaryEmbedding, ({ + "interleaved" : !!$1, + "numHeads" : $2, + "rotaryEmbeddingDim" : $3, + "scale" : $4, + }), + static_cast(interleaved), static_cast(num_heads), + static_cast(rotary_embedding_dim), scale); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 25e7567a2e..4536f662bf 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -15,6 +15,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedC class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); template <> @@ -33,6 +34,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo};