mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Fix a shape inference bug for FusedConv and MaxpoolWithMask (#748)
This commit is contained in:
parent
fc26b24138
commit
6e9ed17adc
3 changed files with 8 additions and 8 deletions
|
|
@ -385,7 +385,7 @@ Sample echo operator.)DOC");
|
|||
"T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input0 and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, true, false);
|
||||
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedConv)
|
||||
|
|
@ -447,7 +447,7 @@ activation.)DOC")
|
|||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
|
||||
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, true, false);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(Tra
|
|||
|
||||
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer =
|
||||
std::make_unique<RuleBasedGraphTransformer>(transformer_utils::GenerateRuleBasedTransformerName(level),
|
||||
"Apply rewrite rules for Level" +
|
||||
"Apply rewrite rules for Level" +
|
||||
std::to_string(static_cast<uint32_t>(level)),
|
||||
compatible_execution_providers);
|
||||
for (auto& entry : rewrite_rules_to_register) {
|
||||
|
|
@ -97,6 +97,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
}
|
||||
transformers.emplace_back(std::make_unique<GemmActivationFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(std::make_unique<MatMulAddFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(std::make_unique<ConvActivationFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(std::make_unique<ConvAddFusion>());
|
||||
transformers.emplace_back(std::make_unique<ConvMulFusion>());
|
||||
transformers.emplace_back(std::make_unique<ConvBNFusion>());
|
||||
|
|
|
|||
|
|
@ -149,11 +149,10 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
|
|||
}
|
||||
|
||||
TEST(GraphTransformationTests, FuseConvActivation) {
|
||||
|
||||
std::unordered_map<std::string, std::string> model_to_op_name {{"fusion/conv_relu.onnx", "Relu"},
|
||||
{"fusion/conv_sigmoid.onnx", "Sigmoid"},
|
||||
{"fusion/conv_tanh.onnx", "Tanh"},
|
||||
{"fusion/conv_leakyrelu.onnx", "LeakyRelu"}};
|
||||
std::unordered_map<std::string, std::string> model_to_op_name{{"fusion/conv_relu.onnx", "Relu"},
|
||||
{"fusion/conv_sigmoid.onnx", "Sigmoid"},
|
||||
{"fusion/conv_tanh.onnx", "Tanh"},
|
||||
{"fusion/conv_leakyrelu.onnx", "LeakyRelu"}};
|
||||
|
||||
for (const auto& model : model_to_op_name) {
|
||||
std::string model_uri = MODEL_FOLDER + model.first;
|
||||
|
|
|
|||
Loading…
Reference in a new issue