[DML EP] Add RotaryEmbedding (#18158)

This is a graph implementation of RotaryEmbedding since there's no time
to add it to DML before 1.16.2, but it eventually should move into
DirectML since we're bandwidth-bound.
This commit is contained in:
Patrice Vignola 2023-11-07 08:26:11 -08:00 committed by GitHub
parent 9868a71373
commit 800ae7742c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 461 additions and 8 deletions

View file

@ -1247,6 +1247,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |

View file

@ -0,0 +1,436 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
// This operator is easier to understand by looking at a python implementation of the non-interleaved version:
//
// def rotate_half(x):
// """Rotates half the hidden dims of the input."""
// half_dim = x.shape[-1] // 2
// x1 = x[..., :half_dim]
// x2 = x[..., half_dim:]
// return np.concatenate((-x2, x1), dim=-1)
//
//
// def apply_rope(x, cos, sin, position_ids):
// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
// x_embed = (x * cos) + (rotate_half(x) * sin)
// return x_embed
//
// For the non-interleaved version, we multiply the cos cache by the non-rotated input tensor while we multiply the sin cache
// by the rotated input tensor. Rotating the tensor means slicing it in half on the head dimension and swapping the 2 halves.
//
// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap
// the sign of every adjacent element.
namespace Dml
{
class DmlOperatorRotaryEmbedding : public DmlOperator
{
public:
DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
enum InputIndex : uint32_t
{
inputDataIndex,
positionIdsIndex,
cosCacheIndex,
sinCacheIndex,
};
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
// When positionIds is a scalar, it represents the start offset for each sequence
const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1;
Initialize(kernelInfo);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes());
const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2;
// The last dimension of the data is the hidden size, so it must be divisible by the head size
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0);
// We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
const uint32_t batchSize = inputDataSizes[1];
const uint32_t sequenceLength = inputDataSizes[2];
const uint32_t numHeads = inputDataSizes[3] / headSize;
const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
if (sequenceLength > maxSequenceLength)
{
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}
const bool interleaved = gsl::narrow_cast<bool>(kernelInfo.GetOptionalAttribute<int64_t>(AttrName::Interleaved, 0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;
// Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle
const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
// Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.ScaleBias = &scaleBias;
const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
// Split the input data into 2 equal parts
const std::vector<uint32_t> inputDataTensorShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, headSize / 2, 2})
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 2, headSize / 2});
const std::vector<uint32_t> splitInputDataTensorShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, headSize / 2, 1})
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});
TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();
TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};
DML_SPLIT_OPERATOR_DESC splitInputDesc{};
splitInputDesc.InputTensor = &inputDataDmlTensorDesc;
splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data();
splitInputDesc.OutputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
splitInputDesc.Axis = interleaved
? gsl::narrow_cast<uint32_t>(splitInputDataTensorShape.size()) - 1
: gsl::narrow_cast<uint32_t>(splitInputDataTensorShape.size()) - 2;
const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc};
// Swap the 2 halves and join them together
DML_JOIN_OPERATOR_DESC joinInputDesc{};
joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
joinInputDesc.OutputTensor = &inputDataDmlTensorDesc;
joinInputDesc.Axis = splitInputDesc.Axis;
joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
// We generate a sequence from 0 to sequenceLength and add the offset to it
const std::array<uint32_t, 4> positionIdsRangeShape = {1, 1, 1, sequenceLength};
auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType;
TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape);
const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc();
const std::array<uint32_t, 4> broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength};
TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape);
const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc();
const std::array<uint32_t, 4> broadcastedOffsetShape = {1, 1, batchSize, sequenceLength};
TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes());
const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc();
TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape);
const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc();
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{};
DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{};
if (positionIdsIsOffset)
{
ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64);
positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64;
positionIdsRange.ValueDelta.Int64 = 1;
positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc;
positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc;
positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc;
positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc;
}
const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange};
const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset};
// Gather the cos/sin values based on the position ids
const std::array<uint32_t, 4> gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2};
TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape);
const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc();
DML_GATHER_OPERATOR_DESC gatherCosSinDesc{};
gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex];
gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? &offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex];
gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc;
gatherCosSinDesc.Axis = 2;
gatherCosSinDesc.IndexDimensions = 2;
const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc};
// After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data
const std::vector<uint32_t> reshapedCosSinShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, 1, headSize / 2, 1})
: std::vector<uint32_t>({batchSize, sequenceLength, 1, 1, headSize / 2});
TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape);
const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc();
// Create a vector that contains the sign values {-1, 1}
const std::array<uint32_t, 1> signTensorShape = {2};
TensorDesc signTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, signTensorShape);
const DML_TENSOR_DESC signDmlTensorDesc = signTensorDesc.GetDmlDesc();
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC signRange{};
signRange.OutputTensor = &signDmlTensorDesc;
if (dataType == MLOperatorTensorDataType::Float16)
{
const auto valueStart = static_cast<MLFloat16>(-1.0f);
const auto valueDelta = static_cast<MLFloat16>(2.0f);
memcpy(signRange.ValueStart.Bytes, reinterpret_cast<const BYTE*>(&valueStart), sizeof(valueStart));
memcpy(signRange.ValueDelta.Bytes, reinterpret_cast<const BYTE*>(&valueDelta), sizeof(valueDelta));
signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT16;
}
else
{
ML_CHECK_VALID_ARGUMENT(dataType == MLOperatorTensorDataType::Float);
signRange.ValueStart.Float32 = -1.0f;
signRange.ValueDelta.Float32 = 2.0f;
signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT32;
}
const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange};
// Multiply the broadcasted sign values with the rotated input
const std::vector<uint32_t> reshapedSignShape = interleaved
? std::vector<uint32_t>({1, 1, 1, 1, 2})
: std::vector<uint32_t>({1, 1, 1, 2, 1});
TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape);
const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
mulSignDesc.ATensor = &inputDataDmlTensorDesc;
mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc;
mulSignDesc.OutputTensor = &inputDataDmlTensorDesc;
const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc};
// Multiply the non-rotated data with the cos and the rotated data with the sin
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{};
mulCosSinDesc.ATensor = &inputDataDmlTensorDesc;
mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc;
mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc;
const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc};
// Add the multiplied cos and sin values together
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
addDesc.ATensor = &inputOutputDmlTensorDesc;
addDesc.BTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &inputOutputDmlTensorDesc;
const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};
// Construct the graph
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<const DML_OPERATOR_DESC*> opDescs = {
&copyInputDmlDesc, // Copy the input data to preseve the real input shape
&splitInputDmlDesc, // Split the input data
&gatherCosSinDmlDesc, // Gather cos
&gatherCosSinDmlDesc, // Gather sin
&signRangeDmlDesc, // Generate the signs
&joinInputDmlDesc, // Join the split data
&mulCosSinDmlDesc, // Multiply cos with the non-rotated data
&mulCosSinDmlDesc, // Multiply sin with the rotated data
&mulSignDmlDesc, // Multiply the sign with the rotated data
&addDmlDesc, // Add the rotated cos and non-rotated sin parts together
};
enum NodeIndex : uint32_t
{
copyInputOpIndex,
splitInputOpIndex,
gatherCosOpIndex,
gatherSinOpIndex,
signRangeOpIndex,
joinInputOpIndex,
mulCosOpIndex,
mulSinOpIndex,
mulSignOpIndex,
addOpIndex,
// The following indices are optional
positionIdsRangeOpIndex,
positionIdsAddOffsetOpIndex,
};
if (positionIdsIsOffset)
{
opDescs.push_back(&positionIdsRangeDmlDesc);
opDescs.push_back(&positionIdsAddOffsetDmlDesc);
DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {};
positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex;
positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex;
positionIdsToAddOffsetEdge.ToNodeInputIndex = 1;
inputEdges.push_back(positionIdsToAddOffsetEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {};
positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex;
positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0;
positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex;
positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {};
positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex;
positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0;
positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex;
positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {};
positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex;
positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0;
positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex;
positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge);
}
else
{
DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {};
positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex;
positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex;
positionIdsToGatherCosEdge.ToNodeInputIndex = 1;
inputEdges.push_back(positionIdsToGatherCosEdge);
DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {};
positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex;
positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex;
positionIdsToGatherSinEdge.ToNodeInputIndex = 1;
inputEdges.push_back(positionIdsToGatherSinEdge);
}
DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {};
inputToCopyInputEdge.GraphInputIndex = inputDataIndex;
inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex;
inputToCopyInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToCopyInputEdge);
DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {};
cosToGatherEdge.GraphInputIndex = cosCacheIndex;
cosToGatherEdge.ToNodeIndex = gatherCosOpIndex;
cosToGatherEdge.ToNodeInputIndex = 0;
inputEdges.push_back(cosToGatherEdge);
DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {};
sinToGatherEdge.GraphInputIndex = sinCacheIndex;
sinToGatherEdge.ToNodeIndex = gatherSinOpIndex;
sinToGatherEdge.ToNodeInputIndex = 0;
inputEdges.push_back(sinToGatherEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {};
inputToSplitEdge.FromNodeIndex = copyInputOpIndex;
inputToSplitEdge.FromNodeOutputIndex = 0;
inputToSplitEdge.ToNodeIndex = splitInputOpIndex;
inputToSplitEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(inputToSplitEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {};
nonRotatedDataToMulEdge.FromNodeIndex = copyInputOpIndex;
nonRotatedDataToMulEdge.FromNodeOutputIndex = 0;
nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex;
nonRotatedDataToMulEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(nonRotatedDataToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {};
secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;
secondHalfDataToJoinEdge.FromNodeOutputIndex = 1;
secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex;
secondHalfDataToJoinEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(secondHalfDataToJoinEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {};
firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;
firstHalfDataToJoinEdge.FromNodeOutputIndex = 0;
firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex;
firstHalfDataToJoinEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(firstHalfDataToJoinEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC cosToMulEdge = {};
cosToMulEdge.FromNodeIndex = gatherCosOpIndex;
cosToMulEdge.FromNodeOutputIndex = 0;
cosToMulEdge.ToNodeIndex = mulCosOpIndex;
cosToMulEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(cosToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {};
rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex;
rotatedDataToMulEdge.FromNodeOutputIndex = 0;
rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex;
rotatedDataToMulEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(rotatedDataToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC sinToMulEdge = {};
sinToMulEdge.FromNodeIndex = gatherSinOpIndex;
sinToMulEdge.FromNodeOutputIndex = 0;
sinToMulEdge.ToNodeIndex = mulSinOpIndex;
sinToMulEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(sinToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToMulEdge = {};
rotatedSinToMulEdge.FromNodeIndex = mulSinOpIndex;
rotatedSinToMulEdge.FromNodeOutputIndex = 0;
rotatedSinToMulEdge.ToNodeIndex = mulSignOpIndex;
rotatedSinToMulEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(rotatedSinToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC signToMulEdge = {};
signToMulEdge.FromNodeIndex = signRangeOpIndex;
signToMulEdge.FromNodeOutputIndex = 0;
signToMulEdge.ToNodeIndex = mulSignOpIndex;
signToMulEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(signToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedCosToAddEdge = {};
nonRotatedCosToAddEdge.FromNodeIndex = mulCosOpIndex;
nonRotatedCosToAddEdge.FromNodeOutputIndex = 0;
nonRotatedCosToAddEdge.ToNodeIndex = addOpIndex;
nonRotatedCosToAddEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(nonRotatedCosToAddEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToAddEdge = {};
rotatedSinToAddEdge.FromNodeIndex = mulSignOpIndex;
rotatedSinToAddEdge.FromNodeOutputIndex = 0;
rotatedSinToAddEdge.ToNodeIndex = addOpIndex;
rotatedSinToAddEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(rotatedSinToAddEdge);
DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {};
addToOutputEdge.FromNodeIndex = addOpIndex;
addToOutputEdge.FromNodeOutputIndex = 0;
addToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(addToOutputEdge);
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding);
} // namespace Dml

View file

@ -510,6 +510,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseOr);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseXor);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseNot);
DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
@ -527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Attention);
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
constexpr static std::array<const char*, 1> typeNameListDefaultV = {"V"};
constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListRotaryEmbedding = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
@ -597,6 +599,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
@ -1006,6 +1009,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},

View file

@ -122,6 +122,7 @@ namespace AttrName
static constexpr const char* GraphFusedActivation = "activation";
static constexpr const char* GraphFusedAxis = "activation_axis";
static constexpr const char* Interleaved = "interleaved";
} // namespace AttrName

View file

@ -1584,6 +1584,7 @@ using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper;

View file

@ -437,6 +437,7 @@ namespace OperatorHelper
static const int sc_sinceVer_BiasAdd = 1;
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper

View file

@ -25,7 +25,8 @@ static void RunTest(
int64_t interleaved,
bool use_float16,
bool disable_cpu,
bool disable_cuda) {
bool disable_cuda,
bool disable_dml) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
// cos cache : (max_sequence_length, head_size / 2)
@ -50,9 +51,14 @@ static void RunTest(
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
if (enable_cuda && !disable_cuda) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_dml && !disable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
if (!use_float16 && !disable_cpu) {
execution_providers.push_back(DefaultCpuExecutionProvider());
}
@ -107,9 +113,10 @@ static void RunTests(const std::vector<float>& input_data,
interleaved,
false, /* use_fp16 */
false, /* disable_cpu */
true /* disable_cuda */);
true, /* disable_cuda */
true /* disable_dml */);
// FP32 test for CUDA
// FP32 test for CUDA and DML
RunTest(input_data,
position_ids,
cos_cache,
@ -123,9 +130,10 @@ static void RunTests(const std::vector<float>& input_data,
interleaved,
false, /* use_fp16 */
false, /* disable_cpu */
false /* disable_cuda */);
false, /* disable_cuda */
false /* disable_dml */);
// FP16 test for CUDA
// FP16 test for CUDA and DML
if (use_float16) {
RunTest(input_data,
position_ids,
@ -138,9 +146,10 @@ static void RunTests(const std::vector<float>& input_data,
num_heads,
max_sequence_length,
interleaved,
true, /* use_fp16 */
true, /* disable_cpu */
false /* disable_cuda*/);
true, /* use_fp16 */
true, /* disable_cpu */
false, /* disable_cuda*/
false /* disable_dml */);
}
}