Expose recompute configs to the frontend (#5318)

* Expose recompute configs to the frontend

* Add frontend test

* Ensure recompute graph transformer is only applied once

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2020-10-02 09:49:47 -07:00 committed by GitHub
parent e33de20861
commit e71668f92c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 262 additions and 51 deletions

View file

@ -39,6 +39,8 @@ class GraphTransformer {
*/
common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;
virtual bool ShouldOnlyApplyOnce() const { return false; }
protected:
/** Helper method to call ApplyImpl on any subgraphs in the Node. */
common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {

View file

@ -12,7 +12,7 @@ __author__ = "Microsoft"
from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
RunOptions, SessionOptions, set_default_logger_severity, NodeArg, ModelMetadata, GraphOptimizationLevel, \
ExecutionMode, OrtDevice, SessionIOBinding
ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding
try:
from onnxruntime.capi._pybind_state import set_cuda_mem_limit, set_cuda_device_id

View file

@ -28,6 +28,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
for (unsigned step = 0; step < steps_; ++step) {
bool graph_changed = false;
for (const auto& transformer : transformers->second) {
if (step > 0 && transformer->ShouldOnlyApplyOnce())
continue;
bool modified = false;
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger));
graph_changed = graph_changed || modified;

View file

@ -937,6 +937,10 @@ void addObjectMethods(py::module& m, Environment& env) {
.value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL)
.value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL);
py::enum_<ExecutionOrder>(m, "ExecutionOrder")
.value("DEFAULT", ExecutionOrder::DEFAULT)
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED);
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc");
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
@ -1089,7 +1093,7 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
.def_readwrite("execution_mode", &PySessionOptions::execution_mode,
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
.def_readwrite("execution_order", &SessionOptions::execution_order,
.def_readwrite("execution_order", &PySessionOptions::execution_order,
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
.def_property(
"graph_optimization_level",

View file

@ -74,12 +74,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
if (config.gelu_recompute) {
rule_transformer->Register(make_unique<GeluRecompute>());
}
if (config.attn_dropout_recompute) {
rule_transformer->Register(make_unique<AttentionDropoutRecompute>());
}
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
@ -106,8 +100,15 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
}
transformers.emplace_back(onnxruntime::make_unique<ComputationReductionTransformer>(compatible_eps));
if (config.gelu_recompute) {
transformers.emplace_back(onnxruntime::make_unique<GeluRecompute>());
}
if (config.attn_dropout_recompute) {
transformers.emplace_back(onnxruntime::make_unique<AttentionDropoutRecompute>());
}
if (config.transformer_layer_recompute) {
transformers.emplace_back(onnxruntime::make_unique<TransformerLayerRecompute>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<TransformerLayerRecompute>(
config.number_recompute_layers, compatible_eps));
}
} break;

View file

@ -10,7 +10,12 @@ using namespace ONNX_NAMESPACE;
namespace onnxruntime {
bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
bool GeluRecompute::SatisfyCondition(const Node& node) const {
static const std::unordered_set<std::string> target_optypes = {"Gelu", "FastGelu", "BiasGelu"};
if (target_optypes.find(node.OpType()) == target_optypes.end()) {
return false;
}
const auto next_node = node.OutputNodesBegin();
if (next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") {
return true;
@ -18,27 +23,42 @@ bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, c
return false;
}
Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
const auto& output = node.OutputDefs()[0];
Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
for (NodeIndex i : order) {
Node& node = *graph.GetNode(i);
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
node.OpType(),
"Recompute of " + node.Name(),
{node.MutableInputDefs()[0]},
{&recomputed_output},
&node.GetAttributes(),
node.Domain());
if (!SatisfyCondition(node)) {
continue;
}
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
const auto& output = node.OutputDefs()[0];
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
node.OpType(),
"Recompute of " + node.Name(),
{node.MutableInputDefs()[0]},
{&recomputed_output},
&node.GetAttributes(),
node.Domain());
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
modified = true;
}
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
return Status::OK();
}
bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const {
if (node.OpType() != "Dropout")
return false;
const auto prev_node = node.InputNodesBegin();
const auto next_node = node.OutputNodesBegin();
if (prev_node != node.InputNodesEnd() && prev_node->OpType() == "Softmax" &&
@ -48,11 +68,22 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const N
return false;
}
Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true);
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
for (NodeIndex i : order) {
Node& node = *graph.GetNode(i);
if (!SatisfyCondition(node)) {
continue;
}
Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true);
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
modified = true;
}
return Status::OK();
}

View file

@ -3,7 +3,7 @@
#pragma once
#include "core/optimizer/rewrite_rule.h"
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
@ -13,18 +13,16 @@ namespace onnxruntime {
Recompute Gelu/BiasGelu/FastGelu
*/
class GeluRecompute : public RewriteRule {
class GeluRecompute : public GraphTransformer {
public:
GeluRecompute() noexcept : RewriteRule("GeluRecompute") {}
GeluRecompute() noexcept : GraphTransformer("GeluRecompute") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Gelu", "FastGelu", "BiasGelu"};
}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool ShouldOnlyApplyOnce() const override { return true; }
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
bool SatisfyCondition(const Node& node) const;
};
/**
@ -33,18 +31,16 @@ class GeluRecompute : public RewriteRule {
Recompute Dropout in the attention layer
*/
class AttentionDropoutRecompute : public RewriteRule {
class AttentionDropoutRecompute : public GraphTransformer {
public:
AttentionDropoutRecompute() noexcept : RewriteRule("AttentionDropoutRecompute") {}
AttentionDropoutRecompute() noexcept : GraphTransformer("AttentionDropoutRecompute") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Dropout"};
}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool ShouldOnlyApplyOnce() const override { return true; }
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
bool SatisfyCondition(const Node& node) const;
};
} // namespace onnxruntime

View file

@ -116,7 +116,7 @@ void TransformerLayerRecompute::InsertRecomputeNodes(Graph& graph, const std::ve
Node* node = graph.GetNode(n->Index());
// recomputed Dropout need to produce the same output as original dropout
// currently reusing original dropout's mask to achieve this
// currently reusing original dropout's mask to achieve this
if (node->OpType() == "Dropout") {
const NodeArg* input = node->InputDefs()[0];
const Node* p_node = graph.GetProducerNode(input->Name());
@ -175,9 +175,25 @@ Status TransformerLayerRecompute::ApplyImpl(Graph& graph, bool& modified, int /*
return Status::OK();
}
// insert recompute nodes expect for the last transformer layer
// by default, apply recompute expect for the last transformer layer
// otherwise, take user specified 'number_recompute_layers_'
int n_layers;
const int n_layers_limit = static_cast<int>(start_end_edges.size() - 1);
if (number_recompute_layers_ > n_layers_limit) {
LOGS(logger, WARNING) << "User specified number_recompute_layers " << number_recompute_layers_
<< " is larger than limit " << n_layers_limit << "."
<< "number_recompute_layers is now cliped to limit.";
n_layers = n_layers_limit;
} else if (number_recompute_layers_ > 0) {
n_layers = number_recompute_layers_;
} else {
LOGS(logger, INFO) << "number_recompute_layers is not set by user, using default " << n_layers_limit << ".";
n_layers = n_layers_limit;
}
// latter recompute layers have higher execution priorty
for (size_t i = 0; i < start_end_edges.size() - 1; ++i) {
for (int i = 0; i < n_layers; ++i) {
std::vector<const Node*> nodes = NodesBetweenEdges(graph, start_end_edges[i].first, start_end_edges[i].second);
InsertRecomputeNodes(graph, nodes, static_cast<int>(start_end_edges.size() - i));
}

View file

@ -10,11 +10,15 @@ namespace onnxruntime {
class TransformerLayerRecompute : public GraphTransformer {
public:
TransformerLayerRecompute(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("TransformerLayerRecompute", compatible_execution_providers) {}
TransformerLayerRecompute(int number_recompute_layers,
const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("TransformerLayerRecompute", compatible_execution_providers),
number_recompute_layers_(number_recompute_layers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool ShouldOnlyApplyOnce() const override { return true; }
private:
Status IdentifyTransformerLayerEdges(const Graph& graph,
std::vector<std::pair<const NodeArg*, const NodeArg*>>& start_end_edges,
@ -23,6 +27,8 @@ class TransformerLayerRecompute : public GraphTransformer {
std::vector<const Node*> NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const;
void InsertRecomputeNodes(Graph& graph, const std::vector<const Node*>& nodes, int priority) const;
int number_recompute_layers_;
};
} // namespace onnxruntime

View file

@ -200,6 +200,8 @@ class TrainingSession : public InferenceSession {
bool gelu_recompute{false};
// Enable recompute of transformer layer ouput to save memory
bool transformer_layer_recompute{false};
// Number of layers to apply recompute
int number_recompute_layers{0};
};
GraphTransformerConfiguration graph_transformer_config{};

View file

@ -173,6 +173,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
cxxopts::value<bool>()->default_value("false"))
("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.",
cxxopts::value<bool>()->default_value("false"))
("number_recompute_layers", "Number of layers to apply recompute.",
cxxopts::value<int>()->default_value("0"))
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
cxxopts::value<bool>()->default_value("false"));
options
@ -463,6 +465,7 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
params.attn_dropout_recompute = flags["attn_dropout_recompute"].as<bool>();
params.gelu_recompute = flags["gelu_recompute"].as<bool>();
params.transformer_layer_recompute = flags["transformer_layer_recompute"].as<bool>();
params.number_recompute_layers = flags["number_recompute_layers"].as<int>();
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
ORT_RETURN_IF_NOT(

View file

@ -187,6 +187,7 @@ Status TrainingRunner::Initialize() {
gt_config.attn_dropout_recompute = params_.attn_dropout_recompute;
gt_config.gelu_recompute = params_.gelu_recompute;
gt_config.transformer_layer_recompute = params_.transformer_layer_recompute;
gt_config.number_recompute_layers = params_.number_recompute_layers;
config.graph_transformer_config = gt_config;
}

View file

@ -180,6 +180,8 @@ class TrainingRunner {
bool gelu_recompute = false;
// Enable checkpointing of transformer layer output to save memory
bool transformer_layer_recompute = false;
// Number of layers to apply recompute
int number_recompute_layers = 0;
// Use invertible layernorm grad
bool use_invertible_layernorm_grad = false;
};

View file

@ -47,6 +47,12 @@ struct TrainingParameters {
bool enable_grad_norm_clip = true;
bool set_gradients_as_graph_outputs = false;
bool use_invertible_layernorm_grad = false;
// recompute
bool attn_dropout_recompute = false;
bool gelu_recompute = false;
bool transformer_layer_recompute = false;
int number_recompute_layers = 0;
};
struct TrainingConfigurationResult {
@ -130,6 +136,11 @@ TrainingConfigurationResult ConfigureSessionForTraining(
config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad;
config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute;
config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute;
config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute;
config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers;
training::TrainingSession::TrainingConfigurationResult config_result{};
OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));
@ -186,7 +197,11 @@ void addObjectMethodsForTraining(py::module& m) {
.def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage)
.def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip)
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad);
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad)
.def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute)
.def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute)
.def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute)
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers);
#if defined(USE_NCCL)
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });

View file

@ -633,9 +633,18 @@ class ORTTrainer(object):
ort_parameters.optimizer_attributes_map = optimizer_attributes_map
ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute
ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute
ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute
ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers
# SessionOptions
session_options = ort.SessionOptions()
session_options.use_deterministic_compute = self.options.debug.deterministic_compute
if (self.options.graph_transformer.attn_dropout_recompute or
self.options.graph_transformer.gelu_recompute or
self.options.graph_transformer.transformer_layer_recompute):
session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED
# TrainingSession
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),

View file

@ -116,6 +116,30 @@ class ORTTrainerOptions(object):
}
}
},
'graph_transformer': {
'type': 'dict',
'required': False,
'default': {},
'schema': {
'attn_dropout_recompute': {
'type': 'boolean',
'default': False
},
'gelu_recompute': {
'type': 'boolean',
'default': False
},
'transformer_layer_recompute': {
'type': 'boolean',
'default': False
},
'number_recompute_layers': {
'type': 'integer',
'min': 0,
'default': 0
}
}
},
'utils' : {
'type' : 'dict',
'required': False,
@ -221,6 +245,17 @@ class ORTTrainerOptions(object):
Users can also instantiate :py:class:`.DynamicLossScaler` and
override its parameters. Lastly, a completely new implementation
can be specified by extending :py:class:`.LossScaler` class from scratch
graph_transformer (dict):
graph transformer related configurations
attn_dropout_recompute (bool, default is False):
enable recomputing attention dropout to save memory
gelu_recompute (bool, default is False):
enable recomputing Gelu activation output to save memory
transformer_layer_recompute (bool, default is False):
enable recomputing transformer layerwise to save memory
number_recompute_layers (int, default is 0)
number of layers to apply transformer_layer_recompute, by default system will
apply recompute to all the layers, except for the last one
utils (dict):
miscellaneous options
utils.frozen_weights (list of str, []):
@ -435,6 +470,30 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
}
}
},
'graph_transformer': {
'type': 'dict',
'default_setter': lambda _: {},
'required': False,
'schema': {
'attn_dropout_recompute': {
'type': 'boolean',
'default': False
},
'gelu_recompute': {
'type': 'boolean',
'default': False
},
'transformer_layer_recompute': {
'type': 'boolean',
'default': False
},
'number_recompute_layers': {
'type': 'integer',
'min': 0,
'default': 0
}
}
},
'utils': {
'type': 'dict',
'default_setter': lambda _: {},

View file

@ -94,6 +94,12 @@ def testORTTrainerOptionsDefaultValues(test_input):
'enabled': False,
'loss_scaler': None
},
'graph_transformer': {
'attn_dropout_recompute': False,
'gelu_recompute': False,
'transformer_layer_recompute': False,
'number_recompute_layers': 0
},
'utils': {
'frozen_weights': [],
'grad_norm_clip': True,
@ -728,6 +734,61 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches)
assert trainer._onnx_model is not None
def _recompute_data():
device_capability_major = torch.cuda.get_device_capability()[0]
if device_capability_major == 7: # V100 for Dev machine
expected_loss = [10.577394, 10.444777, 10.425666, 10.299958, 10.290016]
return [
(False, False, False, 0, expected_loss), # no recompute
(True, False, False, 0, expected_loss), # attn_dropout recompute
(False, True, False, 0, expected_loss), # gelu recompute
(False, False, True, 0, expected_loss), # transformer_layer recompute
(False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer
]
elif device_capability_major == 5: # M60 for CI machines
expected_loss = [10.56341 , 10.461096, 10.364473, 10.297504, 10.249142]
return [
(False, False, False, 0, expected_loss), # no recompute
(True, False, False, 0, expected_loss), # attn_dropout recompute
(False, True, False, 0, expected_loss), # gelu recompute
(False, False, True, 0, expected_loss), # transformer_layer recompute
(False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer
]
@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data())
def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss):
seed = 321
device = 'cuda'
rtol = 1e-3
total_steps = len(expected_loss)
torch.manual_seed(seed)
set_seed(seed)
# Setup ORTTrainer
loss_scaler = amp.DynamicLossScaler()
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
'graph_transformer' : {
'attn_dropout_recompute': attn_dropout,
'gelu_recompute': gelu,
'transformer_layer_recompute': transformer_layer,
'number_recompute_layers': number_layers
},
'debug' : {'deterministic_compute' : True}})
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
# Training loop
actual_loss = []
for i in range(total_steps):
data, targets = batcher_fn(train_data, i)
loss, _ = trainer.train_step(data, targets)
actual_loss.append(loss.cpu())
# Compare loss to ground truth computed from current ORTTrainer API
_test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol)
assert trainer._onnx_model is not None
@pytest.mark.parametrize("seed,device,gradient_accumulation_steps,total_steps,expected_loss", [
(0, 'cuda', 1, 12, [10.5368022919, 10.4146203995, 10.3635568619, 10.2650547028, 10.2284049988, 10.1304626465,\
10.0853414536, 9.9987659454, 9.9472427368, 9.8832416534, 9.8223171234, 9.8222122192]),