diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index a3547188f0..96c6557c01 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -6,6 +6,24 @@ namespace Dml { +constexpr NameAndIndex coordinateTransformationModes[] = +{ + {"half_pixel", 0}, + {"pytorch_half_pixel", 1}, + {"align_corners", 2}, + {"asymmetric", 3}, + {"tf_half_pixel_for_nn", 4}, + {"tf_crop_and_resize", 5}, +}; + +constexpr NameAndIndex nearestNeighborRoundingModes[] = +{ + {"", 0}, + {"round_prefer_floor", 0}, + {"round_prefer_ceil", 1}, + {"floor", 2}, +}; + void ComputePixelOffsetsAndScales( const MLOperatorKernelCreationContext& kernelCreationContext, gsl::span regionOfInterest, // May be empty depending on mode. @@ -23,17 +41,12 @@ void ComputePixelOffsetsAndScales( assert(regionOfInterest.empty() || regionOfInterest.size() == inputDimensions.size() * 2); std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute(AttrName::CoordinateTransformationMode, "half_pixel"); - uint32_t coordinateTransformationModeValue = UINT32_MAX; - - const char* modes[] = { "half_pixel", "pytorch_half_pixel", "align_corners", "asymmetric", "tf_half_pixel_for_nn", "tf_crop_and_resize" }; - for (uint32_t i = 0; i < std::size(modes); ++i) + auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes); + if (!optionalCoordinateTransformationModeValue) { - if (strcmp(modes[i], coordinateTransformationMode.c_str()) == 0) - { - coordinateTransformationModeValue = i; - break; - } + ML_INVALID_ARGUMENT("Unsupported 'coordinate_transformation_mode'"); } + uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue; ML_CHECK_VALID_ARGUMENT( !regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/, @@ -150,7 +163,7 @@ void ComputePixelOffsetsAndScales( break; default: - ML_INVALID_ARGUMENT("Unknown 'coordinate_transformation_mode'"); + assert(false); // TryMapStringToIndex would have already bailed above. } inputPixelOffsets[i] = inputPixelOffset; @@ -233,6 +246,34 @@ public: std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "NEAREST"); DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode); + // DML's nearest neighbor mode uses round-halves-up (or round_prefer_ceil) via floor(input.x + 0.5). + // So to support floor, adjust the input by half a pixel. + // round_prefer_floor is not supported without an API extension, + // but existing code already default to treating it as round_prefer_ceil. + // So continue that. + if (interpolationMode == DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR) + { + std::string nearestMode = kernelCreationContext.GetOptionalAttribute(AttrName::NearestMode, "round_prefer_floor"); + auto optionalNearestModeValue = TryMapStringToIndex(nearestMode, nearestNeighborRoundingModes); + if (optionalNearestModeValue) + { + switch (*optionalNearestModeValue) + { + case 0: // round_prefer_floor + case 1: // round_prefer_ceil + break; + case 2: // floor + for (auto& offset : inputPixelOffsets) + { + offset += 0.5; + } + break; + default: + assert(false); + } + } + } + // Create the operator description. std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); @@ -282,7 +323,8 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* // DML's nearest neighbor mode uses half pixels rounded down. std::string nearestMode = attributes.GetOptionalAttribute(AttrName::NearestMode, "round_prefer_floor"); - if (nearestMode != "round_prefer_floor") + auto optionalNearestModeValue = TryMapStringToIndex(nearestMode, nearestNeighborRoundingModes); + if (!optionalNearestModeValue) { return; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp index 75e7595a3a..533c894119 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp @@ -31,10 +31,11 @@ public: {"avg", DML_REDUCE_FUNCTION_AVERAGE}, }; const std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "avg"); - const auto reductionFunction = MapStringToIndex(mode, mapping); + const auto optionalReductionFunction = TryMapStringToIndex(mode, mapping); const float spatialScale = kernelCreationContext.GetOptionalAttribute(AttrName::SpatialScale, 1.0f); const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute(AttrName::SamplingRatio, 0u); ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive."); + ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode."); DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; @@ -46,7 +47,7 @@ public: operatorDesc.OutOfBoundsInputValue = 0.0f; // ONNX does not specify a value for input elements outside bounds. operatorDesc.MinimumSamplesPerOutput = (samplesPerOutput == 0) ? 1 : samplesPerOutput; operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput; - operatorDesc.ReductionFunction = reductionFunction; + operatorDesc.ReductionFunction = *optionalReductionFunction; operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR; DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index afbe678b60..e8bd9bd952 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -359,7 +359,7 @@ namespace Dml } } - uint32_t MapStringToIndex(std::string_view mode, gsl::span nameAndIndexList) + std::optional TryMapStringToIndex(std::string_view mode, gsl::span nameAndIndexList) { for (auto& nameAndIndex : nameAndIndexList) { @@ -369,7 +369,7 @@ namespace Dml } } - ML_INVALID_ARGUMENT("Unknown mode value."); + return {}; } DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode) @@ -387,7 +387,11 @@ namespace Dml {"BILINEAR", DML_INTERPOLATION_MODE_LINEAR}, {"bilinear", DML_INTERPOLATION_MODE_LINEAR}, }; - return MapStringToIndex(mode, mapping); + if (auto index = TryMapStringToIndex(mode, mapping)) + { + return *index; + } + ML_INVALID_ARGUMENT("Unknown interpolation mode"); } DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode) @@ -397,7 +401,11 @@ namespace Dml {"DCR", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW}, {"CRD", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH}, }; - return MapStringToIndex(mode, mapping); + if (auto index = TryMapStringToIndex(mode, mapping)) + { + return *index; + } + ML_INVALID_ARGUMENT("Unknown depth/space order"); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index 027c2228bb..a638326b2a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -64,12 +64,14 @@ namespace Dml }; template - T MapStringToIndex(std::string_view mode, gsl::span nameAndIndexList) + std::optional TryMapStringToIndex(std::string_view mode, gsl::span nameAndIndexList) { - return static_cast(MapStringToIndex(mode, nameAndIndexList)); + static_assert(sizeof(T) == sizeof(uint32_t)); + auto result = TryMapStringToIndex(mode, nameAndIndexList); + return *reinterpret_cast*>(std::addressof(result)); } - uint32_t MapStringToIndex(std::string_view mode, gsl::span nameAndIndexList); + std::optional TryMapStringToIndex(std::string_view mode, gsl::span nameAndIndexList); DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode);