Fix squeeze.

This commit is contained in:
Dwayne Robinson 2020-03-24 23:24:58 -07:00
parent 5366273110
commit cc095fefbb
4 changed files with 59 additions and 22 deletions

View file

@ -1508,7 +1508,7 @@ onnxruntime::Status AbiOpKernel::Compute(onnxruntime::OpKernelContext* context)
{
tensorWrapper = wil::MakeOrThrow<TensorWrapper>(
const_cast<onnxruntime::Tensor*>(tensor),
IsAllocationInterface(tensor->Location()),
tensor ? IsAllocationInterface(tensor->Location()) : false,
winmlProviderCapture.Get(),
internalOpCapture);
}
@ -1552,21 +1552,27 @@ onnxruntime::Status AbiOpKernel::Compute(onnxruntime::OpKernelContext* context)
m_constantInputTensorContentsOfKernel.resize(context->InputCount());
for (uint32_t index : m_requiredConstantCpuInputs) {
MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get());
const onnxruntime::Tensor* weakTensor = context->Input<onnxruntime::Tensor>(static_cast<int>(index));
if (index >= static_cast<uint32_t>(context->InputCount())) {
continue;
}
m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr);
// Skip optional constant tensors.
if (weakTensor != nullptr)
{
MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get());
if (tensor.GetInterface() != nullptr) {
m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape();
m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType();
m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize());
if (index >= static_cast<uint32_t>(context->InputCount())) {
continue;
}
m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr);
if (tensor.GetInterface() != nullptr) {
m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape();
m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType();
m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize());
}
m_constantInputTensorContentsOfKernel[index].data.assign(
reinterpret_cast<const std::byte*>(tensor.GetByteData()),
reinterpret_cast<const std::byte*>(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize());
}
m_constantInputTensorContentsOfKernel[index].data.assign(
reinterpret_cast<const std::byte*>(tensor.GetByteData()),
reinterpret_cast<const std::byte*>(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize());
}
m_kernel = inferShapesAndCreateKernel(m_inputShapesOfKernelInference, m_inferredOutputShapes);

View file

@ -332,15 +332,17 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)},
#if 0 // TODO:DwayneR
{REG_INFO_VER( 11, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, // Adds negative axes.
#endif
{REG_INFO( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
#if 0 // TODO:DwayneR Pads and Value are inputs. https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11
#if 0 // TODO:NickFe Pads and Value are inputs. https://microsoft.visualstudio.com/OS/_workitems/edit/24674281, https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11
{REG_INFO( 11, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
#endif
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
#if 0
// TODO:Dwayner https://microsoft.visualstudio.com/OS/_workitems/edit/24672169
// TODO:Dwayner Update operator DepthToSpace-11 - added column-row-depth shuffle order mode https://microsoft.visualstudio.com/OS/_workitems/edit/24672169
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
#endif
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1})},
@ -355,8 +357,6 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
//!!!TODO:::DwayneR check remaining 11's for other work besides negative axes.
//Also verify that negative axes are handled.
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
@ -440,7 +440,7 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},

View file

@ -1005,6 +1005,16 @@ namespace OperatorHelper
return { std::move(EdgeShapes(outputDimensions)) };
}
void SqueezeHelper::Initialize(
gsl::span<const int32_t> axes,
gsl::span<const DimensionType> inputDimensions
)
{
m_axes.assign(axes.begin(), axes.end());
HandleNegativeAxes(/*inout*/ m_axes, gsl::narrow_cast<uint32_t>(inputDimensions.size()));
std::sort(m_axes.begin(), m_axes.end());
}
std::vector<EdgeShapes> SqueezeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto outputDimensions = shapeInfo.GetInputTensorShape(0);
@ -1038,6 +1048,17 @@ namespace OperatorHelper
return { std::move(outputDimensions) };
}
void UnsqueezeHelper::Initialize(
gsl::span<const int32_t> axes,
gsl::span<const DimensionType> inputDimensions
)
{
m_axes.assign(axes.begin(), axes.end());
const uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(inputDimensions.size() + axes.size());
HandleNegativeAxes(/*inout*/ m_axes, outputDimensionCount);
std::sort(m_axes.begin(), m_axes.end());
}
std::vector<EdgeShapes> UnsqueezeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto inputDimensions = shapeInfo.GetInputTensorShape(0);

View file

@ -996,12 +996,17 @@ class RoiPoolingHelper {
class SqueezeHelper {
public:
void Initialize(
gsl::span<const int32_t> axes,
gsl::span<const DimensionType> inputDimensions);
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
SqueezeHelper(const Info_t& info, const Shape_t& shape) {
m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes);
std::sort(m_axes.begin(), m_axes.end());
Initialize(
info.GetOptionalAttributeVectorInt32(AttrName::Axes),
shape.GetInputTensorShape(0));
}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
@ -1012,12 +1017,17 @@ class SqueezeHelper {
class UnsqueezeHelper {
public:
void Initialize(
gsl::span<const int32_t> axes,
gsl::span<const DimensionType> inputDimensions);
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
UnsqueezeHelper(const Info_t& info, const Shape_t& shape) {
m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes);
std::sort(m_axes.begin(), m_axes.end());
Initialize(
info.GetOptionalAttributeVectorInt32(AttrName::Axes),
shape.GetInputTensorShape(0));
}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;