mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
PR feedback.
This commit is contained in:
parent
031647635b
commit
6c960a9417
3 changed files with 7 additions and 3 deletions
|
|
@ -15,7 +15,7 @@ public:
|
|||
{
|
||||
const uint32_t inputCount = kernelInfo.GetInputCount();
|
||||
ML_CHECK_VALID_ARGUMENT((opsetVersion < 10 && inputCount == 1)
|
||||
|| (opsetVersion == 10 && inputCount >= 3 && inputCount <= 5));
|
||||
|| (opsetVersion >= 10 && opsetVersion <= 11 && inputCount >= 3 && inputCount <= 5));
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
std::vector<std::optional<uint32_t>> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor.
|
||||
|
|
@ -65,5 +65,5 @@ void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *i
|
|||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<10>);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<11>);
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ void ReadCpuLocalTensorIntoInt32(
|
|||
case MLOperatorTensorDataType::Int64:
|
||||
{
|
||||
const int64_t* data = tensor.GetData<int64_t>();
|
||||
result.reserve(elementCount);
|
||||
for (auto d : gsl::make_span(data, data + elementCount))
|
||||
{
|
||||
result.push_back(gsl::narrow_cast<int32_t>(d));
|
||||
|
|
|
|||
|
|
@ -562,7 +562,7 @@ public:
|
|||
ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends);
|
||||
axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes);
|
||||
}
|
||||
else if (opsetVersion == 10)
|
||||
else if (opsetVersion == 10 || opsetVersion == 11)
|
||||
{
|
||||
// Read starts, ends, and axes from tensors.
|
||||
ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps);
|
||||
|
|
@ -615,6 +615,9 @@ public:
|
|||
end = std::min(end, dim);
|
||||
int size = std::max(end - start, 0);
|
||||
|
||||
// Set the input window offsets/sizes, and compute output size based on input
|
||||
// window size (rounding up).
|
||||
// e.g. a window size 13 and step 3 yields 5 output elements.
|
||||
int absoluteStride = abs(stride);
|
||||
m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0);
|
||||
m_offsets[dimIndex] = start;
|
||||
|
|
|
|||
Loading…
Reference in a new issue