PR feedback.

This commit is contained in:
Dwayne Robinson 2020-03-27 18:10:46 -07:00
parent 031647635b
commit 6c960a9417
3 changed files with 7 additions and 3 deletions

View file

@ -15,7 +15,7 @@ public:
{
const uint32_t inputCount = kernelInfo.GetInputCount();
ML_CHECK_VALID_ARGUMENT((opsetVersion < 10 && inputCount == 1)
|| (opsetVersion == 10 && inputCount >= 3 && inputCount <= 5));
|| (opsetVersion >= 10 && opsetVersion <= 11 && inputCount >= 3 && inputCount <= 5));
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
std::vector<std::optional<uint32_t>> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor.
@ -65,5 +65,5 @@ void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *i
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<10>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, DmlOperatorSliceTemplate<11>);
} // namespace Dml

View file

@ -55,6 +55,7 @@ void ReadCpuLocalTensorIntoInt32(
case MLOperatorTensorDataType::Int64:
{
const int64_t* data = tensor.GetData<int64_t>();
result.reserve(elementCount);
for (auto d : gsl::make_span(data, data + elementCount))
{
result.push_back(gsl::narrow_cast<int32_t>(d));

View file

@ -562,7 +562,7 @@ public:
ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends);
axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes);
}
else if (opsetVersion == 10)
else if (opsetVersion == 10 || opsetVersion == 11)
{
// Read starts, ends, and axes from tensors.
ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps);
@ -615,6 +615,9 @@ public:
end = std::min(end, dim);
int size = std::max(end - start, 0);
// Set the input window offsets/sizes, and compute output size based on input
// window size (rounding up).
// e.g. a window size 13 and step 3 yields 5 output elements.
int absoluteStride = abs(stride);
m_outputDimensions[dimIndex] = (size / absoluteStride) + (size % absoluteStride != 0);
m_offsets[dimIndex] = start;