mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
### Description <!-- Describe your changes. --> Pick up onnx/onnx#6010 to support EinSum shape inference. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This change allows the EinSum operator's output shape to be inferred so that the operator can run on accelerators.
1029 lines
45 KiB
Diff
1029 lines
45 KiB
Diff
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
|
index 6d7ca846..69aa622f 100644
|
|
--- a/CMakeLists.txt
|
|
+++ b/CMakeLists.txt
|
|
@@ -499,6 +499,7 @@ if (MSVC)
|
|
endif()
|
|
else()
|
|
# On non-Windows, hide all symbols we don't need
|
|
+ set(EXTRA_FLAGS "-Wno-unused-parameter")
|
|
set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)")
|
|
set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden)
|
|
set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1)
|
|
@@ -653,20 +654,9 @@ endif()
|
|
if(MSVC)
|
|
target_compile_options(onnx_proto
|
|
PRIVATE /MP
|
|
- /wd4244 #'argument': conversion from 'google::
|
|
- #protobuf::uint64' to 'int', possible
|
|
- # loss of data
|
|
- /wd4267 # Conversion from 'size_t' to 'int',
|
|
- # possible loss of data
|
|
${EXTRA_FLAGS})
|
|
target_compile_options(onnx
|
|
PRIVATE /MP
|
|
- /wd4244 # 'argument': conversion from 'google::
|
|
- # protobuf::uint64' to 'int', possible
|
|
- # loss of data
|
|
- /wd4267 # Conversion from 'size_t' to 'int',
|
|
- # possible loss of data
|
|
- /wd4996 # The second parameter is ignored.
|
|
${EXTRA_FLAGS})
|
|
if(ONNX_USE_PROTOBUF_SHARED_LIBS)
|
|
target_compile_options(onnx_proto
|
|
diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h
|
|
index b847798e..a6c31904 100644
|
|
--- a/onnx/common/file_utils.h
|
|
+++ b/onnx/common/file_utils.h
|
|
@@ -6,7 +6,6 @@
|
|
|
|
#pragma once
|
|
|
|
-#include <filesystem>
|
|
#include <fstream>
|
|
#include <string>
|
|
|
|
@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE {
|
|
|
|
template <typename T>
|
|
void LoadProtoFromPath(const std::string proto_path, T& proto) {
|
|
- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
|
|
- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
|
|
+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary);
|
|
if (!proto_stream.good()) {
|
|
fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
|
|
}
|
|
diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h
|
|
index 0aab3e26..398ac2d6 100644
|
|
--- a/onnx/onnx_pb.h
|
|
+++ b/onnx/onnx_pb.h
|
|
@@ -47,10 +47,28 @@
|
|
#define ONNX_API ONNX_IMPORT
|
|
#endif
|
|
|
|
+#if defined(__GNUC__)
|
|
+#pragma GCC diagnostic push
|
|
+
|
|
+// In file included from onnx/onnx-ml.pb.h:30:
|
|
+// In file included from google/protobuf/extension_set.h:53:
|
|
+// google/protobuf/parse_context.h:328:47: error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32]
|
|
+#if defined(__has_warning)
|
|
+#if __has_warning("-Wshorten-64-to-32")
|
|
+#pragma GCC diagnostic ignored "-Wshorten-64-to-32"
|
|
+#endif
|
|
+#endif // defined(__has_warning)
|
|
+
|
|
+#endif // defined(__GNUC__)
|
|
+
|
|
#ifdef ONNX_ML
|
|
#include "onnx/onnx-ml.pb.h"
|
|
#else
|
|
#include "onnx/onnx.pb.h"
|
|
#endif
|
|
|
|
+#if defined(__GNUC__)
|
|
+#pragma GCC diagnostic pop
|
|
+#endif
|
|
+
|
|
#endif // ! ONNX_ONNX_PB_H
|
|
diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc
|
|
index c315a2a7..58963154 100644
|
|
--- a/onnx/defs/math/defs.cc
|
|
+++ b/onnx/defs/math/defs.cc
|
|
@@ -3472,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
}
|
|
|
|
auto& input_shape = getInputShape(ctx, 0);
|
|
+ if (input_shape.dim_size() < 2) {
|
|
+ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), ".");
|
|
+ }
|
|
auto signal_dim = input_shape.dim(1);
|
|
if (!signal_dim.has_dim_value()) {
|
|
return;
|
|
diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc
|
|
index be6a851d..fad595d0 100644
|
|
--- a/onnx/defs/nn/defs.cc
|
|
+++ b/onnx/defs/nn/defs.cc
|
|
@@ -126,6 +126,9 @@ void convPoolShapeInference(
|
|
residual -= stride;
|
|
}
|
|
}
|
|
+ if (i >= static_cast<int>(effective_kernel_shape.size())) {
|
|
+ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), ".");
|
|
+ }
|
|
int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual;
|
|
if (total_pad < 0)
|
|
total_pad = 0;
|
|
@@ -959,19 +962,21 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
auto w_type = ctx.getInputType(3);
|
|
if (nullptr == x_type || nullptr == w_type || x_type->value_case() != TypeProto::kTensorType ||
|
|
w_type->value_case() != TypeProto::kTensorType) {
|
|
- fail_type_inference("inputs are expected to have tensor type.");
|
|
+ fail_type_inference("inputs are expected to have tensor type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
|
|
auto x_zero_point_type = ctx.getInputType(2);
|
|
if (nullptr == x_zero_point_type ||
|
|
x_zero_point_type->tensor_type().elem_type() != x_type->tensor_type().elem_type()) {
|
|
- fail_type_inference("input and zero_point pair is expected to have be same type.");
|
|
+ fail_type_inference(
|
|
+ "input and zero_point pair is expected to have be same type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
|
|
auto w_zero_point_type = ctx.getInputType(5);
|
|
if (nullptr == w_zero_point_type ||
|
|
w_zero_point_type->tensor_type().elem_type() != w_type->tensor_type().elem_type()) {
|
|
- fail_type_inference("weight and zero_point pair is expected to have same type.");
|
|
+ fail_type_inference(
|
|
+ "weight and zero_point pair is expected to have same type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
|
|
propagateElemTypeFromInputToOutput(ctx, 7, 0);
|
|
@@ -2647,7 +2652,8 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
if (!hasNInputShapes(ctx, 1)) {
|
|
return;
|
|
}
|
|
- auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
|
|
+
|
|
+ auto& input_shape = getInputShape(ctx, 0);
|
|
int64_t input_ndim = input_shape.dim_size();
|
|
int64_t axis = -1;
|
|
auto axis_proto = ctx.getAttribute("axis");
|
|
@@ -2659,7 +2665,16 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
// positive value.
|
|
axis += input_ndim;
|
|
}
|
|
-
|
|
+ if (axis < 0) {
|
|
+ fail_shape_inference(
|
|
+ "Unexpected axis value (",
|
|
+ axis,
|
|
+ ") rank of first input is ",
|
|
+ input_ndim,
|
|
+ " in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
+ }
|
|
if (ctx.getNumOutputs() > 1) {
|
|
auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape();
|
|
mean_shape->CopyFrom(input_shape);
|
|
diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc
|
|
index 57f8e2a4..8b2dc07f 100644
|
|
--- a/onnx/defs/nn/old.cc
|
|
+++ b/onnx/defs/nn/old.cc
|
|
@@ -201,6 +201,9 @@ void convPoolShapeInference_opset19(
|
|
residual -= stride;
|
|
}
|
|
}
|
|
+ if (i >= static_cast<int>(effective_kernel_shape.size())) {
|
|
+ fail_shape_inference("kernel shape should have ", input_dims_size, " values in ", ctx.getDisplayName(), ".");
|
|
+ }
|
|
int64_t total_pad = residual == 0 ? effective_kernel_shape[i] - stride : effective_kernel_shape[i] - residual;
|
|
if (total_pad < 0)
|
|
total_pad = 0;
|
|
diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h
|
|
index a80473b3..d1bcd401 100644
|
|
--- a/onnx/defs/shape_inference.h
|
|
+++ b/onnx/defs/shape_inference.h
|
|
@@ -105,6 +105,10 @@ struct InferenceContext {
|
|
virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
|
|
// Gets the shape inputs computed by partial data propagation.
|
|
virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
|
|
+ // To display a name the user can use to narrow its search.
|
|
+ virtual std::string getDisplayName() const {
|
|
+ return "";
|
|
+ }
|
|
};
|
|
|
|
// We use data propagation to perform partial evaluation of the model, to compute statically
|
|
@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput(
|
|
} else {
|
|
// This is not expected to happen
|
|
fail_type_inference(
|
|
- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case);
|
|
+ "Output ",
|
|
+ outputIndex,
|
|
+ " expected to have: ",
|
|
+ expected_value_case,
|
|
+ " or UNDEFINED. Got: ",
|
|
+ output_value_case,
|
|
+ " in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
}
|
|
}
|
|
|
|
@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr
|
|
const auto attr_type = attr->type();
|
|
if (attr_type == AttributeProto::TENSOR) {
|
|
if (attr->t().dims().size() != 1) {
|
|
- fail_type_inference("Attribute expected to have a one-dim tensor");
|
|
+ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), ".");
|
|
}
|
|
data_type = attr->t().data_type();
|
|
expected_value_case = TypeProto::kTensorType;
|
|
} else if (attr_type == AttributeProto::SPARSE_TENSOR) {
|
|
if (attr->sparse_tensor().dims().size() != 1) {
|
|
- fail_type_inference("Attribute expected to have a one-dim sparse tensor");
|
|
+ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), ".");
|
|
}
|
|
data_type = attr->sparse_tensor().values().data_type();
|
|
expected_value_case = TypeProto::kSparseTensorType;
|
|
} else {
|
|
- fail_type_inference("Attribute expected to have tensor or sparse tensor type");
|
|
+ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
|
|
propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
|
|
@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t
|
|
const auto* input_type = ctx.getInputType(n);
|
|
const auto value_case = input_type->value_case();
|
|
if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
|
|
- fail_type_inference("Attribute expected to have tensor or sparse tensor type");
|
|
+ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
|
|
+ }
|
|
+ if (!hasShape(*input_type)) {
|
|
+ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), ".");
|
|
}
|
|
if (value_case == TypeProto::kTensorType) {
|
|
return input_type->tensor_type().shape();
|
|
@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size
|
|
|
|
const auto value_case = input_type->value_case();
|
|
if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
|
|
- fail_type_inference("Attribute expected to have tensor or sparse tensor type");
|
|
+ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
if (value_case == TypeProto::kTensorType) {
|
|
return &input_type->tensor_type().shape();
|
|
@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType(
|
|
" does not match type of output: ",
|
|
outputIndex,
|
|
"type: ",
|
|
- output_value_case);
|
|
+ output_value_case,
|
|
+ " in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
}
|
|
if (TypeProto::kTensorType == input_value_case) {
|
|
auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
|
|
@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType(
|
|
*dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
|
|
} else {
|
|
fail_type_inference(
|
|
- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type");
|
|
+ "Input ",
|
|
+ inputIndex,
|
|
+ " and Output ",
|
|
+ outputIndex,
|
|
+ " expected to have tensor or sparse tensor type in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
}
|
|
}
|
|
|
|
@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType
|
|
setTensorElementType(elemType, expected_type, *output_type);
|
|
} else {
|
|
// This is not expected to happen
|
|
- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type);
|
|
+ fail_type_inference(
|
|
+ "Output ",
|
|
+ outputIndex,
|
|
+ " expected to have tensor or sparse tensor type: ",
|
|
+ expected_type,
|
|
+ " in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
}
|
|
}
|
|
|
|
@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput(
|
|
updateOutputElemType(ctx, outputIndex, default_value, expected_type);
|
|
return;
|
|
} else {
|
|
- fail_type_inference("Value of attribute ", attributeName, " not specified");
|
|
+ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), ".");
|
|
}
|
|
}
|
|
if (!attr_proto->has_i()) {
|
|
- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type.");
|
|
+ fail_type_inference(
|
|
+ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
auto attr_value = attr_proto->i();
|
|
auto elem_type = static_cast<TensorProto_DataType>(attr_value);
|
|
if (!TensorProto_DataType_IsValid(elem_type)) {
|
|
- fail_type_inference("Attribute ", attributeName, " does not specify a valid type.");
|
|
+ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
|
|
}
|
|
@@ -497,7 +529,7 @@ inline TensorShapeProto*
|
|
getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
|
|
auto output_type = ctx.getOutputType(n);
|
|
if (output_type == nullptr) {
|
|
- fail_type_inference("Output ", n, " expected to have tensor or sparse type");
|
|
+ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
const auto output_value_case = output_type->value_case();
|
|
if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
|
|
@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ
|
|
} else if (output_value_case == TypeProto::VALUE_NOT_SET) {
|
|
return getTensorMutableShape(default_type, *output_type);
|
|
} else {
|
|
- fail_type_inference("Output ", n, " expected to have tensor type");
|
|
+ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), ".");
|
|
}
|
|
}
|
|
|
|
@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput(
|
|
auto attr_proto = ctx.getAttribute(attributeName);
|
|
if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
|
|
(attr_proto->type() != AttributeProto_AttributeType_INTS)) {
|
|
- fail_shape_inference("Attribute ", attributeName, " should specify a shape");
|
|
+ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), ".");
|
|
}
|
|
auto& int_list = attr_proto->ints();
|
|
TensorShapeProto shape;
|
|
for (auto dim_size : int_list) {
|
|
if (dim_size < 0) {
|
|
- fail_shape_inference("Negative values are not allowed in a shape specification");
|
|
+ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), ".");
|
|
}
|
|
shape.add_dim()->set_dim_value(dim_size);
|
|
}
|
|
@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect
|
|
if (hasInputShape(ctx, input_index)) {
|
|
auto rank = getInputShape(ctx, input_index).dim_size();
|
|
if (rank != expected_rank) {
|
|
- fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank);
|
|
+ fail_shape_inference(
|
|
+ "Input ",
|
|
+ input_index,
|
|
+ " expected to have rank ",
|
|
+ expected_rank,
|
|
+ " but has rank ",
|
|
+ rank,
|
|
+ " in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
}
|
|
}
|
|
}
|
|
@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind
|
|
// This shape is expected to have rank > dim_index:
|
|
if (input_shape.dim_size() <= dim_index) {
|
|
fail_shape_inference(
|
|
- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size());
|
|
+ "Input ",
|
|
+ input_index,
|
|
+ " expected to have rank >",
|
|
+ dim_index,
|
|
+ " but has rank ",
|
|
+ input_shape.dim_size(),
|
|
+ " in ",
|
|
+ ctx.getDisplayName(),
|
|
+ ".");
|
|
}
|
|
const Dim& input_dim = input_shape.dim(dim_index);
|
|
// Now, unify dim and input_dim:
|
|
diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc
|
|
index 8723dcd4..8249fc59 100644
|
|
--- a/onnx/shape_inference/implementation.cc
|
|
+++ b/onnx/shape_inference/implementation.cc
|
|
@@ -906,7 +906,7 @@ struct FunctionInferenceContext : public InferenceContext {
|
|
const std::vector<TypeProto>& input_types,
|
|
const std::vector<AttributeProto>& attributes,
|
|
const ShapeInferenceOptions& options)
|
|
- : input_types_(input_types), options_(options) {
|
|
+ : input_types_(input_types), options_(options), func_proto_(&func_proto) {
|
|
for (const auto& attr : attributes) {
|
|
attributesByName_[attr.name()] = &attr;
|
|
}
|
|
@@ -971,11 +971,25 @@ struct FunctionInferenceContext : public InferenceContext {
|
|
return std::move(output_types_);
|
|
}
|
|
|
|
+ std::string getDisplayName() const override {
|
|
+ if (func_proto_ == nullptr)
|
|
+ return "";
|
|
+ if (func_proto_->domain().empty()) {
|
|
+ if (func_proto_->name().empty())
|
|
+ return "";
|
|
+ return MakeString("function ", func_proto_->name());
|
|
+ }
|
|
+ if (func_proto_->name().empty())
|
|
+ return MakeString("function [", func_proto_->domain(), "]");
|
|
+ return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]");
|
|
+ }
|
|
+
|
|
private:
|
|
const std::vector<TypeProto>& input_types_;
|
|
std::vector<TypeProto> output_types_;
|
|
std::unordered_map<std::string, const AttributeProto*> attributesByName_;
|
|
ShapeInferenceOptions options_;
|
|
+ const FunctionProto* func_proto_;
|
|
};
|
|
|
|
std::vector<TypeProto> InferFunctionOutputTypes(
|
|
diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h
|
|
index 2c63c910..b0e4c32d 100644
|
|
--- a/onnx/shape_inference/implementation.h
|
|
+++ b/onnx/shape_inference/implementation.h
|
|
@@ -146,7 +146,7 @@ struct InferenceContextImpl : public InferenceContext {
|
|
const ShapeInferenceOptions& options,
|
|
DataValueMap* generatedShapeData = nullptr,
|
|
GraphInferenceContext* graphInferenceContext = nullptr)
|
|
- : graphInferenceContext_{graphInferenceContext}, options_(options) {
|
|
+ : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) {
|
|
for (auto& attr : *n.mutable_attribute()) {
|
|
attributesByName_[attr.name()] = &attr;
|
|
if (attr.has_g()) {
|
|
@@ -277,6 +277,19 @@ struct InferenceContextImpl : public InferenceContext {
|
|
return inferencer;
|
|
}
|
|
|
|
+ std::string getDisplayName() const override {
|
|
+ if (node_ == nullptr)
|
|
+ return "";
|
|
+ if (node_->domain().empty()) {
|
|
+ if (node_->name().empty())
|
|
+ return MakeString("node ", node_->op_type());
|
|
+ return MakeString("node ", node_->op_type(), " (", node_->name(), ")");
|
|
+ }
|
|
+ if (node_->name().empty())
|
|
+ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]");
|
|
+ return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")");
|
|
+ }
|
|
+
|
|
std::vector<const TensorProto*> allInputData_;
|
|
std::vector<const SparseTensorProto*> allInputSparseData_;
|
|
std::vector<const TensorShapeProto*> allShapeInputData_;
|
|
@@ -289,6 +302,7 @@ struct InferenceContextImpl : public InferenceContext {
|
|
// mutable as internal cache of GraphInferencer instances
|
|
mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
|
|
ShapeInferenceOptions options_;
|
|
+ NodeProto* node_;
|
|
};
|
|
|
|
struct DataPropagationContextImpl : public DataPropagationContext {
|
|
diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc
|
|
index ef379d8f..b7dfe3c8 100644
|
|
--- a/onnx/defs/math/defs.cc
|
|
+++ b/onnx/defs/math/defs.cc
|
|
@@ -2568,17 +2568,17 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
}
|
|
}));
|
|
|
|
-void einsumRankInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string equation) {
|
|
- const size_t numInputs = ctx.getNumInputs();
|
|
- if (numInputs < 1 || !hasNInputShapes(ctx, static_cast<int>(numInputs))) {
|
|
+void einsumShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string const& equation) {
|
|
+ // Only accept letters for indices
|
|
+ auto is_letter = [](char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); };
|
|
+
|
|
+ const size_t num_inputs = ctx.getNumInputs();
|
|
+ if (num_inputs < 1 || !hasNInputShapes(ctx, static_cast<int>(num_inputs))) {
|
|
return;
|
|
}
|
|
-
|
|
- auto* output_shape = getOutputShape(ctx, 0);
|
|
+ ONNX_NAMESPACE::TensorShapeProto output_shape;
|
|
std::string left_equation;
|
|
|
|
- equation.erase(std::remove(equation.begin(), equation.end(), ' '),
|
|
- equation.end()); // Remove space char
|
|
auto mid_index = equation.find("->");
|
|
if (mid_index != std::string::npos) {
|
|
// Separate right and left hand sides of the equation
|
|
@@ -2595,73 +2595,130 @@ void einsumRankInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string equa
|
|
|
|
// Parse the left-hand side
|
|
std::stringstream str(left_equation);
|
|
+ std::map<char, size_t> label_maps;
|
|
+ std::set<char> repeated_labels;
|
|
+ ONNX_NAMESPACE::TensorShapeProto dims_value, ellipsis_dims_value;
|
|
+ size_t num_labels = 0;
|
|
+ bool ellipsis_flag = true;
|
|
+
|
|
while (!str.eof()) {
|
|
std::getline(str, term, ',');
|
|
auto ellipsis_index = term.find("...");
|
|
- if (numInputs <= num_operands) {
|
|
+ if (num_inputs <= num_operands) {
|
|
fail_shape_inference("Number of input tensors does not match the operands in the equation.");
|
|
}
|
|
- size_t rank = ctx.getInputType(num_operands)->tensor_type().shape().dim_size();
|
|
+ const auto& shape = ctx.getInputType(num_operands)->tensor_type().shape();
|
|
+ size_t rank = shape.dim_size();
|
|
+ size_t ellipsis_dims = 0;
|
|
+
|
|
+ size_t term_size = 0; // number of legal indices for the current term
|
|
+ size_t num_illegal_char = 0; // number of illegal char before the current 'index' in the current term
|
|
+
|
|
+ for (size_t index = 0; index < term.size(); ++index) {
|
|
+ if (is_letter(term[index])) {
|
|
+ term_size += 1;
|
|
+ }
|
|
+ }
|
|
+
|
|
+ for (size_t index = 0; index < term.size(); ++index) {
|
|
+ if (index == ellipsis_index) {
|
|
+ // find ellipsis and record the dims represented by ellipsis
|
|
+ ellipsis_dims = rank - term_size;
|
|
+ if (ellipsis_flag) {
|
|
+ ellipsis_flag = false;
|
|
+ for (size_t i = 0; i < ellipsis_dims; i++) {
|
|
+ *ellipsis_dims_value.add_dim() = shape.dim(index + i - num_illegal_char);
|
|
+ }
|
|
+ } else {
|
|
+ for (size_t i = 0; i < ellipsis_dims; i++) {
|
|
+ const auto shape_dim = shape.dim(index + i - num_illegal_char);
|
|
+ const auto current_dim = ellipsis_dims_value.mutable_dim(i);
|
|
+ if (shape_dim.has_dim_value() && current_dim->has_dim_value() &&
|
|
+ shape_dim.dim_value() > current_dim->dim_value() && current_dim->dim_value() == 1) {
|
|
+ current_dim->set_dim_value(shape_dim.dim_value());
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+ index += 2; // skip the rest of dots
|
|
+ num_illegal_char += 3;
|
|
+ continue;
|
|
+
|
|
+ } else if (!is_letter(term[index])) {
|
|
+ num_illegal_char += 1;
|
|
+ continue;
|
|
+ }
|
|
+
|
|
+ const auto inserted = label_maps.insert({term[index], num_labels}).second;
|
|
+ if (inserted) {
|
|
+ *dims_value.add_dim() = shape.dim(index + ellipsis_dims - num_illegal_char);
|
|
+ ++num_labels;
|
|
+ } else {
|
|
+ repeated_labels.insert(term[index]);
|
|
+ }
|
|
+ }
|
|
+
|
|
if (ellipsis_index != std::string::npos) {
|
|
// If there is an ellipsis, the number of dimensions it represents
|
|
// must be total dim - letter dimensions
|
|
if (num_ellipsis == 0) {
|
|
- if (rank + 3 < term.size()) {
|
|
+ if (rank < term_size) {
|
|
fail_shape_inference("Ellipsis represents incompatible dimensions.");
|
|
}
|
|
- num_ellipsis_indices = rank - term.size() + 3;
|
|
+ num_ellipsis_indices = rank - term_size;
|
|
} else { // ellipsis has been seen before. Check that if dimensions
|
|
// are compatible
|
|
- if (num_ellipsis_indices != rank - term.size() + 3) {
|
|
+ if (num_ellipsis_indices != rank - term_size) {
|
|
fail_shape_inference("Ellipsis represents incompatible dimensions.");
|
|
}
|
|
}
|
|
num_ellipsis++;
|
|
} else {
|
|
- if (rank != term.size()) {
|
|
+ if (rank != term_size) {
|
|
fail_shape_inference("Rank of input ", num_operands, " does not match the equation indices.");
|
|
}
|
|
}
|
|
num_operands++;
|
|
}
|
|
|
|
- if (numInputs != num_operands) {
|
|
+ if (num_inputs != num_operands) {
|
|
fail_shape_inference("Number of input tensors does not match the operands in the equation.");
|
|
}
|
|
|
|
- const size_t number_of_letters = 26;
|
|
- size_t num_letter_occurrences[number_of_letters] = {0};
|
|
// Parse the provided right-hand side
|
|
if (mid_index != std::string::npos) {
|
|
std::string right_equation = equation.substr(mid_index + 2);
|
|
auto right_ellipsis_index = right_equation.find("...");
|
|
- if (right_ellipsis_index != std::string::npos) { // Right-hand side contains ellipsis
|
|
- for (size_t i = 0; i < num_ellipsis_indices; ++i) {
|
|
- output_shape->add_dim();
|
|
+
|
|
+ for (size_t index = 0; index < right_equation.size(); ++index) {
|
|
+ // If there's an ellipsis, add its corresponding dimensions
|
|
+ if (index == right_ellipsis_index) {
|
|
+ for (size_t i = 0; i < num_ellipsis_indices; i++) {
|
|
+ *output_shape.add_dim() = ellipsis_dims_value.dim(i);
|
|
+ }
|
|
+ index += 2; // skip the rest of dots
|
|
+ continue;
|
|
}
|
|
- }
|
|
- for (char c : right_equation) { // Add a dimension per each character
|
|
- // in right hand equation
|
|
- if (c != '.') {
|
|
- output_shape->add_dim();
|
|
+
|
|
+ if (is_letter(right_equation[index])) {
|
|
+ *output_shape.add_dim() = dims_value.dim(label_maps[right_equation[index]]);
|
|
}
|
|
}
|
|
} else { // Infer the dimension for right-hand side
|
|
- // If there's an ellipsis, add it's corresponding dimensions
|
|
+ // If there's an ellipsis, add its corresponding dimensions
|
|
for (size_t i = 0; i < num_ellipsis_indices; i++) {
|
|
- output_shape->add_dim();
|
|
+ *output_shape.add_dim() = ellipsis_dims_value.dim(i);
|
|
}
|
|
- for (size_t i = 0; i < left_equation.size(); i++) { // Count chars that appear exactly once on left hand side
|
|
- if ((left_equation.at(i) != ',') && (left_equation.at(i) != '.')) {
|
|
- num_letter_occurrences[left_equation.at(i) - 'a']++;
|
|
- }
|
|
- }
|
|
- for (size_t index = 0; index < number_of_letters; index++) {
|
|
- if (num_letter_occurrences[index] == 1) {
|
|
- output_shape->add_dim();
|
|
+ // If no explicit output was given, generate an implicit output by ordering all the
|
|
+ // labels in alphabetic order (by ASCII value consistent with numpy, so Z < a).
|
|
+ // Exclude any labels that occurred more than once, as these cancel out.
|
|
+ for (auto i : label_maps) {
|
|
+ if (repeated_labels.count(i.first) == 0) {
|
|
+ *output_shape.add_dim() = dims_value.dim(i.second);
|
|
}
|
|
}
|
|
}
|
|
+
|
|
+ updateOutputShape(ctx, 0, output_shape);
|
|
}
|
|
|
|
static const char* Einsum_ver12_doc = R"DOC(
|
|
@@ -2711,7 +2768,10 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
if (equation.compare("") == 0) {
|
|
return;
|
|
}
|
|
- einsumRankInference(ctx, equation);
|
|
+
|
|
+ equation.erase(std::remove(equation.begin(), equation.end(), ' '),
|
|
+ equation.end()); // Remove space char
|
|
+ einsumShapeInference(ctx, equation);
|
|
}));
|
|
|
|
const char* reduction_doc_sce =
|
|
diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py
|
|
index 75280f6c..5543fda0 100644
|
|
--- a/onnx/test/shape_inference_test.py
|
|
+++ b/onnx/test/shape_inference_test.py
|
|
@@ -7026,7 +7026,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[make_node("Einsum", ["x"], ["y"], equation="ij->ji")],
|
|
[],
|
|
)
|
|
- self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (None, None))]) # type: ignore
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (4, 3))]) # type: ignore
|
|
|
|
def test_einsum_dot(self) -> None:
|
|
graph = self._make_graph(
|
|
@@ -7050,7 +7050,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[make_node("Einsum", ["x", "y"], ["z"], equation="ij,ab->ijab")],
|
|
[],
|
|
)
|
|
- self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None, None))]) # type: ignore
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 5, 7, 9))]) # type: ignore
|
|
|
|
def test_einsum_sum_along_dim(self) -> None:
|
|
graph = self._make_graph(
|
|
@@ -7058,7 +7058,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[make_node("Einsum", ["x"], ["y"], equation="i j->i ")],
|
|
[],
|
|
)
|
|
- self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (None,))]) # type: ignore
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
|
|
def test_einsum_ellipsis(self) -> None:
|
|
graph = self._make_graph(
|
|
@@ -7066,26 +7066,36 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[make_node("Einsum", ["x"], ["y"], equation="... ii ->... i")],
|
|
[],
|
|
)
|
|
- self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (None, None))]) # type: ignore
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 4))]) # type: ignore
|
|
|
|
def test_einsum_ellipsis_2(self) -> None:
|
|
graph = self._make_graph(
|
|
- [("x", TensorProto.FLOAT, (2, 2, 2)), ("y", TensorProto.FLOAT, (2, 2, 2))],
|
|
+ [("x", TensorProto.FLOAT, (2, 3, 4)), ("y", TensorProto.FLOAT, (2, 4, 5))],
|
|
[make_node("Einsum", ["x", "y"], ["z"], equation="...ij,...jk->...ik")],
|
|
[],
|
|
)
|
|
self._assert_inferred(
|
|
- graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None))]
|
|
+ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 5))]
|
|
) # type: ignore
|
|
|
|
def test_einsum_ellipsis_3(self) -> None:
|
|
graph = self._make_graph(
|
|
- [("x", TensorProto.FLOAT, (2, 2, 2)), ("y", TensorProto.FLOAT, (2, 2, 2))],
|
|
+ [("x", TensorProto.FLOAT, (2, 3, 4)), ("y", TensorProto.FLOAT, (2, 4, 5))],
|
|
[make_node("Einsum", ["x", "y"], ["z"], equation="...ij,...jk")],
|
|
[],
|
|
)
|
|
self._assert_inferred(
|
|
- graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None))]
|
|
+ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 5))]
|
|
+ ) # type: ignore
|
|
+
|
|
+ def test_einsum_ellipsis_broadcast(self) -> None:
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (1, 3, 4)), ("y", TensorProto.FLOAT, (32, 4, 5))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="...ij,...jk->...ik")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(
|
|
+ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (32, 3, 5))]
|
|
) # type: ignore
|
|
|
|
def test_einsum_contraction(self) -> None:
|
|
@@ -7099,11 +7109,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
)
|
|
self._assert_inferred(
|
|
graph,
|
|
- [
|
|
- make_tensor_value_info(
|
|
- "z", TensorProto.FLOAT, (None, None, None, None, None)
|
|
- )
|
|
- ],
|
|
+ [make_tensor_value_info("z", TensorProto.FLOAT, (5, 6, 7, 9, 10))],
|
|
) # type: ignore
|
|
|
|
def test_einsum_contraction_2(self) -> None:
|
|
@@ -7113,7 +7119,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[],
|
|
)
|
|
self._assert_inferred(
|
|
- graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None))]
|
|
+ graph, [make_tensor_value_info("z", TensorProto.FLOAT, (4, 5))]
|
|
) # type: ignore
|
|
|
|
def test_einsum_batch_matmul(self) -> None:
|
|
@@ -7122,7 +7128,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[make_node("Einsum", ["x", "y"], ["z"], equation="bij , b jk-> bik")],
|
|
[],
|
|
)
|
|
- self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None))]) # type: ignore
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (5, 2, 4))]) # type: ignore
|
|
|
|
def test_einsum_left_hand_eqn(self) -> None:
|
|
graph = self._make_graph(
|
|
@@ -7130,7 +7136,7 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
[make_node("Einsum", ["x", "y"], ["z"], equation="ij,kl")],
|
|
[],
|
|
)
|
|
- self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (None, None, None, None))]) # type: ignore
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 3, 4))]) # type: ignore
|
|
|
|
def test_einsum_incorrect_num_inputs(self) -> None:
|
|
graph = self._make_graph(
|
|
@@ -7144,6 +7150,244 @@ class TestShapeInference(TestShapeInferenceHelper):
|
|
)
|
|
self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph)
|
|
|
|
+ def test_einsum_view_A1(self) -> None: # returns a view of A1
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3,))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="i")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_sum_A1(self) -> None: # sums the values of A1
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3,))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="i->")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, ())]) # type: ignore
|
|
+
|
|
+ def test_einsum_element_wise_multiplication_A1_B1(
|
|
+ self,
|
|
+ ) -> None: # element-wise multiplication of A1 and B1
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3,)), ("y", TensorProto.FLOAT, (3,))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="i,i->i")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_inner_product_A1_B1(self) -> None: # inner product of A1 and B1
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3,)), ("y", TensorProto.FLOAT, (3,))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="i,i->")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, ())]) # type: ignore
|
|
+
|
|
+ def test_einsum_outer_product_A1_B1(self) -> None: # outer product of A1 and B1
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3,)), ("y", TensorProto.FLOAT, (3,))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="i,j->ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_view_A2(self) -> None: # returns a view of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ij->ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_view_A2_2(self) -> None: # returns a view of A2, another case
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_transpose_A2(self) -> None: # view transpose of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ji")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_transpose_A2_to_ij(self) -> None: # view transpose of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ji->ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_diag_A2(self) -> None: # view main diagonal of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ii->i")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_trace_A2(self) -> None: # sums main diagonal of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ii->")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, ())]) # type: ignore
|
|
+
|
|
+ def test_einsum_sum_A2(self) -> None: # sums the values of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ij->")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, ())]) # type: ignore
|
|
+
|
|
+ def test_einsum_sum_columns_A2(
|
|
+ self,
|
|
+ ) -> None: # sum down the columns of A2 (across rows)
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ij->j")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_sum_rows_A2(self) -> None: # sum horizontally along the rows of A2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x"], ["y"], equation="ij->i")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("y", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_element_wise_multiplication_A2_B2(
|
|
+ self,
|
|
+ ) -> None: # element-wise multiplication of A2 and B2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,ij->ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_element_wise_multiplication_A2_B2_transpose(
|
|
+ self,
|
|
+ ) -> None: # element-wise multiplication of A2 and B2.T
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,ji->ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_matrix_multiplication_A2_B2(
|
|
+ self,
|
|
+ ) -> None: # matrix multiplication of A2 and B2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,jk")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_matrix_multiplication_A2_B2_to_ik(
|
|
+ self,
|
|
+ ) -> None: # matrix multiplication of A2 and B2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,jk->ik")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_matrix_multiplication_A3_B3(
|
|
+ self,
|
|
+ ) -> None: # matrix multiplication of A3 and B3 (a stack of 2D matrices)
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (2, 3, 3)), ("y", TensorProto.FLOAT, (2, 3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="bij,bjk->bik")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_matrix_multiplication_A3_B3_transpose(
|
|
+ self,
|
|
+ ) -> None: # matrix multiplication of A3 and B3 (a stack of 2D matrices)
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (2, 3, 3)), ("y", TensorProto.FLOAT, (2, 3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="bij,bkj->bik")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (2, 3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_inner_product_A2_B2(self) -> None: # inner product of A2 and B2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kj->ik")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_row_multiplication_A2_B2(
|
|
+ self,
|
|
+ ) -> None: # each row of A2 multiplied by B2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kj->ikj")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_value_multiplication_A2_B2(
|
|
+ self,
|
|
+ ) -> None: # each value of A2 multiplied by B2
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,kl->ijkl")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3, 3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_scalar_times_array(self) -> None: # Scalar times array
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, ()), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation=",ij->ij")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3, 3))]) # type: ignore
|
|
+
|
|
+ def test_einsum_matrix_vector_A2_B1(self) -> None: # Matrix and vector.
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3,))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ij,j->i")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_diag_multiplication_A2_B2(
|
|
+ self,
|
|
+ ) -> None: # diagonals multiplied by each other
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ii,ii->i")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, (3,))]) # type: ignore
|
|
+
|
|
+ def test_einsum_diag_dot_product_A2_B2(self) -> None: # dot product of diagonals
|
|
+ graph = self._make_graph(
|
|
+ [("x", TensorProto.FLOAT, (3, 3)), ("y", TensorProto.FLOAT, (3, 3))],
|
|
+ [make_node("Einsum", ["x", "y"], ["z"], equation="ii,ii->")],
|
|
+ [],
|
|
+ )
|
|
+ self._assert_inferred(graph, [make_tensor_value_info("z", TensorProto.FLOAT, ())]) # type: ignore
|
|
+
|
|
def test_negative_log_likehood_shape_is_NCdd(self) -> None:
|
|
N, C = 3, 4
|
|
graph = self._make_graph(
|
|
|