diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 7988ecd42f..65a1dc992a 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -43,6 +43,13 @@ struct OP_Gelu : public CtxGelu { } }; +template <> +struct OP_Gelu : public CtxGelu { + __device__ __inline__ half operator()(const half& a) const { + return static_cast(_Gelu(static_cast(a))); + } +}; + #define UNARY_ACTIVATION_IMPL(name) \ UNARY_ACTIVATION_IMPL_DECLARATION(name) { \ UnaryElementWiseImpl(stream, \ diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index db084e95ef..8c772f7a91 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -140,9 +140,9 @@ static bool IsFP16Allow(const std::string& op_type, size_t level, const FP16Allo using OpsSetType = InlinedHashSet; static const OpsSetType level1_fp16_allow_set = - {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze"}; + {"Expand", "Transpose", "Relu", "Reshape", "Split", "Tanh", "Squeeze", "Unsqueeze", "Gelu"}; static const OpsSetType level2_fp16_allow_set = { - "Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "Gelu", "LayerNormalization", "Where"}; + "Add", "BiasGelu", "Dropout", "FastGelu", "Gather", "LayerNormalization", "Where"}; // To support new optimization levels, you need to extend the below array with a set ops for the new level static const std::array, MaxSupportedCastPropagationLevel> allowed_ops = diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index af6ed1886f..1e9b65a991 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -73,6 +73,7 @@ #include 
"test/common/tensor_op_test_utils.h" #include "test/compare_ortvalue.h" #include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/providers/provider_test_utils.h" #include "test/test_environment.h" @@ -4992,5 +4993,68 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { } } +#ifdef ENABLE_TRAINING +TEST_F(GraphTransformationTests, PropagateCastOpsTests_Gelu) { + using Strategy = GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy; + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* cast_out_0 = builder.MakeIntermediate(); + auto* gelu_out = builder.MakeIntermediate(); + auto* cast_out_1 = builder.MakeIntermediate(); + auto* identity_out = builder.MakeOutput(); + + builder.AddNode("Cast", {input_arg}, {cast_out_0}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + builder.AddNode("Gelu", {cast_out_0}, {gelu_out}, kMSDomain); + builder.AddNode("Cast", {gelu_out}, {cast_out_1}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + builder.AddNode("Identity", {cast_out_1}, {identity_out}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 2); + }; + + auto post_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 0); + }; + + std::unique_ptr transformer = std::make_unique(Strategy::FloodFill, 1); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{2, -1, 3, -1}}); + auto* cast_out_0 = builder.MakeIntermediate(); + auto* gelu_out = builder.MakeIntermediate(); + auto* cast_out_1 = builder.MakeIntermediate(); + auto* identity_out 
= builder.MakeOutput(); + + builder.AddNode("Cast", {input_arg}, {cast_out_0}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + builder.AddNode("Gelu", {cast_out_0}, {gelu_out}, kMSDomain); + builder.AddNode("Cast", {gelu_out}, {cast_out_1}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); + builder.AddNode("Identity", {cast_out_1}, {identity_out}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 2); + }; + + auto post_graph_checker = [&](Graph& graph) { + ASSERT_EQ(CountOpsInGraph(graph)["Cast"], 2); + }; + + std::unique_ptr transformer = std::make_unique(Strategy::FloodFill, 1); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 28a5e01af9..0bb14c15f4 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -97,5 +97,28 @@ void TransformerTester(const std::function& buil } } +void TestGraphTransformer(const std::function& build_test_case, int opset_version, + const logging::Logger& logger, std::unique_ptr transformer, + TransformerLevel level, unsigned steps, const std::function& pre_graph_checker, + const std::function& post_graph_checker) { + // Build the model for this test. 
+ std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; + Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, logger); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + ASSERT_TRUE(build_test_case); + build_test_case(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + pre_graph_checker(graph); + onnxruntime::GraphTransformerManager graph_transformation_mgr{steps}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), level)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, level, logger)); + post_graph_checker(graph); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index a721a940d6..c1dbd76f34 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -68,6 +68,23 @@ class ModelTestBuilder { return MakeInput(shape, data); } + template + NodeArg* MakeInput(const std::optional>& shape) { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + if (shape != std::nullopt) { + type_proto.mutable_tensor_type()->mutable_shape(); + for (auto& d : *shape) { + auto dim = type_proto.mutable_tensor_type()->mutable_shape()->add_dim(); + if (d != -1) { + dim->set_dim_value(d); + } + } + } + std::string name = graph_.GenerateNodeArgName("input"); + return &graph_.GetOrCreateNodeArg(name, &type_proto); + } + NodeArg* MakeOutput() { std::string name = graph_.GenerateNodeArgName("output"); output_names_.push_back(name); @@ -285,5 +302,21 @@ void TransformerTester(const std::function& buil const std::function& add_session_options = {}, const InlinedHashSet& 
disabled_optimizers = {}); +/** + * @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer. + * + * @param build_test_case The function to build a graph for testing + * @param opset_version The OpSet version of the graph + * @param logger The logger + * @param transformer The GraphTransformer to be applied + * @param level The transformer level on which the transformer will be applied + * @param steps The step count of the GraphTransformerManager + * @param pre_graph_checker The graph checker function before applying the transformer + * @param post_graph_checker The graph checker function after applying the transformer + */ +void TestGraphTransformer(const std::function& build_test_case, int opset_version, + const logging::Logger& logger, std::unique_ptr transformer, + TransformerLevel level, unsigned steps, const std::function& pre_graph_checker, + const std::function& post_graph_checker); } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 47c0608311..4eb46dde55 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -121,8 +121,9 @@ class GraphExecutionManager(GraphExecutionInterface): # as "FP16 safe", in order to insert/(re)move cast operations before/after to perform such operations in reduced (16-bit) precision. # - If propagate_cast_ops_level is positive, 1 or 2, then in addition to opcode codes specified by propagate_cast_ops_allow use onnxruntime # predetermined list of opcodes considered safe to move before/after cast operation. - # - Onnxruntime Level 1 predetermind "FP16 safe" opcodes include only opcode that do not perform any computation such as Transpose, Split, Reshape, etc. 
- # whereas Level 2 perdetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, GeLU, Dropout, LayerNormalization, etc. + # - Onnxruntime Level 1 predetermined "FP16 safe" opcodes include only opcodes that do not perform any computation such as Transpose, Split, Reshape, etc., + # or whose computation is actually performed in float such as GeLU, etc. + # whereas Level 2 predetermined "FP16 safe" opcodes include opcodes that perform computation using contrib ops, Dropout, LayerNormalization, etc. self._propagate_cast_ops_level = 1 # List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. self._propagate_cast_ops_allow = [] diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu index befcda3dec..1633cc45b7 100644 --- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu @@ -17,6 +17,14 @@ struct OP_GeluGrad : public CtxGeluGrad { } }; +template <> +struct OP_GeluGrad : public CtxGeluGrad { + __device__ __inline__ half operator()(const half& dy, const half& x) const { + return static_cast( + ComputeGeluGradScalar(static_cast(dy), static_cast(x), gelu_computation_mode::Default{})); + } +}; + template struct OP_FastGeluGrad : public CtxGeluGrad { __device__ __inline__ T operator()(const T& dy, const T& x) const {