Add data type check in ConvAddRelu fusion (#12058)

This commit is contained in:
Hariharan Seshadri 2022-07-01 15:31:15 -07:00 committed by GitHub
parent 57ac3d0a61
commit df712d80ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 97 additions and 4 deletions

View file

@ -43,6 +43,21 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
return data_type == actual_data_type;
}
// Decides whether a Conv node passes the data type gate for the Conv fusions.
//
// Returns false only when the node is assigned to the CUDA or CPU execution
// provider and its first input is not of element type float. For every other
// execution provider no data type restriction is applied and true is returned.
bool ConvFusionDataTypeCheck(const Node& conv_node) {
  // TODO(hasesh): The CPU and CUDA EP only support float type for the Conv+Activation
  // and the Conv+Add+Relu fusions.
  // Assess the support level for the other compatible EPs and if they also
  // only support float, remove the EP check altogether.
  const std::string_view assigned_ep = conv_node.GetExecutionProviderType();
  const bool is_float_only_ep =
      assigned_ep == kCudaExecutionProvider || assigned_ep == kCpuExecutionProvider;

  // EPs outside the float-only set are not restricted by this check.
  if (!is_float_only_ep) {
    return true;
  }

  // For the float-only EPs, the fusion is allowed only if the first input is float.
  return HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
}
class ConvActivation : public NodeSelector {
public:
ConvActivation() = default;
@ -74,12 +89,12 @@ class ConvActivation : public NodeSelector {
return false;
};
if (!ConvFusionDataTypeCheck(node)) {
return std::nullopt;
}
// check EP type and activation
if (node_ep == kCudaExecutionProvider) {
if (!HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
return std::nullopt;
}
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
return std::nullopt;
}
@ -112,6 +127,10 @@ class ConvAddRelu : public NodeSelector {
return std::nullopt;
}
if (!ConvFusionDataTypeCheck(node)) {
return std::nullopt;
}
const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
if (!add_node ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) ||

View file

@ -714,6 +714,37 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph
}
// Currently the ConvAddRelu fusion is only backed by a float kernel for
// the CUDA EP.
// When we see the corresponding pattern for the fp16 data type, the fusion
// should not be triggered as there is no kernel to back the fused pattern.
// TODO(hasesh): Limit the test to using the CUDA EP for now as the level of
// data type support in other compatible EPs is still yet to be ascertained.
// TODO(hasesh): If at all the fp16 type is supported for the fusion, adjust/remove
// this test.
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) {
  // Load the fp16 model used for this test.
  auto fp16_model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx";
  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(fp16_model_uri, model, nullptr, *logger_));
  Graph& graph = model->MainGraph();

  // Assign every node to the CUDA EP so the CUDA-specific fusion path is considered.
  for (auto& graph_node : graph.Nodes()) {
    graph_node.SetExecutionProviderType(kCudaExecutionProvider);
  }

  // Sanity check: the graph starts out with exactly one Add and one Relu.
  std::map<std::string, int> op_counts = CountOpsInGraph(graph);
  ASSERT_EQ(op_counts["Add"], 1);
  ASSERT_EQ(op_counts["Relu"], 1);

  // Run the ConvActivationFusion transformer at Level2.
  onnxruntime::GraphTransformerManager transformer_mgr{5};
  ASSERT_STATUS_OK(transformer_mgr.Register(std::make_unique<ConvActivationFusion>(),
                                            TransformerLevel::Level2));
  ASSERT_STATUS_OK(transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

  // The fusion must not have been triggered: both ops are still in the graph.
  op_counts = CountOpsInGraph(graph);
  ASSERT_EQ(op_counts["Add"], 1);   // Add not removed from graph (fusion not triggered)
  ASSERT_EQ(op_counts["Relu"], 1);  // Relu not removed from graph (fusion not triggered)
}
// Conv->Add->Relu will be left intact since there is Identity depend on Add
TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";

View file

@ -0,0 +1,43 @@


X
W
BC"Conv
SY"Relu

C
AS"AddgraphZ
X





Z
W





Z
B


Z
A





b
Y





B