mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Remove Cast before and after Gelu (#11885)
* fuse cast gelu * use PropagateCastOps * fix ut
This commit is contained in:
parent
4bf22e2a40
commit
03beed0ceb
7 changed files with 140 additions and 4 deletions
|
|
@ -43,6 +43,13 @@ struct OP_Gelu : public CtxGelu {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OP_Gelu<half> : public CtxGelu {
|
||||
__device__ __inline__ half operator()(const half& a) const {
|
||||
return static_cast<half>(_Gelu(static_cast<float>(a)));
|
||||
}
|
||||
};
|
||||
|
||||
#define UNARY_ACTIVATION_IMPL(name) \
|
||||
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
|
||||
UnaryElementWiseImpl(stream, \
|
||||
|
|
|
|||
|
|
@ -140,9 +140,9 @@ static bool IsFP16Allow(const std::string& op_type, size_t level, const FP16Allo
|
|||
|
||||
using OpsSetType = InlinedHashSet<std::string_view>;
|
||||
static const OpsSetType level1_fp16_allow_set =
|
||||
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze"};
|
||||
{"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"};
|
||||
static const OpsSetType level2_fp16_allow_set = {
|
||||
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "Gelu", "LayerNormalization", "Where"};
|
||||
"Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"};
|
||||
|
||||
// To support new optimization levels, you need to extend the below array with a set ops for the new level
|
||||
static const std::array<std::reference_wrapper<const OpsSetType>, MaxSupportedCastPropagationLevel> allowed_ops =
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@
|
|||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/compare_ortvalue.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/optimizer/graph_transform_test_builder.h"
|
||||
#include "test/optimizer/graph_transform_test_fixture.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
|
|
@ -4992,5 +4993,68 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// Exercises PropagateCastOps (FloodFill strategy, level 1) on a Cast -> Gelu -> Cast -> Identity chain
// for two input element types and checks the number of Cast nodes before and after the transformer.
TEST_F(GraphTransformationTests, PropagateCastOpsTests_Gelu) {
  using Strategy = GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy;

  // Produces a graph checker asserting that the graph holds exactly `expected` Cast nodes.
  auto expect_cast_count = [](int expected) {
    return [expected](Graph& graph) {
      ASSERT_EQ(CountOpsInGraph(graph)["Cast"], expected);
    };
  };

  // Scenario 1: MLFloat16 input of shape {2, 3, 3, 3}, cast to FLOAT and back to FLOAT16.
  // The test expects no Cast node to remain after the transformer has run.
  {
    auto build_fp16_graph = [&](ModelTestBuilder& builder) {
      auto* fp16_input = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
      auto* widened = builder.MakeIntermediate();
      auto* activated = builder.MakeIntermediate();
      auto* narrowed = builder.MakeIntermediate();
      auto* graph_output = builder.MakeOutput();

      builder.AddNode("Cast", {fp16_input}, {widened})
          .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
      builder.AddNode("Gelu", {widened}, {activated}, kMSDomain);
      builder.AddNode("Cast", {activated}, {narrowed})
          .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
      builder.AddNode("Identity", {narrowed}, {graph_output});
    };

    std::unique_ptr<GraphTransformer> propagate_casts =
        std::make_unique<PropagateCastOps>(Strategy::FloodFill, 1);
    TestGraphTransformer(build_fp16_graph, 14, *logger_, std::move(propagate_casts), TransformerLevel::Level1, 1,
                         expect_cast_count(2), expect_cast_count(0));
  }

  // Scenario 2: BFloat16 input of shape {2, -1, 3, -1}, cast to FLOAT and back to BFLOAT16.
  // The test expects both Cast nodes to still be present after the transformer has run.
  {
    auto build_bf16_graph = [&](ModelTestBuilder& builder) {
      auto* bf16_input = builder.MakeInput<BFloat16>({{2, -1, 3, -1}});
      auto* widened = builder.MakeIntermediate();
      auto* activated = builder.MakeIntermediate();
      auto* narrowed = builder.MakeIntermediate();
      auto* graph_output = builder.MakeOutput();

      builder.AddNode("Cast", {bf16_input}, {widened})
          .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
      builder.AddNode("Gelu", {widened}, {activated}, kMSDomain);
      builder.AddNode("Cast", {activated}, {narrowed})
          .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16));
      builder.AddNode("Identity", {narrowed}, {graph_output});
    };

    std::unique_ptr<GraphTransformer> propagate_casts =
        std::make_unique<PropagateCastOps>(Strategy::FloodFill, 1);
    TestGraphTransformer(build_bf16_graph, 14, *logger_, std::move(propagate_casts), TransformerLevel::Level1, 1,
                         expect_cast_count(2), expect_cast_count(2));
  }
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -97,5 +97,28 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
|
|||
}
|
||||
}
|
||||
|
||||
void TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case, int opset_version,
|
||||
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
|
||||
TransformerLevel level, unsigned steps, const std::function<void(Graph&)>& pre_graph_checker,
|
||||
const std::function<void(Graph&)>& post_graph_checker) {
|
||||
// Build the model for this test.
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = opset_version;
|
||||
domain_to_version[kMSDomain] = 1;
|
||||
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, {}, logger);
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
ASSERT_TRUE(build_test_case);
|
||||
build_test_case(helper);
|
||||
helper.SetGraphOutputs();
|
||||
ASSERT_STATUS_OK(graph.Resolve());
|
||||
pre_graph_checker(graph);
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{steps};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), level));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
|
||||
post_graph_checker(graph);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -68,6 +68,23 @@ class ModelTestBuilder {
|
|||
return MakeInput<bool>(shape, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
NodeArg* MakeInput(const std::optional<std::vector<int64_t>>& shape) {
|
||||
ONNX_NAMESPACE::TypeProto type_proto;
|
||||
type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType<T>());
|
||||
if (shape != std::nullopt) {
|
||||
type_proto.mutable_tensor_type()->mutable_shape();
|
||||
for (auto& d : *shape) {
|
||||
auto dim = type_proto.mutable_tensor_type()->mutable_shape()->add_dim();
|
||||
if (d != -1) {
|
||||
dim->set_dim_value(d);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string name = graph_.GenerateNodeArgName("input");
|
||||
return &graph_.GetOrCreateNodeArg(name, &type_proto);
|
||||
}
|
||||
|
||||
NodeArg* MakeOutput() {
|
||||
std::string name = graph_.GenerateNodeArgName("output");
|
||||
output_names_.push_back(name);
|
||||
|
|
@ -285,5 +302,21 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
|
|||
const std::function<void(SessionOptions&)>& add_session_options = {},
|
||||
const InlinedHashSet<std::string>& disabled_optimizers = {});
|
||||
|
||||
/**
|
||||
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
|
||||
*
|
||||
* @param build_test_case The function to build a graph for testing
|
||||
* @param opset_version The OpSet version of the graph
|
||||
* @param logger The logger
|
||||
* @param transformer The GraphTransformer to be applied
|
||||
* @param level The transformer level on which the transformer will be applied
|
||||
* @param steps The step count of the GraphTransformerManager
|
||||
* @param pre_graph_checker The graph checker function before applying the transformer
|
||||
* @param post_graph_checker The graph checker function after applying the transformer
|
||||
*/
|
||||
void TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case, int opset_version,
|
||||
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
|
||||
TransformerLevel level, unsigned steps, const std::function<void(Graph&)>& pre_graph_checker,
|
||||
const std::function<void(Graph&)>& post_graph_checker);
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -121,8 +121,9 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
# as "FP16 safe", in order to insert/(re)move cast operations before/after to perform such operations in reduced (16-bit) precision.
|
||||
# - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by propagate_cast_ops_allow use onnxruntime
|
||||
# predetermined list of opcodes considered safe to move before/after cast operation.
|
||||
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc.
|
||||
# whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, GeLU, Dropout, LayerNormalization, etc.
|
||||
# - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc.,
|
||||
# or the computation is actually performed in float, such as GeLU, etc.
|
||||
# whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, Dropout, LayerNormalization, etc.
|
||||
self._propagate_cast_ops_level = 1
|
||||
# List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero.
|
||||
self._propagate_cast_ops_allow = []
|
||||
|
|
|
|||
|
|
@ -17,6 +17,14 @@ struct OP_GeluGrad : public CtxGeluGrad {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OP_GeluGrad<half> : public CtxGeluGrad {
|
||||
__device__ __inline__ half operator()(const half& dy, const half& x) const {
|
||||
return static_cast<half>(
|
||||
ComputeGeluGradScalar(static_cast<float>(dy), static_cast<float>(x), gelu_computation_mode::Default{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_FastGeluGrad : public CtxGeluGrad {
|
||||
__device__ __inline__ T operator()(const T& dy, const T& x) const {
|
||||
|
|
|
|||
Loading…
Reference in a new issue