mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix squeeze.
This commit is contained in:
parent
5366273110
commit
cc095fefbb
4 changed files with 59 additions and 22 deletions
|
|
@ -1508,7 +1508,7 @@ onnxruntime::Status AbiOpKernel::Compute(onnxruntime::OpKernelContext* context)
|
|||
{
|
||||
tensorWrapper = wil::MakeOrThrow<TensorWrapper>(
|
||||
const_cast<onnxruntime::Tensor*>(tensor),
|
||||
IsAllocationInterface(tensor->Location()),
|
||||
tensor ? IsAllocationInterface(tensor->Location()) : false,
|
||||
winmlProviderCapture.Get(),
|
||||
internalOpCapture);
|
||||
}
|
||||
|
|
@ -1552,21 +1552,27 @@ onnxruntime::Status AbiOpKernel::Compute(onnxruntime::OpKernelContext* context)
|
|||
|
||||
m_constantInputTensorContentsOfKernel.resize(context->InputCount());
|
||||
for (uint32_t index : m_requiredConstantCpuInputs) {
|
||||
MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get());
|
||||
const onnxruntime::Tensor* weakTensor = context->Input<onnxruntime::Tensor>(static_cast<int>(index));
|
||||
|
||||
if (index >= static_cast<uint32_t>(context->InputCount())) {
|
||||
continue;
|
||||
}
|
||||
m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr);
|
||||
// Skip optional constant tensors.
|
||||
if (weakTensor != nullptr)
|
||||
{
|
||||
MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get());
|
||||
|
||||
if (tensor.GetInterface() != nullptr) {
|
||||
m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape();
|
||||
m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType();
|
||||
m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize());
|
||||
if (index >= static_cast<uint32_t>(context->InputCount())) {
|
||||
continue;
|
||||
}
|
||||
m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr);
|
||||
|
||||
if (tensor.GetInterface() != nullptr) {
|
||||
m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape();
|
||||
m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType();
|
||||
m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize());
|
||||
}
|
||||
m_constantInputTensorContentsOfKernel[index].data.assign(
|
||||
reinterpret_cast<const std::byte*>(tensor.GetByteData()),
|
||||
reinterpret_cast<const std::byte*>(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize());
|
||||
}
|
||||
m_constantInputTensorContentsOfKernel[index].data.assign(
|
||||
reinterpret_cast<const std::byte*>(tensor.GetByteData()),
|
||||
reinterpret_cast<const std::byte*>(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize());
|
||||
}
|
||||
|
||||
m_kernel = inferShapesAndCreateKernel(m_inputShapesOfKernelInference, m_inferredOutputShapes);
|
||||
|
|
|
|||
|
|
@ -332,15 +332,17 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
|
|||
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis.
|
||||
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)},
|
||||
#if 0 // TODO:DwayneR
|
||||
{REG_INFO_VER( 11, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, // Adds negative axes.
|
||||
#endif
|
||||
{REG_INFO( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
#if 0 // TODO:DwayneR Pads and Value are inputs. https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11
|
||||
#if 0 // TODO:NickFe Pads and Value are inputs. https://microsoft.visualstudio.com/OS/_workitems/edit/24674281, https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Pad-11
|
||||
{REG_INFO( 11, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
#endif
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
#if 0
|
||||
// TODO:Dwayner https://microsoft.visualstudio.com/OS/_workitems/edit/24672169
|
||||
// TODO:Dwayner Update operator DepthToSpace-11 - added column-row-depth shuffle order mode https://microsoft.visualstudio.com/OS/_workitems/edit/24672169
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
#endif
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1})},
|
||||
|
|
@ -355,8 +357,6 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
|
|||
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
//!!!TODO:::DwayneR check remaining 11's for other work besides negative axes.
|
||||
//Also verify that negative axes are handled.
|
||||
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)},
|
||||
|
|
@ -440,7 +440,7 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
|
|||
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -1005,6 +1005,16 @@ namespace OperatorHelper
|
|||
return { std::move(EdgeShapes(outputDimensions)) };
|
||||
}
|
||||
|
||||
void SqueezeHelper::Initialize(
|
||||
gsl::span<const int32_t> axes,
|
||||
gsl::span<const DimensionType> inputDimensions
|
||||
)
|
||||
{
|
||||
m_axes.assign(axes.begin(), axes.end());
|
||||
HandleNegativeAxes(/*inout*/ m_axes, gsl::narrow_cast<uint32_t>(inputDimensions.size()));
|
||||
std::sort(m_axes.begin(), m_axes.end());
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> SqueezeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto outputDimensions = shapeInfo.GetInputTensorShape(0);
|
||||
|
|
@ -1038,6 +1048,17 @@ namespace OperatorHelper
|
|||
return { std::move(outputDimensions) };
|
||||
}
|
||||
|
||||
void UnsqueezeHelper::Initialize(
|
||||
gsl::span<const int32_t> axes,
|
||||
gsl::span<const DimensionType> inputDimensions
|
||||
)
|
||||
{
|
||||
m_axes.assign(axes.begin(), axes.end());
|
||||
const uint32_t outputDimensionCount = gsl::narrow_cast<uint32_t>(inputDimensions.size() + axes.size());
|
||||
HandleNegativeAxes(/*inout*/ m_axes, outputDimensionCount);
|
||||
std::sort(m_axes.begin(), m_axes.end());
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> UnsqueezeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
auto inputDimensions = shapeInfo.GetInputTensorShape(0);
|
||||
|
|
|
|||
|
|
@ -996,12 +996,17 @@ class RoiPoolingHelper {
|
|||
|
||||
class SqueezeHelper {
|
||||
public:
|
||||
void Initialize(
|
||||
gsl::span<const int32_t> axes,
|
||||
gsl::span<const DimensionType> inputDimensions);
|
||||
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
SqueezeHelper(const Info_t& info, const Shape_t& shape) {
|
||||
m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes);
|
||||
std::sort(m_axes.begin(), m_axes.end());
|
||||
Initialize(
|
||||
info.GetOptionalAttributeVectorInt32(AttrName::Axes),
|
||||
shape.GetInputTensorShape(0));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
|
@ -1012,12 +1017,17 @@ class SqueezeHelper {
|
|||
|
||||
class UnsqueezeHelper {
|
||||
public:
|
||||
void Initialize(
|
||||
gsl::span<const int32_t> axes,
|
||||
gsl::span<const DimensionType> inputDimensions);
|
||||
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
UnsqueezeHelper(const Info_t& info, const Shape_t& shape) {
|
||||
m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes);
|
||||
std::sort(m_axes.begin(), m_axes.end());
|
||||
Initialize(
|
||||
info.GetOptionalAttributeVectorInt32(AttrName::Axes),
|
||||
shape.GetInputTensorShape(0));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
|
|
|||
Loading…
Reference in a new issue