Fix a shape inference bug for FusedConv and MaxpoolWithMask (#748)

This commit is contained in:
Du Li 2019-04-02 21:48:06 -07:00 committed by Changming Sun
parent fc26b24138
commit 6e9ed17adc
3 changed files with 8 additions and 8 deletions

View file

@ -385,7 +385,7 @@ Sample echo operator.)DOC");
"T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input0 and output types to float tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, true, false);
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedConv)
@ -447,7 +447,7 @@ activation.)DOC")
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, true, false);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm)

View file

@ -57,7 +57,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(Tra
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer =
std::make_unique<RuleBasedGraphTransformer>(transformer_utils::GenerateRuleBasedTransformerName(level),
"Apply rewrite rules for Level" +
"Apply rewrite rules for Level" +
std::to_string(static_cast<uint32_t>(level)),
compatible_execution_providers);
for (auto& entry : rewrite_rules_to_register) {
@ -97,6 +97,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
}
transformers.emplace_back(std::make_unique<GemmActivationFusion>(l2_execution_providers));
transformers.emplace_back(std::make_unique<MatMulAddFusion>(l2_execution_providers));
transformers.emplace_back(std::make_unique<ConvActivationFusion>(l2_execution_providers));
transformers.emplace_back(std::make_unique<ConvAddFusion>());
transformers.emplace_back(std::make_unique<ConvMulFusion>());
transformers.emplace_back(std::make_unique<ConvBNFusion>());

View file

@ -149,11 +149,10 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
}
TEST(GraphTransformationTests, FuseConvActivation) {
std::unordered_map<std::string, std::string> model_to_op_name {{"fusion/conv_relu.onnx", "Relu"},
{"fusion/conv_sigmoid.onnx", "Sigmoid"},
{"fusion/conv_tanh.onnx", "Tanh"},
{"fusion/conv_leakyrelu.onnx", "LeakyRelu"}};
std::unordered_map<std::string, std::string> model_to_op_name{{"fusion/conv_relu.onnx", "Relu"},
{"fusion/conv_sigmoid.onnx", "Sigmoid"},
{"fusion/conv_tanh.onnx", "Tanh"},
{"fusion/conv_leakyrelu.onnx", "LeakyRelu"}};
for (const auto& model : model_to_op_name) {
std::string model_uri = MODEL_FOLDER + model.first;