mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Support training_mode flag in eval (#4324)
* add training_mode feed for evaluation to support opset12
This commit is contained in:
parent
71aec2adcb
commit
b156ae4448
12 changed files with 123 additions and 52 deletions
|
|
@ -26,8 +26,10 @@ struct OrtRunOptions {
|
|||
// So it is possible that only some of the nodes are executed.
|
||||
bool only_execute_path_to_fetches = false;
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// Set to 'true' to run in training mode.
|
||||
bool training_mode = false;
|
||||
bool training_mode = true;
|
||||
#endif
|
||||
|
||||
OrtRunOptions() = default;
|
||||
~OrtRunOptions() = default;
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ class InferenceSession {
|
|||
* @return OK if success.
|
||||
*/
|
||||
common::Status Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* See Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches)
|
||||
|
|
@ -271,7 +271,7 @@ class InferenceSession {
|
|||
common::Status NewIOBinding(std::unique_ptr<IOBinding>* io_binding) ORT_MUST_USE_RESULT;
|
||||
|
||||
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
virtual common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
|
||||
|
|
|
|||
|
|
@ -815,10 +815,12 @@ Applies to a particular Run() invocation. Default is 0.)pbdoc")
|
|||
.def_readwrite("terminate", &RunOptions::terminate,
|
||||
R"pbdoc(Set to True to terminate any currently executing calls that are using this
|
||||
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
|
||||
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
|
||||
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
|
||||
#ifdef ENABLE_TRAINING
|
||||
.def_readwrite("training_mode", &RunOptions::training_mode,
|
||||
R"pbdoc(Choose to run in training or inferencing mode)pbdoc");
|
||||
R"pbdoc(Choose to run in training or inferencing mode)pbdoc")
|
||||
#endif
|
||||
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
|
||||
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc");
|
||||
|
||||
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
|
||||
It is usually used to identify the model used to run the prediction and
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ Status AddToExistingNodeArgs(
|
|||
std::vector<const NodeArg*>& nodeargs) {
|
||||
std::unordered_set<const NodeArg*> nodeargs_set(existing_nodeargs.begin(), existing_nodeargs.end());
|
||||
nodeargs = existing_nodeargs;
|
||||
|
||||
for (const auto& new_nodearg_name : new_nodearg_names) {
|
||||
const auto* new_nodearg = graph.GetNodeArg(new_nodearg_name);
|
||||
ORT_RETURN_IF_NOT(
|
||||
|
|
|
|||
|
|
@ -122,6 +122,8 @@ void TrainingSession::FilterUnusedWeights(const std::unordered_set<std::string>&
|
|||
}
|
||||
}
|
||||
|
||||
const std::string TrainingSession::training_mode_string_ = "training_mode";
|
||||
|
||||
Status TrainingSession::ConfigureForTraining(
|
||||
const TrainingConfiguration& config, TrainingConfigurationResult& config_result_out) {
|
||||
ORT_RETURN_IF(
|
||||
|
|
@ -309,8 +311,8 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
}
|
||||
|
||||
// Set eval feed names for Dropout ratio.
|
||||
ORT_RETURN_IF_ERROR(SetDropoutEvalFeedNames());
|
||||
// Set eval feed names for nodes that differ between training and inferencing.
|
||||
ORT_RETURN_IF_ERROR(SetEvalFeedNames());
|
||||
|
||||
// add Tensorboard
|
||||
if (config.tensorboard_config.has_value()) {
|
||||
|
|
@ -827,17 +829,37 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons
|
|||
common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
|
||||
// Override initializers in eval mode.
|
||||
if (!run_options.training_mode) {
|
||||
// override all dropout raiots to 0
|
||||
for (auto& drop_ratio : dropout_eval_feeds_) {
|
||||
OrtValue feed_value;
|
||||
// We allocate on CPU first, copy will be taken care off downstream.
|
||||
const auto& session_state = GetSessionState();
|
||||
auto default_cpu_alloc_info = session_state.GetExecutionProviders().GetDefaultCpuMemoryInfo();
|
||||
auto cpu_allocator = session_state.GetAllocator(default_cpu_alloc_info);
|
||||
|
||||
feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
|
||||
std::vector<std::pair<std::string, OrtValue>> new_feeds;
|
||||
if (!dropout_eval_feeds_.empty()) {
|
||||
// override all dropout ratios to 0
|
||||
for (auto& drop_ratio : dropout_eval_feeds_) {
|
||||
OrtValue feed_value;
|
||||
// We allocate on CPU first, copy will be taken care of downstream.
|
||||
auto cpu_allocator = GetSessionState().GetExecutionProviders()
|
||||
.Get(onnxruntime::kCpuExecutionProvider)
|
||||
->GetAllocator(0, OrtMemTypeDefault);
|
||||
feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
|
||||
// Bind new feed to graph input.
|
||||
new_feeds.emplace_back(drop_ratio, feed_value);
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto& input_names = io_binding.GetInputNames();
|
||||
if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) != GetSessionState().GetInputNodeInfoMap().end() &&
|
||||
std::find(input_names.begin(), input_names.end(), training_mode_string_) == input_names.end()) {
|
||||
// Set training_mode input to false
|
||||
OrtValue training_mode_feed_value;
|
||||
// We allocate on CPU first, copy will be taken care of downstream.
|
||||
auto cpu_allocator = GetSessionState().GetExecutionProviders()
|
||||
.Get(onnxruntime::kCpuExecutionProvider)
|
||||
->GetAllocator(0, OrtMemTypeDefault);
|
||||
training_mode_feed_value = onnxruntime::MakeScalarMLValue<bool>(cpu_allocator, false, true /*is_1d*/);
|
||||
new_feeds.emplace_back(training_mode_string_, training_mode_feed_value);
|
||||
}
|
||||
}
|
||||
for (auto& new_feed : new_feeds) {
|
||||
// Bind new feed to graph input.
|
||||
ORT_RETURN_IF_ERROR(io_binding.BindInput(drop_ratio, feed_value));
|
||||
ORT_RETURN_IF_ERROR(io_binding.BindInput(new_feed.first, new_feed.second));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -845,33 +867,50 @@ common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io
|
|||
return InferenceSession::Run(run_options, io_binding);
|
||||
}
|
||||
|
||||
common::Status TrainingSession::Run(IOBinding& io_binding) {
|
||||
RunOptions run_options;
|
||||
// Set training_mode to true in training session by default.
|
||||
run_options.training_mode = true;
|
||||
return Run(run_options, io_binding);
|
||||
}
|
||||
|
||||
static const std::unordered_set<std::string> Dropout_Nodes = {
|
||||
static const std::unordered_set<std::string> Nodes_Need_Eval_Feeds = {
|
||||
// TODO remove this once ONNX TrainableDropout is completely deprecated.
|
||||
"TrainableDropout",
|
||||
"Dropout",
|
||||
};
|
||||
// TODO remove this once ONNX properly supports training_mode input.
|
||||
Status TrainingSession::SetDropoutEvalFeedNames() {
|
||||
Status TrainingSession::SetEvalFeedNames() {
|
||||
Graph& graph = model_->MainGraph();
|
||||
|
||||
// add ratio node to graph input for overriding.
|
||||
GraphAugmenter::GraphDefs defs{};
|
||||
|
||||
for (const auto& node : graph.Nodes()) {
|
||||
auto it = Dropout_Nodes.find(node.OpType());
|
||||
if (it != Dropout_Nodes.cend()) {
|
||||
auto& ratio_name = node.InputDefs()[1]->Name();
|
||||
dropout_eval_feeds_.insert(ratio_name);
|
||||
ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr,
|
||||
"Input: " + ratio_name + " should not have any producer node.");
|
||||
defs.AddGraphInputs({ratio_name});
|
||||
for (auto& node : graph.Nodes()) {
|
||||
auto it = Nodes_Need_Eval_Feeds.find(node.OpType());
|
||||
if(it != Nodes_Need_Eval_Feeds.cend()) {
|
||||
// The opset is < 12, add each ratio input to graph inputs for overriding.
|
||||
// Needs to be removed when TrainableDropout is deprecated.
|
||||
if(it->compare("TrainableDropout") == 0) {
|
||||
auto& ratio_name = node.InputDefs()[1]->Name();
|
||||
dropout_eval_feeds_.insert(ratio_name);
|
||||
ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr,
|
||||
"Input: " + ratio_name + " should not have any producer node.");
|
||||
defs.AddGraphInputs({ratio_name});
|
||||
}
|
||||
// Found an opset-12 dropout node, replace initializer name.
|
||||
else if(node.InputArgCount().size() > 2) {
|
||||
auto& mode_input = node.MutableInputDefs()[2];
|
||||
const ONNX_NAMESPACE::TensorProto* mode_initializer = nullptr;
|
||||
if (!graph.GetInitializedTensor(training_mode_string_, mode_initializer)) {
|
||||
// training_mode initializer has not been added before, add it here.
|
||||
// Ideally we want only 1 training_mode initializer to control all relevant nodes.
|
||||
const ONNX_NAMESPACE::TensorProto* original_mode_initializer = nullptr;
|
||||
ORT_ENFORCE(graph.GetInitializedTensor(mode_input->Name(), original_mode_initializer) == true,
|
||||
"Dropout's input: " + mode_input->Name() + " must be an initializer.");
|
||||
ONNX_NAMESPACE::TensorProto new_mode_initializer(*original_mode_initializer);
|
||||
new_mode_initializer.set_name(training_mode_string_);
|
||||
defs.AddInitializers({new_mode_initializer});
|
||||
}
|
||||
mode_input = &model_->MainGraph().GetOrCreateNodeArg(training_mode_string_, mode_input->TypeAsProto());
|
||||
// Set training_mode as graph input if any node that needs eval feed is found,
|
||||
// it's okay to add it multiple times since it will be de-dup'ed downstream.
|
||||
defs.AddGraphInputs({training_mode_string_});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs));
|
||||
return DoPostLoadProcessing(*model_);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -301,12 +301,11 @@ class TrainingSession : public InferenceSession {
|
|||
* @return The list of feed names.
|
||||
*/
|
||||
std::unordered_set<std::string> GetDropoutEvalFeeds() const { return dropout_eval_feeds_; }
|
||||
|
||||
/** Override Run function in InferenceSession to inject some training-specific logic **/
|
||||
using InferenceSession::Run; // For overload resolution.
|
||||
common::Status Run(const RunOptions& run_options, IOBinding& io_binding) override;
|
||||
|
||||
common::Status Run(IOBinding& io_binding) override;
|
||||
|
||||
private:
|
||||
/** Configures the loss function.
|
||||
The loss function can either be provided externally or built from the provided loss function information.
|
||||
|
|
@ -446,7 +445,7 @@ class TrainingSession : public InferenceSession {
|
|||
|
||||
std::unordered_set<std::string> GetStateTensorNames() const;
|
||||
|
||||
common::Status SetDropoutEvalFeedNames();
|
||||
common::Status SetEvalFeedNames();
|
||||
|
||||
NameMLValMap GetWeights() const;
|
||||
|
||||
|
|
@ -479,6 +478,7 @@ class TrainingSession : public InferenceSession {
|
|||
std::unordered_map<std::string, OptimizerNodeConfig> opt_configs_;
|
||||
|
||||
GradientGraphConfiguration gradient_graph_config_;
|
||||
static const std::string training_mode_string_;
|
||||
};
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -662,14 +662,14 @@ void TrainingRunner::RunWithUpdate(VectorString& feed_names,
|
|||
#else
|
||||
ORT_UNUSED_PARAMETER(step);
|
||||
#endif
|
||||
RunOptions run_options;
|
||||
status = session_.Run(
|
||||
RunOptions(),
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feeds,
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
|
||||
&(pipeline_worker_pool_.worker_states[worker_id].fetches));
|
||||
},
|
||||
worker_id, step_);
|
||||
run_options,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feeds,
|
||||
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
|
||||
&(pipeline_worker_pool_.worker_states[worker_id].fetches));
|
||||
}, worker_id, step_);
|
||||
|
||||
// Wait for all workers to finish this round of pipeline parallelism.
|
||||
// The last batch in a pipeline collects gradients and updates the model.
|
||||
|
|
@ -751,6 +751,7 @@ void TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
|
|||
#endif
|
||||
RunOptions run_options;
|
||||
run_options.only_execute_path_to_fetches = true;
|
||||
run_options.training_mode = true;
|
||||
auto status = session_.Run(
|
||||
run_options,
|
||||
pipeline_worker_pool_.worker_states[worker_id].feed_names,
|
||||
|
|
@ -1095,7 +1096,7 @@ Status TrainingRunner::EndTraining(IDataLoader* data_loader) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loader) {
|
||||
Status TrainingRunner::Evaluate(TrainingSession& session, IDataLoader& data_loader) {
|
||||
if (params_.skip_evaluation) {
|
||||
printf("Skipping evaluation...\n");
|
||||
return Status::OK();
|
||||
|
|
@ -1139,6 +1140,26 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
|
|||
batch_idx,
|
||||
feed_names,
|
||||
feeds);
|
||||
if (!session.GetDropoutEvalFeeds().empty()) {
|
||||
float eval_ratio = 0.0f;
|
||||
for (auto& dropout_ratio : session.GetDropoutEvalFeeds()) {
|
||||
feed_names.push_back(dropout_ratio);
|
||||
OrtValue ratio_val;
|
||||
TrainingUtil::CreateCpuMLScalar(eval_ratio, &ratio_val, input_allocator_);
|
||||
feeds.push_back(ratio_val);
|
||||
}
|
||||
}
|
||||
const std::string training_mode_string = "training_mode";
|
||||
auto input_list = session.GetOverridableInitializers().second;
|
||||
for (auto input : *input_list) {
|
||||
if(input->Name().compare(training_mode_string) == 0) {
|
||||
feed_names.push_back("training_mode");
|
||||
OrtValue mode_val;
|
||||
TrainingUtil::CreateCpuMLScalar(false, &mode_val, input_allocator_);
|
||||
feeds.push_back(mode_val);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
PrepareFetchNamesAndFetches(EvaluateStep,
|
||||
fetch_names,
|
||||
|
|
@ -1159,6 +1180,7 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
|
|||
pipeline_worker_pool_.workers[worker_id] = std::thread([&]() {
|
||||
RunOptions run_options;
|
||||
run_options.only_execute_path_to_fetches = true;
|
||||
run_options.training_mode = false;
|
||||
status = session.Run(
|
||||
run_options,
|
||||
feed_names,
|
||||
|
|
|
|||
|
|
@ -214,8 +214,8 @@ class TrainingRunner {
|
|||
std::vector<MLValue>& feeds,
|
||||
size_t& gradient_accumulation_step_count);
|
||||
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
|
||||
const MapStringToString& mapped_dimensions);
|
||||
Status Evaluate(InferenceSession& session, IDataLoader& data_loader);
|
||||
const MapStringToString& mapped_dimensions);
|
||||
Status Evaluate(TrainingSession& session, IDataLoader& data_loader);
|
||||
|
||||
Status SaveCheckpoint(const PathString& checkpoint_path);
|
||||
Status LoadCheckpoint(const PathString& checkpoint_path);
|
||||
|
|
|
|||
|
|
@ -902,7 +902,6 @@ class ORTTrainer():
|
|||
elif self.current_step % self.gradient_accumulation_steps != 0:
|
||||
run_options = ort.RunOptions()
|
||||
run_options.only_execute_path_to_fetches = True
|
||||
run_options.training_mode = True
|
||||
output_desc = self.output_desc_with_group_accumulated_gradients
|
||||
elif self.use_mixed_precision:
|
||||
has_if_all_finite = True
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ static std::unique_ptr<TrainingSession> RunTrainingSessionWithChecks(
|
|||
RunOptions run_options;
|
||||
run_options.run_log_verbosity_level = so.session_log_verbosity_level;
|
||||
run_options.run_tag = so.session_logid;
|
||||
run_options.training_mode = true;
|
||||
|
||||
// Create dummy feeds
|
||||
std::vector<int64_t> image_dims = {1, 784};
|
||||
|
|
@ -317,6 +318,7 @@ static void RunBertTrainingWithChecks(
|
|||
RunOptions run_options;
|
||||
run_options.run_log_verbosity_level = so.session_log_verbosity_level;
|
||||
run_options.run_tag = so.session_logid;
|
||||
run_options.training_mode = true;
|
||||
|
||||
// Creating feeds
|
||||
int batch_size = 13;
|
||||
|
|
@ -1390,6 +1392,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
|||
|
||||
sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level;
|
||||
sub_sess.run_options.run_tag = sub_sess.so.session_logid;
|
||||
sub_sess.run_options.training_mode = true;
|
||||
|
||||
sub_sess.sess = onnxruntime::make_unique<TrainingSession>(sub_sess.so, *env);
|
||||
ASSERT_STATUS_OK(sub_sess.sess->Load(sub_model_files[sub_id]));
|
||||
|
|
|
|||
|
|
@ -435,6 +435,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
|
||||
// Now run
|
||||
RunOptions run_options;
|
||||
run_options.training_mode = true;
|
||||
st = session_object.Run(run_options, feeds, output_names, &expected_ort_values);
|
||||
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
|
|
@ -466,6 +467,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
|
|||
|
||||
// Now run
|
||||
RunOptions run_options;
|
||||
run_options.training_mode = true;
|
||||
st = session_object.Run(run_options, feeds, output_names, &actual_ort_values);
|
||||
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
|
|
@ -562,6 +564,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest)
|
|||
|
||||
// Now run
|
||||
RunOptions run_options;
|
||||
run_options.training_mode = true;
|
||||
st = session_object.Run(run_options, feeds, output_names, &expected_ort_values);
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
}
|
||||
|
|
@ -596,6 +599,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest)
|
|||
|
||||
// Now run
|
||||
RunOptions run_options;
|
||||
run_options.training_mode = true;
|
||||
st = session_object.Run(run_options, feeds, output_names, &actual_ort_values);
|
||||
EXPECT_TRUE(st.IsOK());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ TwoDArray OpFunctionTester::RunFunctionBodyGraphOnCPU() {
|
|||
RunOptions run_options;
|
||||
run_options.run_tag = op_;
|
||||
run_options.run_log_verbosity_level = 1;
|
||||
run_options.training_mode = true;
|
||||
|
||||
std::vector<MLValue> cpu_fetches;
|
||||
status = cpu_session_object.Run(run_options, feeds, output_names, &cpu_fetches);
|
||||
|
|
|
|||
Loading…
Reference in a new issue