[DML EP] Support partial rotary embedding (#22417)

### Description
This adds support for partial RotaryEmbedding to DML. Essentially,
partial RotaryEmbedding simply consists of doing the rotary embedding
calculation on a subregion of the input tensor of as if its head size
was `rotary_embedding_dim`, while leaving the second part of the tensor
(i.e. `head_size - rotary_embedding_dim`) alone.

To achieve this, all we need to do is follow the following steps:

1. Split the tensor into 2 parts
2. Run the rotary embedding algorithm on the first part, just like we
were doing before on the entire tensor
3. Join the 2 parts back together

Since we're leaving the middle part intact, the RotaryEmbedding fusion
will still be done within DML. Also, the concat at the end is
essentially free because DML optimizes it out and directly allocate the
result of RotaryEmbedding at the right place. The only overhead here is
the splitting of the tensor at the beginning, which we should eventually
make part of the RotaryEmbedding fusion within DML.



### Motivation and Context
This fix allows us to correctly run models that have a
`partial_rotary_factor` setting in huggingface, including Nvidia's
Nemotron: https://huggingface.co/nvidia/Nemotron-Mini-4B-Instruct
This commit is contained in:
Patrice Vignola 2024-10-16 13:28:44 -07:00 committed by GitHub
parent a164228c10
commit f610605a48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 232 additions and 78 deletions

View file

@ -25,6 +25,41 @@
// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap
// the sign of every adjacent element.
// Here's a representation of what the graph looks like in DML, before getting fused together:
/*
Input CosCache PositionIds SinCache
| | | |
| | +--------+-----------+ |
Split | | | |
| | Gather Gather
+-------+ | | |
| | | |
| Identity----------+ | |
| | | | |
| | | | |
| --Split-- | | |
| \ / | +-----------------+ |
| \ / | | |
| \ / Mul |
| \ / | |
| X | |
| / \ | |
| / \ | |
| Join | |
| | | |
| | +---------------------------------------------------------+
| | | |
| Mul |
| | |
| +-----+ +------+
| | |
| Add
| |
+-------------+ |
| |
Join
*/
namespace Dml
{
class DmlOperatorRotaryEmbedding : public DmlOperator
@ -56,25 +91,45 @@ public:
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes());
const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2;
// The last dimension of the data is the hidden size, so it must be divisible by the head size
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0);
uint32_t numHeads = gsl::narrow_cast<uint32_t>(kernelInfo.GetOptionalAttribute<int64_t>(AttrName::NumHeads, 0));
uint32_t rotaryEmbeddingDim = gsl::narrow_cast<uint32_t>(kernelInfo.GetOptionalAttribute<int64_t>(AttrName::RotaryEmbeddingDim, 0));
// We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
const uint32_t hiddenSize = inputIs4D ? inputDataSizes[1] * inputDataSizes[3] : inputDataSizes.back();
const uint32_t headSize = numHeads == 0
? m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2
: hiddenSize / numHeads;
if (rotaryEmbeddingDim > 0)
{
ORT_ENFORCE(numHeads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
}
else
{
rotaryEmbeddingDim = headSize;
}
if (numHeads == 0)
{
numHeads = hiddenSize / headSize;
}
else if (inputIs4D)
{
ORT_ENFORCE(numHeads == inputDataSizes[1], "When the input has 4 dimensions, num_heads must be 0 or have the same value as the second dimension of the input");
}
const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1];
const uint32_t sequenceLength = inputDataSizes[2];
const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize;
const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
if (sequenceLength > maxSequenceLength)
const bool isPackedBatching = gsl::narrow_cast<uint32_t>(kernelInfo.GetOptionalAttribute<int64_t>(AttrName::IsPackedBatching, 0)) == 1;
if (!isPackedBatching && sequenceLength > maxSequenceLength)
{
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}
@ -84,64 +139,103 @@ public:
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;
const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
// We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
const std::array<uint32_t, 4> inputOutputShape = inputIs4D
? std::array<uint32_t, 4>({batchSize, numHeads, sequenceLength, headSize})
: std::array<uint32_t, 4>({batchSize, sequenceLength, numHeads, headSize});
const std::array<uint32_t, 4> splitInputOutputShape1 = inputIs4D
? std::array<uint32_t, 4>({batchSize, numHeads, sequenceLength, rotaryEmbeddingDim})
: std::array<uint32_t, 4>({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim});
const std::array<uint32_t, 4> splitInputOutputShape2 = inputIs4D
? std::array<uint32_t, 4>({batchSize, numHeads, sequenceLength, headSize - rotaryEmbeddingDim})
: std::array<uint32_t, 4>({batchSize, sequenceLength, numHeads, headSize - rotaryEmbeddingDim});
TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
TensorDesc splitInputOutputTensorDesc1 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape1);
TensorDesc splitInputOutputTensorDesc2 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape2);
// Split the input to perform the rotary embedding only on a subregion of the tensor if needed. The split inputs
// will be joined back together at the end.
const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
std::array<DML_TENSOR_DESC, 2> splitTensorDescs = {
splitInputOutputTensorDesc1.GetDmlDesc(),
splitInputOutputTensorDesc2.GetDmlDesc(),
};
DML_SPLIT_OPERATOR_DESC splitInputOperatorDesc{};
DML_OPERATOR_DESC splitInputDmlOperatorDesc{};
if (headSize != rotaryEmbeddingDim)
{
splitInputOperatorDesc.InputTensor = &inputOutputDmlTensorDesc;
splitInputOperatorDesc.OutputCount = gsl::narrow_cast<uint32_t>(splitTensorDescs.size());
splitInputOperatorDesc.OutputTensors = splitTensorDescs.data();
splitInputOperatorDesc.Axis = gsl::narrow_cast<uint32_t>(inputOutputShape.size()) - 1;
splitInputDmlOperatorDesc.Type = DML_OPERATOR_SPLIT;
splitInputDmlOperatorDesc.Desc = &splitInputOperatorDesc;
}
// Copy the partial input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};
const std::array<uint32_t, 4> partialInputOutputShape = {batchSize, sequenceLength, numHeads, rotaryEmbeddingDim};
TensorDesc partialStridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape);
TensorDesc partialInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape);
if (inputIs4D)
{
const std::array<uint32_t, 4> inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};
stridedInputOutputTensorDesc.SetStrides(inputOutputStrides);
const std::array<uint32_t, 4> partialInputOutputStrides = {rotaryEmbeddingDim * numHeads * sequenceLength, rotaryEmbeddingDim, sequenceLength * rotaryEmbeddingDim, 1};
partialStridedInputOutputTensorDesc.SetStrides(partialInputOutputStrides);
}
const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc();
// Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};
const DML_TENSOR_DESC partialStridedInputOutputDmlTensorDesc = partialStridedInputOutputTensorDesc.GetDmlDesc();
const DML_TENSOR_DESC partialInputOutputDmlTensorDesc = partialInputOutputTensorDesc.GetDmlDesc();
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc;
copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.InputTensor = &partialStridedInputOutputDmlTensorDesc;
copyInputDesc.OutputTensor = &partialInputOutputDmlTensorDesc;
copyInputDesc.ScaleBias = &scaleBias;
const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
const uint32_t halfRoraryEmbeddingDim = rotaryEmbeddingDim / 2;
// Split the input data into 2 equal parts
const std::vector<uint32_t> inputDataTensorShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, headSize / 2, 2})
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 2, headSize / 2});
const std::vector<uint32_t> partialInputDataTensorShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 2})
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 2, rotaryEmbeddingDim / 2});
const std::vector<uint32_t> splitInputDataTensorShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, headSize / 2, 1})
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});
? std::vector<uint32_t>({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 1})
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, rotaryEmbeddingDim / 2});
TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
TensorDesc partialInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape);
const DML_TENSOR_DESC partialInputDataDmlTensorDesc = partialInputDataTensorDesc.GetDmlDesc();
const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();
TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape);
const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc();
TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};
DML_SPLIT_OPERATOR_DESC splitInputDesc{};
splitInputDesc.InputTensor = &inputDataDmlTensorDesc;
splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data();
splitInputDesc.OutputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
splitInputDesc.Axis = interleaved
DML_SPLIT_OPERATOR_DESC splitPartialInputDesc{};
splitPartialInputDesc.InputTensor = &partialInputDataDmlTensorDesc;
splitPartialInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data();
splitPartialInputDesc.OutputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
splitPartialInputDesc.Axis = interleaved
? gsl::narrow_cast<uint32_t>(splitInputDataTensorShape.size()) - 1
: gsl::narrow_cast<uint32_t>(splitInputDataTensorShape.size()) - 2;
const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc};
const DML_OPERATOR_DESC splitPartialInputDmlDesc = {DML_OPERATOR_SPLIT, &splitPartialInputDesc};
// Swap the 2 halves and join them together
DML_JOIN_OPERATOR_DESC joinInputDesc{};
joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc;
joinInputDesc.Axis = splitInputDesc.Axis;
joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
DML_JOIN_OPERATOR_DESC joinPartialInputDesc{};
joinPartialInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
joinPartialInputDesc.OutputTensor = &joinedDataDmlTensorDesc;
joinPartialInputDesc.Axis = splitPartialInputDesc.Axis;
joinPartialInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
const DML_OPERATOR_DESC joinPartialInputDmlDesc = {DML_OPERATOR_JOIN, &joinPartialInputDesc};
// We generate a sequence from 0 to sequenceLength and add the offset to it
const std::array<uint32_t, 4> positionIdsRangeShape = {1, 1, 1, sequenceLength};
@ -177,7 +271,7 @@ public:
const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset};
// Gather the cos/sin values based on the position ids
const std::array<uint32_t, 4> gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2};
const std::array<uint32_t, 4> gatheredCosSinShape = {1, batchSize, sequenceLength, rotaryEmbeddingDim / 2};
TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape);
const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc();
@ -191,9 +285,9 @@ public:
// After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data
const std::vector<uint32_t> reshapedCosSinShape = interleaved
? std::vector<uint32_t>({batchSize, sequenceLength, 1, headSize / 2, 1})
: std::vector<uint32_t>({batchSize, sequenceLength, 1, 1, headSize / 2});
TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape);
? std::vector<uint32_t>({batchSize, sequenceLength, 1, rotaryEmbeddingDim / 2, 1})
: std::vector<uint32_t>({batchSize, sequenceLength, 1, 1, rotaryEmbeddingDim / 2});
TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedCosSinShape);
const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc();
// Create a vector that contains the sign values {-1, 1}
@ -224,7 +318,7 @@ public:
const std::vector<uint32_t> reshapedSignShape = interleaved
? std::vector<uint32_t>({1, 1, 1, 1, 2})
: std::vector<uint32_t>({1, 1, 1, 2, 1});
TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape);
TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedSignShape);
const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
@ -242,11 +336,23 @@ public:
// Add the multiplied cos and sin values together
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
addDesc.ATensor = &inputOutputDmlTensorDesc;
addDesc.BTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc;
addDesc.ATensor = &partialInputOutputDmlTensorDesc;
addDesc.BTensor = &partialInputOutputDmlTensorDesc;
addDesc.OutputTensor = &partialStridedInputOutputDmlTensorDesc;
const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};
DML_JOIN_OPERATOR_DESC joinOutputOperatorDesc{};
DML_OPERATOR_DESC joinOutputDmlOperatorDesc{};
if (headSize != rotaryEmbeddingDim)
{
joinOutputOperatorDesc.InputCount = gsl::narrow_cast<uint32_t>(splitTensorDescs.size());
joinOutputOperatorDesc.InputTensors = splitTensorDescs.data();
joinOutputOperatorDesc.OutputTensor = &inputOutputDmlTensorDesc;
joinOutputOperatorDesc.Axis = gsl::narrow_cast<uint32_t>(inputOutputShape.size()) - 1;
joinOutputDmlOperatorDesc.Type = DML_OPERATOR_JOIN;
joinOutputDmlOperatorDesc.Desc = &joinOutputOperatorDesc;
}
// Construct the graph
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
@ -254,12 +360,12 @@ public:
std::vector<const DML_OPERATOR_DESC*> opDescs = {
&copyInputDmlDesc, // Copy the input data to preseve the real input shape
&splitInputDmlDesc, // Split the input data
&splitPartialInputDmlDesc, // Split the input data
&gatherCosSinDmlDesc, // Gather cos
&gatherCosSinDmlDesc, // Gather sin
&signRangeDmlDesc, // Generate the signs
&joinInputDmlDesc, // Join the split data
&joinPartialInputDmlDesc, // Join the split data
&mulCosSinDmlDesc, // Multiply cos with the non-rotated data
&mulCosSinDmlDesc, // Multiply sin with the rotated data
&mulSignDmlDesc, // Multiply the sign with the rotated data
@ -269,12 +375,12 @@ public:
enum NodeIndex : uint32_t
{
copyInputOpIndex,
splitInputOpIndex,
splitPartialInputOpIndex,
gatherCosOpIndex,
gatherSinOpIndex,
signRangeOpIndex,
joinInputOpIndex,
joinPartialInputOpIndex,
mulCosOpIndex,
mulSinOpIndex,
mulSignOpIndex,
@ -285,6 +391,9 @@ public:
positionIdsAddOffsetOpIndex,
};
uint32_t splitInputOpIndex = positionIdsIsOffset ? positionIdsAddOffsetOpIndex + 1 : addOpIndex + 1;
uint32_t joinOutputOpIndex = splitInputOpIndex + 1;
if (positionIdsIsOffset)
{
opDescs.push_back(&positionIdsRangeDmlDesc);
@ -332,11 +441,32 @@ public:
inputEdges.push_back(positionIdsToGatherSinEdge);
}
DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {};
inputToCopyInputEdge.GraphInputIndex = inputDataIndex;
inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex;
inputToCopyInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToCopyInputEdge);
if (splitInputDmlOperatorDesc.Desc)
{
opDescs.push_back(&splitInputDmlOperatorDesc);
opDescs.push_back(&joinOutputDmlOperatorDesc);
DML_INPUT_GRAPH_EDGE_DESC inputToSplitInputEdge = {};
inputToSplitInputEdge.GraphInputIndex = inputDataIndex;
inputToSplitInputEdge.ToNodeIndex = splitInputOpIndex;
inputToSplitInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToSplitInputEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC partialInputToCopyInputEdge = {};
partialInputToCopyInputEdge.FromNodeIndex = splitInputOpIndex;
partialInputToCopyInputEdge.FromNodeOutputIndex = 0;
partialInputToCopyInputEdge.ToNodeIndex = copyInputOpIndex;
partialInputToCopyInputEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(partialInputToCopyInputEdge);
}
else
{
DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {};
inputToCopyInputEdge.GraphInputIndex = inputDataIndex;
inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex;
inputToCopyInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToCopyInputEdge);
}
DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {};
cosToGatherEdge.GraphInputIndex = cosCacheIndex;
@ -353,7 +483,7 @@ public:
DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {};
inputToSplitEdge.FromNodeIndex = copyInputOpIndex;
inputToSplitEdge.FromNodeOutputIndex = 0;
inputToSplitEdge.ToNodeIndex = splitInputOpIndex;
inputToSplitEdge.ToNodeIndex = splitPartialInputOpIndex;
inputToSplitEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(inputToSplitEdge);
@ -365,16 +495,16 @@ public:
intermediateEdges.push_back(nonRotatedDataToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {};
secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;
secondHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex;
secondHalfDataToJoinEdge.FromNodeOutputIndex = 1;
secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex;
secondHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex;
secondHalfDataToJoinEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(secondHalfDataToJoinEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {};
firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;
firstHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex;
firstHalfDataToJoinEdge.FromNodeOutputIndex = 0;
firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex;
firstHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex;
firstHalfDataToJoinEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(firstHalfDataToJoinEdge);
@ -386,7 +516,7 @@ public:
intermediateEdges.push_back(cosToMulEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {};
rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex;
rotatedDataToMulEdge.FromNodeIndex = joinPartialInputOpIndex;
rotatedDataToMulEdge.FromNodeOutputIndex = 0;
rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex;
rotatedDataToMulEdge.ToNodeInputIndex = 0;
@ -427,11 +557,36 @@ public:
rotatedSinToAddEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(rotatedSinToAddEdge);
DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {};
addToOutputEdge.FromNodeIndex = addOpIndex;
addToOutputEdge.FromNodeOutputIndex = 0;
addToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(addToOutputEdge);
if (splitInputDmlOperatorDesc.Desc)
{
DML_INTERMEDIATE_GRAPH_EDGE_DESC addToJoinOutputEdge = {};
addToJoinOutputEdge.FromNodeIndex = addOpIndex;
addToJoinOutputEdge.FromNodeOutputIndex = 0;
addToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex;
addToJoinOutputEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(addToJoinOutputEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC remainingInputToJoinOutputEdge = {};
remainingInputToJoinOutputEdge.FromNodeIndex = splitInputOpIndex;
remainingInputToJoinOutputEdge.FromNodeOutputIndex = 1;
remainingInputToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex;
remainingInputToJoinOutputEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(remainingInputToJoinOutputEdge);
DML_OUTPUT_GRAPH_EDGE_DESC joinOutputToOutputEdge = {};
joinOutputToOutputEdge.FromNodeIndex = joinOutputOpIndex;
joinOutputToOutputEdge.FromNodeOutputIndex = 0;
joinOutputToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(joinOutputToOutputEdge);
}
else
{
DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {};
addToOutputEdge.FromNodeIndex = addOpIndex;
addToOutputEdge.FromNodeOutputIndex = 0;
addToOutputEdge.GraphOutputIndex = 0;
outputEdges.push_back(addToOutputEdge);
}
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());

View file

@ -130,6 +130,8 @@ namespace AttrName
static constexpr const char* UppercaseN = "N";
static constexpr const char* UppercaseK = "K";
static constexpr const char* MatMulNBitsBlockSize = "block_size";
static constexpr const char* RotaryEmbeddingDim = "rotary_embedding_dim";
static constexpr const char* IsPackedBatching = "is_packed_batching";
} // namespace AttrName

View file

@ -135,8 +135,7 @@ static void RunTests(const std::vector<float>& input_data,
int max_sequence_length = 0,
int64_t interleaved = 0,
int64_t is_packed_batching = 0,
bool use_float16 = true,
bool disable_dml = false) {
bool use_float16 = true) {
// FP32 test for CPU
RunTest(input_data,
position_ids,
@ -173,7 +172,7 @@ static void RunTests(const std::vector<float>& input_data,
TensorType::kFloat,
false, /* disable_cpu */
false, /* disable_cuda */
disable_dml || false /* disable_dml */);
false /* disable_dml */);
// FP16 test for CUDA and DML
if (use_float16) {
@ -193,7 +192,7 @@ static void RunTests(const std::vector<float>& input_data,
TensorType::kFloat16,
true, /* disable_cpu */
false, /* disable_cuda*/
disable_dml || false /* disable_dml */);
false /* disable_dml */);
// RunTest(input_data,
// position_ids,
@ -743,9 +742,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) {
num_heads,
max_sequence_length,
interleaved,
0, // is_packed_batching
true, /*use_fp16*/
true /*disable_dml*/);
0, // is_packed_batching
true /*use_fp16*/);
}
TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) {
@ -785,9 +783,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_B
num_heads,
max_sequence_length,
interleaved,
1, // is_packed_batching
true, /*use_fp16*/
true /*disable_dml*/);
1, // is_packed_batching
true /*use_fp16*/);
}
} // namespace test