From 0638565fe0d6149d1512ec3e63512e6ea13743a3 Mon Sep 17 00:00:00 2001 From: Tixxx Date: Tue, 28 Apr 2020 21:03:37 -0700 Subject: [PATCH] Fix evaluation issues (#3538) * allow switching between eval and training modes dynamically Co-authored-by: Tixxx --- .../onnxruntime/core/framework/run_options.h | 3 + .../core/providers/cpu/controlflow/loop.cc | 14 ---- .../core/providers/cpu/controlflow/utils.h | 15 ++++ onnxruntime/core/session/inference_session.h | 4 +- .../python/onnxruntime_pybind_state.cc | 4 +- .../python/onnxruntime_test_ort_trainer.py | 12 +-- ...e_test_ort_trainer_with_mixed_precision.py | 12 +-- .../onnxruntime_test_training_unit_tests.py | 76 +++++++++++++++++++ ...nnxruntime_test_training_unittest_utils.py | 49 ++++++++++++ .../core/session/training_session.cc | 56 ++++++++++++++ .../core/session/training_session.h | 22 +++++- orttraining/orttraining/python/ort_trainer.py | 4 +- tools/ci_build/build.py | 3 + 13 files changed, 243 insertions(+), 31 deletions(-) create mode 100644 onnxruntime/test/python/onnxruntime_test_training_unit_tests.py create mode 100644 onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index 25cb1c29d5..50d3e39e03 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -26,6 +26,9 @@ struct OrtRunOptions { // So it is possible that only some of the nodes are executed. bool only_execute_path_to_fetches = false; + // Set to 'true' to run in training mode. + bool training_mode = false; + OrtRunOptions() = default; ~OrtRunOptions() = default; diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index 8da807bc9f..39aa97a5e9 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -344,20 +344,6 @@ LoopImpl::LoopImpl(OpKernelContextInternal& context, condition_ = cond_tensor ? *cond_tensor->Data() : true; } -template -static OrtValue MakeScalarMLValue(const AllocatorPtr& allocator, T value, bool is_1d) { - auto* data_type = DataTypeImpl::GetType(); - std::unique_ptr p_tensor = onnxruntime::make_unique(data_type, - is_1d ? TensorShape({1}) : TensorShape({}), - allocator); - - *p_tensor->MutableData() = value; - - auto ml_tensor = DataTypeImpl::GetType(); - return OrtValue{p_tensor.release(), ml_tensor, - ml_tensor->GetDeleteFunc()}; -} - Status LoopImpl::Initialize() { auto status = Status::OK(); diff --git a/onnxruntime/core/providers/cpu/controlflow/utils.h b/onnxruntime/core/providers/cpu/controlflow/utils.h index 177463ae22..7c0732f56c 100644 --- a/onnxruntime/core/providers/cpu/controlflow/utils.h +++ b/onnxruntime/core/providers/cpu/controlflow/utils.h @@ -13,6 +13,21 @@ namespace onnxruntime { class Graph; +// Creates a scalar MLValue based on given value and allocator. +template +OrtValue MakeScalarMLValue(const AllocatorPtr& allocator, T value, bool is_1d) { + auto* data_type = DataTypeImpl::GetType(); + std::unique_ptr p_tensor = onnxruntime::make_unique(data_type, + is_1d ? TensorShape({1}) : TensorShape({}), + allocator); + + *p_tensor->MutableData() = value; + + auto ml_tensor = DataTypeImpl::GetType(); + return OrtValue{p_tensor.release(), ml_tensor, + ml_tensor->GetDeleteFunc()}; +} + namespace controlflow { /** Interface for control flow kernels */ diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index d65f185b04..208e081485 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -269,8 +269,8 @@ class InferenceSession { */ common::Status NewIOBinding(std::unique_ptr* io_binding) ORT_MUST_USE_RESULT; - common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT; - common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT; + virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT; + virtual common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT; /** * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index b7b0bc0b4f..a67cd8aa76 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -717,7 +717,9 @@ Applies to a particular Run() invocation. Default is 0.)pbdoc") R"pbdoc(Set to True to terminate any currently executing calls that are using this RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc") .def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches, - R"pbdoc(Only execute the nodes needed by fetch list)pbdoc"); + R"pbdoc(Only execute the nodes needed by fetch list)pbdoc") + .def_readwrite("training_mode", &RunOptions::training_mode, + R"pbdoc(Choose to run in training or inferencing mode)pbdoc"); py::class_(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model. It is usually used to identify the model used to run the prediction and diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index 00a73a0b71..4660e7b671 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -343,7 +343,7 @@ class TestOrtTrainer(unittest.TestCase): expected_losses = [ 11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543, 11.029067039489746, 11.040265083312988, 11.046793937683105, 10.993699073791504] - expected_eval_loss = [10.9691801071167] + expected_eval_loss = [10.95898914] actual_losses, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=1, use_mixed_precision=False, allreduce_post_accumulation=False) @@ -354,7 +354,7 @@ class TestOrtTrainer(unittest.TestCase): # print('eval_loss actual: ', actual_eval_loss) # import pdb; pdb.set_trace() - rtol = 1e-03 + rtol = 1e-04 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") @@ -362,7 +362,7 @@ class TestOrtTrainer(unittest.TestCase): expected_losses = [ 11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266, 11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668] - expected_eval_loss = [10.969207763671875] + expected_eval_loss = [10.959011] actual_losses, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False) @@ -374,9 +374,9 @@ class TestOrtTrainer(unittest.TestCase): # print('eval_loss actual: ', actual_eval_loss) # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, err_msg="loss mismatch") - assert_allclose(expected_eval_loss, actual_eval_loss, err_msg="evaluation loss mismatch") + rtol = 1e-04 + assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") + assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") if __name__ == '__main__': unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py index 7a742cac01..81127d12a7 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py @@ -12,18 +12,18 @@ class TestOrtTrainer(unittest.TestCase): def testBertTrainingMixedPrecision(self): expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875] expected_all_finites = [False, True, True, True, True, True, True, True] - expected_eval_loss = [10.96875] + expected_eval_loss = [10.960938] actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=1, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False) - rtol = 1e-03 + rtol = 1e-04 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") def testBertTrainingMixedPrecisionInternalLossScale(self): expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875] - expected_eval_loss = [10.96875] + expected_eval_loss = [10.960938] actual_losses, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=1, use_mixed_precision=True, @@ -31,18 +31,18 @@ class TestOrtTrainer(unittest.TestCase): use_simple_model_desc=False, use_internel_loss_scale=True) - rtol = 1e-03 + rtol = 1e-04 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") def testBertTrainingGradientAccumulationMixedPrecision(self): expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875] expected_all_finites = [False, True] - expected_eval_loss = [10.96875] + expected_eval_loss = [10.960938] actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=4, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False) - rtol = 1e-03 + rtol = 1e-04 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") diff --git a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py new file mode 100644 index 0000000000..5511ce30b8 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import unittest +import pytest +import sys +import copy +from numpy.testing import assert_allclose, assert_array_equal + +import onnx +import torch +import torch.nn as nn +import torch.nn.functional as F + +from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description +from helper import get_name +import onnxruntime +from onnxruntime_test_training_unittest_utils import process_dropout +from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample + +torch.manual_seed(1) +onnxruntime.set_seed(1) + +class TestTrainingDropout(unittest.TestCase): + def testTrainingAndEvalDropout(self): + # Temporarily disable this test. + # The graph below will trigger ORT + # to sort backward graph before forward graph which gives incorrect result. + # TODO Re-enable when that is fixed. + return + class TwoDropoutNet(nn.Module): + def __init__(self, drop_prb_1, drop_prb_2, dim_size): + super(TwoDropoutNet, self).__init__() + self.drop_1 = nn.Dropout(drop_prb_1) + self.drop_2 = nn.Dropout(drop_prb_2) + self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32)) + def forward(self, x): + x = x + self.weight_1 + x = self.drop_1(x) + x = self.drop_2(x) + output = x + return output[0] + dim_size = 3 + device = torch.device("cuda", 0) + # This will drop all values, therefore expecting all 0 in output tensor + model = TwoDropoutNet(0.999, 0.999, dim_size) + input_desc = IODescription('input', [dim_size], torch.float32) + output_desc = IODescription('output', [], torch.float32) + model_desc = ModelDescription([input_desc], [output_desc]) + lr_desc = ort_trainer_learning_rate_description() + model = ORTTrainer(model, None, model_desc, "LambOptimizer", + map_optimizer_attributes, + lr_desc, + device, + postprocess_model=process_dropout, + world_rank=0, world_size=1) + input = torch.ones(dim_size, dtype=torch.float32).to(device) + expected_training_output = [0.0] + expected_eval_output = [1.0] + learning_rate = torch.tensor([1.0000000e+00]).to(device) + input_args=[input, learning_rate] + train_output = model.train_step(*input_args) + + rtol = 1e-04 + assert_allclose(expected_training_output, train_output.item(), rtol=rtol, err_msg="dropout training loss mismatch") + + eval_output = model.eval_step(input) + assert_allclose(expected_eval_output, eval_output.item(), rtol=rtol, err_msg="dropout eval loss mismatch") + + # Do another train step to make sure it's using original ratios + train_output_2 = model.train_step(*input_args) + assert_allclose(expected_training_output, train_output_2.item(), rtol=rtol, err_msg="dropout training loss 2 mismatch") + +if __name__ == '__main__': + unittest.main(module=__name__, buffer=True) + diff --git a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py new file mode 100644 index 0000000000..566df444da --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py @@ -0,0 +1,49 @@ +import sys +import numpy as np +from onnx import numpy_helper + +def get_node_index(model, node): + i = 0 + while i < len(model.graph.node): + if model.graph.node[i] == node: + break + i += 1 + return i if i < len(model.graph.node) else None + +def add_const(model, name, output, t_value = None, f_value = None): + const_node = model.graph.node.add() + const_node.op_type = 'Constant' + const_node.name = name + const_node.output.extend([output]) + attr = const_node.attribute.add() + attr.name = 'value' + if t_value is not None: + attr.type = 4 + attr.t.CopyFrom(t_value) + else: + attr.type = 1 + attr.f = f_value + return const_node + +def process_dropout(model): + dropouts = [] + index = 0 + for node in model.graph.node: + if node.op_type == 'Dropout': + new_dropout = model.graph.node.add() + new_dropout.op_type = 'TrainableDropout' + new_dropout.name = 'TrainableDropout_%d' % index + #make ratio node + ratio = np.asarray([node.attribute[0].f], dtype=np.float32) + print(ratio.shape) + ratio_value = numpy_helper.from_array(ratio) + ratio_node = add_const(model, 'dropout_node_ratio_%d' % index, 'dropout_node_ratio_%d' % index, t_value=ratio_value) + print (ratio_node) + new_dropout.input.extend([node.input[0], ratio_node.output[0]]) + new_dropout.output.extend(node.output) + dropouts.append(get_node_index(model, node)) + index += 1 + dropouts.sort(reverse=True) + for d in dropouts: + del model.graph.node[d] + model.opset_import[0].version = 10 diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 117ffcf262..2388f79016 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -5,6 +5,8 @@ #include "core/framework/data_transfer_utils.h" #include "core/graph/model.h" +#include "core/session/IOBinding.h" +#include "core/providers/cpu/controlflow/utils.h" #include "orttraining/core/graph/loss_function_builder.h" #include "orttraining/core/graph/optimizer_builder.h" #include "orttraining/core/framework/checkpointing.h" @@ -228,6 +230,9 @@ Status TrainingSession::ConfigureForTraining( } } + // Set eval feed names for Dropout ratio. + ORT_RETURN_IF_ERROR(SetDropoutEvalFeedNames()); + // add Tensorboard if (config.tensorboard_config.has_value()) { const auto& tensorboard_config = config.tensorboard_config.value(); @@ -676,6 +681,57 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons return IsFP32Node(output_producer_node); } +common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) { + // Override initializers in eval mode. + if (!run_options.training_mode) { + // override all dropout raiots to 0 + for (auto& drop_ratio : dropout_eval_feeds_) { + OrtValue feed_value; + // We allocate on CPU first, copy will be taken care off downstream. + auto cpu_allocator = session_state_->GetExecutionProviders() + .Get(onnxruntime::kCpuExecutionProvider) + ->GetAllocator(0, OrtMemTypeDefault); + feed_value = onnxruntime::MakeScalarMLValue(cpu_allocator, 0.f, true /*is_1d*/); + // Bind new feed to graph input. + ORT_RETURN_IF_ERROR(io_binding.BindInput(drop_ratio, feed_value)); + } + } + + // Call Run in inferenceSession + return InferenceSession::Run(run_options, io_binding); +} + +common::Status TrainingSession::Run(IOBinding& io_binding) { + RunOptions run_options; + // Set training_mode to true in training session by default. + run_options.training_mode = true; + return Run(run_options, io_binding); +} + +static const std::unordered_set Dropout_Nodes = { + "TrainableDropout", +}; +// TODO remove this once ONNX properly supports training_mode input. +Status TrainingSession::SetDropoutEvalFeedNames() { + Graph& graph = model_->MainGraph(); + + // add ratio node to graph input for overriding. + GraphAugmenter::GraphDefs defs{}; + + for (const auto& node : graph.Nodes()) { + auto it = Dropout_Nodes.find(node.OpType()); + if(it != Dropout_Nodes.cend()) { + auto& ratio_name = node.InputDefs()[1]->Name(); + dropout_eval_feeds_.insert(ratio_name); + ORT_ENFORCE(model_->MainGraph().GetProducerNode(ratio_name) == nullptr, + "Input: " + ratio_name + " should not have any producer node."); + defs.AddGraphInputs({ratio_name}); + } + } + ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs)); + return DoPostLoadProcessing(*model_); +} + Status TrainingSession::SetStateTensors(const NameMLValMap& state_tensors, bool strict) { ORT_RETURN_IF_NOT(IsInitialized(), "Can't update initializers before session has been initialized."); diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 91725cadc1..1731c696cd 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -212,9 +212,26 @@ class TrainingSession : public InferenceSession { /** Gets the model location. */ const PathString& GetModelLocation() const { return model_location_; } - /** Checks to be see if given graph output is produced by an fp32-only node. */ + /** + * Checks to be see if given graph output is produced by an fp32-only node. + * @param The name of the output. + * @return Whether output is from fp32-only node or not. + */ bool IsGraphOutputFp32Node(const std::string& output_name) const; + /** + * Gets the list of Dropout ratio inputs that will be used as feeds in eval mode, + * since each ratio input has its own name. + * @return The list of feed names. + */ + std::unordered_set GetDropoutEvalFeeds() const { return dropout_eval_feeds_; } + + /** Override Run function in InferenceSession to inject some training-specific logics **/ + using InferenceSession::Run; // For overload resolution. + common::Status Run(const RunOptions& run_options, IOBinding& io_binding) override; + + common::Status Run(IOBinding& io_binding) override; + private: /** Configures the loss function. The loss function can either be provided externally or built from the provided loss function information. @@ -312,6 +329,8 @@ class TrainingSession : public InferenceSession { std::unordered_set GetStateTensorNames() const; + common::Status SetDropoutEvalFeedNames(); + NameMLValMap GetWeights() const; static bool IsImmutableWeight(const ImmutableWeights& immutable_weights, @@ -335,6 +354,7 @@ class TrainingSession : public InferenceSession { std::unique_ptr loss_graph_builder_; optional loss_function_info_; + std::unordered_set dropout_eval_feeds_; OptimizerGraphConfig opt_graph_config_; std::unordered_map opt_configs_; }; diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index c38b92df83..26079cff3a 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -90,7 +90,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out device_index = input_get_device_index(input) iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype), list(input.size()), input.data_ptr()) - + output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) torch_outputs = {} for output_desc in output_descs_resolved: @@ -787,6 +787,7 @@ class ORTTrainer(): elif self.current_step % self.gradient_accumulation_steps != 0: run_options = ort.RunOptions() run_options.only_execute_path_to_fetches = True + run_options.training_mode = True output_desc = self.output_desc_with_group_accumulated_gradients elif self.use_mixed_precision: has_if_all_finite = True @@ -864,6 +865,7 @@ class ORTTrainer(): run_options = ort.RunOptions() run_options.only_execute_path_to_fetches = True + run_options.training_mode = False session_run_results = ort_training_session_run_helper(self.session, self.eval_io_binding, input, input_desc, diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index b1833db77d..6e54e26791 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1095,6 +1095,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, run_subprocess( [sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd, dll_path=dll_path) + run_subprocess( + [sys.executable, 'onnxruntime_test_training_unit_tests.py'], + cwd=cwd, dll_path=dll_path) # run additional frontend tests for orttraining-linux-gpu-frontend_test_ci-pipeline if args.enable_training_python_frontend_e2e_tests: