mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Add data type check in ConvAddRelu fusion (#12058)
This commit is contained in:
parent
57ac3d0a61
commit
df712d80ca
3 changed files with 97 additions and 4 deletions
|
|
@ -43,6 +43,21 @@ bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
|
|||
return data_type == actual_data_type;
|
||||
}
|
||||
|
||||
bool ConvFusionDataTypeCheck(const Node& conv_node) {
|
||||
// TODO(hasesh): The CPU and CUDA EP only support float type for the Conv+Activation
|
||||
// and the Conv+Add+Relu fusions.
|
||||
// Assess the support level for the other compatible EPs and if they also
|
||||
// only support float, remove the EP check altogether.
|
||||
const std::string_view node_ep = conv_node.GetExecutionProviderType();
|
||||
if (node_ep == kCudaExecutionProvider || node_ep == kCpuExecutionProvider) {
|
||||
if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
class ConvActivation : public NodeSelector {
|
||||
public:
|
||||
ConvActivation() = default;
|
||||
|
|
@ -74,12 +89,12 @@ class ConvActivation : public NodeSelector {
|
|||
return false;
|
||||
};
|
||||
|
||||
if (!ConvFusionDataTypeCheck(node)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// check EP type and activation
|
||||
if (node_ep == kCudaExecutionProvider) {
|
||||
if (!HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
|
@ -112,6 +127,10 @@ class ConvAddRelu : public NodeSelector {
|
|||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!ConvFusionDataTypeCheck(node)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
|
||||
if (!add_node ||
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {6, 7, 13, 14}) ||
|
||||
|
|
|
|||
|
|
@ -714,6 +714,37 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
|
|||
ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph
|
||||
}
|
||||
|
||||
// Currently the ConvAddRelu fusion is only backed by a float kernel for the
|
||||
// the CUDA EP.
|
||||
|
||||
// When we see the corresponding pattern for the fp16 data type, the fusion
|
||||
// should not be triggered as there is no kernel to back the fused pattern.
|
||||
|
||||
// TODO(hasesh): Limit the test to using the CUDA EP for now as the level of
|
||||
// data type support in other compatible EPs is still yet to be ascertained.
|
||||
|
||||
// TODO(hasesh): If at all the fp16 type is supported for the fusion, adjust/remove
|
||||
// this test.
|
||||
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu_UnsupportedType) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
for (auto& node : p_model->MainGraph().Nodes()) {
|
||||
node.SetExecutionProviderType(kCudaExecutionProvider);
|
||||
}
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Add"], 1);
|
||||
ASSERT_EQ(op_to_count["Relu"], 1);
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<ConvActivationFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["Add"], 1); // Add not removed from graph (fusion not triggered)
|
||||
ASSERT_EQ(op_to_count["Relu"], 1); // Relu not removed from graph (fusion not triggered)
|
||||
}
|
||||
|
||||
// Conv->Add->Relu will be left intact since there is Identity depend on Add
|
||||
TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";
|
||||
|
|
|
|||
43
onnxruntime/test/testdata/transform/fusion/conv_add_relu_fp16.onnx
vendored
Normal file
43
onnxruntime/test/testdata/transform/fusion/conv_add_relu_fp16.onnx
vendored
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
:¾
|
||||
|
||||
X
|
||||
W
|
||||
BC"Conv
|
||||
|
||||
SY"Relu
|
||||
|
||||
C
|
||||
AS"AddgraphZ
|
||||
X
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
W
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
B
|
||||
|
||||
|
||||
|
||||
Z
|
||||
A
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
Y
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
Loading…
Reference in a new issue