Support training_mode flag in eval (#4324)

* add training_mode feed for evaluation to support opset12
This commit is contained in:
Tixxx 2020-07-08 10:38:54 -07:00 committed by GitHub
parent 71aec2adcb
commit b156ae4448
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 123 additions and 52 deletions

View file

@ -26,8 +26,10 @@ struct OrtRunOptions {
// So it is possible that only some of the nodes are executed.
bool only_execute_path_to_fetches = false;
#ifdef ENABLE_TRAINING
// Set to 'true' to run in training mode.
bool training_mode = false;
bool training_mode = true;
#endif
OrtRunOptions() = default;
~OrtRunOptions() = default;

View file

@ -252,7 +252,7 @@ class InferenceSession {
* @return OK if success.
*/
common::Status Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
std::vector<OrtValue>* p_fetches) ORT_MUST_USE_RESULT;
/**
* See Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches)
@ -271,7 +271,7 @@ class InferenceSession {
common::Status NewIOBinding(std::unique_ptr<IOBinding>* io_binding) ORT_MUST_USE_RESULT;
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
virtual common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
/**
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.

View file

@ -815,10 +815,12 @@ Applies to a particular Run() invocation. Default is 0.)pbdoc")
.def_readwrite("terminate", &RunOptions::terminate,
R"pbdoc(Set to True to terminate any currently executing calls that are using this
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
#ifdef ENABLE_TRAINING
.def_readwrite("training_mode", &RunOptions::training_mode,
R"pbdoc(Choose to run in training or inferencing mode)pbdoc");
R"pbdoc(Choose to run in training or inferencing mode)pbdoc")
#endif
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc");
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
It is usually used to identify the model used to run the prediction and

View file

@ -18,7 +18,6 @@ Status AddToExistingNodeArgs(
std::vector<const NodeArg*>& nodeargs) {
std::unordered_set<const NodeArg*> nodeargs_set(existing_nodeargs.begin(), existing_nodeargs.end());
nodeargs = existing_nodeargs;
for (const auto& new_nodearg_name : new_nodearg_names) {
const auto* new_nodearg = graph.GetNodeArg(new_nodearg_name);
ORT_RETURN_IF_NOT(

View file

@ -122,6 +122,8 @@ void TrainingSession::FilterUnusedWeights(const std::unordered_set<std::string>&
}
}
const std::string TrainingSession::training_mode_string_ = "training_mode";
Status TrainingSession::ConfigureForTraining(
const TrainingConfiguration& config, TrainingConfigurationResult& config_result_out) {
ORT_RETURN_IF(
@ -309,8 +311,8 @@ Status TrainingSession::ConfigureForTraining(
}
}
// Set eval feed names for Dropout ratio.
ORT_RETURN_IF_ERROR(SetDropoutEvalFeedNames());
// Set eval feed names for nodes that differ between training and inferencing.
ORT_RETURN_IF_ERROR(SetEvalFeedNames());
// add Tensorboard
if (config.tensorboard_config.has_value()) {
@ -827,17 +829,37 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons
common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
// Override initializers in eval mode.
if (!run_options.training_mode) {
// override all dropout ratios to 0
for (auto& drop_ratio : dropout_eval_feeds_) {
OrtValue feed_value;
// We allocate on CPU first, copy will be taken care of downstream.
const auto& session_state = GetSessionState();
auto default_cpu_alloc_info = session_state.GetExecutionProviders().GetDefaultCpuMemoryInfo();
auto cpu_allocator = session_state.GetAllocator(default_cpu_alloc_info);
feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
std::vector<std::pair<std::string, OrtValue>> new_feeds;
if (!dropout_eval_feeds_.empty()) {
// override all dropout ratios to 0
for (auto& drop_ratio : dropout_eval_feeds_) {
OrtValue feed_value;
// We allocate on CPU first, copy will be taken care of downstream.
auto cpu_allocator = GetSessionState().GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
// Bind new feed to graph input.
new_feeds.emplace_back(drop_ratio, feed_value);
}
}
else {
auto& input_names = io_binding.GetInputNames();
if (GetSessionState().GetInputNodeInfoMap().find(training_mode_string_) != GetSessionState().GetInputNodeInfoMap().end() &&
std::find(input_names.begin(), input_names.end(), training_mode_string_) == input_names.end()) {
// Set training_mode input to false
OrtValue training_mode_feed_value;
// We allocate on CPU first, copy will be taken care of downstream.
auto cpu_allocator = GetSessionState().GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
training_mode_feed_value = onnxruntime::MakeScalarMLValue<bool>(cpu_allocator, false, true /*is_1d*/);
new_feeds.emplace_back(training_mode_string_, training_mode_feed_value);
}
}
for (auto& new_feed : new_feeds) {
// Bind new feed to graph input.
ORT_RETURN_IF_ERROR(io_binding.BindInput(drop_ratio, feed_value));
ORT_RETURN_IF_ERROR(io_binding.BindInput(new_feed.first, new_feed.second));
}
}
@ -845,33 +867,50 @@ common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io
return InferenceSession::Run(run_options, io_binding);
}
// Convenience overload of Run that takes only an IOBinding.
// Builds a default RunOptions, forces its training_mode flag to true, and
// delegates to Run(const RunOptions&, IOBinding&), returning that call's status.
common::Status TrainingSession::Run(IOBinding& io_binding) {
RunOptions run_options;
// Set training_mode to true in training session by default.
run_options.training_mode = true;
return Run(run_options, io_binding);
}
static const std::unordered_set<std::string> Dropout_Nodes = {
static const std::unordered_set<std::string> Nodes_Need_Eval_Feeds = {
// TODO remove this once ONNX TrainableDropout is completely deprecated.
"TrainableDropout",
"Dropout",
};
// TODO remove this once ONNX properly supports training_mode input.
Status TrainingSession::SetDropoutEvalFeedNames() {
Status TrainingSession::SetEvalFeedNames() {
Graph& graph = model_->MainGraph();
// add ratio node to graph input for overriding.
GraphAugmenter::GraphDefs defs{};
for (const auto& node : graph.Nodes()) {
auto it = Dropout_Nodes.find(node.OpType());
if (it != Dropout_Nodes.cend()) {
auto& ratio_name = node.InputDefs()[1]->Name();
dropout_eval_feeds_.insert(ratio_name);
ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr,
"Input: " + ratio_name + " should not have any producer node.");
defs.AddGraphInputs({ratio_name});
for (auto& node : graph.Nodes()) {
auto it = Nodes_Need_Eval_Feeds.find(node.OpType());
if(it != Nodes_Need_Eval_Feeds.cend()) {
// The opset is < 12, add each ratio input to graph inputs for overriding.
// Needs to be removed when TrainableDropout is deprecated.
if(it->compare("TrainableDropout") == 0) {
auto& ratio_name = node.InputDefs()[1]->Name();
dropout_eval_feeds_.insert(ratio_name);
ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr,
"Input: " + ratio_name + " should not have any producer node.");
defs.AddGraphInputs({ratio_name});
}
// Found an opset-12 dropout node, replace initializer name.
else if(node.InputArgCount().size() > 2) {
auto& mode_input = node.MutableInputDefs()[2];
const ONNX_NAMESPACE::TensorProto* mode_initializer = nullptr;
if (!graph.GetInitializedTensor(training_mode_string_, mode_initializer)) {
// training_mode initializer has not been added before, add it here.
// Ideally we want only 1 training_mode initializer to control all relevant nodes.
const ONNX_NAMESPACE::TensorProto* original_mode_initializer = nullptr;
ORT_ENFORCE(graph.GetInitializedTensor(mode_input->Name(), original_mode_initializer) == true,
"Dropout's input: " + mode_input->Name() + " must be an initializer.");
ONNX_NAMESPACE::TensorProto new_mode_initializer(*original_mode_initializer);
new_mode_initializer.set_name(training_mode_string_);
defs.AddInitializers({new_mode_initializer});
}
mode_input = &model_->MainGraph().GetOrCreateNodeArg(training_mode_string_, mode_input->TypeAsProto());
// Set training_mode as graph input if any node that needs eval feed is found,
// it's okay to add it multiple times since it will be de-dup'ed downstream.
defs.AddGraphInputs({training_mode_string_});
}
}
}
ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs));
return DoPostLoadProcessing(*model_);
}

View file

@ -301,12 +301,11 @@ class TrainingSession : public InferenceSession {
* @return The list of feed names.
*/
std::unordered_set<std::string> GetDropoutEvalFeeds() const { return dropout_eval_feeds_; }
/** Override the Run function in InferenceSession to inject training-specific logic. **/
using InferenceSession::Run; // For overload resolution.
common::Status Run(const RunOptions& run_options, IOBinding& io_binding) override;
common::Status Run(IOBinding& io_binding) override;
private:
/** Configures the loss function.
The loss function can either be provided externally or built from the provided loss function information.
@ -446,7 +445,7 @@ class TrainingSession : public InferenceSession {
std::unordered_set<std::string> GetStateTensorNames() const;
common::Status SetDropoutEvalFeedNames();
common::Status SetEvalFeedNames();
NameMLValMap GetWeights() const;
@ -479,6 +478,7 @@ class TrainingSession : public InferenceSession {
std::unordered_map<std::string, OptimizerNodeConfig> opt_configs_;
GradientGraphConfiguration gradient_graph_config_;
static const std::string training_mode_string_;
};
} // namespace training
} // namespace onnxruntime

View file

@ -662,14 +662,14 @@ void TrainingRunner::RunWithUpdate(VectorString& feed_names,
#else
ORT_UNUSED_PARAMETER(step);
#endif
RunOptions run_options;
status = session_.Run(
RunOptions(),
pipeline_worker_pool_.worker_states[worker_id].feed_names,
pipeline_worker_pool_.worker_states[worker_id].feeds,
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
&(pipeline_worker_pool_.worker_states[worker_id].fetches));
},
worker_id, step_);
run_options,
pipeline_worker_pool_.worker_states[worker_id].feed_names,
pipeline_worker_pool_.worker_states[worker_id].feeds,
pipeline_worker_pool_.worker_states[worker_id].fetch_names,
&(pipeline_worker_pool_.worker_states[worker_id].fetches));
}, worker_id, step_);
// Wait for all workers to finish this round of pipeline parallelism.
// The last batch in a pipeline collects gradients and updates the model.
@ -751,6 +751,7 @@ void TrainingRunner::RunWithoutUpdate(VectorString& feed_names,
#endif
RunOptions run_options;
run_options.only_execute_path_to_fetches = true;
run_options.training_mode = true;
auto status = session_.Run(
run_options,
pipeline_worker_pool_.worker_states[worker_id].feed_names,
@ -1095,7 +1096,7 @@ Status TrainingRunner::EndTraining(IDataLoader* data_loader) {
return Status::OK();
}
Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loader) {
Status TrainingRunner::Evaluate(TrainingSession& session, IDataLoader& data_loader) {
if (params_.skip_evaluation) {
printf("Skipping evaluation...\n");
return Status::OK();
@ -1139,6 +1140,26 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
batch_idx,
feed_names,
feeds);
if (!session.GetDropoutEvalFeeds().empty()) {
float eval_ratio = 0.0f;
for (auto& dropout_ratio : session.GetDropoutEvalFeeds()) {
feed_names.push_back(dropout_ratio);
OrtValue ratio_val;
TrainingUtil::CreateCpuMLScalar(eval_ratio, &ratio_val, input_allocator_);
feeds.push_back(ratio_val);
}
}
const std::string training_mode_string = "training_mode";
auto input_list = session.GetOverridableInitializers().second;
for (auto input : *input_list) {
if(input->Name().compare(training_mode_string) == 0) {
feed_names.push_back("training_mode");
OrtValue mode_val;
TrainingUtil::CreateCpuMLScalar(false, &mode_val, input_allocator_);
feeds.push_back(mode_val);
break;
}
}
PrepareFetchNamesAndFetches(EvaluateStep,
fetch_names,
@ -1159,6 +1180,7 @@ Status TrainingRunner::Evaluate(InferenceSession& session, IDataLoader& data_loa
pipeline_worker_pool_.workers[worker_id] = std::thread([&]() {
RunOptions run_options;
run_options.only_execute_path_to_fetches = true;
run_options.training_mode = false;
status = session.Run(
run_options,
feed_names,

View file

@ -214,8 +214,8 @@ class TrainingRunner {
std::vector<MLValue>& feeds,
size_t& gradient_accumulation_step_count);
Status TrainingLoop(IDataLoader& training_data_loader, IDataLoader* test_data_loader,
const MapStringToString& mapped_dimensions);
Status Evaluate(InferenceSession& session, IDataLoader& data_loader);
const MapStringToString& mapped_dimensions);
Status Evaluate(TrainingSession& session, IDataLoader& data_loader);
Status SaveCheckpoint(const PathString& checkpoint_path);
Status LoadCheckpoint(const PathString& checkpoint_path);

View file

@ -902,7 +902,6 @@ class ORTTrainer():
elif self.current_step % self.gradient_accumulation_steps != 0:
run_options = ort.RunOptions()
run_options.only_execute_path_to_fetches = True
run_options.training_mode = True
output_desc = self.output_desc_with_group_accumulated_gradients
elif self.use_mixed_precision:
has_if_all_finite = True

View file

@ -97,6 +97,7 @@ static std::unique_ptr<TrainingSession> RunTrainingSessionWithChecks(
RunOptions run_options;
run_options.run_log_verbosity_level = so.session_log_verbosity_level;
run_options.run_tag = so.session_logid;
run_options.training_mode = true;
// Create dummy feeds
std::vector<int64_t> image_dims = {1, 784};
@ -317,6 +318,7 @@ static void RunBertTrainingWithChecks(
RunOptions run_options;
run_options.run_log_verbosity_level = so.session_log_verbosity_level;
run_options.run_tag = so.session_logid;
run_options.training_mode = true;
// Creating feeds
int batch_size = 13;
@ -1390,6 +1392,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level;
sub_sess.run_options.run_tag = sub_sess.so.session_logid;
sub_sess.run_options.training_mode = true;
sub_sess.sess = onnxruntime::make_unique<TrainingSession>(sub_sess.so, *env);
ASSERT_STATUS_OK(sub_sess.sess->Load(sub_model_files[sub_id]));

View file

@ -435,6 +435,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
// Now run
RunOptions run_options;
run_options.training_mode = true;
st = session_object.Run(run_options, feeds, output_names, &expected_ort_values);
EXPECT_TRUE(st.IsOK());
@ -466,6 +467,7 @@ TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {
// Now run
RunOptions run_options;
run_options.training_mode = true;
st = session_object.Run(run_options, feeds, output_names, &actual_ort_values);
EXPECT_TRUE(st.IsOK());
@ -562,6 +564,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest)
// Now run
RunOptions run_options;
run_options.training_mode = true;
st = session_object.Run(run_options, feeds, output_names, &expected_ort_values);
EXPECT_TRUE(st.IsOK());
}
@ -596,6 +599,7 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionCorrectnessTest)
// Now run
RunOptions run_options;
run_options.training_mode = true;
st = session_object.Run(run_options, feeds, output_names, &actual_ort_values);
EXPECT_TRUE(st.IsOK());
}

View file

@ -52,6 +52,7 @@ TwoDArray OpFunctionTester::RunFunctionBodyGraphOnCPU() {
RunOptions run_options;
run_options.run_tag = op_;
run_options.run_log_verbosity_level = 1;
run_options.training_mode = true;
std::vector<MLValue> cpu_fetches;
status = cpu_session_object.Run(run_options, feeds, output_names, &cpu_fetches);