mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Merged PR 5873494: Resize support nearest_mode floor in DML EP
Resize support nearest_mode floor in DML EP. Related work items: #32221069
This commit is contained in:
parent
e6f35cc132
commit
06a2b0401a
4 changed files with 73 additions and 20 deletions
|
|
@ -6,6 +6,24 @@
|
|||
namespace Dml
|
||||
{
|
||||
|
||||
constexpr NameAndIndex coordinateTransformationModes[] =
|
||||
{
|
||||
{"half_pixel", 0},
|
||||
{"pytorch_half_pixel", 1},
|
||||
{"align_corners", 2},
|
||||
{"asymmetric", 3},
|
||||
{"tf_half_pixel_for_nn", 4},
|
||||
{"tf_crop_and_resize", 5},
|
||||
};
|
||||
|
||||
constexpr NameAndIndex nearestNeighborRoundingModes[] =
|
||||
{
|
||||
{"", 0},
|
||||
{"round_prefer_floor", 0},
|
||||
{"round_prefer_ceil", 1},
|
||||
{"floor", 2},
|
||||
};
|
||||
|
||||
void ComputePixelOffsetsAndScales(
|
||||
const MLOperatorKernelCreationContext& kernelCreationContext,
|
||||
gsl::span<const float> regionOfInterest, // May be empty depending on mode.
|
||||
|
|
@ -23,17 +41,12 @@ void ComputePixelOffsetsAndScales(
|
|||
assert(regionOfInterest.empty() || regionOfInterest.size() == inputDimensions.size() * 2);
|
||||
|
||||
std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::CoordinateTransformationMode, "half_pixel");
|
||||
uint32_t coordinateTransformationModeValue = UINT32_MAX;
|
||||
|
||||
const char* modes[] = { "half_pixel", "pytorch_half_pixel", "align_corners", "asymmetric", "tf_half_pixel_for_nn", "tf_crop_and_resize" };
|
||||
for (uint32_t i = 0; i < std::size(modes); ++i)
|
||||
auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes);
|
||||
if (!optionalCoordinateTransformationModeValue)
|
||||
{
|
||||
if (strcmp(modes[i], coordinateTransformationMode.c_str()) == 0)
|
||||
{
|
||||
coordinateTransformationModeValue = i;
|
||||
break;
|
||||
}
|
||||
ML_INVALID_ARGUMENT("Unsupported 'coordinate_transformation_mode'");
|
||||
}
|
||||
uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue;
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(
|
||||
!regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/,
|
||||
|
|
@ -150,7 +163,7 @@ void ComputePixelOffsetsAndScales(
|
|||
break;
|
||||
|
||||
default:
|
||||
ML_INVALID_ARGUMENT("Unknown 'coordinate_transformation_mode'");
|
||||
assert(false); // TryMapStringToIndex would have already bailed above.
|
||||
}
|
||||
|
||||
inputPixelOffsets[i] = inputPixelOffset;
|
||||
|
|
@ -233,6 +246,34 @@ public:
|
|||
std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "NEAREST");
|
||||
DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode);
|
||||
|
||||
// DML's nearest neighbor mode uses round-halves-up (or round_prefer_ceil) via floor(input.x + 0.5).
|
||||
// So to support floor, adjust the input by half a pixel.
|
||||
// round_prefer_floor is not supported without an API extension,
|
||||
// but existing code already default to treating it as round_prefer_ceil.
|
||||
// So continue that.
|
||||
if (interpolationMode == DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR)
|
||||
{
|
||||
std::string nearestMode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::NearestMode, "round_prefer_floor");
|
||||
auto optionalNearestModeValue = TryMapStringToIndex(nearestMode, nearestNeighborRoundingModes);
|
||||
if (optionalNearestModeValue)
|
||||
{
|
||||
switch (*optionalNearestModeValue)
|
||||
{
|
||||
case 0: // round_prefer_floor
|
||||
case 1: // round_prefer_ceil
|
||||
break;
|
||||
case 2: // floor
|
||||
for (auto& offset : inputPixelOffsets)
|
||||
{
|
||||
offset += 0.5;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the operator description.
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
|
@ -282,7 +323,8 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool*
|
|||
|
||||
// DML's nearest neighbor mode uses half pixels rounded down.
|
||||
std::string nearestMode = attributes.GetOptionalAttribute<std::string>(AttrName::NearestMode, "round_prefer_floor");
|
||||
if (nearestMode != "round_prefer_floor")
|
||||
auto optionalNearestModeValue = TryMapStringToIndex(nearestMode, nearestNeighborRoundingModes);
|
||||
if (!optionalNearestModeValue)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,10 +31,11 @@ public:
|
|||
{"avg", DML_REDUCE_FUNCTION_AVERAGE},
|
||||
};
|
||||
const std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "avg");
|
||||
const auto reductionFunction = MapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
|
||||
const auto optionalReductionFunction = TryMapStringToIndex<DML_REDUCE_FUNCTION>(mode, mapping);
|
||||
const float spatialScale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f);
|
||||
const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::SamplingRatio, 0u);
|
||||
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive.");
|
||||
ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode.");
|
||||
|
||||
DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.InputTensor = &inputDescs[0];
|
||||
|
|
@ -46,7 +47,7 @@ public:
|
|||
operatorDesc.OutOfBoundsInputValue = 0.0f; // ONNX does not specify a value for input elements outside bounds.
|
||||
operatorDesc.MinimumSamplesPerOutput = (samplesPerOutput == 0) ? 1 : samplesPerOutput;
|
||||
operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput;
|
||||
operatorDesc.ReductionFunction = reductionFunction;
|
||||
operatorDesc.ReductionFunction = *optionalReductionFunction;
|
||||
operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR;
|
||||
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc };
|
||||
|
||||
|
|
|
|||
|
|
@ -359,7 +359,7 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
|
||||
std::optional<uint32_t> TryMapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
|
||||
{
|
||||
for (auto& nameAndIndex : nameAndIndexList)
|
||||
{
|
||||
|
|
@ -369,7 +369,7 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
ML_INVALID_ARGUMENT("Unknown mode value.");
|
||||
return {};
|
||||
}
|
||||
|
||||
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode)
|
||||
|
|
@ -387,7 +387,11 @@ namespace Dml
|
|||
{"BILINEAR", DML_INTERPOLATION_MODE_LINEAR},
|
||||
{"bilinear", DML_INTERPOLATION_MODE_LINEAR},
|
||||
};
|
||||
return MapStringToIndex<DML_INTERPOLATION_MODE>(mode, mapping);
|
||||
if (auto index = TryMapStringToIndex<DML_INTERPOLATION_MODE>(mode, mapping))
|
||||
{
|
||||
return *index;
|
||||
}
|
||||
ML_INVALID_ARGUMENT("Unknown interpolation mode");
|
||||
}
|
||||
|
||||
DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode)
|
||||
|
|
@ -397,7 +401,11 @@ namespace Dml
|
|||
{"DCR", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW},
|
||||
{"CRD", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH},
|
||||
};
|
||||
return MapStringToIndex<DML_DEPTH_SPACE_ORDER>(mode, mapping);
|
||||
if (auto index = TryMapStringToIndex<DML_DEPTH_SPACE_ORDER>(mode, mapping))
|
||||
{
|
||||
return *index;
|
||||
}
|
||||
ML_INVALID_ARGUMENT("Unknown depth/space order");
|
||||
}
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -64,12 +64,14 @@ namespace Dml
|
|||
};
|
||||
|
||||
template<typename T>
|
||||
T MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
|
||||
std::optional<T> TryMapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
|
||||
{
|
||||
return static_cast<T>(MapStringToIndex(mode, nameAndIndexList));
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
auto result = TryMapStringToIndex(mode, nameAndIndexList);
|
||||
return *reinterpret_cast<std::optional<T>*>(std::addressof(result));
|
||||
}
|
||||
|
||||
uint32_t MapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList);
|
||||
std::optional<uint32_t> TryMapStringToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList);
|
||||
|
||||
DML_INTERPOLATION_MODE MapStringToInteropolationMode(std::string_view mode);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue