Fix evaluation issues (#3538)

* allow switching between eval and training modes dynamically

Co-authored-by: Tixxx <root@525204a066204ea794f942530b05ae7f000000.axlncovkyjne5caro2tmz3zryb.xx.internal.cloudapp.net>
This commit is contained in:
Tixxx 2020-04-28 21:03:37 -07:00 committed by GitHub
parent 939589c265
commit 0638565fe0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 243 additions and 31 deletions

View file

@ -26,6 +26,9 @@ struct OrtRunOptions {
// So it is possible that only some of the nodes are executed.
bool only_execute_path_to_fetches = false;
// Set to 'true' to run in training mode.
bool training_mode = false;
OrtRunOptions() = default;
~OrtRunOptions() = default;

View file

@ -344,20 +344,6 @@ LoopImpl::LoopImpl(OpKernelContextInternal& context,
condition_ = cond_tensor ? *cond_tensor->Data<bool>() : true;
}
// Wraps `value` in a freshly allocated single-element Tensor and returns it as an OrtValue.
// The tensor shape is {1} when `is_1d` is true and {} (rank 0) otherwise.
template <typename T>
static OrtValue MakeScalarMLValue(const AllocatorPtr& allocator, T value, bool is_1d) {
  const TensorShape scalar_shape = is_1d ? TensorShape({1}) : TensorShape({});
  auto tensor = onnxruntime::make_unique<Tensor>(DataTypeImpl::GetType<T>(), scalar_shape, allocator);
  *tensor->MutableData<T>() = value;

  // The OrtValue takes over the released tensor and frees it through the Tensor type's delete function.
  auto tensor_type = DataTypeImpl::GetType<Tensor>();
  return OrtValue{tensor.release(), tensor_type, tensor_type->GetDeleteFunc()};
}
Status LoopImpl::Initialize() {
auto status = Status::OK();

View file

@ -13,6 +13,21 @@
namespace onnxruntime {
class Graph;
// Creates a scalar MLValue based on given value and allocator.
// `is_1d` selects the tensor shape: {1} when true, {} (rank 0) when false.
template <typename T>
OrtValue MakeScalarMLValue(const AllocatorPtr& allocator, T value, bool is_1d) {
  const TensorShape shape = is_1d ? TensorShape({1}) : TensorShape({});
  auto scalar_tensor = onnxruntime::make_unique<Tensor>(DataTypeImpl::GetType<T>(), shape, allocator);
  *scalar_tensor->MutableData<T>() = value;

  // Ownership of the tensor moves into the returned OrtValue together with the matching deleter.
  auto tensor_ml_type = DataTypeImpl::GetType<Tensor>();
  return OrtValue{scalar_tensor.release(), tensor_ml_type, tensor_ml_type->GetDeleteFunc()};
}
namespace controlflow {
/** Interface for control flow kernels */

View file

@ -269,8 +269,8 @@ class InferenceSession {
*/
common::Status NewIOBinding(std::unique_ptr<IOBinding>* io_binding) ORT_MUST_USE_RESULT;
common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
virtual common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
/**
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.

View file

@ -717,7 +717,9 @@ Applies to a particular Run() invocation. Default is 0.)pbdoc")
R"pbdoc(Set to True to terminate any currently executing calls that are using this
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc");
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
.def_readwrite("training_mode", &RunOptions::training_mode,
R"pbdoc(Choose to run in training or inferencing mode)pbdoc");
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
It is usually used to identify the model used to run the prediction and

View file

@ -343,7 +343,7 @@ class TestOrtTrainer(unittest.TestCase):
expected_losses = [
11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
11.029067039489746, 11.040265083312988, 11.046793937683105, 10.993699073791504]
expected_eval_loss = [10.9691801071167]
expected_eval_loss = [10.95898914]
actual_losses, actual_eval_loss = runBertTrainingTest(
gradient_accumulation_steps=1, use_mixed_precision=False, allreduce_post_accumulation=False)
@ -354,7 +354,7 @@ class TestOrtTrainer(unittest.TestCase):
# print('eval_loss actual: ', actual_eval_loss)
# import pdb; pdb.set_trace()
rtol = 1e-03
rtol = 1e-04
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
@ -362,7 +362,7 @@ class TestOrtTrainer(unittest.TestCase):
expected_losses = [
11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266,
11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668]
expected_eval_loss = [10.969207763671875]
expected_eval_loss = [10.959011]
actual_losses, actual_eval_loss = runBertTrainingTest(
gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False)
@ -374,9 +374,9 @@ class TestOrtTrainer(unittest.TestCase):
# print('eval_loss actual: ', actual_eval_loss)
# import pdb; pdb.set_trace()
rtol = 1e-03
assert_allclose(expected_losses, actual_losses, err_msg="loss mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, err_msg="evaluation loss mismatch")
rtol = 1e-04
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
if __name__ == '__main__':
unittest.main(module=__name__, buffer=True)

View file

@ -12,18 +12,18 @@ class TestOrtTrainer(unittest.TestCase):
def testBertTrainingMixedPrecision(self):
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
expected_all_finites = [False, True, True, True, True, True, True, True]
expected_eval_loss = [10.96875]
expected_eval_loss = [10.960938]
actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest(
gradient_accumulation_steps=1, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False)
rtol = 1e-03
rtol = 1e-04
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
def testBertTrainingMixedPrecisionInternalLossScale(self):
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
expected_eval_loss = [10.96875]
expected_eval_loss = [10.960938]
actual_losses, actual_eval_loss = runBertTrainingTest(
gradient_accumulation_steps=1,
use_mixed_precision=True,
@ -31,18 +31,18 @@ class TestOrtTrainer(unittest.TestCase):
use_simple_model_desc=False,
use_internel_loss_scale=True)
rtol = 1e-03
rtol = 1e-04
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")
def testBertTrainingGradientAccumulationMixedPrecision(self):
expected_losses = [11.0234375, 11.09375, 11.0078125, 11.0625, 11.03125, 11.0390625, 11.046875, 10.9921875]
expected_all_finites = [False, True]
expected_eval_loss = [10.96875]
expected_eval_loss = [10.960938]
actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest(
gradient_accumulation_steps=4, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False)
rtol = 1e-03
rtol = 1e-04
assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch")
assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch")
assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch")

View file

@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import unittest
import pytest
import sys
import copy
from numpy.testing import assert_allclose, assert_array_equal
import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description
from helper import get_name
import onnxruntime
from onnxruntime_test_training_unittest_utils import process_dropout
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
# Seed both torch and onnxruntime with the same fixed value (1) at import time.
# NOTE(review): presumably so the dropout results below are reproducible - confirm.
torch.manual_seed(1)
onnxruntime.set_seed(1)
class TestTrainingDropout(unittest.TestCase):
    """Checks that ORTTrainer applies dropout in train_step and bypasses it in eval_step."""

    # Temporarily disable this test.
    # The graph below will trigger ORT
    # to sort backward graph before forward graph which gives incorrect result.
    # TODO Re-enable when that is fixed.
    # Using unittest.skip (instead of an early `return`) makes the runner report the test
    # as skipped rather than as a silent pass.
    @unittest.skip("ORT sorts the backward graph before the forward graph for this model, "
                   "which gives incorrect results")
    def testTrainingAndEvalDropout(self):
        class TwoDropoutNet(nn.Module):
            """Adds a zero-initialized weight to the input and applies two dropouts in sequence."""

            def __init__(self, drop_prb_1, drop_prb_2, dim_size):
                super(TwoDropoutNet, self).__init__()
                self.drop_1 = nn.Dropout(drop_prb_1)
                self.drop_2 = nn.Dropout(drop_prb_2)
                self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32))

            def forward(self, x):
                x = x + self.weight_1
                x = self.drop_1(x)
                x = self.drop_2(x)
                output = x
                # Only the first element is returned, so the model output is a scalar.
                return output[0]

        dim_size = 3
        device = torch.device("cuda", 0)
        # This will drop all values, therefore expecting all 0 in output tensor
        torch_model = TwoDropoutNet(0.999, 0.999, dim_size)
        input_desc = IODescription('input', [dim_size], torch.float32)
        output_desc = IODescription('output', [], torch.float32)
        model_desc = ModelDescription([input_desc], [output_desc])
        lr_desc = ort_trainer_learning_rate_description()
        trainer = ORTTrainer(torch_model, None, model_desc, "LambOptimizer",
                             map_optimizer_attributes,
                             lr_desc,
                             device,
                             postprocess_model=process_dropout,
                             world_rank=0, world_size=1)

        # Named input_tensor so the builtin `input` is not shadowed.
        input_tensor = torch.ones(dim_size, dtype=torch.float32).to(device)
        expected_training_output = [0.0]
        expected_eval_output = [1.0]
        learning_rate = torch.tensor([1.0000000e+00]).to(device)
        input_args = [input_tensor, learning_rate]
        rtol = 1e-04

        train_output = trainer.train_step(*input_args)
        assert_allclose(expected_training_output, train_output.item(), rtol=rtol,
                        err_msg="dropout training loss mismatch")

        eval_output = trainer.eval_step(input_tensor)
        assert_allclose(expected_eval_output, eval_output.item(), rtol=rtol,
                        err_msg="dropout eval loss mismatch")

        # Do another train step to make sure it's using original ratios
        train_output_2 = trainer.train_step(*input_args)
        assert_allclose(expected_training_output, train_output_2.item(), rtol=rtol,
                        err_msg="dropout training loss 2 mismatch")
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
unittest.main(module=__name__, buffer=True)

View file

@ -0,0 +1,49 @@
import sys
import numpy as np
from onnx import numpy_helper
def get_node_index(model, node):
    """Return the position of `node` in `model.graph.node`.

    Args:
        model: Object exposing the node sequence as `model.graph.node`.
        node: The node to look for; compared with `==` against each entry.

    Returns:
        The index of the first entry equal to `node`, or None when there is no match.
    """
    for index, candidate in enumerate(model.graph.node):
        if candidate == node:
            return index
    return None
def add_const(model, name, output, t_value=None, f_value=None):
    """Append a 'Constant' node to `model.graph.node` and return it.

    The constant is stored in the node's 'value' attribute, either as a tensor
    (`t_value`) or as a float (`f_value`). `t_value` wins when both are given.

    Args:
        model: Model whose `graph.node` receives the new node.
        name: Name of the new node.
        output: Name of the node's single output.
        t_value: Tensor value, copied into the attribute with `CopyFrom`.
        f_value: Float value, used only when `t_value` is None.

    Returns:
        The newly added Constant node.

    Raises:
        ValueError: If neither `t_value` nor `f_value` is provided.
    """
    # Validate before touching the graph so a bad call does not leave a half-built node behind.
    if t_value is None and f_value is None:
        raise ValueError("add_const requires either t_value or f_value")

    # Numeric values of onnx.AttributeProto.AttributeType.
    attr_type_float = 1
    attr_type_tensor = 4

    const_node = model.graph.node.add()
    const_node.op_type = 'Constant'
    const_node.name = name
    const_node.output.extend([output])

    attr = const_node.attribute.add()
    attr.name = 'value'
    if t_value is not None:
        attr.type = attr_type_tensor
        attr.t.CopyFrom(t_value)
    else:
        attr.type = attr_type_float
        attr.f = f_value
    return const_node
def process_dropout(model):
    """Replace every 'Dropout' node in `model` with a 'TrainableDropout' node, in place.

    For each Dropout node a Constant node holding the dropout ratio (a float32 tensor of
    shape [1]) is added, and a TrainableDropout node is created that takes the original
    data input plus that constant and produces the original outputs. The original Dropout
    nodes are then removed and the first opset import is set to version 10.

    Args:
        model: The ONNX model to modify.
    """
    # Snapshot the Dropout nodes first: the loop appends to model.graph.node, and the
    # appended nodes must not be visited while iterating.
    dropout_nodes = [node for node in model.graph.node if node.op_type == 'Dropout']

    removed_indices = []
    for index, node in enumerate(dropout_nodes):
        new_dropout = model.graph.node.add()
        new_dropout.op_type = 'TrainableDropout'
        new_dropout.name = 'TrainableDropout_%d' % index

        # Look the ratio up by attribute name instead of relying on attribute order;
        # 0.5 is the ONNX default when the attribute is not set.
        ratio_attr = next((attr for attr in node.attribute if attr.name == 'ratio'), None)
        ratio = np.asarray([ratio_attr.f if ratio_attr is not None else 0.5], dtype=np.float32)

        # make ratio node
        ratio_name = 'dropout_node_ratio_%d' % index
        ratio_node = add_const(model, ratio_name, ratio_name, t_value=numpy_helper.from_array(ratio))

        new_dropout.input.extend([node.input[0], ratio_node.output[0]])
        new_dropout.output.extend(node.output)
        removed_indices.append(get_node_index(model, node))

    # Delete from the highest index down so the remaining indices stay valid.
    removed_indices.sort(reverse=True)
    for node_index in removed_indices:
        del model.graph.node[node_index]

    model.opset_import[0].version = 10

View file

@ -5,6 +5,8 @@
#include "core/framework/data_transfer_utils.h"
#include "core/graph/model.h"
#include "core/session/IOBinding.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "orttraining/core/graph/loss_function_builder.h"
#include "orttraining/core/graph/optimizer_builder.h"
#include "orttraining/core/framework/checkpointing.h"
@ -228,6 +230,9 @@ Status TrainingSession::ConfigureForTraining(
}
}
// Set eval feed names for Dropout ratio.
ORT_RETURN_IF_ERROR(SetDropoutEvalFeedNames());
// add Tensorboard
if (config.tensorboard_config.has_value()) {
const auto& tensorboard_config = config.tensorboard_config.value();
@ -676,6 +681,57 @@ bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) cons
return IsFP32Node(output_producer_node);
}
// Runs the session with the given options and bindings.
// When run_options.training_mode is false, every dropout ratio input recorded in
// dropout_eval_feeds_ is first bound to a 1-D float tensor holding 0, then the call is
// forwarded to InferenceSession::Run.
// NOTE(review): the ratio feeds are bound on the caller's io_binding and nothing here
// removes them afterwards - confirm callers use a separate binding for eval and training.
common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
  // Override initializers in eval mode.
  if (!run_options.training_mode && !dropout_eval_feeds_.empty()) {
    // We allocate on CPU first, copy will be taken care of downstream.
    // The allocator lookup does not depend on the feed, so it is done once for all feeds.
    auto* cpu_provider = session_state_->GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider);
    ORT_RETURN_IF_NOT(cpu_provider != nullptr,
                      "CPU execution provider is required to create dropout ratio feeds.");
    auto cpu_allocator = cpu_provider->GetAllocator(0, OrtMemTypeDefault);

    // Override all dropout ratios to 0.
    for (const auto& drop_ratio : dropout_eval_feeds_) {
      OrtValue feed_value = onnxruntime::MakeScalarMLValue<float>(cpu_allocator, 0.f, true /*is_1d*/);
      // Bind new feed to graph input.
      ORT_RETURN_IF_ERROR(io_binding.BindInput(drop_ratio, feed_value));
    }
  }

  // Call Run in inferenceSession
  return InferenceSession::Run(run_options, io_binding);
}
// Convenience overload: runs with default RunOptions, except that training_mode is switched on.
common::Status TrainingSession::Run(IOBinding& io_binding) {
  // A training session defaults to training mode when the caller supplies no options.
  RunOptions training_run_options;
  training_run_options.training_mode = true;

  return Run(training_run_options, io_binding);
}
// Op types that are treated as dropout nodes when collecting the ratio inputs
// to override in eval mode (looked up by node.OpType()).
static const std::unordered_set<std::string> Dropout_Nodes = {
"TrainableDropout",
};
// TODO remove this once ONNX properly supports training_mode input.
// Records the ratio input name (second input) of every node whose op type is in
// Dropout_Nodes into dropout_eval_feeds_, exposes those names as graph inputs so they
// can be overridden with feeds, and then re-runs post-load processing on the model.
// Throws (via ORT_ENFORCE) if a ratio input is produced by another node.
Status TrainingSession::SetDropoutEvalFeedNames() {
  Graph& graph = model_->MainGraph();

  // add ratio node to graph input for overriding.
  GraphAugmenter::GraphDefs defs{};
  for (const auto& node : graph.Nodes()) {
    if (Dropout_Nodes.count(node.OpType()) == 0) {
      continue;
    }

    // The ratio is the second input. A node without it has nothing to override,
    // and indexing past the end of InputDefs() would be undefined behavior.
    const auto& input_defs = node.InputDefs();
    if (input_defs.size() < 2 || input_defs[1]->Name().empty()) {
      continue;
    }

    const auto& ratio_name = input_defs[1]->Name();
    ORT_ENFORCE(graph.GetProducerNode(ratio_name) == nullptr,
                "Input: " + ratio_name + " should not have any producer node.");
    dropout_eval_feeds_.insert(ratio_name);
    defs.AddGraphInputs({ratio_name});
  }

  ORT_RETURN_IF_ERROR(GraphAugmenter::AugmentGraph(graph, defs));
  return DoPostLoadProcessing(*model_);
}
Status TrainingSession::SetStateTensors(const NameMLValMap& state_tensors, bool strict) {
ORT_RETURN_IF_NOT(IsInitialized(), "Can't update initializers before session has been initialized.");

View file

@ -212,9 +212,26 @@ class TrainingSession : public InferenceSession {
/** Gets the model location. */
const PathString& GetModelLocation() const { return model_location_; }
/** Checks to be see if given graph output is produced by an fp32-only node. */
/**
* Checks whether the given graph output is produced by an fp32-only node.
* @param output_name The name of the output.
* @return Whether the output is produced by an fp32-only node.
*/
bool IsGraphOutputFp32Node(const std::string& output_name) const;
/**
* Gets the names of the Dropout ratio inputs that will be used as feeds in eval mode.
* Each ratio input has its own name, so one feed name is kept per ratio input.
* The set is returned by value, i.e. the caller receives a copy of the session's set.
* @return The set of feed names.
*/
std::unordered_set<std::string> GetDropoutEvalFeeds() const { return dropout_eval_feeds_; }
/** Overrides the Run functions of InferenceSession to inject training-specific logic. */
using InferenceSession::Run; // For overload resolution.
common::Status Run(const RunOptions& run_options, IOBinding& io_binding) override;
common::Status Run(IOBinding& io_binding) override;
private:
/** Configures the loss function.
The loss function can either be provided externally or built from the provided loss function information.
@ -312,6 +329,8 @@ class TrainingSession : public InferenceSession {
std::unordered_set<std::string> GetStateTensorNames() const;
common::Status SetDropoutEvalFeedNames();
NameMLValMap GetWeights() const;
static bool IsImmutableWeight(const ImmutableWeights& immutable_weights,
@ -335,6 +354,7 @@ class TrainingSession : public InferenceSession {
std::unique_ptr<ILossFunction> loss_graph_builder_;
optional<LossFunctionInfo> loss_function_info_;
std::unordered_set<std::string> dropout_eval_feeds_;
OptimizerGraphConfig opt_graph_config_;
std::unordered_map<std::string, OptimizerNodeConfig> opt_configs_;
};

View file

@ -90,7 +90,7 @@ def ort_training_session_run_helper(session, iobinding, inputs, input_descs, out
device_index = input_get_device_index(input)
iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype),
list(input.size()), input.data_ptr())
output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs)
torch_outputs = {}
for output_desc in output_descs_resolved:
@ -787,6 +787,7 @@ class ORTTrainer():
elif self.current_step % self.gradient_accumulation_steps != 0:
run_options = ort.RunOptions()
run_options.only_execute_path_to_fetches = True
run_options.training_mode = True
output_desc = self.output_desc_with_group_accumulated_gradients
elif self.use_mixed_precision:
has_if_all_finite = True
@ -864,6 +865,7 @@ class ORTTrainer():
run_options = ort.RunOptions()
run_options.only_execute_path_to_fetches = True
run_options.training_mode = False
session_run_results = ort_training_session_run_helper(self.session, self.eval_io_binding, input,
input_desc,

View file

@ -1095,6 +1095,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs,
run_subprocess(
[sys.executable, 'onnxruntime_test_ort_trainer.py'],
cwd=cwd, dll_path=dll_path)
run_subprocess(
[sys.executable, 'onnxruntime_test_training_unit_tests.py'],
cwd=cwd, dll_path=dll_path)
# run additional frontend tests for orttraining-linux-gpu-frontend_test_ci-pipeline
if args.enable_training_python_frontend_e2e_tests: