Remove Cast before and after Gelu (#11885)

* fuse cast gelu

* use PropagateCastOps

* fix ut
This commit is contained in:
Vincent Wang 2022-06-22 09:07:48 +08:00 committed by GitHub
parent 4bf22e2a40
commit 03beed0ceb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 140 additions and 4 deletions

View file

@ -43,6 +43,13 @@ struct OP_Gelu : public CtxGelu {
}
};
template <>
struct OP_Gelu<half> : public CtxGelu {
__device__ __inline__ half operator()(const half& a) const {
return static_cast<half>(_Gelu(static_cast<float>(a)));
}
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \

View file

@ -140,9 +140,9 @@ static bool IsFP16Allow(const std::string& op_type, size_t level, const FP16Allo
using OpsSetType = InlinedHashSet<std::string_view>;
static const OpsSetType level1_fp16_allow_set =
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze"};
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
static const OpsSetType level2_fp16_allow_set = {
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "Gelu", "LayerNormalization", "Where"};
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};
// To support new optimization levels, you need to extend the below array with a set ops for the new level
static const std::array<std::reference_wrapper<const OpsSetType>, MaxSupportedCastPropagationLevel> allowed_ops =

View file

@ -73,6 +73,7 @@
#include "test/common/tensor_op_test_utils.h"
#include "test/compare_ortvalue.h"
#include "test/framework/test_utils.h"
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/optimizer/graph_transform_test_fixture.h"
#include "test/providers/provider_test_utils.h"
#include "test/test_environment.h"
@ -4992,5 +4993,68 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
}
}
#ifdef ENABLE_TRAINING
TEST_F(GraphTransformationTests, PropagateCastOpsTests_Gelu) {
using Strategy = GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy;
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
auto* cast_out_0 = builder.MakeIntermediate();
auto* gelu_out = builder.MakeIntermediate();
auto* cast_out_1 = builder.MakeIntermediate();
auto* identity_out = builder.MakeOutput();
builder.AddNode("Cast", {input_arg}, {cast_out_0})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
builder.AddNode("Gelu", {cast_out_0}, {gelu_out}, kMSDomain);
builder.AddNode("Cast", {gelu_out}, {cast_out_1})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
builder.AddNode("Identity", {cast_out_1}, {identity_out});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 2);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 0);
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<PropagateCastOps>(Strategy::FloodFill, 1);
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<BFloat16>({{2, -1, 3, -1}});
auto* cast_out_0 = builder.MakeIntermediate();
auto* gelu_out = builder.MakeIntermediate();
auto* cast_out_1 = builder.MakeIntermediate();
auto* identity_out = builder.MakeOutput();
builder.AddNode("Cast", {input_arg}, {cast_out_0})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
builder.AddNode("Gelu", {cast_out_0}, {gelu_out}, kMSDomain);
builder.AddNode("Cast", {gelu_out}, {cast_out_1})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16));
builder.AddNode("Identity", {cast_out_1}, {identity_out});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 2);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 2);
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<PropagateCastOps>(Strategy::FloodFill, 1);
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -97,5 +97,28 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
}
}
void TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case, int opset_version,
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<void(Graph&)>& pre_graph_checker,
const std::function<void(Graph&)>& post_graph_checker) {
// Build the model for this test.
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = opset_version;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
ASSERT_TRUE(build_test_case);
build_test_case(helper);
helper.SetGraphOutputs();
ASSERT_STATUS_OK(graph.Resolve());
pre_graph_checker(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{steps};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), level));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
post_graph_checker(graph);
}
} // namespace test
} // namespace onnxruntime

View file

@ -68,6 +68,23 @@ class ModelTestBuilder {
return MakeInput<bool>(shape, data);
}
template <typename T>
NodeArg* MakeInput(const std::optional<std::vector<int64_t>>& shape) {
ONNX_NAMESPACE::TypeProto type_proto;
type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType<T>());
if (shape != std::nullopt) {
type_proto.mutable_tensor_type()->mutable_shape();
for (auto& d : *shape) {
auto dim = type_proto.mutable_tensor_type()->mutable_shape()->add_dim();
if (d != -1) {
dim->set_dim_value(d);
}
}
}
std::string name = graph_.GenerateNodeArgName("input");
return &graph_.GetOrCreateNodeArg(name, &type_proto);
}
NodeArg* MakeOutput() {
std::string name = graph_.GenerateNodeArgName("output");
output_names_.push_back(name);
@ -285,5 +302,21 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
const std::function<void(SessionOptions&)>& add_session_options = {},
const InlinedHashSet<std::string>& disabled_optimizers = {});
/**
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
*
* @param build_test_case The function to build a graph for testing
* @param opset_version The OpSet version of the graph
* @param logger The logger
* @param transformer The GraphTransformer to be applied
* @param level The transformer level on which the transformer will be applied
* @param steps The step count of the GraphTransformerManager
* @param pre_graph_checker The graph checker function before applying the transformer
* @param post_graph_checker The graph checker function after applying the transformer
*/
void TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case, int opset_version,
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<void(Graph&)>& pre_graph_checker,
const std::function<void(Graph&)>& post_graph_checker);
} // namespace test
} // namespace onnxruntime

View file

@ -121,8 +121,9 @@ class GraphExecutionManager(GraphExecutionInterface):
# as "FP16 safe", in order to insert/(re)move cast operations before/after to perform such operations in reduced (16-bit) precision.
# - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by propagate_cast_ops_allow use onnxruntime
# predetermined list of opcodes considered safe to move before/after cast operation.
# - Onnxruntime Level 1 predetermind "FP16 safe" opcodes include only opcode that do not perform any computation such as Transpose, Split, Reshape, etc.
# whereas Level 2 perdetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, GeLU, Dropout, LayerNormalization, etc.
# - Onnxruntime Level 1 predetermind "FP16 safe" opcodes include only opcode that do not perform any computation such as Transpose, Split, Reshape, etc.,
# or the computation is actual in Float such as GeLU, etc.
# whereas Level 2 perdetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, Dropout, LayerNormalization, etc.
self._propagate_cast_ops_level = 1
# List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
self._propagate_cast_ops_allow = []

View file

@ -17,6 +17,14 @@ struct OP_GeluGrad : public CtxGeluGrad {
}
};
template <>
struct OP_GeluGrad<half> : public CtxGeluGrad {
__device__ __inline__ half operator()(const half& dy, const half& x) const {
return static_cast<half>(
ComputeGeluGradScalar(static_cast<float>(dy), static_cast<float>(x), gelu_computation_mode::Default{}));
}
};
template <typename T>
struct OP_FastGeluGrad : public CtxGeluGrad {
__device__ __inline__ T operator()(const T& dy, const T& x) const {