mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
[DML EP] Add RotaryEmbedding (#18158)
This is a graph implementation of RotaryEmbedding since there's no time to add it to DML before 1.16.2, but it eventually should move into DirectML since we're bandwidth-bound.
This commit is contained in:
parent
9868a71373
commit
800ae7742c
7 changed files with 461 additions and 8 deletions
|
|
@ -1247,6 +1247,7 @@ Do not modify directly.*
|
|||
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
| |
|
||||
| |
|
||||
|
|
|
|||
|
|
@ -0,0 +1,436 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
// This operator is easier to understand by looking at a python implementation of the non-interleaved version:
|
||||
//
|
||||
// def rotate_half(x):
|
||||
// """Rotates half the hidden dims of the input."""
|
||||
// half_dim = x.shape[-1] // 2
|
||||
// x1 = x[..., :half_dim]
|
||||
// x2 = x[..., half_dim:]
|
||||
// return np.concatenate((-x2, x1), dim=-1)
|
||||
//
|
||||
//
|
||||
// def apply_rope(x, cos, sin, position_ids):
|
||||
// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
// x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
// return x_embed
|
||||
//
|
||||
// For the non-interleaved version, we multiply the cos cache by the non-rotated input tensor while we multiply the sin cache
|
||||
// by the rotated input tensor. Rotating the tensor means slicing it in half on the head dimension and swapping the 2 halves.
|
||||
//
|
||||
// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap
|
||||
// the sign of every adjacent element.
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
class DmlOperatorRotaryEmbedding : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
|
||||
{
|
||||
enum InputIndex : uint32_t
|
||||
{
|
||||
inputDataIndex,
|
||||
positionIdsIndex,
|
||||
cosCacheIndex,
|
||||
sinCacheIndex,
|
||||
};
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
// When positionIds is a scalar, it represents the start offset for each sequence
|
||||
const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1;
|
||||
|
||||
Initialize(kernelInfo);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4);
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4);
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4);
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes());
|
||||
const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2;
|
||||
|
||||
// The last dimension of the data is the hidden size, so it must be divisible by the head size
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0);
|
||||
|
||||
// We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
|
||||
const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
|
||||
const uint32_t batchSize = inputDataSizes[1];
|
||||
const uint32_t sequenceLength = inputDataSizes[2];
|
||||
const uint32_t numHeads = inputDataSizes[3] / headSize;
|
||||
|
||||
const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
|
||||
const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
|
||||
|
||||
if (sequenceLength > maxSequenceLength)
|
||||
{
|
||||
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
|
||||
}
|
||||
|
||||
const bool interleaved = gsl::narrow_cast<bool>(kernelInfo.GetOptionalAttribute<int64_t>(AttrName::Interleaved, 0));
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;
|
||||
|
||||
// Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle
|
||||
const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
|
||||
TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
|
||||
const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
// Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
|
||||
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};
|
||||
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
|
||||
copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
|
||||
copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
|
||||
copyInputDesc.ScaleBias = &scaleBias;
|
||||
const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc};
|
||||
|
||||
// Split the input data into 2 equal parts
|
||||
const std::vector<uint32_t> inputDataTensorShape = interleaved
|
||||
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, headSize / 2, 2})
|
||||
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 2, headSize / 2});
|
||||
|
||||
const std::vector<uint32_t> splitInputDataTensorShape = interleaved
|
||||
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, headSize / 2, 1})
|
||||
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});
|
||||
|
||||
TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
|
||||
const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();
|
||||
|
||||
TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
|
||||
const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};
|
||||
|
||||
DML_SPLIT_OPERATOR_DESC splitInputDesc{};
|
||||
splitInputDesc.InputTensor = &inputDataDmlTensorDesc;
|
||||
splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data();
|
||||
splitInputDesc.OutputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
|
||||
splitInputDesc.Axis = interleaved
|
||||
? gsl::narrow_cast<uint32_t>(splitInputDataTensorShape.size()) - 1
|
||||
: gsl::narrow_cast<uint32_t>(splitInputDataTensorShape.size()) - 2;
|
||||
|
||||
const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc};
|
||||
|
||||
// Swap the 2 halves and join them together
|
||||
DML_JOIN_OPERATOR_DESC joinInputDesc{};
|
||||
joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
|
||||
joinInputDesc.OutputTensor = &inputDataDmlTensorDesc;
|
||||
joinInputDesc.Axis = splitInputDesc.Axis;
|
||||
joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
|
||||
const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
|
||||
|
||||
// We generate a sequence from 0 to sequenceLength and add the offset to it
|
||||
const std::array<uint32_t, 4> positionIdsRangeShape = {1, 1, 1, sequenceLength};
|
||||
auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType;
|
||||
TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape);
|
||||
const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc();
|
||||
|
||||
const std::array<uint32_t, 4> broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength};
|
||||
TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape);
|
||||
const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc();
|
||||
|
||||
const std::array<uint32_t, 4> broadcastedOffsetShape = {1, 1, batchSize, sequenceLength};
|
||||
TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes());
|
||||
const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc();
|
||||
|
||||
TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape);
|
||||
const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{};
|
||||
DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{};
|
||||
if (positionIdsIsOffset)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64);
|
||||
positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64;
|
||||
positionIdsRange.ValueDelta.Int64 = 1;
|
||||
positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc;
|
||||
|
||||
positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc;
|
||||
positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc;
|
||||
positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc;
|
||||
}
|
||||
const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange};
|
||||
const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset};
|
||||
|
||||
// Gather the cos/sin values based on the position ids
|
||||
const std::array<uint32_t, 4> gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2};
|
||||
TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape);
|
||||
const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_GATHER_OPERATOR_DESC gatherCosSinDesc{};
|
||||
gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex];
|
||||
gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? &offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex];
|
||||
gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc;
|
||||
gatherCosSinDesc.Axis = 2;
|
||||
gatherCosSinDesc.IndexDimensions = 2;
|
||||
const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc};
|
||||
|
||||
// After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data
|
||||
const std::vector<uint32_t> reshapedCosSinShape = interleaved
|
||||
? std::vector<uint32_t>({batchSize, sequenceLength, 1, headSize / 2, 1})
|
||||
: std::vector<uint32_t>({batchSize, sequenceLength, 1, 1, headSize / 2});
|
||||
TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape);
|
||||
const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc();
|
||||
|
||||
// Create a vector that contains the sign values {-1, 1}
|
||||
const std::array<uint32_t, 1> signTensorShape = {2};
|
||||
TensorDesc signTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, signTensorShape);
|
||||
const DML_TENSOR_DESC signDmlTensorDesc = signTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC signRange{};
|
||||
signRange.OutputTensor = &signDmlTensorDesc;
|
||||
if (dataType == MLOperatorTensorDataType::Float16)
|
||||
{
|
||||
const auto valueStart = static_cast<MLFloat16>(-1.0f);
|
||||
const auto valueDelta = static_cast<MLFloat16>(2.0f);
|
||||
memcpy(signRange.ValueStart.Bytes, reinterpret_cast<const BYTE*>(&valueStart), sizeof(valueStart));
|
||||
memcpy(signRange.ValueDelta.Bytes, reinterpret_cast<const BYTE*>(&valueDelta), sizeof(valueDelta));
|
||||
signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT16;
|
||||
}
|
||||
else
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(dataType == MLOperatorTensorDataType::Float);
|
||||
signRange.ValueStart.Float32 = -1.0f;
|
||||
signRange.ValueDelta.Float32 = 2.0f;
|
||||
signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT32;
|
||||
}
|
||||
const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange};
|
||||
|
||||
// Multiply the broadcasted sign values with the rotated input
|
||||
const std::vector<uint32_t> reshapedSignShape = interleaved
|
||||
? std::vector<uint32_t>({1, 1, 1, 1, 2})
|
||||
: std::vector<uint32_t>({1, 1, 1, 2, 1});
|
||||
TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape);
|
||||
const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
|
||||
mulSignDesc.ATensor = &inputDataDmlTensorDesc;
|
||||
mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc;
|
||||
mulSignDesc.OutputTensor = &inputDataDmlTensorDesc;
|
||||
const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc};
|
||||
|
||||
// Multiply the non-rotated data with the cos and the rotated data with the sin
|
||||
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{};
|
||||
mulCosSinDesc.ATensor = &inputDataDmlTensorDesc;
|
||||
mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc;
|
||||
mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc;
|
||||
const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc};
|
||||
|
||||
// Add the multiplied cos and sin values together
|
||||
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
|
||||
addDesc.ATensor = &inputOutputDmlTensorDesc;
|
||||
addDesc.BTensor = &inputOutputDmlTensorDesc;
|
||||
addDesc.OutputTensor = &inputOutputDmlTensorDesc;
|
||||
const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};
|
||||
|
||||
// Construct the graph
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs = {
|
||||
©InputDmlDesc, // Copy the input data to preseve the real input shape
|
||||
&splitInputDmlDesc, // Split the input data
|
||||
&gatherCosSinDmlDesc, // Gather cos
|
||||
&gatherCosSinDmlDesc, // Gather sin
|
||||
&signRangeDmlDesc, // Generate the signs
|
||||
|
||||
&joinInputDmlDesc, // Join the split data
|
||||
&mulCosSinDmlDesc, // Multiply cos with the non-rotated data
|
||||
&mulCosSinDmlDesc, // Multiply sin with the rotated data
|
||||
&mulSignDmlDesc, // Multiply the sign with the rotated data
|
||||
&addDmlDesc, // Add the rotated cos and non-rotated sin parts together
|
||||
};
|
||||
|
||||
enum NodeIndex : uint32_t
|
||||
{
|
||||
copyInputOpIndex,
|
||||
splitInputOpIndex,
|
||||
gatherCosOpIndex,
|
||||
gatherSinOpIndex,
|
||||
signRangeOpIndex,
|
||||
|
||||
joinInputOpIndex,
|
||||
mulCosOpIndex,
|
||||
mulSinOpIndex,
|
||||
mulSignOpIndex,
|
||||
addOpIndex,
|
||||
|
||||
// The following indices are optional
|
||||
positionIdsRangeOpIndex,
|
||||
positionIdsAddOffsetOpIndex,
|
||||
};
|
||||
|
||||
if (positionIdsIsOffset)
|
||||
{
|
||||
opDescs.push_back(&positionIdsRangeDmlDesc);
|
||||
opDescs.push_back(&positionIdsAddOffsetDmlDesc);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {};
|
||||
positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex;
|
||||
positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex;
|
||||
positionIdsToAddOffsetEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(positionIdsToAddOffsetEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {};
|
||||
positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex;
|
||||
positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0;
|
||||
positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex;
|
||||
positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {};
|
||||
positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex;
|
||||
positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0;
|
||||
positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex;
|
||||
positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {};
|
||||
positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex;
|
||||
positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0;
|
||||
positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex;
|
||||
positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge);
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {};
|
||||
positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex;
|
||||
positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex;
|
||||
positionIdsToGatherCosEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(positionIdsToGatherCosEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {};
|
||||
positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex;
|
||||
positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex;
|
||||
positionIdsToGatherSinEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(positionIdsToGatherSinEdge);
|
||||
}
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {};
|
||||
inputToCopyInputEdge.GraphInputIndex = inputDataIndex;
|
||||
inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex;
|
||||
inputToCopyInputEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToCopyInputEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {};
|
||||
cosToGatherEdge.GraphInputIndex = cosCacheIndex;
|
||||
cosToGatherEdge.ToNodeIndex = gatherCosOpIndex;
|
||||
cosToGatherEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(cosToGatherEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {};
|
||||
sinToGatherEdge.GraphInputIndex = sinCacheIndex;
|
||||
sinToGatherEdge.ToNodeIndex = gatherSinOpIndex;
|
||||
sinToGatherEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(sinToGatherEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {};
|
||||
inputToSplitEdge.FromNodeIndex = copyInputOpIndex;
|
||||
inputToSplitEdge.FromNodeOutputIndex = 0;
|
||||
inputToSplitEdge.ToNodeIndex = splitInputOpIndex;
|
||||
inputToSplitEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(inputToSplitEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {};
|
||||
nonRotatedDataToMulEdge.FromNodeIndex = copyInputOpIndex;
|
||||
nonRotatedDataToMulEdge.FromNodeOutputIndex = 0;
|
||||
nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex;
|
||||
nonRotatedDataToMulEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(nonRotatedDataToMulEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {};
|
||||
secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;
|
||||
secondHalfDataToJoinEdge.FromNodeOutputIndex = 1;
|
||||
secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex;
|
||||
secondHalfDataToJoinEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(secondHalfDataToJoinEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {};
|
||||
firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;
|
||||
firstHalfDataToJoinEdge.FromNodeOutputIndex = 0;
|
||||
firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex;
|
||||
firstHalfDataToJoinEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(firstHalfDataToJoinEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC cosToMulEdge = {};
|
||||
cosToMulEdge.FromNodeIndex = gatherCosOpIndex;
|
||||
cosToMulEdge.FromNodeOutputIndex = 0;
|
||||
cosToMulEdge.ToNodeIndex = mulCosOpIndex;
|
||||
cosToMulEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(cosToMulEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {};
|
||||
rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex;
|
||||
rotatedDataToMulEdge.FromNodeOutputIndex = 0;
|
||||
rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex;
|
||||
rotatedDataToMulEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(rotatedDataToMulEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC sinToMulEdge = {};
|
||||
sinToMulEdge.FromNodeIndex = gatherSinOpIndex;
|
||||
sinToMulEdge.FromNodeOutputIndex = 0;
|
||||
sinToMulEdge.ToNodeIndex = mulSinOpIndex;
|
||||
sinToMulEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(sinToMulEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToMulEdge = {};
|
||||
rotatedSinToMulEdge.FromNodeIndex = mulSinOpIndex;
|
||||
rotatedSinToMulEdge.FromNodeOutputIndex = 0;
|
||||
rotatedSinToMulEdge.ToNodeIndex = mulSignOpIndex;
|
||||
rotatedSinToMulEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(rotatedSinToMulEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC signToMulEdge = {};
|
||||
signToMulEdge.FromNodeIndex = signRangeOpIndex;
|
||||
signToMulEdge.FromNodeOutputIndex = 0;
|
||||
signToMulEdge.ToNodeIndex = mulSignOpIndex;
|
||||
signToMulEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(signToMulEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedCosToAddEdge = {};
|
||||
nonRotatedCosToAddEdge.FromNodeIndex = mulCosOpIndex;
|
||||
nonRotatedCosToAddEdge.FromNodeOutputIndex = 0;
|
||||
nonRotatedCosToAddEdge.ToNodeIndex = addOpIndex;
|
||||
nonRotatedCosToAddEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(nonRotatedCosToAddEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToAddEdge = {};
|
||||
rotatedSinToAddEdge.FromNodeIndex = mulSignOpIndex;
|
||||
rotatedSinToAddEdge.FromNodeOutputIndex = 0;
|
||||
rotatedSinToAddEdge.ToNodeIndex = addOpIndex;
|
||||
rotatedSinToAddEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(rotatedSinToAddEdge);
|
||||
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {};
|
||||
addToOutputEdge.FromNodeIndex = addOpIndex;
|
||||
addToOutputEdge.FromNodeOutputIndex = 0;
|
||||
addToOutputEdge.GraphOutputIndex = 0;
|
||||
outputEdges.push_back(addToOutputEdge);
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
|
||||
operatorGraphDesc.inputEdges = inputEdges.data();
|
||||
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
|
||||
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
|
||||
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
|
||||
operatorGraphDesc.outputEdges = outputEdges.data();
|
||||
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
|
||||
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
|
||||
|
||||
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -510,6 +510,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseOr);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseXor);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseNot);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding);
|
||||
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
|
||||
|
|
@ -527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Attention);
|
|||
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
|
||||
constexpr static std::array<const char*, 1> typeNameListDefaultV = {"V"};
|
||||
constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
|
||||
constexpr static std::array<const char*, 2> typeNameListRotaryEmbedding = {"T", "M"};
|
||||
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
|
||||
constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
|
||||
constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
|
||||
|
|
@ -597,6 +599,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape
|
|||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
|
||||
|
||||
|
|
@ -1006,6 +1009,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
|
||||
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
|
||||
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
|
||||
|
||||
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -122,6 +122,7 @@ namespace AttrName
|
|||
|
||||
static constexpr const char* GraphFusedActivation = "activation";
|
||||
static constexpr const char* GraphFusedAxis = "activation_axis";
|
||||
static constexpr const char* Interleaved = "interleaved";
|
||||
|
||||
} // namespace AttrName
|
||||
|
||||
|
|
|
|||
|
|
@ -1584,6 +1584,7 @@ using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Attention = AttentionHelper;
|
||||
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
|
||||
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -437,6 +437,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_BiasAdd = 1;
|
||||
static const int sc_sinceVer_QuickGelu = 1;
|
||||
static const int sc_sinceVer_GroupNorm = 1;
|
||||
static const int sc_sinceVer_RotaryEmbedding = 1;
|
||||
} // namespace MsftOperatorSet1
|
||||
|
||||
} // namespace OperatorHelper
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ static void RunTest(
|
|||
int64_t interleaved,
|
||||
bool use_float16,
|
||||
bool disable_cpu,
|
||||
bool disable_cuda) {
|
||||
bool disable_cuda,
|
||||
bool disable_dml) {
|
||||
// input : (batch_size, sequence_length, hidden_size)
|
||||
// position ids : (1) or (batch_size, sequence_length)
|
||||
// cos cache : (max_sequence_length, head_size / 2)
|
||||
|
|
@ -50,9 +51,14 @@ static void RunTest(
|
|||
|
||||
int min_cuda_architecture = use_float16 ? 530 : 0;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
|
||||
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
|
||||
|
||||
if (enable_cuda && !disable_cuda) {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
if (enable_dml && !disable_dml) {
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
}
|
||||
if (!use_float16 && !disable_cpu) {
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
}
|
||||
|
|
@ -107,9 +113,10 @@ static void RunTests(const std::vector<float>& input_data,
|
|||
interleaved,
|
||||
false, /* use_fp16 */
|
||||
false, /* disable_cpu */
|
||||
true /* disable_cuda */);
|
||||
true, /* disable_cuda */
|
||||
true /* disable_dml */);
|
||||
|
||||
// FP32 test for CUDA
|
||||
// FP32 test for CUDA and DML
|
||||
RunTest(input_data,
|
||||
position_ids,
|
||||
cos_cache,
|
||||
|
|
@ -123,9 +130,10 @@ static void RunTests(const std::vector<float>& input_data,
|
|||
interleaved,
|
||||
false, /* use_fp16 */
|
||||
false, /* disable_cpu */
|
||||
false /* disable_cuda */);
|
||||
false, /* disable_cuda */
|
||||
false /* disable_dml */);
|
||||
|
||||
// FP16 test for CUDA
|
||||
// FP16 test for CUDA and DML
|
||||
if (use_float16) {
|
||||
RunTest(input_data,
|
||||
position_ids,
|
||||
|
|
@ -138,9 +146,10 @@ static void RunTests(const std::vector<float>& input_data,
|
|||
num_heads,
|
||||
max_sequence_length,
|
||||
interleaved,
|
||||
true, /* use_fp16 */
|
||||
true, /* disable_cpu */
|
||||
false /* disable_cuda*/);
|
||||
true, /* use_fp16 */
|
||||
true, /* disable_cpu */
|
||||
false, /* disable_cuda*/
|
||||
false /* disable_dml */);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue