diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 5df941e9ef..0c4e3fb9ea 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1508,7 +1508,7 @@ onnxruntime::Status AbiOpKernel::Compute(onnxruntime::OpKernelContext* context) { tensorWrapper = wil::MakeOrThrow( const_cast(tensor), - IsAllocationInterface(tensor->Location()), + tensor ? IsAllocationInterface(tensor->Location()) : false, winmlProviderCapture.Get(), internalOpCapture); } @@ -1552,21 +1552,27 @@ onnxruntime::Status AbiOpKernel::Compute(onnxruntime::OpKernelContext* context) m_constantInputTensorContentsOfKernel.resize(context->InputCount()); for (uint32_t index : m_requiredConstantCpuInputs) { - MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get()); + const onnxruntime::Tensor* weakTensor = context->Input(static_cast(index)); - if (index >= static_cast(context->InputCount())) { - continue; - } - m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr); + // Skip optional constant tensors. + if (weakTensor != nullptr) + { + MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get()); - if (tensor.GetInterface() != nullptr) { - m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape(); - m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType(); - m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize()); + if (index >= static_cast(context->InputCount())) { + continue; + } + m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr); + + if (tensor.GetInterface() != nullptr) { + m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape(); + m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType(); + m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize()); + } + m_constantInputTensorContentsOfKernel[index].data.assign( + reinterpret_cast(tensor.GetByteData()), + reinterpret_cast(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize()); } - m_constantInputTensorContentsOfKernel[index].data.assign( - reinterpret_cast(tensor.GetByteData()), - reinterpret_cast(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize()); } m_kernel = inferShapesAndCreateKernel(m_inputShapesOfKernelInference, m_inferredOutputShapes); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 34066a6c88..cee4d219de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -332,15 +332,17 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis. {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, {REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, +#if 0 // TODO:DwayneR {REG_INFO_VER( 11, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, // Adds negative axes. +#endif {REG_INFO( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, -#if 0 // TODO:DwayneR Pads and Value are inputs. https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11 +#if 0 // TODO:NickFe Pads and Value are inputs. https://microsoft.visualstudio.com/OS/_workitems/edit/24674281, https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11 {REG_INFO( 11, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, #endif {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, #if 0 - // TODO:Dwayner https://microsoft.visualstudio.com/OS/_workitems/edit/24672169 + // TODO:Dwayner Update operator DepthToSpace-11 - added column-row-depth shuffle order mode https://microsoft.visualstudio.com/OS/_workitems/edit/24672169 {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, #endif {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1})}, @@ -355,8 +357,6 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - //!!!TODO:::DwayneR check remaining 11's for other work besides negative axes. - //Also verify that negative axes are handled. {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, @@ -440,7 +440,7 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)}, {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)}, {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)}, - {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)}, + {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)}, {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index c18652f29f..dbaf941203 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -1005,6 +1005,16 @@ namespace OperatorHelper return { std::move(EdgeShapes(outputDimensions)) }; } + void SqueezeHelper::Initialize( + gsl::span axes, + gsl::span inputDimensions + ) + { + m_axes.assign(axes.begin(), axes.end()); + HandleNegativeAxes(/*inout*/ m_axes, gsl::narrow_cast(inputDimensions.size())); + std::sort(m_axes.begin(), m_axes.end()); + } + std::vector SqueezeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto outputDimensions = shapeInfo.GetInputTensorShape(0); @@ -1038,6 +1048,17 @@ namespace OperatorHelper return { std::move(outputDimensions) }; } + void UnsqueezeHelper::Initialize( + gsl::span axes, + gsl::span inputDimensions + ) + { + m_axes.assign(axes.begin(), axes.end()); + const uint32_t outputDimensionCount = gsl::narrow_cast(inputDimensions.size() + axes.size()); + HandleNegativeAxes(/*inout*/ m_axes, outputDimensionCount); + std::sort(m_axes.begin(), m_axes.end()); + } + std::vector UnsqueezeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { auto inputDimensions = shapeInfo.GetInputTensorShape(0); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 5b8757199f..d63d3214cc 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -996,12 +996,17 @@ class RoiPoolingHelper { class SqueezeHelper { public: + void Initialize( + gsl::span axes, + gsl::span inputDimensions); + // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template SqueezeHelper(const Info_t& info, const Shape_t& shape) { - m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes); - std::sort(m_axes.begin(), m_axes.end()); + Initialize( + info.GetOptionalAttributeVectorInt32(AttrName::Axes), + shape.GetInputTensorShape(0)); } std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; @@ -1012,12 +1017,17 @@ class SqueezeHelper { class UnsqueezeHelper { public: + void Initialize( + gsl::span axes, + gsl::span inputDimensions); + // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template UnsqueezeHelper(const Info_t& info, const Shape_t& shape) { - m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes); - std::sort(m_axes.begin(), m_axes.end()); + Initialize( + info.GetOptionalAttributeVectorInt32(AttrName::Axes), + shape.GetInputTensorShape(0)); } std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;