Merged PR 4591959: Fix ORT DML EP's Slice shape operator helper

Related work items: #24672220
This commit is contained in:
Dwayne Robinson 2020-04-22 21:55:46 +00:00
parent e2288ff2b4
commit 26282359bf

View file

@ -617,13 +617,13 @@ public:
HandleNegativeAxes(/*inout*/ axes, inputDimensionCount);
ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size.");
ML_CHECK_VALID_ARGUMENT(steps.empty() || steps.size() == axes.size(), "'steps' must equal 'axes' in size, or 'steps' must be empty.");
ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty.");
m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end());
m_offsets.resize(m_outputDimensions.size());
m_sizes.resize(m_outputDimensions.size());
m_strides = std::move(steps);
m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2.
m_strides.resize(m_outputDimensions.size(), 1); // Default initialize to all steps to 1's.
// Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count.
std::copy(inputDimensions.begin(), inputDimensions.begin() + m_outputDimensions.size(), m_sizes.begin());
@ -632,7 +632,7 @@ public:
for (int i = 0, ci = gsl::narrow_cast<int>(starts.size()); i < ci; ++i)
{
int dimIndex = axes.empty() ? i : axes[i];
int stride = m_strides[i];
int stride = steps.empty() ? 1 : steps[i];
ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions.");
ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0.");
@ -666,6 +666,7 @@ public:
int absoluteStride = abs(stride);
m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0);
m_offsets[dimIndex] = start;
m_strides[dimIndex] = stride;
m_sizes[dimIndex] = gsl::narrow_cast<uint32_t>(size);
}
}