From 13dabd97b6da9b79c1ddb69ec864ec9139133b23 Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Wed, 25 Mar 2020 21:48:14 -0700 Subject: [PATCH 1/6] Slice --- .../dml/OperatorAuthorHelper/OperatorHelper.cpp | 17 +++++++++++++++++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 4 +++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 9f3a0ed7e6..2d4526d71c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -32,6 +32,23 @@ void HandleNegativeAxes(gsl::span onnxAxes, uint32_t dimCount) } } +void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, uint32_t fillValue) +{ + // e.g. + // input = [6,7] + // elementCount = 4 + // fillValue = 1 + // output = [1,1,6,7] + + const size_t oldElementCount = values.size(); + const size_t newElementCount = std::max(size_t(minimumElementCount), oldElementCount); + const size_t fillCount = newElementCount - oldElementCount; + + values.resize(newElementCount); + std::copy_backward(values.data(), values.data() + oldElementCount, values.data() + fillCount); + std::fill_n(values.data(), fillCount, fillValue); +} + int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) { switch (tensorDataType) diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index aa8486117f..6940ed3a85 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -75,6 +75,8 @@ void RemoveValuesByIndex(gsl::span indices, bool keepOneValue, / values.resize(newValuesCount); } +void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, uint32_t fillValue); + int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p); class EdgeShapes { @@ -641,7 +643,7 @@ public: std::vector m_outputDimensions; std::vector m_offsets; std::vector m_sizes; - std::vector m_strides; + std::vector m_strides; }; class SliceHelper : public SliceHelperBase From ccb840ac99ff3cea6c9e9da03d10e3118a913021 Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Fri, 27 Mar 2020 00:48:06 -0700 Subject: [PATCH 2/6] Fix slice. --- .../src/Operators/DmlOperatorSlice.cpp | 48 +++----- .../src/Operators/OperatorRegistration.cpp | 7 +- .../MLOperatorAuthorHelper.h | 14 +++ .../OperatorAuthorHelper/OperatorHelper.cpp | 64 +++++++---- .../dml/OperatorAuthorHelper/OperatorHelper.h | 108 +++++++----------- 5 files changed, 124 insertions(+), 117 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index 0e9d0feb5a..fb82f7708c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -13,49 +13,36 @@ public: : DmlOperator(kernelInfo), SliceHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion) { - uint32_t minInputCount = (opsetVersion < 10) ? 1 : 3; - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= minInputCount); + const uint32_t inputCount = kernelInfo.GetInputCount(); + ML_CHECK_VALID_ARGUMENT((opsetVersion < 10 && inputCount == 1) + || (opsetVersion == 10 && inputCount >= 3 && inputCount <= 5)); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - // TODO (23108599): Slice V10 introduces an optional "Steps" input which the kernel does not yet support. - THROW_HR_IF(E_NOTIMPL, kernelInfo.GetInputCount() > 4); - - std::vector> kernelInputIndices = { 0 }; + std::vector> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor. DmlOperator::Initialize(kernelInfo, kernelInputIndices); - assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_offsets.size())); - assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_sizes.size())); - assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_strides.size())); + const uint32_t inputTensorRank = m_inputTensorDescs[0].GetDimensionCount(); + assert(inputTensorRank >= gsl::narrow_cast(m_offsets.size())); + assert(inputTensorRank >= gsl::narrow_cast(m_sizes.size())); + assert(inputTensorRank >= gsl::narrow_cast(m_strides.size())); // Pad the parameters to respect DML's requirements - m_offsets.insert( - m_offsets.begin(), - m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast(m_offsets.size()), - 0); - - m_sizes.insert( - m_sizes.begin(), - m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast(m_sizes.size()), - 1); - - m_strides.insert( - m_strides.begin(), - m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast(m_strides.size()), - 1); + FillWithLeadingValues(/*inout*/ m_offsets, m_inputTensorDescs[0].GetDimensionCount(), 0u); + FillWithLeadingValues(/*inout*/ m_sizes, m_inputTensorDescs[0].GetDimensionCount(), 1u); + FillWithLeadingValues(/*inout*/ m_strides, m_inputTensorDescs[0].GetDimensionCount(), 1); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_SLICE_OPERATOR_DESC sliceDesc = {}; + DML_SLICE1_OPERATOR_DESC sliceDesc = {}; sliceDesc.InputTensor = inputDescs.data(); sliceDesc.OutputTensor = outputDescs.data(); sliceDesc.DimensionCount = gsl::narrow_cast(m_offsets.size()); - sliceDesc.Offsets = m_offsets.data(); - sliceDesc.Sizes = m_sizes.data(); - sliceDesc.Strides = m_strides.data(); + sliceDesc.InputWindowOffsets = m_offsets.data(); + sliceDesc.InputWindowSizes = m_sizes.data(); + sliceDesc.InputWindowStrides = m_strides.data(); - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE, &sliceDesc }; - + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE1, &sliceDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } }; @@ -73,9 +60,10 @@ public: void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported) { - *isSupported = (context->GetInputCount() <= 4); + *isSupported = (context->GetInputCount() <= 5); } DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>); DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>); +DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<10>); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index f19bf8d6a9..8be3774de1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -101,6 +101,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Tile); DML_OP_EXTERN_CREATION_FUNCTION(Concat); DML_OP_EXTERN_CREATION_FUNCTION(Slice7); DML_OP_EXTERN_CREATION_FUNCTION(Slice10); +DML_OP_EXTERN_CREATION_FUNCTION(Slice11); DML_OP_EXTERN_CREATION_FUNCTION(Pad); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); @@ -214,6 +215,7 @@ const static char* const typeNameListCast[2] = { "T1", "T2" }; const static char* const typeNameListIsNan[2] = { "T1", "T2" }; const static char* const typeNameListConstantOfShape[2] = { "T1", "T2" }; const static char* const typeNameListScatterGather[2] = { "T", "Tind" }; +const static char* const typeNameListSlice10[2] = { "T", "Tind" }; const static char* const typeNameListQuantize[2] = { "T1", "T2" }; const static char* const typeNameListWhere[2] = { "B", "T" }; const static char* const typeNameListOneHot[3] = { "T1", "T2", "T3" }; @@ -228,6 +230,7 @@ const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTenso const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; @@ -297,7 +300,8 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, {REG_INFO( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -473,7 +477,6 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 11, Scan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 11, ScatterElements, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 11, ScatterND, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, Slice, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 11, Squeeze, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 7fee23b8d0..a6615bd815 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -563,6 +563,20 @@ public: return m_impl->GetOutputCount(); } + // Returns true if an input to the operator is valid. + // This returns false for optional omitted inputs and invalid indices. + bool IsInputValid(uint32_t inputIndex) const noexcept + { + return m_impl->IsInputValid(inputIndex); + } + + // Returns true if an output to the operator is valid. + // This returns false for optional omitted inputs and invalid indices. + bool IsOutputValid(uint32_t inputIndex) const noexcept + { + return m_impl->IsOutputValid(inputIndex); + } + MLOperatorEdgeDescription GetInputEdgeDescription(uint32_t inputIndex) const { MLOperatorEdgeDescription ret; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 2d4526d71c..924df87537 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -32,21 +32,51 @@ void HandleNegativeAxes(gsl::span onnxAxes, uint32_t dimCount) } } -void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, uint32_t fillValue) +void ReadCpuLocalTensorIntoInt32( + const MLOperatorTensor& tensor, + std::vector& result + ) { - // e.g. - // input = [6,7] - // elementCount = 4 - // fillValue = 1 - // output = [1,1,6,7] + result.clear(); + ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); - const size_t oldElementCount = values.size(); - const size_t newElementCount = std::max(size_t(minimumElementCount), oldElementCount); - const size_t fillCount = newElementCount - oldElementCount; + const std::vector& tensorDimensions = tensor.GetShape(); + const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); - values.resize(newElementCount); - std::copy_backward(values.data(), values.data() + oldElementCount, values.data() + fillCount); - std::fill_n(values.data(), fillCount, fillValue); + switch (tensor.GetTensorDataType()) + { + case MLOperatorTensorDataType::Int32: + { + const int32_t* data = tensor.GetData(); + result.assign(data, data + elementCount); + } + break; + + case MLOperatorTensorDataType::Int64: + { + const int64_t* data = tensor.GetData(); + for (auto d : gsl::make_span(data, data + elementCount)) + { + result.push_back(gsl::narrow_cast(d)); + } + } + break; + + default: + ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); + break; + } +} + +void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) +{ + outputDimensions.reserve(inputDimensions.size()); + outputDimensions.clear(); + + for (int64_t dim : inputDimensions) + { + outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); + } } int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) @@ -1070,10 +1100,7 @@ int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) // First element of shape tensor is how many dims to expand to. std::vector desiredTensorShape; - for (int64_t dim : gsl::make_span(shapeData, dimCount)) - { - desiredTensorShape.push_back(gsl::narrow_cast(dim)); - } + DowncastDimensions(gsl::make_span(shapeData, dimCount), /*out*/ desiredTensorShape); // Determine the broadcasted input shape. outputDimensions = OperatorHelper::BroadcastTensorShape(actualInputTensorShape, desiredTensorShape); @@ -1098,10 +1125,7 @@ int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) // First element of shape tensor is how many dims to expand to. std::vector desiredTensorShape; - for (int64_t dim : gsl::make_span(shapeData, dimCount)) - { - desiredTensorShape.push_back(gsl::narrow_cast(dim)); - } + DowncastDimensions(gsl::make_span(shapeData, dimCount), /*out*/ desiredTensorShape); return { std::move(EdgeShapes(desiredTensorShape)) }; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 6940ed3a85..e593c39cb2 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -75,10 +75,28 @@ void RemoveValuesByIndex(gsl::span indices, bool keepOneValue, / values.resize(newValuesCount); } -void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, uint32_t fillValue); +template +void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, T fillValue) +{ + // e.g. + // input = [6,7] + // elementCount = 4 + // fillValue = 1 + // output = [1,1,6,7] + + const size_t oldElementCount = values.size(); + const size_t newElementCount = std::max(size_t(minimumElementCount), oldElementCount); + const size_t fillCount = newElementCount - oldElementCount; + + values.resize(newElementCount); + std::copy_backward(values.data(), values.data() + oldElementCount, values.data() + fillCount); + std::fill_n(values.data(), fillCount, fillValue); +} int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p); +void ReadCpuLocalTensorIntoInt32(const MLOperatorTensor& tensor, std::vector& result); + class EdgeShapes { public: EdgeShapes() = default; @@ -503,56 +521,25 @@ class SplitHelper { class SliceHelperBase { public: - template + template void ReadIndexTensors( const Info_t& operatorInfo, - std::vector& starts, - std::vector& ends, - std::vector& axes, - std::vector& steps - ) + /*out*/ std::vector& starts, + /*out*/ std::vector& ends, + /*out*/ std::vector& axes, + /*out*/ std::vector& steps + ) { - // Get starts, ends, optional axes and optional steps from constant inputs. - MLOperatorTensor startsTensor = operatorInfo.GetConstantInputTensor(1); - const std::vector& startsTensorDimensions = startsTensor.GetShape(); - size_t dimCount = startsTensorDimensions[0]; - const Index_t* startsData = startsTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) + // Get starts, ends, optional axes, and optional steps from constant inputs. + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(1), /*out*/ starts); + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(2), /*out*/ ends); + if (operatorInfo.IsInputValid(3)) { - starts.push_back(gsl::narrow_cast(startsData[i])); + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(3), /*out*/ axes); } - - MLOperatorTensor endsTensor = operatorInfo.GetConstantInputTensor(2); - const std::vector& endsTensorDimensions = endsTensor.GetShape(); - dimCount = endsTensorDimensions[0]; - const Index_t* endsData = endsTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) + if (operatorInfo.IsInputValid(4)) { - ends.push_back(gsl::narrow_cast(endsData[i])); - } - uint32_t inputCount = operatorInfo.GetInputCount(); - if (inputCount > 3) - { - MLOperatorTensor axesTensor = operatorInfo.GetConstantInputTensor(3); - const std::vector& axesTensorDimensions = axesTensor.GetShape(); - dimCount = axesTensorDimensions[0]; - const Index_t* axesData = axesTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) - { - axes.push_back(gsl::narrow_cast(axesData[i])); - } - } - - if (inputCount > 4) - { - MLOperatorTensor stepsTensor = operatorInfo.GetConstantInputTensor(4); - const std::vector& stepsTensorDimensions = stepsTensor.GetShape(); - dimCount = stepsTensorDimensions[0]; - const Index_t* stepsData = stepsTensor.GetData(); - for (size_t i = 0; i < dimCount; ++i) - { - steps.push_back(gsl::narrow_cast(stepsData[i])); - } + ReadCpuLocalTensorIntoInt32(operatorInfo.GetConstantInputTensor(4), /*out*/ steps); } } @@ -567,29 +554,23 @@ public: std::vector ends; std::vector axes; std::vector steps; + if (opsetVersion == 7) { - // Get starts, ends and axes from attributes + // Read starts, ends, and axes from attributes. starts = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Starts); ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends); axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes); } else if (opsetVersion == 10) { - if (operatorInfo.GetConstantInputTensor(1).GetTensorDataType() == MLOperatorTensorDataType::Int32) - { - ReadIndexTensors(operatorInfo, starts, ends, axes, steps); - } - else - { - THROW_HR_IF(E_INVALIDARG, operatorInfo.GetConstantInputTensor(1).GetTensorDataType() != MLOperatorTensorDataType::Int64); - ReadIndexTensors(operatorInfo, starts, ends, axes, steps); - } + // Read starts, ends, and axes from tensors. + ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps); } - const uint32_t dimCount = gsl::narrow_cast(inputDimensions.size()); - HandleNegativeAxes(/*inout*/ axes, dimCount); - + const uint32_t inputDimensionCount = gsl::narrow_cast(inputDimensions.size()); + HandleNegativeAxes(/*inout*/ axes, inputDimensionCount); + ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size."); ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty."); @@ -604,18 +585,14 @@ public: // Clamp selected dimensions to given 'starts' and 'ends'. for (int i = 0, ci = gsl::narrow_cast(starts.size()); i < ci; ++i) { - int dimIndex = i; - if (!axes.empty()) - { - dimIndex = axes[i]; - } + int dimIndex = axes.empty() ? i : axes[i]; ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions."); // Positive values are offsets from 0. // Negative values are offsets from the dimension's size. int dim = gsl::narrow_cast(inputDimensions[dimIndex]); - int start = (starts[i] < 0) ? (starts[i] + dim) : starts[i]; - int end = (ends[i] < 0) ? (ends[i] + dim) : ends[i]; + int start = (starts[i] < 0 && starts[i] > INT_MIN) ? (starts[i] + dim) : starts[i]; + int end = (ends[i] < 0 && ends[i] < INT_MAX) ? (ends[i] + dim) : ends[i]; // Clamp the dimensions to the slice extents. // Clamp negative numbers to 0, per case test_slice_start_out_of_bounds. @@ -1195,6 +1172,7 @@ using ShapeInferenceHelper_Transpose = TransposeHelper; using ShapeInferenceHelper_Concat = ConcatHelper; using ShapeInferenceHelper_Slice7 = SliceHelper; using ShapeInferenceHelper_Slice10 = Slice10Helper; +using ShapeInferenceHelper_Slice11 = Slice10Helper; // No functional change from 10. using ShapeInferenceHelper_Pad = PaddingHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; From 89df6abac7ca4e80e396bd7ef4ed7768c3af6b1e Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Fri, 27 Mar 2020 02:25:58 -0700 Subject: [PATCH 3/6] Fix slice. --- .../src/External/DirectMLHelpers/ApiTraits.h | 358 ++- .../External/DirectMLHelpers/ApiTraits.h.bak | 1936 +++++++++++++++++ .../External/DirectMLHelpers/DirectMLSchema.h | 360 +++ .../DirectMLHelpers/DirectMLSchema.h.bak | 1514 +++++++++++++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 336 +++ .../GeneratedSchemaHelpers.h.bak | 1388 ++++++++++++ .../DirectMLHelpers/GeneratedSchemaTypes.h | 22 +- .../GeneratedSchemaTypes.h.bak | 105 + .../External/DirectMLHelpers/SchemaHelpers.h | 46 + .../DirectMLHelpers/SchemaHelpers.h.bak | 345 +++ .../src/MLOperatorAuthorImpl.cpp | 3 +- .../src/Operators/DmlOperatorSlice.cpp | 6 +- .../DmlExecutionProvider/src/TensorDesc.cpp | 8 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 2 +- .../dml/OperatorAuthorHelper/OperatorHelper.h | 24 +- 15 files changed, 6438 insertions(+), 15 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 2870946480..703d027412 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -13,7 +13,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 9; + static constexpr auto ValueCount = 12; }; template <> @@ -25,7 +25,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 97; + static constexpr auto ValueCount = 119; static constexpr size_t ActivationFunctionCount = 19; }; @@ -90,6 +90,30 @@ struct EnumTraits static constexpr auto ValueCount = 2; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + template constexpr auto EnumValueCount = EnumTraits::ValueCount; @@ -610,6 +634,138 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ROUND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_INFINITY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FILL_VALUE_CONSTANT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FILL_VALUE_SEQUENCE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_SUMMATION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REVERSE_SUBSEQUENCES; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ELEMENTS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING2; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TOP_K1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEPTH_TO_SPACE1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPACE_TO_DEPTH1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; +}; + template <> struct OperatorDescTraits { @@ -1192,6 +1348,138 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE> using DescType = DML_RESAMPLE_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT> +{ + using DescType = DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT> +{ + using DescType = DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ROUND> +{ + using DescType = DML_ELEMENT_WISE_ROUND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_INFINITY> +{ + using DescType = DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE> +{ + using DescType = DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR> +{ + using DescType = DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FILL_VALUE_CONSTANT> +{ + using DescType = DML_FILL_VALUE_CONSTANT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FILL_VALUE_SEQUENCE> +{ + using DescType = DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_SUMMATION> +{ + using DescType = DML_CUMULATIVE_SUMMATION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES> +{ + using DescType = DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ELEMENTS> +{ + using DescType = DML_GATHER_ELEMENTS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND> +{ + using DescType = DML_GATHER_ND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND> +{ + using DescType = DML_SCATTER_ND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING2> +{ + using DescType = DML_MAX_POOLING2_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE1> +{ + using DescType = DML_SLICE1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K1> +{ + using DescType = DML_TOP_K1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE1> +{ + using DescType = DML_DEPTH_TO_SPACE1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH1> +{ + using DescType = DML_SPACE_TO_DEPTH1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1> +{ + using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY> +{ + using DescType = DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION> +{ + using DescType = DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR> +{ + using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -1474,6 +1762,50 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ONE_HOT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_RESAMPLE: return std::invoke(std::forward(visitor), DML_RESAMPLE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ROUND: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ROUND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_FILL_VALUE_CONSTANT: + return std::invoke(std::forward(visitor), DML_FILL_VALUE_CONSTANT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_FILL_VALUE_SEQUENCE: + return std::invoke(std::forward(visitor), DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_CUMULATIVE_SUMMATION: + return std::invoke(std::forward(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_REVERSE_SUBSEQUENCES: + return std::invoke(std::forward(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ELEMENTS: + return std::invoke(std::forward(visitor), DML_GATHER_ELEMENTS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ND: + return std::invoke(std::forward(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SCATTER_ND: + return std::invoke(std::forward(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_POOLING2: + return std::invoke(std::forward(visitor), DML_MAX_POOLING2_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SLICE1: + return std::invoke(std::forward(visitor), DML_SLICE1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_TOP_K1: + return std::invoke(std::forward(visitor), DML_TOP_K1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DEPTH_TO_SPACE1: + return std::invoke(std::forward(visitor), DML_DEPTH_TO_SPACE1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SPACE_TO_DEPTH1: + return std::invoke(std::forward(visitor), DML_SPACE_TO_DEPTH1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: + return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return std::invoke(std::forward(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_HARDMAX: @@ -1601,6 +1933,28 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; case DML_OPERATOR_ONE_HOT: return "DML_OPERATOR_ONE_HOT"; case DML_OPERATOR_RESAMPLE: return "DML_OPERATOR_RESAMPLE"; + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: return "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT"; + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: return "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT"; + case DML_OPERATOR_ELEMENT_WISE_ROUND: return "DML_OPERATOR_ELEMENT_WISE_ROUND"; + case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; + case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; + case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; + case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; + case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; + case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; + case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; + case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; + case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND"; + case DML_OPERATOR_MAX_POOLING2: return "DML_OPERATOR_MAX_POOLING2"; + case DML_OPERATOR_SLICE1: return "DML_OPERATOR_SLICE1"; + case DML_OPERATOR_TOP_K1: return "DML_OPERATOR_TOP_K1"; + case DML_OPERATOR_DEPTH_TO_SPACE1: return "DML_OPERATOR_DEPTH_TO_SPACE1"; + case DML_OPERATOR_SPACE_TO_DEPTH1: return "DML_OPERATOR_SPACE_TO_DEPTH1"; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1"; + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY"; + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak new file mode 100644 index 0000000000..d75afeb4b6 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak @@ -0,0 +1,1936 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace ApiTraits +{ +template +struct EnumTraits +{ +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 12; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 119; + static constexpr size_t ActivationFunctionCount = 19; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 12; + static constexpr DML_REDUCE_FUNCTION Invalid = static_cast(ValueCount); +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 2; +}; + +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 3; +}; + +template +constexpr auto EnumValueCount = EnumTraits::ValueCount; + +template +constexpr bool IsValidEnumValue(T value) +{ + return (std::make_unsigned_t(value) < std::make_unsigned_t(EnumValueCount)); +} + +template +struct FlagTraits +{ +}; + +template <> +struct FlagTraits +{ + static constexpr auto ValidMask = DML_TENSOR_FLAG_OWNED_BY_DML; +}; + +template <> +struct FlagTraits +{ + static constexpr auto ValidMask = DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION | DML_EXECUTION_FLAG_DISABLE_META_COMMANDS | DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE; +}; + +template <> +struct FlagTraits +{ + static constexpr auto ValidMask = DML_CREATE_DEVICE_FLAG_DEBUG; +}; + +template +constexpr auto FlagsValidMask = FlagTraits::ValidMask; + +template +constexpr bool IsValidFlags(T value) +{ + return (value & ~FlagsValidMask) == 0; +} + +template +struct TensorDescTraits +{ +}; + +template <> +struct TensorDescTraits +{ + static constexpr DML_TENSOR_TYPE Type = DML_TENSOR_TYPE_BUFFER; +}; + + +template +struct OperatorDescTraits +{ +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IDENTITY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ABS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ACOS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ADD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ASIN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATAN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CEIL; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_COS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DIVIDE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_EXP; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_FLOOR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOG; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MAX; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MEAN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MIN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MULTIPLY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_POW; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_RECIP; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SIN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SQRT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SUBTRACT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_TAN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_THRESHOLD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CONVOLUTION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GEMM; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REDUCE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CAST; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPLIT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_JOIN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_VALUE_SCALE_2D; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_UPSAMPLE_2D; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPACE_TO_DEPTH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEPTH_TO_SPACE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TILE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TOP_K; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_NORMALIZATION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RNN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LSTM; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GRU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SIGN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_NAN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ERF; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SINH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_COSH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_TANH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ASINH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ACOSH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATANH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IF; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ADD1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_UNPOOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DIAGONAL_MATRIX; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ONE_HOT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ROUND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_INFINITY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FILL_VALUE_CONSTANT; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FILL_VALUE_SEQUENCE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_SUMMATION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REVERSE_SUBSEQUENCES; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ELEMENTS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING2; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TOP_K1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEPTH_TO_SPACE1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPACE_TO_DEPTH1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_ELU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_IDENTITY; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LINEAR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_RELU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SCALED_ELU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SCALED_TANH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SIGMOID; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTPLUS; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTSIGN; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_TANH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SHRINK; +}; + + +template +struct OperatorTypeTraits +{ +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IDENTITY> +{ + using DescType = DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ABS> +{ + using DescType = DML_ELEMENT_WISE_ABS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ACOS> +{ + using DescType = DML_ELEMENT_WISE_ACOS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ADD> +{ + using DescType = DML_ELEMENT_WISE_ADD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ASIN> +{ + using DescType = DML_ELEMENT_WISE_ASIN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATAN> +{ + using DescType = DML_ELEMENT_WISE_ATAN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CEIL> +{ + using DescType = DML_ELEMENT_WISE_CEIL_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP> +{ + using DescType = DML_ELEMENT_WISE_CLIP_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COS> +{ + using DescType = DML_ELEMENT_WISE_COS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DIVIDE> +{ + using DescType = DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_EXP> +{ + using DescType = DML_ELEMENT_WISE_EXP_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_FLOOR> +{ + using DescType = DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOG> +{ + using DescType = DML_ELEMENT_WISE_LOG_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR> +{ + using DescType = DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MAX> +{ + using DescType = DML_ELEMENT_WISE_MAX_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MEAN> +{ + using DescType = DML_ELEMENT_WISE_MEAN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MIN> +{ + using DescType = DML_ELEMENT_WISE_MIN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MULTIPLY> +{ + using DescType = DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_POW> +{ + using DescType = DML_ELEMENT_WISE_POW_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW> +{ + using DescType = DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_RECIP> +{ + using DescType = DML_ELEMENT_WISE_RECIP_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SIN> +{ + using DescType = DML_ELEMENT_WISE_SIN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SQRT> +{ + using DescType = DML_ELEMENT_WISE_SQRT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SUBTRACT> +{ + using DescType = DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_TAN> +{ + using DescType = DML_ELEMENT_WISE_TAN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_THRESHOLD> +{ + using DescType = DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR> +{ + using DescType = DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR> +{ + using DescType = DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION> +{ + using DescType = DML_CONVOLUTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GEMM> +{ + using DescType = DML_GEMM_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REDUCE> +{ + using DescType = DML_REDUCE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING> +{ + using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING> +{ + using DescType = DML_LP_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING> +{ + using DescType = DML_MAX_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING1> +{ + using DescType = DML_MAX_POOLING1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> +{ + using DescType = DML_ROI_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> +{ + using DescType = DML_SLICE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CAST> +{ + using DescType = DML_CAST_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPLIT> +{ + using DescType = DML_SPLIT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_JOIN> +{ + using DescType = DML_JOIN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING> +{ + using DescType = DML_PADDING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_VALUE_SCALE_2D> +{ + using DescType = DML_VALUE_SCALE_2D_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_UPSAMPLE_2D> +{ + using DescType = DML_UPSAMPLE_2D_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER> +{ + using DescType = DML_GATHER_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH> +{ + using DescType = DML_SPACE_TO_DEPTH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE> +{ + using DescType = DML_DEPTH_TO_SPACE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TILE> +{ + using DescType = DML_TILE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K> +{ + using DescType = DML_TOP_K_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION> +{ + using DescType = DML_BATCH_NORMALIZATION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION> +{ + using DescType = DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION> +{ + using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_NORMALIZATION> +{ + using DescType = DML_LP_NORMALIZATION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RNN> +{ + using DescType = DML_RNN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LSTM> +{ + using DescType = DML_LSTM_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GRU> +{ + using DescType = DML_GRU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SIGN> +{ + using DescType = DML_ELEMENT_WISE_SIGN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_NAN> +{ + using DescType = DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ERF> +{ + using DescType = DML_ELEMENT_WISE_ERF_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SINH> +{ + using DescType = DML_ELEMENT_WISE_SINH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COSH> +{ + using DescType = DML_ELEMENT_WISE_COSH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_TANH> +{ + using DescType = DML_ELEMENT_WISE_TANH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ASINH> +{ + using DescType = DML_ELEMENT_WISE_ASINH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ACOSH> +{ + using DescType = DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATANH> +{ + using DescType = DML_ELEMENT_WISE_ATANH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IF> +{ + using DescType = DML_ELEMENT_WISE_IF_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ADD1> +{ + using DescType = DML_ELEMENT_WISE_ADD1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_UNPOOLING> +{ + using DescType = DML_MAX_UNPOOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX> +{ + using DescType = DML_DIAGONAL_MATRIX_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER> +{ + using DescType = DML_SCATTER_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ONE_HOT> +{ + using DescType = DML_ONE_HOT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE> +{ + using DescType = DML_RESAMPLE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT> +{ + using DescType = DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT> +{ + using DescType = DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ROUND> +{ + using DescType = DML_ELEMENT_WISE_ROUND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_INFINITY> +{ + using DescType = DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE> +{ + using DescType = DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR> +{ + using DescType = DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FILL_VALUE_CONSTANT> +{ + using DescType = DML_FILL_VALUE_CONSTANT_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FILL_VALUE_SEQUENCE> +{ + using DescType = DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_SUMMATION> +{ + using DescType = DML_CUMULATIVE_SUMMATION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES> +{ + using DescType = DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ELEMENTS> +{ + using DescType = DML_GATHER_ELEMENTS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND> +{ + using DescType = DML_GATHER_ND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND> +{ + using DescType = DML_SCATTER_ND_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING2> +{ + using DescType = DML_MAX_POOLING2_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE1> +{ + using DescType = DML_SLICE1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K1> +{ + using DescType = DML_TOP_K1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE1> +{ + using DescType = DML_DEPTH_TO_SPACE1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH1> +{ + using DescType = DML_SPACE_TO_DEPTH1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1> +{ + using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> +{ + using DescType = DML_ACTIVATION_ELU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX> +{ + using DescType = DML_ACTIVATION_HARDMAX_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SIGMOID> +{ + using DescType = DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_IDENTITY> +{ + using DescType = DML_ACTIVATION_IDENTITY_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LEAKY_RELU> +{ + using DescType = DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LINEAR> +{ + using DescType = DML_ACTIVATION_LINEAR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX> +{ + using DescType = DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU> +{ + using DescType = DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS> +{ + using DescType = DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_RELU> +{ + using DescType = DML_ACTIVATION_RELU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SCALED_ELU> +{ + using DescType = DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SCALED_TANH> +{ + using DescType = DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SIGMOID> +{ + using DescType = DML_ACTIVATION_SIGMOID_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX> +{ + using DescType = DML_ACTIVATION_SOFTMAX_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTPLUS> +{ + using DescType = DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTSIGN> +{ + using DescType = DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_TANH> +{ + using DescType = DML_ACTIVATION_TANH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU> +{ + using DescType = DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SHRINK> +{ + using DescType = DML_ACTIVATION_SHRINK_OPERATOR_DESC; +}; + + +// Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as +// the first argument. +// +// For example: +// Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { +// using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs +// }); +// +template +auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args) +{ + switch (static_cast(type)) + { + case DML_OPERATOR_ELEMENT_WISE_IDENTITY: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ABS: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ABS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ACOS: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ACOS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ADD: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ADD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ASIN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ASIN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ATAN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ATAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_CEIL: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CEIL_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_CLIP: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CLIP_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_COS: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_COS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_DIVIDE: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_EXP: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_EXP_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_FLOOR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOG: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOG_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MAX: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MEAN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MEAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MIN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MIN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_POW: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_POW_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_RECIP: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_RECIP_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_SIN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SIN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_SQRT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SQRT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_TAN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_TAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_CONVOLUTION: + return std::invoke(std::forward(visitor), DML_CONVOLUTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GEMM: + return std::invoke(std::forward(visitor), DML_GEMM_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_REDUCE: + return std::invoke(std::forward(visitor), DML_REDUCE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LP_POOLING: + return std::invoke(std::forward(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_POOLING: + return std::invoke(std::forward(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_POOLING1: + return std::invoke(std::forward(visitor), DML_MAX_POOLING1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ROI_POOLING: + return std::invoke(std::forward(visitor), DML_ROI_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SLICE: + return std::invoke(std::forward(visitor), DML_SLICE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_CAST: + return std::invoke(std::forward(visitor), DML_CAST_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SPLIT: + return std::invoke(std::forward(visitor), DML_SPLIT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_JOIN: + return std::invoke(std::forward(visitor), DML_JOIN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_PADDING: + return std::invoke(std::forward(visitor), DML_PADDING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_VALUE_SCALE_2D: + return std::invoke(std::forward(visitor), DML_VALUE_SCALE_2D_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_UPSAMPLE_2D: + return std::invoke(std::forward(visitor), DML_UPSAMPLE_2D_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER: + return std::invoke(std::forward(visitor), DML_GATHER_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SPACE_TO_DEPTH: + return std::invoke(std::forward(visitor), DML_SPACE_TO_DEPTH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DEPTH_TO_SPACE: + return std::invoke(std::forward(visitor), DML_DEPTH_TO_SPACE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_TILE: + return std::invoke(std::forward(visitor), DML_TILE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_TOP_K: + return std::invoke(std::forward(visitor), DML_TOP_K_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_BATCH_NORMALIZATION: + return std::invoke(std::forward(visitor), DML_BATCH_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: + return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: + return std::invoke(std::forward(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LP_NORMALIZATION: + return std::invoke(std::forward(visitor), DML_LP_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_RNN: + return std::invoke(std::forward(visitor), DML_RNN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LSTM: + return std::invoke(std::forward(visitor), DML_LSTM_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GRU: + return std::invoke(std::forward(visitor), DML_GRU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_SIGN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SIGN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_IS_NAN: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ERF: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ERF_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_SINH: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SINH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_COSH: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_COSH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_TANH: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_TANH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ASINH: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ASINH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ACOSH: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ATANH: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ATANH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_IF: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IF_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ADD1: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ADD1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_UNPOOLING: + return std::invoke(std::forward(visitor), DML_MAX_UNPOOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DIAGONAL_MATRIX: + return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SCATTER: + return std::invoke(std::forward(visitor), DML_SCATTER_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ONE_HOT: + return std::invoke(std::forward(visitor), DML_ONE_HOT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_RESAMPLE: + return std::invoke(std::forward(visitor), DML_RESAMPLE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ROUND: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ROUND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_FILL_VALUE_CONSTANT: + return std::invoke(std::forward(visitor), DML_FILL_VALUE_CONSTANT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_FILL_VALUE_SEQUENCE: + return std::invoke(std::forward(visitor), DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_CUMULATIVE_SUMMATION: + return std::invoke(std::forward(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_REVERSE_SUBSEQUENCES: + return std::invoke(std::forward(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ELEMENTS: + return std::invoke(std::forward(visitor), DML_GATHER_ELEMENTS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_GATHER_ND: + return std::invoke(std::forward(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SCATTER_ND: + return std::invoke(std::forward(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MAX_POOLING2: + return std::invoke(std::forward(visitor), DML_MAX_POOLING2_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SLICE1: + return std::invoke(std::forward(visitor), DML_SLICE1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_TOP_K1: + return std::invoke(std::forward(visitor), DML_TOP_K1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DEPTH_TO_SPACE1: + return std::invoke(std::forward(visitor), DML_DEPTH_TO_SPACE1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_SPACE_TO_DEPTH1: + return std::invoke(std::forward(visitor), DML_SPACE_TO_DEPTH1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: + return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_ELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARDMAX: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_IDENTITY: + return std::invoke(std::forward(visitor), DML_ACTIVATION_IDENTITY_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_LINEAR: + return std::invoke(std::forward(visitor), DML_ACTIVATION_LINEAR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: + return std::invoke(std::forward(visitor), DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: + return std::invoke(std::forward(visitor), DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_RELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SCALED_ELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SCALED_TANH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SIGMOID: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SOFTMAX: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTMAX_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SOFTPLUS: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SOFTSIGN: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_TANH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_TANH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: + return std::invoke(std::forward(visitor), DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_SHRINK: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); + + default: + THROW_HR(E_INVALIDARG); + } +} + + +inline gsl::czstring ToString(DML_OPERATOR_TYPE value) +{ + switch (value) + { + case DML_OPERATOR_INVALID: return "DML_OPERATOR_INVALID"; + case DML_OPERATOR_ELEMENT_WISE_IDENTITY: return "DML_OPERATOR_ELEMENT_WISE_IDENTITY"; + case DML_OPERATOR_ELEMENT_WISE_ABS: return "DML_OPERATOR_ELEMENT_WISE_ABS"; + case DML_OPERATOR_ELEMENT_WISE_ACOS: return "DML_OPERATOR_ELEMENT_WISE_ACOS"; + case DML_OPERATOR_ELEMENT_WISE_ADD: return "DML_OPERATOR_ELEMENT_WISE_ADD"; + case DML_OPERATOR_ELEMENT_WISE_ASIN: return "DML_OPERATOR_ELEMENT_WISE_ASIN"; + case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; + case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; + case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; + case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; + case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; + case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; + case DML_OPERATOR_ELEMENT_WISE_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_FLOOR"; + case DML_OPERATOR_ELEMENT_WISE_LOG: return "DML_OPERATOR_ELEMENT_WISE_LOG"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR"; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR"; + case DML_OPERATOR_ELEMENT_WISE_MAX: return "DML_OPERATOR_ELEMENT_WISE_MAX"; + case DML_OPERATOR_ELEMENT_WISE_MEAN: return "DML_OPERATOR_ELEMENT_WISE_MEAN"; + case DML_OPERATOR_ELEMENT_WISE_MIN: return "DML_OPERATOR_ELEMENT_WISE_MIN"; + case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: return "DML_OPERATOR_ELEMENT_WISE_MULTIPLY"; + case DML_OPERATOR_ELEMENT_WISE_POW: return "DML_OPERATOR_ELEMENT_WISE_POW"; + case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: return "DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW"; + case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; + case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; + case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; + case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; + case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; + case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; + case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR"; + case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR"; + case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; + case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; + case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; + case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; + case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; + case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; + case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; + case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; + case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE"; + case DML_OPERATOR_CAST: return "DML_OPERATOR_CAST"; + case DML_OPERATOR_SPLIT: return "DML_OPERATOR_SPLIT"; + case DML_OPERATOR_JOIN: return "DML_OPERATOR_JOIN"; + case DML_OPERATOR_PADDING: return "DML_OPERATOR_PADDING"; + case DML_OPERATOR_VALUE_SCALE_2D: return "DML_OPERATOR_VALUE_SCALE_2D"; + case DML_OPERATOR_UPSAMPLE_2D: return "DML_OPERATOR_UPSAMPLE_2D"; + case DML_OPERATOR_GATHER: return "DML_OPERATOR_GATHER"; + case DML_OPERATOR_SPACE_TO_DEPTH: return "DML_OPERATOR_SPACE_TO_DEPTH"; + case DML_OPERATOR_DEPTH_TO_SPACE: return "DML_OPERATOR_DEPTH_TO_SPACE"; + case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; + case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; + case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; + case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; + case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; + case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; + case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU"; + case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; + case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; + case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; + case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; + case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; + case DML_OPERATOR_ELEMENT_WISE_TANH: return "DML_OPERATOR_ELEMENT_WISE_TANH"; + case DML_OPERATOR_ELEMENT_WISE_ASINH: return "DML_OPERATOR_ELEMENT_WISE_ASINH"; + case DML_OPERATOR_ELEMENT_WISE_ACOSH: return "DML_OPERATOR_ELEMENT_WISE_ACOSH"; + case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH"; + case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF"; + case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1"; + case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING"; + case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX"; + case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; + case DML_OPERATOR_ONE_HOT: return "DML_OPERATOR_ONE_HOT"; + case DML_OPERATOR_RESAMPLE: return "DML_OPERATOR_RESAMPLE"; + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: return "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT"; + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: return "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT"; + case DML_OPERATOR_ELEMENT_WISE_ROUND: return "DML_OPERATOR_ELEMENT_WISE_ROUND"; + case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; + case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; + case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; + case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; + case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; + case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; + case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; + case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; + case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND"; + case DML_OPERATOR_MAX_POOLING2: return "DML_OPERATOR_MAX_POOLING2"; + case DML_OPERATOR_SLICE1: return "DML_OPERATOR_SLICE1"; + case DML_OPERATOR_TOP_K1: return "DML_OPERATOR_TOP_K1"; + case DML_OPERATOR_DEPTH_TO_SPACE1: return "DML_OPERATOR_DEPTH_TO_SPACE1"; + case DML_OPERATOR_SPACE_TO_DEPTH1: return "DML_OPERATOR_SPACE_TO_DEPTH1"; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1"; + default: + assert(false); + return ""; + } +} +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 7c46a8a6a2..07592cc2ee 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -19,12 +19,15 @@ enum DML_SCHEMA_FIELD_TYPE DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, DML_SCHEMA_FIELD_TYPE_UINT, + DML_SCHEMA_FIELD_TYPE_UINT64, DML_SCHEMA_FIELD_TYPE_INT, DML_SCHEMA_FIELD_TYPE_FLOAT, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, + DML_SCHEMA_FIELD_TYPE_INT_ARRAY, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, DML_SCHEMA_FIELD_TYPE_SIZE_2D, + DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, }; enum DML_SCHEMA_OPERATOR_SUPPORT_FLAGS @@ -1246,6 +1249,363 @@ constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_OPERATOR_SCHEMA { DML_RESAMPLE_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT", + DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT", + DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "RoundingMode", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ROUND", + DML_OPERATOR_ELEMENT_WISE_ROUND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InfinityMode", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY", + DML_OPERATOR_ELEMENT_WISE_IS_INFINITY, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE", + DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR", + DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Value", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA { + "DML_OPERATOR_FILL_VALUE_CONSTANT", + DML_OPERATOR_FILL_VALUE_CONSTANT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "ValueStart", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "ValueDelta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA { + "DML_OPERATOR_FILL_VALUE_SEQUENCE", + DML_OPERATOR_FILL_VALUE_SEQUENCE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA { + "DML_OPERATOR_CUMULATIVE_SUMMATION", + DML_OPERATOR_CUMULATIVE_SUMMATION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 5, + DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA { + "DML_OPERATOR_REVERSE_SUBSEQUENCES", + DML_OPERATOR_REVERSE_SUBSEQUENCES, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GATHER_ELEMENTS_OPERATOR_SCHEMA { + "DML_OPERATOR_GATHER_ELEMENTS", + DML_OPERATOR_GATHER_ELEMENTS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_GATHER_ELEMENTS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND_OPERATOR_SCHEMA { + "DML_OPERATOR_GATHER_ND", + DML_OPERATOR_GATHER_ND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_GATHER_ND_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "UpdatesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputDimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndicesDimensionCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SCATTER_ND_OPERATOR_SCHEMA { + "DML_OPERATOR_SCATTER_ND", + DML_OPERATOR_SCATTER_ND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_SCATTER_ND_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndicesTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING2_OPERATOR_SCHEMA { + "DML_OPERATOR_MAX_POOLING2", + DML_OPERATOR_MAX_POOLING2, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_MAX_POOLING2_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SLICE1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowOffsets", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "InputWindowSizes", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT_ARRAY, "InputWindowStrides", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SLICE1_OPERATOR_SCHEMA { + "DML_OPERATOR_SLICE1", + DML_OPERATOR_SLICE1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_SLICE1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_TOP_K1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputValueTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndexTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "K", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_TOP_K1_OPERATOR_SCHEMA { + "DML_OPERATOR_TOP_K1", + DML_OPERATOR_TOP_K1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_TOP_K1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Order", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA { + "DML_OPERATOR_DEPTH_TO_SPACE1", + DML_OPERATOR_DEPTH_TO_SPACE1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Order", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA { + "DML_OPERATOR_SPACE_TO_DEPTH1", + DML_OPERATOR_SPACE_TO_DEPTH1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "NormalizeVariance", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA { + "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", + DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", + DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS[16] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "GroupCount", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", + DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 16, + DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", + DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak new file mode 100644 index 0000000000..7c46a8a6a2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak @@ -0,0 +1,1514 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +extern "C" { + +enum DML_SCHEMA_FIELD_KIND +{ + DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, + DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, + DML_SCHEMA_FIELD_KIND_ATTRIBUTE, +}; + +enum DML_SCHEMA_FIELD_TYPE +{ + DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, + DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY, + DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, + DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, + DML_SCHEMA_FIELD_TYPE_UINT, + DML_SCHEMA_FIELD_TYPE_INT, + DML_SCHEMA_FIELD_TYPE_FLOAT, + DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, + DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, + DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, + DML_SCHEMA_FIELD_TYPE_SIZE_2D, +}; + +enum DML_SCHEMA_OPERATOR_SUPPORT_FLAGS +{ + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE = 0, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION = (1 << 0), +}; + +DEFINE_ENUM_FLAG_OPERATORS(DML_SCHEMA_OPERATOR_SUPPORT_FLAGS); + +struct DML_SCHEMA_FIELD +{ + DML_SCHEMA_FIELD_KIND Kind; + DML_SCHEMA_FIELD_TYPE Type; + const CHAR* Name; + BOOL Optional; +}; + +struct DML_OPERATOR_SCHEMA +{ + const CHAR* OperatorName; + DML_OPERATOR_TYPE OperatorType; + DML_SCHEMA_OPERATOR_SUPPORT_FLAGS SupportFlags; + + UINT FieldCount; + const DML_SCHEMA_FIELD* Fields; +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_IDENTITY", + DML_OPERATOR_ELEMENT_WISE_IDENTITY, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ABS", + DML_OPERATOR_ELEMENT_WISE_ABS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ACOS", + DML_OPERATOR_ELEMENT_WISE_ACOS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ADD", + DML_OPERATOR_ELEMENT_WISE_ADD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ASIN", + DML_OPERATOR_ELEMENT_WISE_ASIN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ATAN", + DML_OPERATOR_ELEMENT_WISE_ATAN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_CEIL", + DML_OPERATOR_ELEMENT_WISE_CEIL, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Max", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_CLIP", + DML_OPERATOR_ELEMENT_WISE_CLIP, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 5, + DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_COS", + DML_OPERATOR_ELEMENT_WISE_COS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_DIVIDE", + DML_OPERATOR_ELEMENT_WISE_DIVIDE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_EXP", + DML_OPERATOR_ELEMENT_WISE_EXP, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_FLOOR", + DML_OPERATOR_ELEMENT_WISE_FLOOR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOG", + DML_OPERATOR_ELEMENT_WISE_LOG, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", + DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_MAX", + DML_OPERATOR_ELEMENT_WISE_MAX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_MEAN", + DML_OPERATOR_ELEMENT_WISE_MEAN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_MIN", + DML_OPERATOR_ELEMENT_WISE_MIN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_MULTIPLY", + DML_OPERATOR_ELEMENT_WISE_MULTIPLY, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ExponentTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_POW", + DML_OPERATOR_ELEMENT_WISE_POW, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Exponent", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", + DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_RECIP", + DML_OPERATOR_ELEMENT_WISE_RECIP, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_SIN", + DML_OPERATOR_ELEMENT_WISE_SIN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_SQRT", + DML_OPERATOR_ELEMENT_WISE_SQRT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_SUBTRACT", + DML_OPERATOR_ELEMENT_WISE_SUBTRACT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_TAN", + DML_OPERATOR_ELEMENT_WISE_TAN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_THRESHOLD", + DML_OPERATOR_ELEMENT_WISE_THRESHOLD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", + DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", + DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_OPERATOR_SCHEMA_FIELDS[14] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Mode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "OutputPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "GroupCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_CONVOLUTION_OPERATOR_SCHEMA { + "DML_OPERATOR_CONVOLUTION", + DML_OPERATOR_CONVOLUTION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 14, + DML_CONVOLUTION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GEMM_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "CTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "TransA", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "TransB", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GEMM_OPERATOR_SCHEMA { + "DML_OPERATOR_GEMM", + DML_OPERATOR_GEMM, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_GEMM_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_REDUCE_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Function", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_REDUCE_OPERATOR_SCHEMA { + "DML_OPERATOR_REDUCE", + DML_OPERATOR_REDUCE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_REDUCE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_AVERAGE_POOLING", + DML_OPERATOR_AVERAGE_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_LP_POOLING", + DML_OPERATOR_LP_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_MAX_POOLING", + DML_OPERATOR_MAX_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 7, + DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MAX_POOLING1_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndicesTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING1_OPERATOR_SCHEMA { + "DML_OPERATOR_MAX_POOLING1", + DML_OPERATOR_MAX_POOLING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MAX_POOLING1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScale", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SIZE_2D, "PooledSize", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_ROI_POOLING", + DML_OPERATOR_ROI_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Offsets", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Sizes", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SLICE_OPERATOR_SCHEMA { + "DML_OPERATOR_SLICE", + DML_OPERATOR_SLICE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 6, + DML_SLICE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_CAST_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_CAST_OPERATOR_SCHEMA { + "DML_OPERATOR_CAST", + DML_OPERATOR_CAST, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_CAST_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SPLIT_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "OutputCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY, "OutputTensors", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SPLIT_OPERATOR_SCHEMA { + "DML_OPERATOR_SPLIT", + DML_OPERATOR_SPLIT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_SPLIT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_JOIN_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY, "InputTensors", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_JOIN_OPERATOR_SCHEMA { + "DML_OPERATOR_JOIN", + DML_OPERATOR_JOIN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_JOIN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_PADDING_OPERATOR_SCHEMA_FIELDS[7] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "PaddingMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "PaddingValue", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_PADDING_OPERATOR_SCHEMA { + "DML_OPERATOR_PADDING", + DML_OPERATOR_PADDING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 7, + DML_PADDING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Scale", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ChannelCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Bias", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_VALUE_SCALE_2D_OPERATOR_SCHEMA { + "DML_OPERATOR_VALUE_SCALE_2D", + DML_OPERATOR_VALUE_SCALE_2D, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_UPSAMPLE_2D_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SIZE_2D, "ScaleSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_UPSAMPLE_2D_OPERATOR_SCHEMA { + "DML_OPERATOR_UPSAMPLE_2D", + DML_OPERATOR_UPSAMPLE_2D, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_UPSAMPLE_2D_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GATHER_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndexDimensions", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GATHER_OPERATOR_SCHEMA { + "DML_OPERATOR_GATHER", + DML_OPERATOR_GATHER, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_GATHER_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA { + "DML_OPERATOR_SPACE_TO_DEPTH", + DML_OPERATOR_SPACE_TO_DEPTH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA { + "DML_OPERATOR_DEPTH_TO_SPACE", + DML_OPERATOR_DEPTH_TO_SPACE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_TILE_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "RepeatsCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Repeats", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_TILE_OPERATOR_SCHEMA { + "DML_OPERATOR_TILE", + DML_OPERATOR_TILE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_TILE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_TOP_K_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputValueTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndexTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "K", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_TOP_K_OPERATOR_SCHEMA { + "DML_OPERATOR_TOP_K", + DML_OPERATOR_TOP_K, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_TOP_K_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Spatial", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA { + "DML_OPERATOR_BATCH_NORMALIZATION", + DML_OPERATOR_BATCH_NORMALIZATION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "CrossChannel", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "NormalizeVariance", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA { + "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", + DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[7] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "CrossChannel", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "LocalSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Bias", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA { + "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", + DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 7, + DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LP_NORMALIZATION_OPERATOR_SCHEMA { + "DML_OPERATOR_LP_NORMALIZATION", + DML_OPERATOR_LP_NORMALIZATION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_RNN_OPERATOR_SCHEMA_FIELDS[11] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "WeightTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RecurrenceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "HiddenInitTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSequenceTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSingleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ActivationDescCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, "ActivationDescs", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_RNN_OPERATOR_SCHEMA { + "DML_OPERATOR_RNN", + DML_OPERATOR_RNN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 11, + DML_RNN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_LSTM_OPERATOR_SCHEMA_FIELDS[17] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "WeightTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RecurrenceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "HiddenInitTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "CellMemInitTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PeepholeTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSequenceTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSingleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCellSingleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ActivationDescCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, "ActivationDescs", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "ClipThreshold", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "UseClipThreshold", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "CoupleInputForget", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LSTM_OPERATOR_SCHEMA { + "DML_OPERATOR_LSTM", + DML_OPERATOR_LSTM, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 17, + DML_LSTM_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_GRU_OPERATOR_SCHEMA_FIELDS[12] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "WeightTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RecurrenceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "HiddenInitTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSequenceTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSingleTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ActivationDescCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, "ActivationDescs", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "LinearBeforeReset", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_GRU_OPERATOR_SCHEMA { + "DML_OPERATOR_GRU", + DML_OPERATOR_GRU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 12, + DML_GRU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_SIGN", + DML_OPERATOR_ELEMENT_WISE_SIGN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_IS_NAN", + DML_OPERATOR_ELEMENT_WISE_IS_NAN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ERF", + DML_OPERATOR_ELEMENT_WISE_ERF, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_SINH", + DML_OPERATOR_ELEMENT_WISE_SINH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_COSH", + DML_OPERATOR_ELEMENT_WISE_COSH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_TANH", + DML_OPERATOR_ELEMENT_WISE_TANH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ASINH", + DML_OPERATOR_ELEMENT_WISE_ASINH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ACOSH", + DML_OPERATOR_ELEMENT_WISE_ACOSH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ATANH", + DML_OPERATOR_ELEMENT_WISE_ATANH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ConditionTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_IF", + DML_OPERATOR_ELEMENT_WISE_IF, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ADD1", + DML_OPERATOR_ELEMENT_WISE_ADD1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MAX_UNPOOLING_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MAX_UNPOOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_MAX_UNPOOLING", + DML_OPERATOR_MAX_UNPOOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_MAX_UNPOOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "Offset", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Value", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA { + "DML_OPERATOR_DIAGONAL_MATRIX", + DML_OPERATOR_DIAGONAL_MATRIX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_SCATTER_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "UpdatesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_SCATTER_OPERATOR_SCHEMA { + "DML_OPERATOR_SCATTER", + DML_OPERATOR_SCATTER, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_SCATTER_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ONE_HOT_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ValuesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ONE_HOT_OPERATOR_SCHEMA { + "DML_OPERATOR_ONE_HOT", + DML_OPERATOR_ONE_HOT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ONE_HOT_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_RESAMPLE_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ScaleCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Scales", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_OPERATOR_SCHEMA { + "DML_OPERATOR_RESAMPLE", + DML_OPERATOR_RESAMPLE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 5, + DML_RESAMPLE_OPERATOR_SCHEMA_FIELDS, +}; + + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_ELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_ELU", + DML_OPERATOR_ACTIVATION_ELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARDMAX", + DML_OPERATOR_ACTIVATION_HARDMAX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARD_SIGMOID", + DML_OPERATOR_ACTIVATION_HARD_SIGMOID, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_IDENTITY", + DML_OPERATOR_ACTIVATION_IDENTITY, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_LEAKY_RELU", + DML_OPERATOR_ACTIVATION_LEAKY_RELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_LINEAR", + DML_OPERATOR_ACTIVATION_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", + DML_OPERATOR_ACTIVATION_LOG_SOFTMAX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SlopeTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", + DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", + DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_RELU_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_RELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_RELU", + DML_OPERATOR_ACTIVATION_RELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ACTIVATION_RELU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Gamma", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SCALED_ELU", + DML_OPERATOR_ACTIVATION_SCALED_ELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SCALED_TANH", + DML_OPERATOR_ACTIVATION_SCALED_TANH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SIGMOID", + DML_OPERATOR_ACTIVATION_SIGMOID, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SOFTMAX", + DML_OPERATOR_ACTIVATION_SOFTMAX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Steepness", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SOFTPLUS", + DML_OPERATOR_ACTIVATION_SOFTPLUS, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SOFTSIGN", + DML_OPERATOR_ACTIVATION_SOFTSIGN, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_TANH_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_TANH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_TANH", + DML_OPERATOR_ACTIVATION_TANH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 2, + DML_ACTIVATION_TANH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", + DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Bias", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Threshold", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SHRINK", + DML_OPERATOR_ACTIVATION_SHRINK, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 4, + DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA_FIELDS, +}; + +} // extern "C" diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index b8285c77a7..2d0c35b3eb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -738,6 +738,232 @@ inline std::vector GetFields(const DML_RESAMPLE_OPERATOR_DESC& de OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Scales), desc.ScaleCount)), }; } +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ROUND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RoundingMode))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InfinityMode))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_FILL_VALUE_CONSTANT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ValueDataType))), + OperatorField(&DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Value))), + }; +} +inline std::vector GetFields(const DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ValueDataType))), + OperatorField(&DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ValueStart))), + OperatorField(&DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ValueDelta))), + }; +} +inline std::vector GetFields(const DML_CUMULATIVE_SUMMATION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.HasExclusiveSum))), + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.AxisDirection))), + }; +} +inline std::vector GetFields(const DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), + OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_GATHER_ELEMENTS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GATHER_ELEMENTS_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_GATHER_ND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InputDimensionCount))), + OperatorField(&DML_GATHER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.IndicesDimensionCount))), + }; +} +inline std::vector GetFields(const DML_SCATTER_ND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.UpdatesTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.InputDimensionCount))), + OperatorField(&DML_SCATTER_ND_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.IndicesDimensionCount))), + }; +} +inline std::vector GetFields(const DML_MAX_POOLING2_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndicesTensor))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING2_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_SLICE1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InputWindowOffsets), desc.DimensionCount)), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.InputWindowSizes), desc.DimensionCount)), + OperatorField(&DML_SLICE1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.InputWindowStrides), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_TOP_K1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputValueTensor))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndexTensor))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.K))), + OperatorField(&DML_TOP_K1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.AxisDirection))), + }; +} +inline std::vector GetFields(const DML_DEPTH_TO_SPACE1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), + OperatorField(&DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Order))), + }; +} +inline std::vector GetFields(const DML_SPACE_TO_DEPTH1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), + OperatorField(&DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Order))), + }; +} +inline std::vector GetFields(const DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.NormalizeVariance))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Epsilon))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.FilterTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.FilterScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.FilterZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast(desc.GroupCount))), + }; +} +inline std::vector GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -970,6 +1196,28 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_SCATTER: return DML_SCATTER_OPERATOR_SCHEMA; case DML_OPERATOR_ONE_HOT: return DML_ONE_HOT_OPERATOR_SCHEMA; case DML_OPERATOR_RESAMPLE: return DML_RESAMPLE_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: return DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: return DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ROUND: return DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA; + case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA; + case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA; + case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA; + case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA; + case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA; + case DML_OPERATOR_SCATTER_ND: return DML_SCATTER_ND_OPERATOR_SCHEMA; + case DML_OPERATOR_MAX_POOLING2: return DML_MAX_POOLING2_OPERATOR_SCHEMA; + case DML_OPERATOR_SLICE1: return DML_SLICE1_OPERATOR_SCHEMA; + case DML_OPERATOR_TOP_K1: return DML_TOP_K1_OPERATOR_SCHEMA; + case DML_OPERATOR_DEPTH_TO_SPACE1: return DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA; + case DML_OPERATOR_SPACE_TO_DEPTH1: return DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; @@ -1305,6 +1553,94 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_RESAMPLE_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ROUND: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ROUND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_FILL_VALUE_CONSTANT: + return AbstractOperatorDesc( + &DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_FILL_VALUE_SEQUENCE: + return AbstractOperatorDesc( + &DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_CUMULATIVE_SUMMATION: + return AbstractOperatorDesc( + &DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_REVERSE_SUBSEQUENCES: + return AbstractOperatorDesc( + &DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GATHER_ELEMENTS: + return AbstractOperatorDesc( + &DML_GATHER_ELEMENTS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GATHER_ND: + return AbstractOperatorDesc( + &DML_GATHER_ND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SCATTER_ND: + return AbstractOperatorDesc( + &DML_SCATTER_ND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MAX_POOLING2: + return AbstractOperatorDesc( + &DML_MAX_POOLING2_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SLICE1: + return AbstractOperatorDesc( + &DML_SLICE1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_TOP_K1: + return AbstractOperatorDesc( + &DML_TOP_K1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DEPTH_TO_SPACE1: + return AbstractOperatorDesc( + &DML_DEPTH_TO_SPACE1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SPACE_TO_DEPTH1: + return AbstractOperatorDesc( + &DML_SPACE_TO_DEPTH1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: + return AbstractOperatorDesc( + &DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return AbstractOperatorDesc( + &DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak new file mode 100644 index 0000000000..b8285c77a7 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak @@ -0,0 +1,1388 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace SchemaHelpers +{ +AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc); + +inline std::vector GetFields(const DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ABS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ACOS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ADD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ASIN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ATAN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_CEIL_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_CLIP_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Min))), + OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Max))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_COS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_EXP_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOG_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_MAX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_MEAN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_MIN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_POW_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ExponentTensor))), + OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Exponent))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_RECIP_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_SIN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_SQRT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_TAN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Min))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_CONVOLUTION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.FilterTensor))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Mode))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Direction))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.OutputPadding), desc.DimensionCount)), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.GroupCount))), + OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_GEMM_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.CTensor))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.TransA))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.TransB))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Beta))), + OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_REDUCE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.Function))), + OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.AxisCount))), + OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), + }; +} +inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} +inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.P))), + }; +} +inline std::vector GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_MAX_POOLING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndicesTensor))), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ROITensor))), + OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.SpatialScale))), + OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), + }; +} +inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Offsets), desc.DimensionCount)), + OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Sizes), desc.DimensionCount)), + OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_CAST_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_CAST_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_CAST_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_SPLIT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputCount))), + OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensors), desc.OutputCount)), + OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_JOIN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputCount))), + OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputTensors), desc.InputCount)), + OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_PADDING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.PaddingMode))), + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.PaddingValue))), + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + }; +} +inline std::vector GetFields(const DML_VALUE_SCALE_2D_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Scale))), + OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ChannelCount))), + OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Bias), desc.ChannelCount)), + }; +} +inline std::vector GetFields(const DML_UPSAMPLE_2D_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleSize))), + OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InterpolationMode))), + }; +} +inline std::vector GetFields(const DML_GATHER_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.IndexDimensions))), + }; +} +inline std::vector GetFields(const DML_SPACE_TO_DEPTH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), + }; +} +inline std::vector GetFields(const DML_DEPTH_TO_SPACE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), + }; +} +inline std::vector GetFields(const DML_TILE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RepeatsCount))), + OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Repeats), desc.RepeatsCount)), + }; +} +inline std::vector GetFields(const DML_TOP_K_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputValueTensor))), + OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndexTensor))), + OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.K))), + }; +} +inline std::vector GetFields(const DML_BATCH_NORMALIZATION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.MeanTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.VarianceTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Spatial))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Epsilon))), + OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.CrossChannel))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.NormalizeVariance))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Epsilon))), + OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.CrossChannel))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.LocalSize))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Beta))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Bias))), + }; +} +inline std::vector GetFields(const DML_LP_NORMALIZATION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Epsilon))), + OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.P))), + }; +} +inline std::vector GetFields(const DML_RNN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.WeightTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RecurrenceTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HiddenInitTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputSequenceTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputSingleTensor))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.ActivationDescCount))), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.ActivationDescs), desc.ActivationDescCount)), + OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.Direction))), + }; +} +inline std::vector GetFields(const DML_LSTM_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.WeightTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RecurrenceTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HiddenInitTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.CellMemInitTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.PeepholeTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputSequenceTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.OutputSingleTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.OutputCellSingleTensor))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.ActivationDescCount))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.ActivationDescs), desc.ActivationDescCount)), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.Direction))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast(desc.ClipThreshold))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast(desc.UseClipThreshold))), + OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[16], ToOperatorFieldType(static_cast(desc.CoupleInputForget))), + }; +} +inline std::vector GetFields(const DML_GRU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.WeightTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RecurrenceTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HiddenInitTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputSequenceTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputSingleTensor))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.ActivationDescCount))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.ActivationDescs), desc.ActivationDescCount)), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.Direction))), + OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.LinearBeforeReset))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_SIGN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ERF_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_SINH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_COSH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_TANH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ASINH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ATANH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_IF_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ConditionTensor))), + OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ADD1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} +inline std::vector GetFields(const DML_MAX_UNPOOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MAX_UNPOOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_MAX_UNPOOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_MAX_UNPOOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_DIAGONAL_MATRIX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.Offset))), + OperatorField(&DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Value))), + }; +} +inline std::vector GetFields(const DML_SCATTER_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.UpdatesTensor))), + OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_ONE_HOT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.IndicesTensor))), + OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ValuesTensor))), + OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), + }; +} +inline std::vector GetFields(const DML_RESAMPLE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InterpolationMode))), + OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ScaleCount))), + OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Scales), desc.ScaleCount)), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_HARDMAX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_IDENTITY_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.SlopeTensor))), + OperatorField(&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_RELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Gamma))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SIGMOID_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SOFTMAX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Steepness))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_TANH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_TANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_TANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_SHRINK_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Bias))), + OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Threshold))), + }; +} + +inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) +{ + switch (operatorType) + { + case DML_OPERATOR_ELEMENT_WISE_IDENTITY: return DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ABS: return DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ACOS: return DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ADD: return DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ASIN: return DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ATAN: return DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_CEIL: return DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_CLIP: return DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_COS: return DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_EXP: return DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_FLOOR: return DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOG: return DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: return DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_MAX: return DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_MEAN: return DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_MIN: return DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: return DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_POW: return DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: return DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_RECIP: return DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_SIN: return DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_SQRT: return DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_TAN: return DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA; + case DML_OPERATOR_CONVOLUTION: return DML_CONVOLUTION_OPERATOR_SCHEMA; + case DML_OPERATOR_GEMM: return DML_GEMM_OPERATOR_SCHEMA; + case DML_OPERATOR_REDUCE: return DML_REDUCE_OPERATOR_SCHEMA; + case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA; + case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_SLICE: return DML_SLICE_OPERATOR_SCHEMA; + case DML_OPERATOR_CAST: return DML_CAST_OPERATOR_SCHEMA; + case DML_OPERATOR_SPLIT: return DML_SPLIT_OPERATOR_SCHEMA; + case DML_OPERATOR_JOIN: return DML_JOIN_OPERATOR_SCHEMA; + case DML_OPERATOR_PADDING: return DML_PADDING_OPERATOR_SCHEMA; + case DML_OPERATOR_VALUE_SCALE_2D: return DML_VALUE_SCALE_2D_OPERATOR_SCHEMA; + case DML_OPERATOR_UPSAMPLE_2D: return DML_UPSAMPLE_2D_OPERATOR_SCHEMA; + case DML_OPERATOR_GATHER: return DML_GATHER_OPERATOR_SCHEMA; + case DML_OPERATOR_SPACE_TO_DEPTH: return DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA; + case DML_OPERATOR_DEPTH_TO_SPACE: return DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA; + case DML_OPERATOR_TILE: return DML_TILE_OPERATOR_SCHEMA; + case DML_OPERATOR_TOP_K: return DML_TOP_K_OPERATOR_SCHEMA; + case DML_OPERATOR_BATCH_NORMALIZATION: return DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA; + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA; + case DML_OPERATOR_LP_NORMALIZATION: return DML_LP_NORMALIZATION_OPERATOR_SCHEMA; + case DML_OPERATOR_RNN: return DML_RNN_OPERATOR_SCHEMA; + case DML_OPERATOR_LSTM: return DML_LSTM_OPERATOR_SCHEMA; + case DML_OPERATOR_GRU: return DML_GRU_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_SIGN: return DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ERF: return DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_SINH: return DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_COSH: return DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_TANH: return DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ASINH: return DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ACOSH: return DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ATANH: return DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_IF: return DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ADD1: return DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA; + case DML_OPERATOR_MAX_UNPOOLING: return DML_MAX_UNPOOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_DIAGONAL_MATRIX: return DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA; + case DML_OPERATOR_SCATTER: return DML_SCATTER_OPERATOR_SCHEMA; + case DML_OPERATOR_ONE_HOT: return DML_ONE_HOT_OPERATOR_SCHEMA; + case DML_OPERATOR_RESAMPLE: return DML_RESAMPLE_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_LINEAR: return DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_RELU: return DML_ACTIVATION_RELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SCALED_ELU: return DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SCALED_TANH: return DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SIGMOID: return DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SOFTMAX: return DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SOFTPLUS: return DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SOFTSIGN: return DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_TANH: return DML_ACTIVATION_TANH_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA; + default: THROW_HR(E_INVALIDARG); + } +} + +inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) +{ + switch (static_cast(opDesc.Type)) + { + case DML_OPERATOR_ELEMENT_WISE_IDENTITY: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ABS: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ACOS: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ADD: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ASIN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ATAN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_CEIL: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_CLIP: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_COS: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_DIVIDE: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_EXP: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_FLOOR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOG: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_MAX: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_MEAN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_MIN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_POW: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_RECIP: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_SIN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_SQRT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_TAN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_CONVOLUTION: + return AbstractOperatorDesc( + &DML_CONVOLUTION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GEMM: + return AbstractOperatorDesc( + &DML_GEMM_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_REDUCE: + return AbstractOperatorDesc( + &DML_REDUCE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LP_POOLING: + return AbstractOperatorDesc( + &DML_LP_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MAX_POOLING: + return AbstractOperatorDesc( + &DML_MAX_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MAX_POOLING1: + return AbstractOperatorDesc( + &DML_MAX_POOLING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ROI_POOLING: + return AbstractOperatorDesc( + &DML_ROI_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SLICE: + return AbstractOperatorDesc( + &DML_SLICE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_CAST: + return AbstractOperatorDesc( + &DML_CAST_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SPLIT: + return AbstractOperatorDesc( + &DML_SPLIT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_JOIN: + return AbstractOperatorDesc( + &DML_JOIN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_PADDING: + return AbstractOperatorDesc( + &DML_PADDING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_VALUE_SCALE_2D: + return AbstractOperatorDesc( + &DML_VALUE_SCALE_2D_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_UPSAMPLE_2D: + return AbstractOperatorDesc( + &DML_UPSAMPLE_2D_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GATHER: + return AbstractOperatorDesc( + &DML_GATHER_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SPACE_TO_DEPTH: + return AbstractOperatorDesc( + &DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DEPTH_TO_SPACE: + return AbstractOperatorDesc( + &DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_TILE: + return AbstractOperatorDesc( + &DML_TILE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_TOP_K: + return AbstractOperatorDesc( + &DML_TOP_K_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_BATCH_NORMALIZATION: + return AbstractOperatorDesc( + &DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: + return AbstractOperatorDesc( + &DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: + return AbstractOperatorDesc( + &DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LP_NORMALIZATION: + return AbstractOperatorDesc( + &DML_LP_NORMALIZATION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_RNN: + return AbstractOperatorDesc( + &DML_RNN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LSTM: + return AbstractOperatorDesc( + &DML_LSTM_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_GRU: + return AbstractOperatorDesc( + &DML_GRU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_SIGN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_IS_NAN: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ERF: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_SINH: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_COSH: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_TANH: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ASINH: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ACOSH: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ATANH: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_IF: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ADD1: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MAX_UNPOOLING: + return AbstractOperatorDesc( + &DML_MAX_UNPOOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DIAGONAL_MATRIX: + return AbstractOperatorDesc( + &DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_SCATTER: + return AbstractOperatorDesc( + &DML_SCATTER_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ONE_HOT: + return AbstractOperatorDesc( + &DML_ONE_HOT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_RESAMPLE: + return AbstractOperatorDesc( + &DML_RESAMPLE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_ELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARDMAX: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_IDENTITY: + return AbstractOperatorDesc( + &DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_LINEAR: + return AbstractOperatorDesc( + &DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: + return AbstractOperatorDesc( + &DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: + return AbstractOperatorDesc( + &DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_RELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_RELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SCALED_ELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SCALED_TANH: + return AbstractOperatorDesc( + &DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SIGMOID: + return AbstractOperatorDesc( + &DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SOFTMAX: + return AbstractOperatorDesc( + &DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SOFTPLUS: + return AbstractOperatorDesc( + &DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SOFTSIGN: + return AbstractOperatorDesc( + &DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_TANH: + return AbstractOperatorDesc( + &DML_ACTIVATION_TANH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: + return AbstractOperatorDesc( + &DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_SHRINK: + return AbstractOperatorDesc( + &DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + default: THROW_HR(E_INVALIDARG); + } + +} +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h index 57c8ec8ce0..25f0dd26c6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h @@ -7,12 +7,15 @@ using ApiAttributeVariant = std::variant< const DML_TENSOR_DESC*, const DML_OPERATOR_DESC*, UINT, + UINT64, INT, FLOAT, const UINT*, + const INT*, const FLOAT*, const DML_SCALE_BIAS*, - DML_SIZE_2D + DML_SIZE_2D, + DML_SCALAR_UNION >; namespace OperatorFieldTypes @@ -22,12 +25,15 @@ namespace OperatorFieldTypes using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT + using UInt64 = uint64_t; // DML_SCHEMA_FIELD_TYPE_UINT64 using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY + using IntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY using ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D + using ScalarUnion = DML_SCALAR_UNION; // DML_SCHEMA_FIELD_TYPE_SCALAR_UNION } using OperatorFieldVariant = std::variant< @@ -36,12 +42,15 @@ using OperatorFieldVariant = std::variant< OperatorFieldTypes::OperatorDesc, OperatorFieldTypes::OperatorDescArray, OperatorFieldTypes::UInt, + OperatorFieldTypes::UInt64, OperatorFieldTypes::Int, OperatorFieldTypes::Float, OperatorFieldTypes::UIntArray, + OperatorFieldTypes::IntArray, OperatorFieldTypes::FloatArray, OperatorFieldTypes::ScaleBias, - OperatorFieldTypes::Size2D + OperatorFieldTypes::Size2D, + OperatorFieldTypes::ScalarUnion >; class OperatorField @@ -80,6 +89,9 @@ public: const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } + const OperatorFieldTypes::UInt64& AsUInt64() const { return std::get(m_data); } + OperatorFieldTypes::UInt64& AsUInt64() { return std::get(m_data); } + const OperatorFieldTypes::Int& AsInt() const { return std::get(m_data); } OperatorFieldTypes::Int& AsInt() { return std::get(m_data); } @@ -89,6 +101,9 @@ public: const OperatorFieldTypes::UIntArray& AsUIntArray() const { return std::get(m_data); } OperatorFieldTypes::UIntArray& AsUIntArray() { return std::get(m_data); } + const OperatorFieldTypes::IntArray& AsIntArray() const { return std::get(m_data); } + OperatorFieldTypes::IntArray& AsIntArray() { return std::get(m_data); } + const OperatorFieldTypes::FloatArray& AsFloatArray() const { return std::get(m_data); } OperatorFieldTypes::FloatArray& AsFloatArray() { return std::get(m_data); } @@ -98,6 +113,9 @@ public: const OperatorFieldTypes::Size2D& AsSize2D() const { return std::get(m_data); } OperatorFieldTypes::Size2D& AsSize2D() { return std::get(m_data); } + const OperatorFieldTypes::ScalarUnion& AsScalarUnion() const { return std::get(m_data); } + OperatorFieldTypes::ScalarUnion& AsScalarUnion() { return std::get(m_data); } + private: const DML_SCHEMA_FIELD* m_schema; OperatorFieldVariant m_data; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak new file mode 100644 index 0000000000..57c8ec8ce0 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +using ApiAttributeVariant = std::variant< + const DML_TENSOR_DESC*, + const DML_OPERATOR_DESC*, + UINT, + INT, + FLOAT, + const UINT*, + const FLOAT*, + const DML_SCALE_BIAS*, + DML_SIZE_2D + >; + +namespace OperatorFieldTypes +{ + using TensorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC + using TensorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY + using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC + using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY + using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT + using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT + using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT + using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY + using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY + using ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS + using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D +} + +using OperatorFieldVariant = std::variant< + OperatorFieldTypes::TensorDesc, + OperatorFieldTypes::TensorDescArray, + OperatorFieldTypes::OperatorDesc, + OperatorFieldTypes::OperatorDescArray, + OperatorFieldTypes::UInt, + OperatorFieldTypes::Int, + OperatorFieldTypes::Float, + OperatorFieldTypes::UIntArray, + OperatorFieldTypes::FloatArray, + OperatorFieldTypes::ScaleBias, + OperatorFieldTypes::Size2D + >; + +class OperatorField +{ +public: + OperatorField() = default; + explicit OperatorField(const DML_SCHEMA_FIELD* schema, OperatorFieldVariant&& data) + : m_schema(schema) + , m_data(std::move(data)) + { + assert(m_schema->Type == (DML_SCHEMA_FIELD_TYPE)m_data.index()); + } + + const DML_SCHEMA_FIELD* GetSchema() const + { + return m_schema; + } + + const OperatorFieldVariant& GetData() const + { + return m_data; + } + + const OperatorFieldTypes::TensorDesc& AsTensorDesc() const { return std::get(m_data); } + OperatorFieldTypes::TensorDesc& AsTensorDesc() { return std::get(m_data); } + + const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get(m_data); } + OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get(m_data); } + + const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get(m_data); } + OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get(m_data); } + + const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get(m_data); } + OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get(m_data); } + + const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } + OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } + + const OperatorFieldTypes::Int& AsInt() const { return std::get(m_data); } + OperatorFieldTypes::Int& AsInt() { return std::get(m_data); } + + const OperatorFieldTypes::Float& AsFloat() const { return std::get(m_data); } + OperatorFieldTypes::Float& AsFloat() { return std::get(m_data); } + + const OperatorFieldTypes::UIntArray& AsUIntArray() const { return std::get(m_data); } + OperatorFieldTypes::UIntArray& AsUIntArray() { return std::get(m_data); } + + const OperatorFieldTypes::FloatArray& AsFloatArray() const { return std::get(m_data); } + OperatorFieldTypes::FloatArray& AsFloatArray() { return std::get(m_data); } + + const OperatorFieldTypes::ScaleBias& AsScaleBias() const { return std::get(m_data); } + OperatorFieldTypes::ScaleBias& AsScaleBias() { return std::get(m_data); } + + const OperatorFieldTypes::Size2D& AsSize2D() const { return std::get(m_data); } + OperatorFieldTypes::Size2D& AsSize2D() { return std::get(m_data); } + +private: + const DML_SCHEMA_FIELD* m_schema; + OperatorFieldVariant m_data; +}; + diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h index fba6503dd6..09f1b1cdc4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h @@ -50,6 +50,11 @@ namespace SchemaHelpers return value; } + inline OperatorFieldTypes::UInt64 ToOperatorFieldType(uint64_t value) + { + return value; + } + inline OperatorFieldTypes::Int ToOperatorFieldType(int32_t value) { return value; @@ -71,6 +76,17 @@ namespace SchemaHelpers return field; } + inline OperatorFieldTypes::IntArray ToOperatorFieldType(const int32_t* values, uint32_t count) + { + OperatorFieldTypes::IntArray field; + if (values && count != 0) + { + field.emplace(count); + std::copy_n(values, count, field->begin()); + } + return field; + } + inline OperatorFieldTypes::FloatArray ToOperatorFieldType(const float* values, uint32_t count) { OperatorFieldTypes::FloatArray field; @@ -92,6 +108,10 @@ namespace SchemaHelpers return value; } + inline OperatorFieldTypes::ScalarUnion ToOperatorFieldType(DML_SCALAR_UNION value) + { + return value; + } class StructFieldWriter { @@ -250,6 +270,12 @@ namespace SchemaHelpers dst->Write(value); } break; + case DML_SCHEMA_FIELD_TYPE_UINT64: + { + uint64_t value = field.AsUInt64(); + dst->Write(value); + } break; + case DML_SCHEMA_FIELD_TYPE_INT: { int32_t value = field.AsInt(); @@ -276,6 +302,20 @@ namespace SchemaHelpers dst->Write(arrayPtr); } break; + case DML_SCHEMA_FIELD_TYPE_INT_ARRAY: + { + int32_t* arrayPtr = nullptr; + + const auto& values = field.AsIntArray(); + if (values) + { + arrayPtr = allocator->Allocate(values->size()); + std::copy(values->begin(), values->end(), arrayPtr); + } + + dst->Write(arrayPtr); + } break; + case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: { float* arrayPtr = nullptr; @@ -310,6 +350,12 @@ namespace SchemaHelpers dst->Write(value); } break; + case DML_SCHEMA_FIELD_TYPE_SCALAR_UNION: + { + uint64_t value = field.AsScalarUnion().UInt64; + dst->Write(value); + } break; + default: assert(false); THROW_HR(E_UNEXPECTED); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak new file mode 100644 index 0000000000..fba6503dd6 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace SchemaHelpers +{ + inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc); + + inline OperatorFieldTypes::TensorDesc ToOperatorFieldType(const DML_TENSOR_DESC* value) + { + return value ? OperatorFieldTypes::TensorDesc(*value) : std::nullopt; + } + + inline OperatorFieldTypes::TensorDescArray ToOperatorFieldType(const DML_TENSOR_DESC* values, uint32_t count) + { + OperatorFieldTypes::TensorDescArray field; + if (values && count != 0) + { + field.emplace(count); + for (uint32_t i = 0; i < count; ++i) + { + (*field)[i] = values[i]; + } + } + return field; + } + + inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) + { + return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; + } + + inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) + { + OperatorFieldTypes::OperatorDescArray field; + if (values && count != 0) + { + field.emplace(count); + for (uint32_t i = 0; i < count; ++i) + { + (*field)[i] = ConvertOperatorDesc(values[i]); + } + } + return field; + } + + inline OperatorFieldTypes::UInt ToOperatorFieldType(uint32_t value) + { + return value; + } + + inline OperatorFieldTypes::Int ToOperatorFieldType(int32_t value) + { + return value; + } + + inline OperatorFieldTypes::Float ToOperatorFieldType(float value) + { + return value; + } + + inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count) + { + OperatorFieldTypes::UIntArray field; + if (values && count != 0) + { + field.emplace(count); + std::copy_n(values, count, field->begin()); + } + return field; + } + + inline OperatorFieldTypes::FloatArray ToOperatorFieldType(const float* values, uint32_t count) + { + OperatorFieldTypes::FloatArray field; + if (values && count != 0) + { + field.emplace(count); + std::copy_n(values, count, field->begin()); + } + return field; + } + + inline OperatorFieldTypes::ScaleBias ToOperatorFieldType(const DML_SCALE_BIAS* value) + { + return value ? OperatorFieldTypes::ScaleBias(*value) : std::nullopt; + } + + inline OperatorFieldTypes::Size2D ToOperatorFieldType(DML_SIZE_2D value) + { + return value; + } + + + class StructFieldWriter + { + public: + explicit StructFieldWriter(gsl::span dst) + : m_dst(dst) + , m_bytesWritten(0) + {} + + template + void Write(const T& value) + { + static_assert(std::is_trivial_v, "Only trivial types are supported."); + + size_t dstOffset = RoundUpToMultiple(m_bytesWritten, alignof(T)); + size_t newBytesWritten = dstOffset + sizeof(value); + + assert(newBytesWritten <= gsl::narrow_cast(m_dst.size())); + memcpy(m_dst.data() + dstOffset, &value, sizeof(value)); + + m_bytesWritten = newBytesWritten; + } + + private: + template + T RoundUpToMultiple(T value, T multiple) + { + static_assert(std::is_integral_v); + + T remainder = value % multiple; + if (remainder != 0) + { + value += multiple - remainder; + } + + return value; + } + + gsl::span m_dst; + size_t m_bytesWritten; + }; + + template + DML_BUFFER_TENSOR_DESC MakeBufferTensorDesc(const DmlBufferTensorDesc& src, StackAllocator* allocator) + { + size_t dimensionCount = src.sizes.size(); + + auto* sizes = allocator->Allocate(dimensionCount); + std::copy_n(src.sizes.begin(), dimensionCount, sizes); + + UINT* strides = nullptr; + if (src.strides) + { + strides = allocator->Allocate(dimensionCount); + std::copy_n(src.strides->begin(), dimensionCount, strides); + } + + DML_BUFFER_TENSOR_DESC dst; + dst.DataType = src.dataType; + dst.Flags = src.flags; + dst.Sizes = sizes; + dst.Strides = strides; + dst.DimensionCount = static_cast(dimensionCount); + dst.TotalTensorSizeInBytes = src.totalTensorSizeInBytes; + dst.GuaranteedBaseOffsetAlignment = src.guaranteedBaseOffsetAlignment; + return dst; + } + + template + DML_TENSOR_DESC MakeTensorDesc(const DmlBufferTensorDesc& src, StackAllocator* allocator) + { + auto* desc = allocator->Allocate(); + *desc = MakeBufferTensorDesc(src, allocator); + + DML_TENSOR_DESC dst; + dst.Type = DML_TENSOR_TYPE_BUFFER; + dst.Desc = desc; + return dst; + } + + template + DML_OPERATOR_DESC ConvertOperatorDesc(const AbstractOperatorDesc& abstractDesc, StackAllocator* allocator); + + template + void WriteOperatorDescField(const OperatorField& field, StructFieldWriter* dst, StackAllocator* allocator) + { + const DML_SCHEMA_FIELD& schema = *field.GetSchema(); + + switch (schema.Type) + { + case DML_SCHEMA_FIELD_TYPE_TENSOR_DESC: + { + DML_TENSOR_DESC* desc = nullptr; + + const auto& value = field.AsTensorDesc(); + if (value) + { + desc = allocator->Allocate(); + *desc = MakeTensorDesc(*value, allocator); + } + + dst->Write(desc); + } break; + + case DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY: + { + DML_TENSOR_DESC* descs = nullptr; + + const auto& values = field.AsTensorDescArray(); + if (values) + { + descs = allocator->Allocate(values->size()); + for (size_t i = 0; i < values->size(); ++i) + { + descs[i] = MakeTensorDesc((*values)[i], allocator); + } + } + + dst->Write(descs); + } break; + + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC: + { + DML_OPERATOR_DESC* desc = nullptr; + + const auto& value = field.AsOperatorDesc(); + if (value) + { + desc = allocator->Allocate(); + *desc = ConvertOperatorDesc(*value, allocator); + } + + dst->Write(desc); + } break; + + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY: + { + DML_OPERATOR_DESC* descs = nullptr; + + const auto& values = field.AsOperatorDescArray(); + if (values) + { + descs = allocator->Allocate(values->size()); + for (size_t i = 0; i < values->size(); ++i) + { + descs[i] = ConvertOperatorDesc((*values)[i], allocator); + } + } + + dst->Write(descs); + } break; + + case DML_SCHEMA_FIELD_TYPE_UINT: + { + uint32_t value = field.AsUInt(); + dst->Write(value); + } break; + + case DML_SCHEMA_FIELD_TYPE_INT: + { + int32_t value = field.AsInt(); + dst->Write(value); + } break; + + case DML_SCHEMA_FIELD_TYPE_FLOAT: + { + float value = field.AsFloat(); + dst->Write(value); + } break; + + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: + { + uint32_t* arrayPtr = nullptr; + + const auto& values = field.AsUIntArray(); + if (values) + { + arrayPtr = allocator->Allocate(values->size()); + std::copy(values->begin(), values->end(), arrayPtr); + } + + dst->Write(arrayPtr); + } break; + + case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: + { + float* arrayPtr = nullptr; + + const auto& values = field.AsFloatArray(); + if (values) + { + arrayPtr = allocator->Allocate(values->size()); + std::copy(values->begin(), values->end(), arrayPtr); + } + + dst->Write(arrayPtr); + } break; + + case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS: + { + DML_SCALE_BIAS* scaleBias = nullptr; + + const auto& value = field.AsScaleBias(); + if (value) + { + scaleBias = allocator->Allocate(); + *scaleBias = *value; + } + + dst->Write(scaleBias); + } break; + + case DML_SCHEMA_FIELD_TYPE_SIZE_2D: + { + DML_SIZE_2D value = field.AsSize2D(); + dst->Write(value); + } break; + + default: + assert(false); + THROW_HR(E_UNEXPECTED); + } + } + + template + DML_OPERATOR_DESC ConvertOperatorDesc(const AbstractOperatorDesc& abstractDesc, StackAllocator* allocator) + { + const DML_OPERATOR_SCHEMA& schema = *abstractDesc.schema; + + // Retrieve the size of the ABI operator desc struct + size_t abiDescSizeInBytes = ApiTraits::OperatorTypeVisitor(schema.OperatorType, [](auto tag) { + using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs + return sizeof(T); + }); + + // Allocate a blob of bytes to hold the struct + byte* abiDesc = allocator->Allocate(abiDescSizeInBytes); + + // Use the schema to write data into the blob + + StructFieldWriter writer(gsl::make_span(abiDesc, abiDescSizeInBytes)); + + for (const OperatorField& field : abstractDesc.fields) + { + WriteOperatorDescField(field, &writer, allocator); + } + + return DML_OPERATOR_DESC{ schema.OperatorType, abiDesc }; + } + +} // namespace SchemaHelpers diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 5df941e9ef..55e8e23870 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1726,7 +1726,8 @@ void InferAndVerifyOutputSizes( for (uint32_t output_dim = 0; output_dim < outputShapes.GetShape(outputIndex).size(); ++output_dim) { if (shape.dim(output_dim).has_dim_value()) { int64_t expected_size = shape.dim(output_dim).dim_value(); - ML_CHECK_BOOL(expected_size == outputShapes.GetShape(outputIndex)[output_dim]); + int64_t actual_size = outputShapes.GetShape(outputIndex)[output_dim]; + ML_CHECK_BOOL(expected_size == actual_size); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index fb82f7708c..c26f223634 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -27,9 +27,9 @@ public: assert(inputTensorRank >= gsl::narrow_cast(m_strides.size())); // Pad the parameters to respect DML's requirements - FillWithLeadingValues(/*inout*/ m_offsets, m_inputTensorDescs[0].GetDimensionCount(), 0u); - FillWithLeadingValues(/*inout*/ m_sizes, m_inputTensorDescs[0].GetDimensionCount(), 1u); - FillWithLeadingValues(/*inout*/ m_strides, m_inputTensorDescs[0].GetDimensionCount(), 1); + FillWithLeadingValues(/*inout*/ m_offsets, inputTensorRank, 0u); + FillWithLeadingValues(/*inout*/ m_sizes, inputTensorRank, 1u); + FillWithLeadingValues(/*inout*/ m_strides, inputTensorRank, 1); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp index 1afb7d83e6..b4a286e08e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp @@ -277,10 +277,14 @@ DML_TENSOR_DESC TensorDesc::GetDmlDesc() // requires coercion by the caller. void TensorDesc::ForceUnsignedDataType() { - static_assert(ApiTraits::EnumValueCount == 9, "New tensor data type. Update cases."); + static_assert(ApiTraits::EnumValueCount == 12, "New tensor data type. Update cases."); switch (m_bufferTensorDesc.DataType) { + case DML_TENSOR_DATA_TYPE_INT64: + m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT64; + break; + case DML_TENSOR_DATA_TYPE_INT32: m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT32; break; @@ -293,7 +297,7 @@ void TensorDesc::ForceUnsignedDataType() m_bufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_UINT8; break; - // Nothing to do if already unsigned + // Nothing to do if already unsigned case DML_TENSOR_DATA_TYPE_UINT32: case DML_TENSOR_DATA_TYPE_UINT16: case DML_TENSOR_DATA_TYPE_UINT8: diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 924df87537..185020ce24 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -477,7 +477,7 @@ int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) return edgeShapes; } - std::vector SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { return { m_outputDimensions }; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index e593c39cb2..4e62ba9bb8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -89,7 +89,7 @@ void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumEle const size_t fillCount = newElementCount - oldElementCount; values.resize(newElementCount); - std::copy_backward(values.data(), values.data() + oldElementCount, values.data() + fillCount); + std::copy_backward(values.begin(), values.begin() + oldElementCount, values.end()); std::fill_n(values.data(), fillCount, fillValue); } @@ -577,6 +577,7 @@ public: m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end()); m_offsets.resize(m_outputDimensions.size()); m_sizes.resize(m_outputDimensions.size()); + m_strides = std::move(steps); m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2. // Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count. @@ -586,13 +587,27 @@ public: for (int i = 0, ci = gsl::narrow_cast(starts.size()); i < ci; ++i) { int dimIndex = axes.empty() ? i : axes[i]; + int stride = m_strides[i]; ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions."); + ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0."); // Positive values are offsets from 0. - // Negative values are offsets from the dimension's size. + // Negative values are offsets from back of the dimension's size. + // INT_MIN is a special value in ONNX which means to treat it as the smallest + // possible value, rather than the usual reversed from-the-back semantics. int dim = gsl::narrow_cast(inputDimensions[dimIndex]); int start = (starts[i] < 0 && starts[i] > INT_MIN) ? (starts[i] + dim) : starts[i]; - int end = (ends[i] < 0 && ends[i] < INT_MAX) ? (ends[i] + dim) : ends[i]; + int end = (ends[i] < 0 && starts[i] > INT_MIN) ? (ends[i] + dim) : ends[i]; + + // For negative strides, the ONNX start and end values are off-by-one. + // So fix them such that the start value remains the minimum extent + // of the slice window, and end remains the maximum exclusive extent. + if (stride < 0) + { + std::swap(start, end); + start += (start < INT_MAX) ? 1 : 0; // Avoid overflow wrap. + end += (end < INT_MAX) ? 1 : 0; + } // Clamp the dimensions to the slice extents. // Clamp negative numbers to 0, per case test_slice_start_out_of_bounds. @@ -600,7 +615,8 @@ public: end = std::min(end, dim); int size = std::max(end - start, 0); - m_outputDimensions[dimIndex] = size; + int absoluteStride = abs(stride); + m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0); m_offsets[dimIndex] = start; m_sizes[dimIndex] = gsl::narrow_cast(size); } From 5feb3c0f19ecf30ad492324cebc1a46a00c60e1f Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Fri, 27 Mar 2020 02:42:09 -0700 Subject: [PATCH 4/6] Delete litter backup files. --- .../External/DirectMLHelpers/ApiTraits.h.bak | 1936 ----------------- .../DirectMLHelpers/DirectMLSchema.h.bak | 1514 ------------- .../GeneratedSchemaHelpers.h.bak | 1388 ------------ .../GeneratedSchemaTypes.h.bak | 105 - 4 files changed, 4943 deletions(-) delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak deleted file mode 100644 index d75afeb4b6..0000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h.bak +++ /dev/null @@ -1,1936 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace ApiTraits -{ -template -struct EnumTraits -{ -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 12; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 119; - static constexpr size_t ActivationFunctionCount = 19; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 3; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 12; - static constexpr DML_REDUCE_FUNCTION Invalid = static_cast(ValueCount); -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 3; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 3; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 3; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 2; -}; - -template <> -struct EnumTraits -{ - static constexpr auto ValueCount = 3; -}; - -template -constexpr auto EnumValueCount = EnumTraits::ValueCount; - -template -constexpr bool IsValidEnumValue(T value) -{ - return (std::make_unsigned_t(value) < std::make_unsigned_t(EnumValueCount)); -} - -template -struct FlagTraits -{ -}; - -template <> -struct FlagTraits -{ - static constexpr auto ValidMask = DML_TENSOR_FLAG_OWNED_BY_DML; -}; - -template <> -struct FlagTraits -{ - static constexpr auto ValidMask = DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION | DML_EXECUTION_FLAG_DISABLE_META_COMMANDS | DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE; -}; - -template <> -struct FlagTraits -{ - static constexpr auto ValidMask = DML_CREATE_DEVICE_FLAG_DEBUG; -}; - -template -constexpr auto FlagsValidMask = FlagTraits::ValidMask; - -template -constexpr bool IsValidFlags(T value) -{ - return (value & ~FlagsValidMask) == 0; -} - -template -struct TensorDescTraits -{ -}; - -template <> -struct TensorDescTraits -{ - static constexpr DML_TENSOR_TYPE Type = DML_TENSOR_TYPE_BUFFER; -}; - - -template -struct OperatorDescTraits -{ -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IDENTITY; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ABS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ACOS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ADD; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ASIN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATAN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CEIL; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_COS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DIVIDE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_EXP; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_FLOOR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOG; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MAX; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MEAN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MIN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MULTIPLY; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_POW; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_RECIP; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SIN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SQRT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SUBTRACT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_TAN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_THRESHOLD; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CONVOLUTION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GEMM; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REDUCE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CAST; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPLIT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_JOIN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_VALUE_SCALE_2D; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_UPSAMPLE_2D; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPACE_TO_DEPTH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEPTH_TO_SPACE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TILE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TOP_K; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_NORMALIZATION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RNN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LSTM; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GRU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SIGN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_NAN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ERF; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SINH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_COSH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_TANH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ASINH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ACOSH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATANH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IF; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ADD1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_UNPOOLING; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DIAGONAL_MATRIX; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ONE_HOT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ROUND; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_INFINITY; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FILL_VALUE_CONSTANT; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FILL_VALUE_SEQUENCE; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_SUMMATION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_REVERSE_SUBSEQUENCES; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ELEMENTS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SCATTER_ND; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MAX_POOLING2; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SLICE1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_TOP_K1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEPTH_TO_SPACE1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_SPACE_TO_DEPTH1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_ELU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARDMAX; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_IDENTITY; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LINEAR; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_LOG_SOFTMAX; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_RELU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SCALED_ELU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SCALED_TANH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SIGMOID; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTMAX; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTPLUS; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SOFTSIGN; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_TANH; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU; -}; - -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SHRINK; -}; - - -template -struct OperatorTypeTraits -{ -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IDENTITY> -{ - using DescType = DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ABS> -{ - using DescType = DML_ELEMENT_WISE_ABS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ACOS> -{ - using DescType = DML_ELEMENT_WISE_ACOS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ADD> -{ - using DescType = DML_ELEMENT_WISE_ADD_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ASIN> -{ - using DescType = DML_ELEMENT_WISE_ASIN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATAN> -{ - using DescType = DML_ELEMENT_WISE_ATAN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CEIL> -{ - using DescType = DML_ELEMENT_WISE_CEIL_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP> -{ - using DescType = DML_ELEMENT_WISE_CLIP_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COS> -{ - using DescType = DML_ELEMENT_WISE_COS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DIVIDE> -{ - using DescType = DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_EXP> -{ - using DescType = DML_ELEMENT_WISE_EXP_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_FLOOR> -{ - using DescType = DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOG> -{ - using DescType = DML_ELEMENT_WISE_LOG_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR> -{ - using DescType = DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MAX> -{ - using DescType = DML_ELEMENT_WISE_MAX_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MEAN> -{ - using DescType = DML_ELEMENT_WISE_MEAN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MIN> -{ - using DescType = DML_ELEMENT_WISE_MIN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MULTIPLY> -{ - using DescType = DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_POW> -{ - using DescType = DML_ELEMENT_WISE_POW_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW> -{ - using DescType = DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_RECIP> -{ - using DescType = DML_ELEMENT_WISE_RECIP_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SIN> -{ - using DescType = DML_ELEMENT_WISE_SIN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SQRT> -{ - using DescType = DML_ELEMENT_WISE_SQRT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SUBTRACT> -{ - using DescType = DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_TAN> -{ - using DescType = DML_ELEMENT_WISE_TAN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_THRESHOLD> -{ - using DescType = DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR> -{ - using DescType = DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR> -{ - using DescType = DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION> -{ - using DescType = DML_CONVOLUTION_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GEMM> -{ - using DescType = DML_GEMM_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REDUCE> -{ - using DescType = DML_REDUCE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING> -{ - using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING> -{ - using DescType = DML_LP_POOLING_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING> -{ - using DescType = DML_MAX_POOLING_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING1> -{ - using DescType = DML_MAX_POOLING1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> -{ - using DescType = DML_ROI_POOLING_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> -{ - using DescType = DML_SLICE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CAST> -{ - using DescType = DML_CAST_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPLIT> -{ - using DescType = DML_SPLIT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_JOIN> -{ - using DescType = DML_JOIN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING> -{ - using DescType = DML_PADDING_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_VALUE_SCALE_2D> -{ - using DescType = DML_VALUE_SCALE_2D_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_UPSAMPLE_2D> -{ - using DescType = DML_UPSAMPLE_2D_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER> -{ - using DescType = DML_GATHER_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH> -{ - using DescType = DML_SPACE_TO_DEPTH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE> -{ - using DescType = DML_DEPTH_TO_SPACE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TILE> -{ - using DescType = DML_TILE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K> -{ - using DescType = DML_TOP_K_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION> -{ - using DescType = DML_BATCH_NORMALIZATION_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION> -{ - using DescType = DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION> -{ - using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_NORMALIZATION> -{ - using DescType = DML_LP_NORMALIZATION_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RNN> -{ - using DescType = DML_RNN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LSTM> -{ - using DescType = DML_LSTM_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GRU> -{ - using DescType = DML_GRU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SIGN> -{ - using DescType = DML_ELEMENT_WISE_SIGN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_NAN> -{ - using DescType = DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ERF> -{ - using DescType = DML_ELEMENT_WISE_ERF_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SINH> -{ - using DescType = DML_ELEMENT_WISE_SINH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COSH> -{ - using DescType = DML_ELEMENT_WISE_COSH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_TANH> -{ - using DescType = DML_ELEMENT_WISE_TANH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ASINH> -{ - using DescType = DML_ELEMENT_WISE_ASINH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ACOSH> -{ - using DescType = DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATANH> -{ - using DescType = DML_ELEMENT_WISE_ATANH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IF> -{ - using DescType = DML_ELEMENT_WISE_IF_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ADD1> -{ - using DescType = DML_ELEMENT_WISE_ADD1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_UNPOOLING> -{ - using DescType = DML_MAX_UNPOOLING_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX> -{ - using DescType = DML_DIAGONAL_MATRIX_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER> -{ - using DescType = DML_SCATTER_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ONE_HOT> -{ - using DescType = DML_ONE_HOT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE> -{ - using DescType = DML_RESAMPLE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT> -{ - using DescType = DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT> -{ - using DescType = DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ROUND> -{ - using DescType = DML_ELEMENT_WISE_ROUND_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_INFINITY> -{ - using DescType = DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE> -{ - using DescType = DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR> -{ - using DescType = DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FILL_VALUE_CONSTANT> -{ - using DescType = DML_FILL_VALUE_CONSTANT_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FILL_VALUE_SEQUENCE> -{ - using DescType = DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_SUMMATION> -{ - using DescType = DML_CUMULATIVE_SUMMATION_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES> -{ - using DescType = DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ELEMENTS> -{ - using DescType = DML_GATHER_ELEMENTS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND> -{ - using DescType = DML_GATHER_ND_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SCATTER_ND> -{ - using DescType = DML_SCATTER_ND_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING2> -{ - using DescType = DML_MAX_POOLING2_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE1> -{ - using DescType = DML_SLICE1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_TOP_K1> -{ - using DescType = DML_TOP_K1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEPTH_TO_SPACE1> -{ - using DescType = DML_DEPTH_TO_SPACE1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SPACE_TO_DEPTH1> -{ - using DescType = DML_SPACE_TO_DEPTH1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1> -{ - using DescType = DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> -{ - using DescType = DML_ACTIVATION_ELU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARDMAX> -{ - using DescType = DML_ACTIVATION_HARDMAX_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SIGMOID> -{ - using DescType = DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_IDENTITY> -{ - using DescType = DML_ACTIVATION_IDENTITY_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LEAKY_RELU> -{ - using DescType = DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LINEAR> -{ - using DescType = DML_ACTIVATION_LINEAR_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_LOG_SOFTMAX> -{ - using DescType = DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU> -{ - using DescType = DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS> -{ - using DescType = DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_RELU> -{ - using DescType = DML_ACTIVATION_RELU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SCALED_ELU> -{ - using DescType = DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SCALED_TANH> -{ - using DescType = DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SIGMOID> -{ - using DescType = DML_ACTIVATION_SIGMOID_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTMAX> -{ - using DescType = DML_ACTIVATION_SOFTMAX_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTPLUS> -{ - using DescType = DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SOFTSIGN> -{ - using DescType = DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_TANH> -{ - using DescType = DML_ACTIVATION_TANH_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU> -{ - using DescType = DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC; -}; - -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SHRINK> -{ - using DescType = DML_ACTIVATION_SHRINK_OPERATOR_DESC; -}; - - -// Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as -// the first argument. -// -// For example: -// Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { -// using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs -// }); -// -template -auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args) -{ - switch (static_cast(type)) - { - case DML_OPERATOR_ELEMENT_WISE_IDENTITY: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ABS: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ABS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ACOS: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ACOS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ADD: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ADD_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ASIN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ASIN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ATAN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ATAN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_CEIL: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CEIL_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_CLIP: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CLIP_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_COS: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_COS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_DIVIDE: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_EXP: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_EXP_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_FLOOR: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOG: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOG_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_MAX: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MAX_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_MEAN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MEAN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_MIN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MIN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_POW: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_POW_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_RECIP: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_RECIP_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_SIN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SIN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_SQRT: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SQRT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_TAN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_TAN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_CONVOLUTION: - return std::invoke(std::forward(visitor), DML_CONVOLUTION_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_GEMM: - return std::invoke(std::forward(visitor), DML_GEMM_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_REDUCE: - return std::invoke(std::forward(visitor), DML_REDUCE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_AVERAGE_POOLING: - return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_LP_POOLING: - return std::invoke(std::forward(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_MAX_POOLING: - return std::invoke(std::forward(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_MAX_POOLING1: - return std::invoke(std::forward(visitor), DML_MAX_POOLING1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ROI_POOLING: - return std::invoke(std::forward(visitor), DML_ROI_POOLING_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SLICE: - return std::invoke(std::forward(visitor), DML_SLICE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_CAST: - return std::invoke(std::forward(visitor), DML_CAST_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SPLIT: - return std::invoke(std::forward(visitor), DML_SPLIT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_JOIN: - return std::invoke(std::forward(visitor), DML_JOIN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_PADDING: - return std::invoke(std::forward(visitor), DML_PADDING_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_VALUE_SCALE_2D: - return std::invoke(std::forward(visitor), DML_VALUE_SCALE_2D_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_UPSAMPLE_2D: - return std::invoke(std::forward(visitor), DML_UPSAMPLE_2D_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_GATHER: - return std::invoke(std::forward(visitor), DML_GATHER_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SPACE_TO_DEPTH: - return std::invoke(std::forward(visitor), DML_SPACE_TO_DEPTH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_DEPTH_TO_SPACE: - return std::invoke(std::forward(visitor), DML_DEPTH_TO_SPACE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_TILE: - return std::invoke(std::forward(visitor), DML_TILE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_TOP_K: - return std::invoke(std::forward(visitor), DML_TOP_K_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_BATCH_NORMALIZATION: - return std::invoke(std::forward(visitor), DML_BATCH_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: - return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: - return std::invoke(std::forward(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_LP_NORMALIZATION: - return std::invoke(std::forward(visitor), DML_LP_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_RNN: - return std::invoke(std::forward(visitor), DML_RNN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_LSTM: - return std::invoke(std::forward(visitor), DML_LSTM_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_GRU: - return std::invoke(std::forward(visitor), DML_GRU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_SIGN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SIGN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_IS_NAN: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ERF: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ERF_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_SINH: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SINH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_COSH: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_COSH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_TANH: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_TANH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ASINH: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ASINH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ACOSH: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ATANH: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ATANH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_IF: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IF_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ADD1: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ADD1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_MAX_UNPOOLING: - return std::invoke(std::forward(visitor), DML_MAX_UNPOOLING_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_DIAGONAL_MATRIX: - return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SCATTER: - return std::invoke(std::forward(visitor), DML_SCATTER_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ONE_HOT: - return std::invoke(std::forward(visitor), DML_ONE_HOT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_RESAMPLE: - return std::invoke(std::forward(visitor), DML_RESAMPLE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_ROUND: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ROUND_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: - return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_FILL_VALUE_CONSTANT: - return std::invoke(std::forward(visitor), DML_FILL_VALUE_CONSTANT_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_FILL_VALUE_SEQUENCE: - return std::invoke(std::forward(visitor), DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_CUMULATIVE_SUMMATION: - return std::invoke(std::forward(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_REVERSE_SUBSEQUENCES: - return std::invoke(std::forward(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_GATHER_ELEMENTS: - return std::invoke(std::forward(visitor), DML_GATHER_ELEMENTS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_GATHER_ND: - return std::invoke(std::forward(visitor), DML_GATHER_ND_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SCATTER_ND: - return std::invoke(std::forward(visitor), DML_SCATTER_ND_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_MAX_POOLING2: - return std::invoke(std::forward(visitor), DML_MAX_POOLING2_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SLICE1: - return std::invoke(std::forward(visitor), DML_SLICE1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_TOP_K1: - return std::invoke(std::forward(visitor), DML_TOP_K1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_DEPTH_TO_SPACE1: - return std::invoke(std::forward(visitor), DML_DEPTH_TO_SPACE1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_SPACE_TO_DEPTH1: - return std::invoke(std::forward(visitor), DML_SPACE_TO_DEPTH1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: - return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_ELU: - return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_HARDMAX: - return std::invoke(std::forward(visitor), DML_ACTIVATION_HARDMAX_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: - return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_IDENTITY: - return std::invoke(std::forward(visitor), DML_ACTIVATION_IDENTITY_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_LEAKY_RELU: - return std::invoke(std::forward(visitor), DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_LINEAR: - return std::invoke(std::forward(visitor), DML_ACTIVATION_LINEAR_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: - return std::invoke(std::forward(visitor), DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: - return std::invoke(std::forward(visitor), DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: - return std::invoke(std::forward(visitor), DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_RELU: - return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SCALED_ELU: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SCALED_TANH: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SIGMOID: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SOFTMAX: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTMAX_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SOFTPLUS: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SOFTSIGN: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_TANH: - return std::invoke(std::forward(visitor), DML_ACTIVATION_TANH_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: - return std::invoke(std::forward(visitor), DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC{}, std::forward(args)...); - case DML_OPERATOR_ACTIVATION_SHRINK: - return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); - - default: - THROW_HR(E_INVALIDARG); - } -} - - -inline gsl::czstring ToString(DML_OPERATOR_TYPE value) -{ - switch (value) - { - case DML_OPERATOR_INVALID: return "DML_OPERATOR_INVALID"; - case DML_OPERATOR_ELEMENT_WISE_IDENTITY: return "DML_OPERATOR_ELEMENT_WISE_IDENTITY"; - case DML_OPERATOR_ELEMENT_WISE_ABS: return "DML_OPERATOR_ELEMENT_WISE_ABS"; - case DML_OPERATOR_ELEMENT_WISE_ACOS: return "DML_OPERATOR_ELEMENT_WISE_ACOS"; - case DML_OPERATOR_ELEMENT_WISE_ADD: return "DML_OPERATOR_ELEMENT_WISE_ADD"; - case DML_OPERATOR_ELEMENT_WISE_ASIN: return "DML_OPERATOR_ELEMENT_WISE_ASIN"; - case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; - case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; - case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; - case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; - case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; - case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; - case DML_OPERATOR_ELEMENT_WISE_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_FLOOR"; - case DML_OPERATOR_ELEMENT_WISE_LOG: return "DML_OPERATOR_ELEMENT_WISE_LOG"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR"; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR"; - case DML_OPERATOR_ELEMENT_WISE_MAX: return "DML_OPERATOR_ELEMENT_WISE_MAX"; - case DML_OPERATOR_ELEMENT_WISE_MEAN: return "DML_OPERATOR_ELEMENT_WISE_MEAN"; - case DML_OPERATOR_ELEMENT_WISE_MIN: return "DML_OPERATOR_ELEMENT_WISE_MIN"; - case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: return "DML_OPERATOR_ELEMENT_WISE_MULTIPLY"; - case DML_OPERATOR_ELEMENT_WISE_POW: return "DML_OPERATOR_ELEMENT_WISE_POW"; - case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: return "DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW"; - case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; - case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; - case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; - case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; - case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; - case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; - case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR"; - case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR"; - case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; - case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; - case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; - case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; - case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; - case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; - case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; - case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; - case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE"; - case DML_OPERATOR_CAST: return "DML_OPERATOR_CAST"; - case DML_OPERATOR_SPLIT: return "DML_OPERATOR_SPLIT"; - case DML_OPERATOR_JOIN: return "DML_OPERATOR_JOIN"; - case DML_OPERATOR_PADDING: return "DML_OPERATOR_PADDING"; - case DML_OPERATOR_VALUE_SCALE_2D: return "DML_OPERATOR_VALUE_SCALE_2D"; - case DML_OPERATOR_UPSAMPLE_2D: return "DML_OPERATOR_UPSAMPLE_2D"; - case DML_OPERATOR_GATHER: return "DML_OPERATOR_GATHER"; - case DML_OPERATOR_SPACE_TO_DEPTH: return "DML_OPERATOR_SPACE_TO_DEPTH"; - case DML_OPERATOR_DEPTH_TO_SPACE: return "DML_OPERATOR_DEPTH_TO_SPACE"; - case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; - case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; - case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; - case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; - case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; - case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; - case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; - case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU"; - case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; - case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; - case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; - case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; - case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; - case DML_OPERATOR_ELEMENT_WISE_TANH: return "DML_OPERATOR_ELEMENT_WISE_TANH"; - case DML_OPERATOR_ELEMENT_WISE_ASINH: return "DML_OPERATOR_ELEMENT_WISE_ASINH"; - case DML_OPERATOR_ELEMENT_WISE_ACOSH: return "DML_OPERATOR_ELEMENT_WISE_ACOSH"; - case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH"; - case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF"; - case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1"; - case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING"; - case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX"; - case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; - case DML_OPERATOR_ONE_HOT: return "DML_OPERATOR_ONE_HOT"; - case DML_OPERATOR_RESAMPLE: return "DML_OPERATOR_RESAMPLE"; - case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT: return "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT"; - case DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT: return "DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT"; - case DML_OPERATOR_ELEMENT_WISE_ROUND: return "DML_OPERATOR_ELEMENT_WISE_ROUND"; - case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; - case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; - case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; - case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; - case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; - case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; - case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; - case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; - case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; - case DML_OPERATOR_SCATTER_ND: return "DML_OPERATOR_SCATTER_ND"; - case DML_OPERATOR_MAX_POOLING2: return "DML_OPERATOR_MAX_POOLING2"; - case DML_OPERATOR_SLICE1: return "DML_OPERATOR_SLICE1"; - case DML_OPERATOR_TOP_K1: return "DML_OPERATOR_TOP_K1"; - case DML_OPERATOR_DEPTH_TO_SPACE1: return "DML_OPERATOR_DEPTH_TO_SPACE1"; - case DML_OPERATOR_SPACE_TO_DEPTH1: return "DML_OPERATOR_SPACE_TO_DEPTH1"; - case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1"; - default: - assert(false); - return ""; - } -} -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak deleted file mode 100644 index 7c46a8a6a2..0000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h.bak +++ /dev/null @@ -1,1514 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -extern "C" { - -enum DML_SCHEMA_FIELD_KIND -{ - DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, - DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, - DML_SCHEMA_FIELD_KIND_ATTRIBUTE, -}; - -enum DML_SCHEMA_FIELD_TYPE -{ - DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, - DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY, - DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, - DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, - DML_SCHEMA_FIELD_TYPE_UINT, - DML_SCHEMA_FIELD_TYPE_INT, - DML_SCHEMA_FIELD_TYPE_FLOAT, - DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, - DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, - DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, - DML_SCHEMA_FIELD_TYPE_SIZE_2D, -}; - -enum DML_SCHEMA_OPERATOR_SUPPORT_FLAGS -{ - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE = 0, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION = (1 << 0), -}; - -DEFINE_ENUM_FLAG_OPERATORS(DML_SCHEMA_OPERATOR_SUPPORT_FLAGS); - -struct DML_SCHEMA_FIELD -{ - DML_SCHEMA_FIELD_KIND Kind; - DML_SCHEMA_FIELD_TYPE Type; - const CHAR* Name; - BOOL Optional; -}; - -struct DML_OPERATOR_SCHEMA -{ - const CHAR* OperatorName; - DML_OPERATOR_TYPE OperatorType; - DML_SCHEMA_OPERATOR_SUPPORT_FLAGS SupportFlags; - - UINT FieldCount; - const DML_SCHEMA_FIELD* Fields; -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_IDENTITY", - DML_OPERATOR_ELEMENT_WISE_IDENTITY, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ABS", - DML_OPERATOR_ELEMENT_WISE_ABS, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ACOS", - DML_OPERATOR_ELEMENT_WISE_ACOS, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ADD", - DML_OPERATOR_ELEMENT_WISE_ADD, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ASIN", - DML_OPERATOR_ELEMENT_WISE_ASIN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ATAN", - DML_OPERATOR_ELEMENT_WISE_ATAN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_CEIL", - DML_OPERATOR_ELEMENT_WISE_CEIL, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Max", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_CLIP", - DML_OPERATOR_ELEMENT_WISE_CLIP, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 5, - DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_COS", - DML_OPERATOR_ELEMENT_WISE_COS, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_DIVIDE", - DML_OPERATOR_ELEMENT_WISE_DIVIDE, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_EXP", - DML_OPERATOR_ELEMENT_WISE_EXP, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_FLOOR", - DML_OPERATOR_ELEMENT_WISE_FLOOR, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOG", - DML_OPERATOR_ELEMENT_WISE_LOG, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 2, - DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", - DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_MAX", - DML_OPERATOR_ELEMENT_WISE_MAX, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_MEAN", - DML_OPERATOR_ELEMENT_WISE_MEAN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_MIN", - DML_OPERATOR_ELEMENT_WISE_MIN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_MULTIPLY", - DML_OPERATOR_ELEMENT_WISE_MULTIPLY, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ExponentTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_POW", - DML_OPERATOR_ELEMENT_WISE_POW, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Exponent", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", - DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_RECIP", - DML_OPERATOR_ELEMENT_WISE_RECIP, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_SIN", - DML_OPERATOR_ELEMENT_WISE_SIN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_SQRT", - DML_OPERATOR_ELEMENT_WISE_SQRT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_SUBTRACT", - DML_OPERATOR_ELEMENT_WISE_SUBTRACT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_TAN", - DML_OPERATOR_ELEMENT_WISE_TAN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_THRESHOLD", - DML_OPERATOR_ELEMENT_WISE_THRESHOLD, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", - DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", - DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_OPERATOR_SCHEMA_FIELDS[14] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FilterTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Mode", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "OutputPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "GroupCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_CONVOLUTION_OPERATOR_SCHEMA { - "DML_OPERATOR_CONVOLUTION", - DML_OPERATOR_CONVOLUTION, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 14, - DML_CONVOLUTION_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_GEMM_OPERATOR_SCHEMA_FIELDS[9] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "CTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "TransA", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "TransB", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_GEMM_OPERATOR_SCHEMA { - "DML_OPERATOR_GEMM", - DML_OPERATOR_GEMM, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 9, - DML_GEMM_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_REDUCE_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Function", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Axes", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_REDUCE_OPERATOR_SCHEMA { - "DML_OPERATOR_REDUCE", - DML_OPERATOR_REDUCE, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_REDUCE_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_AVERAGE_POOLING", - DML_OPERATOR_AVERAGE_POOLING, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_LP_POOLING", - DML_OPERATOR_LP_POOLING, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_MAX_POOLING", - DML_OPERATOR_MAX_POOLING, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 7, - DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_MAX_POOLING1_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndicesTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING1_OPERATOR_SCHEMA { - "DML_OPERATOR_MAX_POOLING1", - DML_OPERATOR_MAX_POOLING1, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_MAX_POOLING1_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScale", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SIZE_2D, "PooledSize", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_ROI_POOLING", - DML_OPERATOR_ROI_POOLING, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Offsets", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Sizes", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_SLICE_OPERATOR_SCHEMA { - "DML_OPERATOR_SLICE", - DML_OPERATOR_SLICE, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 6, - DML_SLICE_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_CAST_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_CAST_OPERATOR_SCHEMA { - "DML_OPERATOR_CAST", - DML_OPERATOR_CAST, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 2, - DML_CAST_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_SPLIT_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "OutputCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY, "OutputTensors", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_SPLIT_OPERATOR_SCHEMA { - "DML_OPERATOR_SPLIT", - DML_OPERATOR_SPLIT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_SPLIT_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_JOIN_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InputCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY, "InputTensors", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_JOIN_OPERATOR_SCHEMA { - "DML_OPERATOR_JOIN", - DML_OPERATOR_JOIN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_JOIN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_PADDING_OPERATOR_SCHEMA_FIELDS[7] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "PaddingMode", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "PaddingValue", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_PADDING_OPERATOR_SCHEMA { - "DML_OPERATOR_PADDING", - DML_OPERATOR_PADDING, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 7, - DML_PADDING_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Scale", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ChannelCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Bias", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_VALUE_SCALE_2D_OPERATOR_SCHEMA { - "DML_OPERATOR_VALUE_SCALE_2D", - DML_OPERATOR_VALUE_SCALE_2D, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_UPSAMPLE_2D_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SIZE_2D, "ScaleSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_UPSAMPLE_2D_OPERATOR_SCHEMA { - "DML_OPERATOR_UPSAMPLE_2D", - DML_OPERATOR_UPSAMPLE_2D, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_UPSAMPLE_2D_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_GATHER_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IndexDimensions", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_GATHER_OPERATOR_SCHEMA { - "DML_OPERATOR_GATHER", - DML_OPERATOR_GATHER, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_GATHER_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA { - "DML_OPERATOR_SPACE_TO_DEPTH", - DML_OPERATOR_SPACE_TO_DEPTH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "BlockSize", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA { - "DML_OPERATOR_DEPTH_TO_SPACE", - DML_OPERATOR_DEPTH_TO_SPACE, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_TILE_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "RepeatsCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Repeats", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_TILE_OPERATOR_SCHEMA { - "DML_OPERATOR_TILE", - DML_OPERATOR_TILE, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_TILE_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_TOP_K_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputValueTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputIndexTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "K", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_TOP_K_OPERATOR_SCHEMA { - "DML_OPERATOR_TOP_K", - DML_OPERATOR_TOP_K, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_TOP_K_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[9] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Spatial", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA { - "DML_OPERATOR_BATCH_NORMALIZATION", - DML_OPERATOR_BATCH_NORMALIZATION, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 9, - DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "CrossChannel", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "NormalizeVariance", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA { - "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", - DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[7] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "CrossChannel", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "LocalSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Bias", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA { - "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", - DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 7, - DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_LP_NORMALIZATION_OPERATOR_SCHEMA { - "DML_OPERATOR_LP_NORMALIZATION", - DML_OPERATOR_LP_NORMALIZATION, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_RNN_OPERATOR_SCHEMA_FIELDS[11] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "WeightTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RecurrenceTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "HiddenInitTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSequenceTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSingleTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ActivationDescCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, "ActivationDescs", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_RNN_OPERATOR_SCHEMA { - "DML_OPERATOR_RNN", - DML_OPERATOR_RNN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 11, - DML_RNN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_LSTM_OPERATOR_SCHEMA_FIELDS[17] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "WeightTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RecurrenceTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "HiddenInitTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "CellMemInitTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PeepholeTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSequenceTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSingleTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputCellSingleTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ActivationDescCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, "ActivationDescs", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "ClipThreshold", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "UseClipThreshold", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "CoupleInputForget", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_LSTM_OPERATOR_SCHEMA { - "DML_OPERATOR_LSTM", - DML_OPERATOR_LSTM, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 17, - DML_LSTM_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_GRU_OPERATOR_SCHEMA_FIELDS[12] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "WeightTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RecurrenceTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "HiddenInitTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSequenceTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputSingleTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ActivationDescCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY, "ActivationDescs", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Direction", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "LinearBeforeReset", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_GRU_OPERATOR_SCHEMA { - "DML_OPERATOR_GRU", - DML_OPERATOR_GRU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 12, - DML_GRU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_SIGN", - DML_OPERATOR_ELEMENT_WISE_SIGN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 2, - DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_IS_NAN", - DML_OPERATOR_ELEMENT_WISE_IS_NAN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 2, - DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ERF", - DML_OPERATOR_ELEMENT_WISE_ERF, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_SINH", - DML_OPERATOR_ELEMENT_WISE_SINH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_COSH", - DML_OPERATOR_ELEMENT_WISE_COSH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_TANH", - DML_OPERATOR_ELEMENT_WISE_TANH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ASINH", - DML_OPERATOR_ELEMENT_WISE_ASINH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ACOSH", - DML_OPERATOR_ELEMENT_WISE_ACOSH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ATANH", - DML_OPERATOR_ELEMENT_WISE_ATANH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ConditionTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_IF", - DML_OPERATOR_ELEMENT_WISE_IF, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA { - "DML_OPERATOR_ELEMENT_WISE_ADD1", - DML_OPERATOR_ELEMENT_WISE_ADD1, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_MAX_UNPOOLING_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MAX_UNPOOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_MAX_UNPOOLING", - DML_OPERATOR_MAX_UNPOOLING, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_MAX_UNPOOLING_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "Offset", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Value", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA { - "DML_OPERATOR_DIAGONAL_MATRIX", - DML_OPERATOR_DIAGONAL_MATRIX, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, - DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_SCATTER_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "UpdatesTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_SCATTER_OPERATOR_SCHEMA { - "DML_OPERATOR_SCATTER", - DML_OPERATOR_SCATTER, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_SCATTER_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ONE_HOT_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ValuesTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ONE_HOT_OPERATOR_SCHEMA { - "DML_OPERATOR_ONE_HOT", - DML_OPERATOR_ONE_HOT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 4, - DML_ONE_HOT_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_RESAMPLE_OPERATOR_SCHEMA_FIELDS[5] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ScaleCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "Scales", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_OPERATOR_SCHEMA { - "DML_OPERATOR_RESAMPLE", - DML_OPERATOR_RESAMPLE, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 5, - DML_RESAMPLE_OPERATOR_SCHEMA_FIELDS, -}; - - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_ELU_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_ELU", - DML_OPERATOR_ACTIVATION_ELU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_HARDMAX", - DML_OPERATOR_ACTIVATION_HARDMAX, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 2, - DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_HARD_SIGMOID", - DML_OPERATOR_ACTIVATION_HARD_SIGMOID, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_IDENTITY", - DML_OPERATOR_ACTIVATION_IDENTITY, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 2, - DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_LEAKY_RELU", - DML_OPERATOR_ACTIVATION_LEAKY_RELU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_LINEAR", - DML_OPERATOR_ACTIVATION_LINEAR, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", - DML_OPERATOR_ACTIVATION_LOG_SOFTMAX, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 2, - DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SlopeTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", - DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", - DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_RELU_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_RELU_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_RELU", - DML_OPERATOR_ACTIVATION_RELU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 2, - DML_ACTIVATION_RELU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Gamma", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SCALED_ELU", - DML_OPERATOR_ACTIVATION_SCALED_ELU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SCALED_TANH", - DML_OPERATOR_ACTIVATION_SCALED_TANH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SIGMOID", - DML_OPERATOR_ACTIVATION_SIGMOID, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 2, - DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SOFTMAX", - DML_OPERATOR_ACTIVATION_SOFTMAX, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 2, - DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Steepness", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SOFTPLUS", - DML_OPERATOR_ACTIVATION_SOFTPLUS, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SOFTSIGN", - DML_OPERATOR_ACTIVATION_SOFTSIGN, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 2, - DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_TANH_OPERATOR_SCHEMA_FIELDS[2] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_TANH_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_TANH", - DML_OPERATOR_ACTIVATION_TANH, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 2, - DML_ACTIVATION_TANH_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA_FIELDS[3] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", - DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 3, - DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA_FIELDS, -}; - -constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA_FIELDS[4] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Bias", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Threshold", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA { - "DML_OPERATOR_ACTIVATION_SHRINK", - DML_OPERATOR_ACTIVATION_SHRINK, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, - 4, - DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA_FIELDS, -}; - -} // extern "C" diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak deleted file mode 100644 index b8285c77a7..0000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h.bak +++ /dev/null @@ -1,1388 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace SchemaHelpers -{ -AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc); - -inline std::vector GetFields(const DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ABS_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ACOS_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ADD_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ASIN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ATAN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_CEIL_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_CLIP_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Min))), - OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Max))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_COS_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_EXP_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOG_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_MAX_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_MEAN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_MIN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_POW_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ExponentTensor))), - OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - OperatorField(&DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Exponent))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_RECIP_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_SIN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_SQRT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_TAN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - OperatorField(&DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Min))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), - OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ZeroPointTensor))), - OperatorField(&DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), - OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ZeroPointTensor))), - OperatorField(&DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_CONVOLUTION_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.FilterTensor))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Mode))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Direction))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.OutputPadding), desc.DimensionCount)), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.GroupCount))), - OperatorField(&DML_CONVOLUTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.FusedActivation))), - }; -} -inline std::vector GetFields(const DML_GEMM_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.CTensor))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.TransA))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.TransB))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Beta))), - OperatorField(&DML_GEMM_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), - }; -} -inline std::vector GetFields(const DML_REDUCE_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.Function))), - OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.AxisCount))), - OperatorField(&DML_REDUCE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Axes), desc.AxisCount)), - }; -} -inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), - }; -} -inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.P))), - }; -} -inline std::vector GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_MAX_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - }; -} -inline std::vector GetFields(const DML_MAX_POOLING1_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndicesTensor))), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_MAX_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - }; -} -inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ROITensor))), - OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.SpatialScale))), - OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), - }; -} -inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Offsets), desc.DimensionCount)), - OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Sizes), desc.DimensionCount)), - OperatorField(&DML_SLICE_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - }; -} -inline std::vector GetFields(const DML_CAST_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_CAST_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_CAST_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_SPLIT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputCount))), - OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensors), desc.OutputCount)), - OperatorField(&DML_SPLIT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), - }; -} -inline std::vector GetFields(const DML_JOIN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputCount))), - OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputTensors), desc.InputCount)), - OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_JOIN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), - }; -} -inline std::vector GetFields(const DML_PADDING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.PaddingMode))), - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.PaddingValue))), - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - }; -} -inline std::vector GetFields(const DML_VALUE_SCALE_2D_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Scale))), - OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ChannelCount))), - OperatorField(&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Bias), desc.ChannelCount)), - }; -} -inline std::vector GetFields(const DML_UPSAMPLE_2D_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleSize))), - OperatorField(&DML_UPSAMPLE_2D_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.InterpolationMode))), - }; -} -inline std::vector GetFields(const DML_GATHER_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), - OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), - OperatorField(&DML_GATHER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.IndexDimensions))), - }; -} -inline std::vector GetFields(const DML_SPACE_TO_DEPTH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), - }; -} -inline std::vector GetFields(const DML_DEPTH_TO_SPACE_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BlockSize))), - }; -} -inline std::vector GetFields(const DML_TILE_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RepeatsCount))), - OperatorField(&DML_TILE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Repeats), desc.RepeatsCount)), - }; -} -inline std::vector GetFields(const DML_TOP_K_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputValueTensor))), - OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputIndexTensor))), - OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), - OperatorField(&DML_TOP_K_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.K))), - }; -} -inline std::vector GetFields(const DML_BATCH_NORMALIZATION_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.MeanTensor))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.VarianceTensor))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ScaleTensor))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Spatial))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Epsilon))), - OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), - }; -} -inline std::vector GetFields(const DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.CrossChannel))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.NormalizeVariance))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Epsilon))), - OperatorField(&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.FusedActivation))), - }; -} -inline std::vector GetFields(const DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.CrossChannel))), - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.LocalSize))), - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Beta))), - OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Bias))), - }; -} -inline std::vector GetFields(const DML_LP_NORMALIZATION_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Axis))), - OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Epsilon))), - OperatorField(&DML_LP_NORMALIZATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.P))), - }; -} -inline std::vector GetFields(const DML_RNN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.WeightTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RecurrenceTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HiddenInitTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputSequenceTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputSingleTensor))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.ActivationDescCount))), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.ActivationDescs), desc.ActivationDescCount)), - OperatorField(&DML_RNN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.Direction))), - }; -} -inline std::vector GetFields(const DML_LSTM_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.WeightTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RecurrenceTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HiddenInitTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.CellMemInitTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.PeepholeTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputSequenceTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.OutputSingleTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.OutputCellSingleTensor))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.ActivationDescCount))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.ActivationDescs), desc.ActivationDescCount)), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.Direction))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast(desc.ClipThreshold))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast(desc.UseClipThreshold))), - OperatorField(&DML_LSTM_OPERATOR_SCHEMA.Fields[16], ToOperatorFieldType(static_cast(desc.CoupleInputForget))), - }; -} -inline std::vector GetFields(const DML_GRU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.WeightTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.RecurrenceTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HiddenInitTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.SequenceLengthsTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputSequenceTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputSingleTensor))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.ActivationDescCount))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.ActivationDescs), desc.ActivationDescCount)), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.Direction))), - OperatorField(&DML_GRU_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.LinearBeforeReset))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_SIGN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ERF_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_SINH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_COSH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_TANH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ASINH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ATANH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_IF_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ConditionTensor))), - OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ELEMENT_WISE_ADD1_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.FusedActivation))), - }; -} -inline std::vector GetFields(const DML_MAX_UNPOOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MAX_UNPOOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_MAX_UNPOOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), - OperatorField(&DML_MAX_UNPOOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_DIAGONAL_MATRIX_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.Offset))), - OperatorField(&DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Value))), - }; -} -inline std::vector GetFields(const DML_SCATTER_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.IndicesTensor))), - OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.UpdatesTensor))), - OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_SCATTER_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Axis))), - }; -} -inline std::vector GetFields(const DML_ONE_HOT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.IndicesTensor))), - OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ValuesTensor))), - OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ONE_HOT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Axis))), - }; -} -inline std::vector GetFields(const DML_RESAMPLE_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InterpolationMode))), - OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.ScaleCount))), - OperatorField(&DML_RESAMPLE_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Scales), desc.ScaleCount)), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_HARDMAX_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_IDENTITY_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_LINEAR_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.SlopeTensor))), - OperatorField(&DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_RELU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Gamma))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - OperatorField(&DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SIGMOID_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SOFTMAX_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Steepness))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_TANH_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_TANH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_TANH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), - }; -} -inline std::vector GetFields(const DML_ACTIVATION_SHRINK_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Bias))), - OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Threshold))), - }; -} - -inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) -{ - switch (operatorType) - { - case DML_OPERATOR_ELEMENT_WISE_IDENTITY: return DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ABS: return DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ACOS: return DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ADD: return DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ASIN: return DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ATAN: return DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_CEIL: return DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_CLIP: return DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_COS: return DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_EXP: return DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_FLOOR: return DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOG: return DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: return DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: return DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: return DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: return DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: return DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: return DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: return DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_MAX: return DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_MEAN: return DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_MIN: return DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: return DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_POW: return DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: return DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_RECIP: return DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_SIN: return DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_SQRT: return DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_TAN: return DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA; - case DML_OPERATOR_CONVOLUTION: return DML_CONVOLUTION_OPERATOR_SCHEMA; - case DML_OPERATOR_GEMM: return DML_GEMM_OPERATOR_SCHEMA; - case DML_OPERATOR_REDUCE: return DML_REDUCE_OPERATOR_SCHEMA; - case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA; - case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_SLICE: return DML_SLICE_OPERATOR_SCHEMA; - case DML_OPERATOR_CAST: return DML_CAST_OPERATOR_SCHEMA; - case DML_OPERATOR_SPLIT: return DML_SPLIT_OPERATOR_SCHEMA; - case DML_OPERATOR_JOIN: return DML_JOIN_OPERATOR_SCHEMA; - case DML_OPERATOR_PADDING: return DML_PADDING_OPERATOR_SCHEMA; - case DML_OPERATOR_VALUE_SCALE_2D: return DML_VALUE_SCALE_2D_OPERATOR_SCHEMA; - case DML_OPERATOR_UPSAMPLE_2D: return DML_UPSAMPLE_2D_OPERATOR_SCHEMA; - case DML_OPERATOR_GATHER: return DML_GATHER_OPERATOR_SCHEMA; - case DML_OPERATOR_SPACE_TO_DEPTH: return DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA; - case DML_OPERATOR_DEPTH_TO_SPACE: return DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA; - case DML_OPERATOR_TILE: return DML_TILE_OPERATOR_SCHEMA; - case DML_OPERATOR_TOP_K: return DML_TOP_K_OPERATOR_SCHEMA; - case DML_OPERATOR_BATCH_NORMALIZATION: return DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA; - case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA; - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA; - case DML_OPERATOR_LP_NORMALIZATION: return DML_LP_NORMALIZATION_OPERATOR_SCHEMA; - case DML_OPERATOR_RNN: return DML_RNN_OPERATOR_SCHEMA; - case DML_OPERATOR_LSTM: return DML_LSTM_OPERATOR_SCHEMA; - case DML_OPERATOR_GRU: return DML_GRU_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_SIGN: return DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ERF: return DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_SINH: return DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_COSH: return DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_TANH: return DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ASINH: return DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ACOSH: return DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ATANH: return DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_IF: return DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA; - case DML_OPERATOR_ELEMENT_WISE_ADD1: return DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA; - case DML_OPERATOR_MAX_UNPOOLING: return DML_MAX_UNPOOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_DIAGONAL_MATRIX: return DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA; - case DML_OPERATOR_SCATTER: return DML_SCATTER_OPERATOR_SCHEMA; - case DML_OPERATOR_ONE_HOT: return DML_ONE_HOT_OPERATOR_SCHEMA; - case DML_OPERATOR_RESAMPLE: return DML_RESAMPLE_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_IDENTITY: return DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_LINEAR: return DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_RELU: return DML_ACTIVATION_RELU_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SCALED_ELU: return DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SCALED_TANH: return DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SIGMOID: return DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SOFTMAX: return DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SOFTPLUS: return DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SOFTSIGN: return DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_TANH: return DML_ACTIVATION_TANH_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA; - case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA; - default: THROW_HR(E_INVALIDARG); - } -} - -inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) -{ - switch (static_cast(opDesc.Type)) - { - case DML_OPERATOR_ELEMENT_WISE_IDENTITY: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_IDENTITY_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ABS: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ABS_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ACOS: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ACOS_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ADD: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ADD_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ASIN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ASIN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ATAN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_CEIL: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_CLIP: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_COS: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_DIVIDE: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_EXP: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_FLOOR: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_FLOOR_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOG: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOG_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_MAX: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_MAX_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_MEAN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_MEAN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_MIN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_MIN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_MULTIPLY: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_MULTIPLY_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_POW: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_POW_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_RECIP: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_SIN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_SQRT: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_TAN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_CONVOLUTION: - return AbstractOperatorDesc( - &DML_CONVOLUTION_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_GEMM: - return AbstractOperatorDesc( - &DML_GEMM_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_REDUCE: - return AbstractOperatorDesc( - &DML_REDUCE_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_AVERAGE_POOLING: - return AbstractOperatorDesc( - &DML_AVERAGE_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_LP_POOLING: - return AbstractOperatorDesc( - &DML_LP_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MAX_POOLING: - return AbstractOperatorDesc( - &DML_MAX_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MAX_POOLING1: - return AbstractOperatorDesc( - &DML_MAX_POOLING1_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ROI_POOLING: - return AbstractOperatorDesc( - &DML_ROI_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_SLICE: - return AbstractOperatorDesc( - &DML_SLICE_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_CAST: - return AbstractOperatorDesc( - &DML_CAST_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_SPLIT: - return AbstractOperatorDesc( - &DML_SPLIT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_JOIN: - return AbstractOperatorDesc( - &DML_JOIN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_PADDING: - return AbstractOperatorDesc( - &DML_PADDING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_VALUE_SCALE_2D: - return AbstractOperatorDesc( - &DML_VALUE_SCALE_2D_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_UPSAMPLE_2D: - return AbstractOperatorDesc( - &DML_UPSAMPLE_2D_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_GATHER: - return AbstractOperatorDesc( - &DML_GATHER_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_SPACE_TO_DEPTH: - return AbstractOperatorDesc( - &DML_SPACE_TO_DEPTH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_DEPTH_TO_SPACE: - return AbstractOperatorDesc( - &DML_DEPTH_TO_SPACE_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_TILE: - return AbstractOperatorDesc( - &DML_TILE_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_TOP_K: - return AbstractOperatorDesc( - &DML_TOP_K_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_BATCH_NORMALIZATION: - return AbstractOperatorDesc( - &DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: - return AbstractOperatorDesc( - &DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: - return AbstractOperatorDesc( - &DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_LP_NORMALIZATION: - return AbstractOperatorDesc( - &DML_LP_NORMALIZATION_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_RNN: - return AbstractOperatorDesc( - &DML_RNN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_LSTM: - return AbstractOperatorDesc( - &DML_LSTM_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_GRU: - return AbstractOperatorDesc( - &DML_GRU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_SIGN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_IS_NAN: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ERF: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_SINH: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_COSH: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_TANH: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_TANH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ASINH: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ASINH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ACOSH: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ACOSH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ATANH: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ATANH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_IF: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_IF_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ELEMENT_WISE_ADD1: - return AbstractOperatorDesc( - &DML_ELEMENT_WISE_ADD1_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MAX_UNPOOLING: - return AbstractOperatorDesc( - &DML_MAX_UNPOOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_DIAGONAL_MATRIX: - return AbstractOperatorDesc( - &DML_DIAGONAL_MATRIX_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_SCATTER: - return AbstractOperatorDesc( - &DML_SCATTER_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ONE_HOT: - return AbstractOperatorDesc( - &DML_ONE_HOT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_RESAMPLE: - return AbstractOperatorDesc( - &DML_RESAMPLE_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_ELU: - return AbstractOperatorDesc( - &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_HARDMAX: - return AbstractOperatorDesc( - &DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: - return AbstractOperatorDesc( - &DML_ACTIVATION_HARD_SIGMOID_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_IDENTITY: - return AbstractOperatorDesc( - &DML_ACTIVATION_IDENTITY_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_LEAKY_RELU: - return AbstractOperatorDesc( - &DML_ACTIVATION_LEAKY_RELU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_LINEAR: - return AbstractOperatorDesc( - &DML_ACTIVATION_LINEAR_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: - return AbstractOperatorDesc( - &DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: - return AbstractOperatorDesc( - &DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: - return AbstractOperatorDesc( - &DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_RELU: - return AbstractOperatorDesc( - &DML_ACTIVATION_RELU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SCALED_ELU: - return AbstractOperatorDesc( - &DML_ACTIVATION_SCALED_ELU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SCALED_TANH: - return AbstractOperatorDesc( - &DML_ACTIVATION_SCALED_TANH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SIGMOID: - return AbstractOperatorDesc( - &DML_ACTIVATION_SIGMOID_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SOFTMAX: - return AbstractOperatorDesc( - &DML_ACTIVATION_SOFTMAX_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SOFTPLUS: - return AbstractOperatorDesc( - &DML_ACTIVATION_SOFTPLUS_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SOFTSIGN: - return AbstractOperatorDesc( - &DML_ACTIVATION_SOFTSIGN_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_TANH: - return AbstractOperatorDesc( - &DML_ACTIVATION_TANH_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: - return AbstractOperatorDesc( - &DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_ACTIVATION_SHRINK: - return AbstractOperatorDesc( - &DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); - default: THROW_HR(E_INVALIDARG); - } - -} -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak deleted file mode 100644 index 57c8ec8ce0..0000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h.bak +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -using ApiAttributeVariant = std::variant< - const DML_TENSOR_DESC*, - const DML_OPERATOR_DESC*, - UINT, - INT, - FLOAT, - const UINT*, - const FLOAT*, - const DML_SCALE_BIAS*, - DML_SIZE_2D - >; - -namespace OperatorFieldTypes -{ - using TensorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC - using TensorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY - using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC - using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY - using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT - using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT - using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT - using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY - using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY - using ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS - using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D -} - -using OperatorFieldVariant = std::variant< - OperatorFieldTypes::TensorDesc, - OperatorFieldTypes::TensorDescArray, - OperatorFieldTypes::OperatorDesc, - OperatorFieldTypes::OperatorDescArray, - OperatorFieldTypes::UInt, - OperatorFieldTypes::Int, - OperatorFieldTypes::Float, - OperatorFieldTypes::UIntArray, - OperatorFieldTypes::FloatArray, - OperatorFieldTypes::ScaleBias, - OperatorFieldTypes::Size2D - >; - -class OperatorField -{ -public: - OperatorField() = default; - explicit OperatorField(const DML_SCHEMA_FIELD* schema, OperatorFieldVariant&& data) - : m_schema(schema) - , m_data(std::move(data)) - { - assert(m_schema->Type == (DML_SCHEMA_FIELD_TYPE)m_data.index()); - } - - const DML_SCHEMA_FIELD* GetSchema() const - { - return m_schema; - } - - const OperatorFieldVariant& GetData() const - { - return m_data; - } - - const OperatorFieldTypes::TensorDesc& AsTensorDesc() const { return std::get(m_data); } - OperatorFieldTypes::TensorDesc& AsTensorDesc() { return std::get(m_data); } - - const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get(m_data); } - OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get(m_data); } - - const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get(m_data); } - - const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get(m_data); } - - const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } - OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } - - const OperatorFieldTypes::Int& AsInt() const { return std::get(m_data); } - OperatorFieldTypes::Int& AsInt() { return std::get(m_data); } - - const OperatorFieldTypes::Float& AsFloat() const { return std::get(m_data); } - OperatorFieldTypes::Float& AsFloat() { return std::get(m_data); } - - const OperatorFieldTypes::UIntArray& AsUIntArray() const { return std::get(m_data); } - OperatorFieldTypes::UIntArray& AsUIntArray() { return std::get(m_data); } - - const OperatorFieldTypes::FloatArray& AsFloatArray() const { return std::get(m_data); } - OperatorFieldTypes::FloatArray& AsFloatArray() { return std::get(m_data); } - - const OperatorFieldTypes::ScaleBias& AsScaleBias() const { return std::get(m_data); } - OperatorFieldTypes::ScaleBias& AsScaleBias() { return std::get(m_data); } - - const OperatorFieldTypes::Size2D& AsSize2D() const { return std::get(m_data); } - OperatorFieldTypes::Size2D& AsSize2D() { return std::get(m_data); } - -private: - const DML_SCHEMA_FIELD* m_schema; - OperatorFieldVariant m_data; -}; - From 031647635b4672bf67a2df277434faf7fa016005 Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Fri, 27 Mar 2020 02:43:11 -0700 Subject: [PATCH 5/6] Delete another litter file. --- .../DirectMLHelpers/SchemaHelpers.h.bak | 345 ------------------ 1 file changed, 345 deletions(-) delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak deleted file mode 100644 index fba6503dd6..0000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h.bak +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace SchemaHelpers -{ - inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc); - - inline OperatorFieldTypes::TensorDesc ToOperatorFieldType(const DML_TENSOR_DESC* value) - { - return value ? OperatorFieldTypes::TensorDesc(*value) : std::nullopt; - } - - inline OperatorFieldTypes::TensorDescArray ToOperatorFieldType(const DML_TENSOR_DESC* values, uint32_t count) - { - OperatorFieldTypes::TensorDescArray field; - if (values && count != 0) - { - field.emplace(count); - for (uint32_t i = 0; i < count; ++i) - { - (*field)[i] = values[i]; - } - } - return field; - } - - inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) - { - return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; - } - - inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) - { - OperatorFieldTypes::OperatorDescArray field; - if (values && count != 0) - { - field.emplace(count); - for (uint32_t i = 0; i < count; ++i) - { - (*field)[i] = ConvertOperatorDesc(values[i]); - } - } - return field; - } - - inline OperatorFieldTypes::UInt ToOperatorFieldType(uint32_t value) - { - return value; - } - - inline OperatorFieldTypes::Int ToOperatorFieldType(int32_t value) - { - return value; - } - - inline OperatorFieldTypes::Float ToOperatorFieldType(float value) - { - return value; - } - - inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count) - { - OperatorFieldTypes::UIntArray field; - if (values && count != 0) - { - field.emplace(count); - std::copy_n(values, count, field->begin()); - } - return field; - } - - inline OperatorFieldTypes::FloatArray ToOperatorFieldType(const float* values, uint32_t count) - { - OperatorFieldTypes::FloatArray field; - if (values && count != 0) - { - field.emplace(count); - std::copy_n(values, count, field->begin()); - } - return field; - } - - inline OperatorFieldTypes::ScaleBias ToOperatorFieldType(const DML_SCALE_BIAS* value) - { - return value ? OperatorFieldTypes::ScaleBias(*value) : std::nullopt; - } - - inline OperatorFieldTypes::Size2D ToOperatorFieldType(DML_SIZE_2D value) - { - return value; - } - - - class StructFieldWriter - { - public: - explicit StructFieldWriter(gsl::span dst) - : m_dst(dst) - , m_bytesWritten(0) - {} - - template - void Write(const T& value) - { - static_assert(std::is_trivial_v, "Only trivial types are supported."); - - size_t dstOffset = RoundUpToMultiple(m_bytesWritten, alignof(T)); - size_t newBytesWritten = dstOffset + sizeof(value); - - assert(newBytesWritten <= gsl::narrow_cast(m_dst.size())); - memcpy(m_dst.data() + dstOffset, &value, sizeof(value)); - - m_bytesWritten = newBytesWritten; - } - - private: - template - T RoundUpToMultiple(T value, T multiple) - { - static_assert(std::is_integral_v); - - T remainder = value % multiple; - if (remainder != 0) - { - value += multiple - remainder; - } - - return value; - } - - gsl::span m_dst; - size_t m_bytesWritten; - }; - - template - DML_BUFFER_TENSOR_DESC MakeBufferTensorDesc(const DmlBufferTensorDesc& src, StackAllocator* allocator) - { - size_t dimensionCount = src.sizes.size(); - - auto* sizes = allocator->Allocate(dimensionCount); - std::copy_n(src.sizes.begin(), dimensionCount, sizes); - - UINT* strides = nullptr; - if (src.strides) - { - strides = allocator->Allocate(dimensionCount); - std::copy_n(src.strides->begin(), dimensionCount, strides); - } - - DML_BUFFER_TENSOR_DESC dst; - dst.DataType = src.dataType; - dst.Flags = src.flags; - dst.Sizes = sizes; - dst.Strides = strides; - dst.DimensionCount = static_cast(dimensionCount); - dst.TotalTensorSizeInBytes = src.totalTensorSizeInBytes; - dst.GuaranteedBaseOffsetAlignment = src.guaranteedBaseOffsetAlignment; - return dst; - } - - template - DML_TENSOR_DESC MakeTensorDesc(const DmlBufferTensorDesc& src, StackAllocator* allocator) - { - auto* desc = allocator->Allocate(); - *desc = MakeBufferTensorDesc(src, allocator); - - DML_TENSOR_DESC dst; - dst.Type = DML_TENSOR_TYPE_BUFFER; - dst.Desc = desc; - return dst; - } - - template - DML_OPERATOR_DESC ConvertOperatorDesc(const AbstractOperatorDesc& abstractDesc, StackAllocator* allocator); - - template - void WriteOperatorDescField(const OperatorField& field, StructFieldWriter* dst, StackAllocator* allocator) - { - const DML_SCHEMA_FIELD& schema = *field.GetSchema(); - - switch (schema.Type) - { - case DML_SCHEMA_FIELD_TYPE_TENSOR_DESC: - { - DML_TENSOR_DESC* desc = nullptr; - - const auto& value = field.AsTensorDesc(); - if (value) - { - desc = allocator->Allocate(); - *desc = MakeTensorDesc(*value, allocator); - } - - dst->Write(desc); - } break; - - case DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY: - { - DML_TENSOR_DESC* descs = nullptr; - - const auto& values = field.AsTensorDescArray(); - if (values) - { - descs = allocator->Allocate(values->size()); - for (size_t i = 0; i < values->size(); ++i) - { - descs[i] = MakeTensorDesc((*values)[i], allocator); - } - } - - dst->Write(descs); - } break; - - case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC: - { - DML_OPERATOR_DESC* desc = nullptr; - - const auto& value = field.AsOperatorDesc(); - if (value) - { - desc = allocator->Allocate(); - *desc = ConvertOperatorDesc(*value, allocator); - } - - dst->Write(desc); - } break; - - case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY: - { - DML_OPERATOR_DESC* descs = nullptr; - - const auto& values = field.AsOperatorDescArray(); - if (values) - { - descs = allocator->Allocate(values->size()); - for (size_t i = 0; i < values->size(); ++i) - { - descs[i] = ConvertOperatorDesc((*values)[i], allocator); - } - } - - dst->Write(descs); - } break; - - case DML_SCHEMA_FIELD_TYPE_UINT: - { - uint32_t value = field.AsUInt(); - dst->Write(value); - } break; - - case DML_SCHEMA_FIELD_TYPE_INT: - { - int32_t value = field.AsInt(); - dst->Write(value); - } break; - - case DML_SCHEMA_FIELD_TYPE_FLOAT: - { - float value = field.AsFloat(); - dst->Write(value); - } break; - - case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: - { - uint32_t* arrayPtr = nullptr; - - const auto& values = field.AsUIntArray(); - if (values) - { - arrayPtr = allocator->Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } - - dst->Write(arrayPtr); - } break; - - case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: - { - float* arrayPtr = nullptr; - - const auto& values = field.AsFloatArray(); - if (values) - { - arrayPtr = allocator->Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } - - dst->Write(arrayPtr); - } break; - - case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS: - { - DML_SCALE_BIAS* scaleBias = nullptr; - - const auto& value = field.AsScaleBias(); - if (value) - { - scaleBias = allocator->Allocate(); - *scaleBias = *value; - } - - dst->Write(scaleBias); - } break; - - case DML_SCHEMA_FIELD_TYPE_SIZE_2D: - { - DML_SIZE_2D value = field.AsSize2D(); - dst->Write(value); - } break; - - default: - assert(false); - THROW_HR(E_UNEXPECTED); - } - } - - template - DML_OPERATOR_DESC ConvertOperatorDesc(const AbstractOperatorDesc& abstractDesc, StackAllocator* allocator) - { - const DML_OPERATOR_SCHEMA& schema = *abstractDesc.schema; - - // Retrieve the size of the ABI operator desc struct - size_t abiDescSizeInBytes = ApiTraits::OperatorTypeVisitor(schema.OperatorType, [](auto tag) { - using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs - return sizeof(T); - }); - - // Allocate a blob of bytes to hold the struct - byte* abiDesc = allocator->Allocate(abiDescSizeInBytes); - - // Use the schema to write data into the blob - - StructFieldWriter writer(gsl::make_span(abiDesc, abiDescSizeInBytes)); - - for (const OperatorField& field : abstractDesc.fields) - { - WriteOperatorDescField(field, &writer, allocator); - } - - return DML_OPERATOR_DESC{ schema.OperatorType, abiDesc }; - } - -} // namespace SchemaHelpers From 6c960a9417447a09227b8a71ae735f3ddbc747b5 Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Fri, 27 Mar 2020 18:10:46 -0700 Subject: [PATCH 6/6] PR feedback. --- .../DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp | 4 ++-- .../providers/dml/OperatorAuthorHelper/OperatorHelper.cpp | 1 + .../core/providers/dml/OperatorAuthorHelper/OperatorHelper.h | 5 ++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index c26f223634..5a2e9d888b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -15,7 +15,7 @@ public: { const uint32_t inputCount = kernelInfo.GetInputCount(); ML_CHECK_VALID_ARGUMENT((opsetVersion < 10 && inputCount == 1) - || (opsetVersion == 10 && inputCount >= 3 && inputCount <= 5)); + || (opsetVersion >= 10 && opsetVersion <= 11 && inputCount >= 3 && inputCount <= 5)); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); std::vector> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor. @@ -65,5 +65,5 @@ void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *i DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>); DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>); -DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<10>); +DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<11>); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 185020ce24..f4123c97d0 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -55,6 +55,7 @@ void ReadCpuLocalTensorIntoInt32( case MLOperatorTensorDataType::Int64: { const int64_t* data = tensor.GetData(); + result.reserve(elementCount); for (auto d : gsl::make_span(data, data + elementCount)) { result.push_back(gsl::narrow_cast(d)); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 4e62ba9bb8..1f89948e34 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -562,7 +562,7 @@ public: ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends); axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes); } - else if (opsetVersion == 10) + else if (opsetVersion == 10 || opsetVersion == 11) { // Read starts, ends, and axes from tensors. ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps); @@ -615,6 +615,9 @@ public: end = std::min(end, dim); int size = std::max(end - start, 0); + // Set the input window offsets/sizes, and compute output size based on input + // window size (rounding up). + // e.g. a window size 13 and step 3 yields 5 output elements. int absoluteStride = abs(stride); m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0); m_offsets[dimIndex] = start;