Merged PR 5873494: Resize support nearest_mode floor in DML EP

Add support for the `floor` value of the Resize operator's `nearest_mode` attribute in the DML EP.

Related work items: #32221069
This commit is contained in:
Dwayne Robinson 2021-04-02 00:28:27 +00:00
parent e6f35cc132
commit 06a2b0401a
4 changed files with 73 additions and 20 deletions

View file

@ -6,6 +6,24 @@
namespace Dml
{
// Lookup table of the names accepted for the 'coordinate_transformation_mode' attribute,
// each paired with the numeric index that ComputePixelOffsetsAndScales uses after
// resolving the string via TryMapStringToIndex.
// NOTE: the indices are referenced by literal value elsewhere (e.g. 5 for
// tf_crop_and_resize), so keep those uses in sync when editing this table.
constexpr NameAndIndex coordinateTransformationModes[] =
{
{"half_pixel", 0},
{"pytorch_half_pixel", 1},
{"align_corners", 2},
{"asymmetric", 3},
{"tf_half_pixel_for_nn", 4},
{"tf_crop_and_resize", 5},
};
// Lookup table of the names accepted for the nearest-mode attribute (AttrName::NearestMode),
// each paired with a rounding-mode index. Callers switch on the literal index values
// (0 = round_prefer_floor, 1 = round_prefer_ceil, 2 = floor), so keep them in sync.
// Any name not listed here makes TryMapStringToIndex return an empty optional.
constexpr NameAndIndex nearestNeighborRoundingModes[] =
{
// The empty string shares index 0 with "round_prefer_floor"
// (presumably so an empty attribute value is treated as the default - TODO confirm).
{"", 0},
{"round_prefer_floor", 0},
{"round_prefer_ceil", 1},
{"floor", 2},
};
void ComputePixelOffsetsAndScales(
const MLOperatorKernelCreationContext& kernelCreationContext,
gsl::span<const float> regionOfInterest, // May be empty depending on mode.
@ -23,17 +41,12 @@ void ComputePixelOffsetsAndScales(
assert(regionOfInterest.empty() || regionOfInterest.size() == inputDimensions.size() * 2);
std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::CoordinateTransformationMode, "half_pixel");
uint32_t coordinateTransformationModeValue = UINT32_MAX;
const char* modes[] = { "half_pixel", "pytorch_half_pixel", "align_corners", "asymmetric", "tf_half_pixel_for_nn", "tf_crop_and_resize" };
for (uint32_t i = 0; i < std::size(modes); ++i)
auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
if (!optionalCoordinateTransformationModeValue)
{
if (strcmp(modes[i], coordinateTransformationMode.c_str()) == 0)
{
coordinateTransformationModeValue = i;
break;
}
ML_INVALID_ARGUMENT("Unsupported 'coordinate_transformation_mode'");
}
uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue;
ML_CHECK_VALID_ARGUMENT(
!regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/,
@ -150,7 +163,7 @@ void ComputePixelOffsetsAndScales(
break;
default:
ML_INVALID_ARGUMENT("Unknown 'coordinate_transformation_mode'");
assert(false); // TryMapStringToIndex would have already bailed above.
}
inputPixelOffsets[i] = inputPixelOffset;
@ -233,6 +246,34 @@ public:
std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "NEAREST");
DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode);
// DML's nearest neighbor mode uses round-halves-up (or round_prefer_ceil) via floor(input.x + 0.5).
// So to support floor, adjust the input by half a pixel.
// round_prefer_floor is not supported without an API extension,
// but existing code already defaults to treating it as round_prefer_ceil,
// so continue doing that.
if (interpolationMode == DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR)
{
std::string nearestMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::NearestMode, "round_prefer_floor");
auto optionalNearestModeValue = TryMapStringToIndex(nearestMode, nearestNeighborRoundingModes);
if (optionalNearestModeValue)
{
switch (*optionalNearestModeValue)
{
case 0: // round_prefer_floor
case 1: // round_prefer_ceil
break;
case 2: // floor
for (auto& offset : inputPixelOffsets)
{
offset += 0.5;
}
break;
default:
assert(false);
}
}
}
// Create the operator description.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
@ -282,7 +323,8 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool*
// DML's nearest neighbor mode uses half pixels rounded down.
std::string nearestMode = attributes.GetOptionalAttribute<std::string>(AttrName::NearestMode, "round_prefer_floor");
if (nearestMode != "round_prefer_floor")
auto optionalNearestModeValue = TryMapStringToIndex(nearestMode, nearestNeighborRoundingModes);
if (!optionalNearestModeValue)
{
return;
}

View file

@ -31,10 +31,11 @@ public:
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
};
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
const auto reductionFunction = MapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
const auto optionalReductionFunction = TryMapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
@ -46,7 +47,7 @@ public:
operatorDesc.OutOfBoundsInputValue = 0.0f; // ONNX does not specify a value for input elements outside bounds.
operatorDesc.MinimumSamplesPerOutput = (samplesPerOutput == 0) ? 1 : samplesPerOutput;
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
operatorDesc.ReductionFunction = reductionFunction;
operatorDesc.ReductionFunction = *optionalReductionFunction;
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };

View file

@ -359,7 +359,7 @@ namespace Dml
}
}
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
std::optional<uint32_t> TryMapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
{
for (auto& nameAndIndex : nameAndIndexList)
{
@ -369,7 +369,7 @@ namespace Dml
}
}
ML_INVALID_ARGUMENT("Unknown mode value.");
return {};
}
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode)
@ -387,7 +387,11 @@ namespace Dml
{"BILINEAR", DML_INTERPOLATION_MODE_LINEAR},
{"bilinear", DML_INTERPOLATION_MODE_LINEAR},
};
return MapStringToIndex<DML_INTERPOLATION_MODE>(mode, mapping);
if (auto index = TryMapStringToIndex<DML_INTERPOLATION_MODE>(mode, mapping))
{
return *index;
}
ML_INVALID_ARGUMENT("Unknown interpolation mode");
}
DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode)
@ -397,7 +401,11 @@ namespace Dml
{"DCR", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW},
{"CRD", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH},
};
return MapStringToIndex<DML_DEPTH_SPACE_ORDER>(mode, mapping);
if (auto index = TryMapStringToIndex<DML_DEPTH_SPACE_ORDER>(mode, mapping))
{
return *index;
}
ML_INVALID_ARGUMENT("Unknown depth/space order");
}
} // namespace Dml

View file

@ -64,12 +64,14 @@ namespace Dml
};
template<typename T>
T MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
std::optional<T> TryMapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
{
return static_cast<T>(MapStringToIndex(mode, nameAndIndexList));
static_assert(sizeof(T) == sizeof(uint32_t));
auto result = TryMapStringToIndex(mode, nameAndIndexList);
return *reinterpret_cast<std::optional<T>*>(std::addressof(result));
}
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList);
std::optional<uint32_t> TryMapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList);
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode);