mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Merged PR 4591959: Fix ORT DML EP's Slice shape operator helper
Related work items: #24672220
This commit is contained in:
parent
e2288ff2b4
commit
26282359bf
1 changed files with 4 additions and 3 deletions
|
|
@ -617,13 +617,13 @@ public:
|
|||
HandleNegativeAxes(/*inout*/ axes, inputDimensionCount);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size.");
|
||||
ML_CHECK_VALID_ARGUMENT(steps.empty() || steps.size() == axes.size(), "'steps' must equal 'axes' in size, or 'steps' must be empty.");
|
||||
ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty.");
|
||||
|
||||
m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end());
|
||||
m_offsets.resize(m_outputDimensions.size());
|
||||
m_sizes.resize(m_outputDimensions.size());
|
||||
m_strides = std::move(steps);
|
||||
m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2.
|
||||
m_strides.resize(m_outputDimensions.size(), 1); // Default initialize to all steps to 1's.
|
||||
|
||||
// Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count.
|
||||
std::copy(inputDimensions.begin(), inputDimensions.begin() + m_outputDimensions.size(), m_sizes.begin());
|
||||
|
|
@ -632,7 +632,7 @@ public:
|
|||
for (int i = 0, ci = gsl::narrow_cast<int>(starts.size()); i < ci; ++i)
|
||||
{
|
||||
int dimIndex = axes.empty() ? i : axes[i];
|
||||
int stride = m_strides[i];
|
||||
int stride = steps.empty() ? 1 : steps[i];
|
||||
ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions.");
|
||||
ML_CHECK_VALID_ARGUMENT(stride != 0, "'steps' must not be 0.");
|
||||
|
||||
|
|
@ -666,6 +666,7 @@ public:
|
|||
int absoluteStride = abs(stride);
|
||||
m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0);
|
||||
m_offsets[dimIndex] = start;
|
||||
m_strides[dimIndex] = stride;
|
||||
m_sizes[dimIndex] = gsl::narrow_cast<uint32_t>(size);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue