diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index 50d3e39e03..3b8d3b1bf3 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -26,8 +26,10 @@ struct OrtRunOptions { // So it is possible that only some of the nodes are executed. bool only_execute_path_to_fetches = false; +#ifdef ENABLE_TRAINING // Set to 'true' to run in training mode. - bool training_mode = false; + bool training_mode = true; +#endif OrtRunOptions() = default; ~OrtRunOptions() = default; diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 16bd3b653a..0d939c82e5 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -252,7 +252,7 @@ class InferenceSession { * @return OK if success. */ common::Status Run(const NameMLValMap& feeds, const std::vector& output_names, - std::vector* p_fetches) ORT_MUST_USE_RESULT; + std::vector* p_fetches) ORT_MUST_USE_RESULT; /** * See Run(const NameMLValMap& feeds, const std::vector& output_names, std::vector* p_fetches) @@ -271,7 +271,7 @@ class InferenceSession { common::Status NewIOBinding(std::unique_ptr* io_binding) ORT_MUST_USE_RESULT; virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT; - virtual common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT; + common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT; /** * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d3d983a37b..ff824c9e46 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -815,10 +815,12 @@ Applies to a particular Run() invocation. 
Default is 0.)pbdoc") .def_readwrite("terminate", &RunOptions::terminate, R"pbdoc(Set to True to terminate any currently executing calls that are using this RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc") - .def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches, - R"pbdoc(Only execute the nodes needed by fetch list)pbdoc") +#ifdef ENABLE_TRAINING .def_readwrite("training_mode", &RunOptions::training_mode, - R"pbdoc(Choose to run in training or inferencing mode)pbdoc"); + R"pbdoc(Choose to run in training or inferencing mode)pbdoc") +#endif + .def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches, + R"pbdoc(Only execute the nodes needed by fetch list)pbdoc"); py::class_(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model. It is usually used to identify the model used to run the prediction and diff --git a/orttraining/orttraining/core/graph/graph_augmenter.cc b/orttraining/orttraining/core/graph/graph_augmenter.cc index 435e09c31a..fe45a9b6bd 100644 --- a/orttraining/orttraining/core/graph/graph_augmenter.cc +++ b/orttraining/orttraining/core/graph/graph_augmenter.cc @@ -18,7 +18,6 @@ Status AddToExistingNodeArgs( std::vector& nodeargs) { std::unordered_set nodeargs_set(existing_nodeargs.begin(), existing_nodeargs.end()); nodeargs = existing_nodeargs; - for (const auto& new_nodearg_name : new_nodearg_names) { const auto* new_nodearg = graph.GetNodeArg(new_nodearg_name); ORT_RETURN_IF_NOT( diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 3cc0e46546..be9f47ab4e 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -122,6 +122,8 @@ void TrainingSession::FilterUnusedWeights(const std::unordered_set& } } +const std::string TrainingSession::training_mode_string_ = 
"training_mode"; + Status TrainingSession::ConfigureForTraining( const TrainingConfiguration& config, TrainingConfigurationResult& config_result_out) { ORT_RETURN_IF( @@ -309,8 +311,8 @@ Status TrainingSession::ConfigureForTraining( } } - // Set eval feed names for Dropout ratio. - ORT_RETURN_IF_ERROR(SetDropoutEvalFeedNames()); + // Set eval feed names for nodes that differ between training and inferencing. + ORT_RETURN_IF_ERROR(SetEvalFeedNames()); // add Tensorboard if (config.tensorboard_config.has_value()) { @@ -827,17 +829,37 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) { // Override initializers in eval mode. if (!run_options.training_mode) { - // override all dropout raiots to 0 - for (auto& drop_ratio : dropout_eval_feeds_) { - OrtValue feed_value; - // We allocate on CPU first, copy will be taken care off downstream. - const auto& session_state = GetSessionState(); - auto default_cpu_alloc_info = session_state.GetExecutionProviders().GetDefaultCpuMemoryInfo(); - auto cpu_allocator = session_state.GetAllocator(default_cpu_alloc_info); - - feed_value = onnxruntime::MakeScalarMLValue(cpu_allocator, 0.f, true /*is_1d*/); + std::vector> new_feeds; + if (!dropout_eval_feeds_.empty()) { + // override all dropout ratios to 0 + for (auto& drop_ratio : dropout_eval_feeds_) { + OrtValue feed_value; + // We allocate on CPU first, copy will be taken care of downstream. + auto cpu_allocator = GetSessionState().GetExecutionProviders() + .Get(onnxruntime::kCpuExecutionProvider) + ->GetAllocator(0, OrtMemTypeDefault); + feed_value = onnxruntime::MakeScalarMLValue(cpu_allocator, 0.f, true /*is_1d*/); + // Bind new feed to graph input. 
+ new_feeds.emplace_back(drop_ratio, feed_value); + } + } + else { + auto& input_names = io_binding.GetInputNames(); + if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) != GetSessionState().GetInputNodeInfoMap().end() && + std::find(input_names.begin(), input_names.end(), training_mode_string_) == input_names.end()) { + // Set training_mode input to false + OrtValue training_mode_feed_value; + // We allocate on CPU first, copy will be taken care of downstream. + auto cpu_allocator = GetSessionState().GetExecutionProviders() + .Get(onnxruntime::kCpuExecutionProvider) + ->GetAllocator(0, OrtMemTypeDefault); + training_mode_feed_value = onnxruntime::MakeScalarMLValue(cpu_allocator, false, true /*is_1d*/); + new_feeds.emplace_back(training_mode_string_, training_mode_feed_value); + } + } + for (auto& new_feed : new_feeds) { // Bind new feed to graph input. - ORT_RETURN_IF_ERROR(io_binding.BindInput(drop_ratio, feed_value)); + ORT_RETURN_IF_ERROR(io_binding.BindInput(new_feed.first, new_feed.second)); } } @@ -845,33 +867,50 @@ common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io return InferenceSession::Run(run_options, io_binding); } -common::Status TrainingSession::Run(IOBinding& io_binding) { - RunOptions run_options; - // Set training_mode to true in training session by default. - run_options.training_mode = true; - return Run(run_options, io_binding); -} - -static const std::unordered_set Dropout_Nodes = { +static const std::unordered_set Nodes_Need_Eval_Feeds = { + // TODO remove this once ONNX TrainableDropout is completely deprecated. "TrainableDropout", + "Dropout", }; -// TODO remove this once ONNX properly supports training_mode input. -Status TrainingSession::SetDropoutEvalFeedNames() { +Status TrainingSession::SetEvalFeedNames() { Graph& graph = model_->MainGraph(); - // add ratio node to graph input for overriding. 
GraphAugmenter::GraphDefs defs{}; - for (const auto& node : graph.Nodes()) { - auto it = Dropout_Nodes.find(node.OpType()); - if (it != Dropout_Nodes.cend()) { - auto& ratio_name = node.InputDefs()[1]->Name(); - dropout_eval_feeds_.insert(ratio_name); - ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr, - "Input: " + ratio_name + " should not have any producer node."); - defs.AddGraphInputs({ratio_name}); + for (auto& node : graph.Nodes()) { + auto it = Nodes_Need_Eval_Feeds.find(node.OpType()); + if(it != Nodes_Need_Eval_Feeds.cend()) { + // The opset is < 12, add each ratio input to graph inputs for overriding. + // Needs to be removed when TrainableDropout is deprecated. + if(it->compare("TrainableDropout") == 0) { + auto& ratio_name = node.InputDefs()[1]->Name(); + dropout_eval_feeds_.insert(ratio_name); + ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr, + "Input: " + ratio_name + " should not have any producer node."); + defs.AddGraphInputs({ratio_name}); + } + // Found an opset-12 dropout node, replace initializer name. + else if(node.InputArgCount().size() > 2) { + auto& mode_input = node.MutableInputDefs()[2]; + const ONNX_NAMESPACE::TensorProto* mode_initializer = nullptr; + if (!graph.GetInitializedTensor(training_mode_string_, mode_initializer)) { + // training_mode initializer has not been added before, add it here. + // Ideally we want only 1 training_mode initializer to control all relevant nodes. 
+ const ONNX_NAMESPACE::TensorProto* original_mode_initializer = nullptr; + ORT_ENFORCE(graph.GetInitializedTensor(mode_input->Name(), original_mode_initializer) == true, + "Dropout's input: " + mode_input->Name() + " must be an initializer."); + ONNX_NAMESPACE::TensorProto new_mode_initializer(*original_mode_initializer); + new_mode_initializer.set_name(training_mode_string_); + defs.AddInitializers({new_mode_initializer}); + } + mode_input = &model_->MainGraph().GetOrCreateNodeArg(training_mode_string_, mode_input->TypeAsProto()); + // Set training_mode as graph input if any node that needs eval feed is found, + // it's okay to add it multiple times since it will be de-dup'ed downstream. + defs.AddGraphInputs({training_mode_string_}); + } + } } } + ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs)); return DoPostLoadProcessing(*model_); } diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index cbd1d20aef..41ebd04395 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -301,12 +301,11 @@ class TrainingSession : public InferenceSession { * @return The list of feed names. */ std::unordered_set GetDropoutEvalFeeds() const { return dropout_eval_feeds_; } + /** Override Run function in InferenceSession to inject some training-specific logic **/ using InferenceSession::Run; // For overload resolution. common::Status Run(const RunOptions& run_options, IOBinding& io_binding) override; - common::Status Run(IOBinding& io_binding) override; - private: /** Configures the loss function. The loss function can either be provided externally or built from the provided loss function information. 
@@ -446,7 +445,7 @@ class TrainingSession : public InferenceSession { std::unordered_set GetStateTensorNames() const; - common::Status SetDropoutEvalFeedNames(); + common::Status SetEvalFeedNames(); NameMLValMap GetWeights() const; @@ -479,6 +478,7 @@ class TrainingSession : public InferenceSession { std::unordered_map opt_configs_; GradientGraphConfiguration gradient_graph_config_; + static const std::string training_mode_string_; }; } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 7a4ddcdbcc..a6c1f61fb2 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -662,14 +662,14 @@ void TrainingRunner::RunWithUpdate(VectorString& feed_names, #else ORT_UNUSED_PARAMETER(step); #endif + RunOptions run_options; status = session_.Run( - RunOptions(), - pipeline_worker_pool_.worker_states[worker_id].feed_names, - pipeline_worker_pool_.worker_states[worker_id].feeds, - pipeline_worker_pool_.worker_states[worker_id].fetch_names, - &(pipeline_worker_pool_.worker_states[worker_id].fetches)); - }, - worker_id, step_); + run_options, + pipeline_worker_pool_.worker_states[worker_id].feed_names, + pipeline_worker_pool_.worker_states[worker_id].feeds, + pipeline_worker_pool_.worker_states[worker_id].fetch_names, + &(pipeline_worker_pool_.worker_states[worker_id].fetches)); + }, worker_id, step_); // Wait all workers to finish this round of pipeline parallelism. // The last batch in a pipeline collects gradient and update the model. 
@@ -751,6 +751,7 @@ void TrainingRunner::RunWithoutUpdate(VectorString& feed_names, #endif RunOptions run_options; run_options.only_execute_path_to_fetches = true; + run_options.training_mode = true; auto status = session_.Run( run_options, pipeline_worker_pool_.worker_states[worker_id].feed_names, @@ -1095,7 +1096,7 @@ Status TrainingRunner::EndTraining(IDataLoader* data_loader) { return Status::OK(); } -Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loader) { +Status TrainingRunner::Evaluate(TrainingSession& session, IDataLoader& data_loader) { if (params_.skip_evaluation) { printf("Skipping evaluation...\n"); return Status::OK(); @@ -1139,6 +1140,26 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa batch_idx, feed_names, feeds); + if (!session.GetDropoutEvalFeeds().empty()) { + float eval_ratio = 0.0f; + for (auto& dropout_ratio : session.GetDropoutEvalFeeds()) { + feed_names.push_back(dropout_ratio); + OrtValue ratio_val; + TrainingUtil::CreateCpuMLScalar(eval_ratio, &ratio_val, input_allocator_); + feeds.push_back(ratio_val); + } + } + const std::string training_mode_string = "training_mode"; + auto input_list = session.GetOverridableInitializers().second; + for (auto input : *input_list) { + if(input->Name().compare(training_mode_string) == 0) { + feed_names.push_back("training_mode"); + OrtValue mode_val; + TrainingUtil::CreateCpuMLScalar(false, &mode_val, input_allocator_); + feeds.push_back(mode_val); + break; + } + } PrepareFetchNamesAndFetches(EvaluateStep, fetch_names, @@ -1159,6 +1180,7 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa pipeline_worker_pool_.workers[worker_id] = std::thread([&]() { RunOptions run_options; run_options.only_execute_path_to_fetches = true; + run_options.training_mode = false; status = session.Run( run_options, feed_names, diff --git a/orttraining/orttraining/models/runner/training_runner.h 
b/orttraining/orttraining/models/runner/training_runner.h index 89caa421a0..65432fe5c2 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -214,8 +214,8 @@ class TrainingRunner { std::vector& feeds, size_t& gradient_accumulation_step_count); Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader, - const MapStringToString& mapped_dimensions); - Status Evaluate(InferenceSession& session, IDataLoader& data_loader); + const MapStringToString& mapped_dimensions); + Status Evaluate(TrainingSession& session, IDataLoader& data_loader); Status SaveCheckpoint(const PathString& checkpoint_path); Status LoadCheckpoint(const PathString& checkpoint_path); diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index a928a1b149..a69c41e312 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -902,7 +902,6 @@ class ORTTrainer(): elif self.current_step % self.gradient_accumulation_steps != 0: run_options = ort.RunOptions() run_options.only_execute_path_to_fetches = True - run_options.training_mode = True output_desc = self.output_desc_with_group_accumulated_gradients elif self.use_mixed_precision: has_if_all_finite = True diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 9d60d756ad..6dbb2349ef 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -97,6 +97,7 @@ static std::unique_ptr RunTrainingSessionWithChecks( RunOptions run_options; run_options.run_log_verbosity_level = so.session_log_verbosity_level; run_options.run_tag = so.session_logid; + run_options.training_mode = true; // Create dummy feeds std::vector image_dims = {1, 784}; @@ -317,6 +318,7 @@ static 
void RunBertTrainingWithChecks( RunOptions run_options; run_options.run_log_verbosity_level = so.session_log_verbosity_level; run_options.run_tag = so.session_logid; + run_options.training_mode = true; // Creating feeds int batch_size = 13; @@ -1390,6 +1392,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) { sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level; sub_sess.run_options.run_tag = sub_sess.so.session_logid; + sub_sess.run_options.training_mode = true; sub_sess.sess = onnxruntime::make_unique(sub_sess.so, *env); ASSERT_STATUS_OK(sub_sess.sess->Load(sub_model_files[sub_id])); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 2f4e119ec0..873c7ade7e 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -435,6 +435,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { // Now run RunOptions run_options; + run_options.training_mode = true; st = session_object.Run(run_options, feeds, output_names, &expected_ort_values); EXPECT_TRUE(st.IsOK()); @@ -466,6 +467,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) { // Now run RunOptions run_options; + run_options.training_mode = true; st = session_object.Run(run_options, feeds, output_names, &actual_ort_values); EXPECT_TRUE(st.IsOK()); @@ -562,6 +564,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) // Now run RunOptions run_options; + run_options.training_mode = true; st = session_object.Run(run_options, feeds, output_names, &expected_ort_values); EXPECT_TRUE(st.IsOK()); } @@ -596,6 +599,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest) // Now run RunOptions run_options; + run_options.training_mode = true; st = session_object.Run(run_options, feeds, output_names, 
&actual_ort_values); EXPECT_TRUE(st.IsOK()); } diff --git a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc index 726963e63b..3ad0015498 100644 --- a/orttraining/orttraining/test/training_ops/function_op_test_utils.cc +++ b/orttraining/orttraining/test/training_ops/function_op_test_utils.cc @@ -52,6 +52,7 @@ TwoDArray OpFunctionTester::RunFunctionBodyGraphOnCPU() { RunOptions run_options; run_options.run_tag = op_; run_options.run_log_verbosity_level = 1; + run_options.training_mode = true; std::vector cpu_fetches; status = cpu_session_object.Run(run_options, feeds, output_names, &cpu_fetches);