mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
Expose recompute configs to the frontend (#5318)
* Expose recompute configs to the frontend * Add frontend test * Ensure recompute graph transformer is only applied once Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
e33de20861
commit
e71668f92c
17 changed files with 262 additions and 51 deletions
|
|
@ -39,6 +39,8 @@ class GraphTransformer {
|
|||
*/
|
||||
common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;
|
||||
|
||||
virtual bool ShouldOnlyApplyOnce() const { return false; }
|
||||
|
||||
protected:
|
||||
/** Helper method to call ApplyImpl on any subgraphs in the Node. */
|
||||
common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ __author__ = "Microsoft"
|
|||
|
||||
from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
|
||||
RunOptions, SessionOptions, set_default_logger_severity, NodeArg, ModelMetadata, GraphOptimizationLevel, \
|
||||
ExecutionMode, OrtDevice, SessionIOBinding
|
||||
ExecutionMode, ExecutionOrder, OrtDevice, SessionIOBinding
|
||||
|
||||
try:
|
||||
from onnxruntime.capi._pybind_state import set_cuda_mem_limit, set_cuda_device_id
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
|
|||
for (unsigned step = 0; step < steps_; ++step) {
|
||||
bool graph_changed = false;
|
||||
for (const auto& transformer : transformers->second) {
|
||||
if (step > 0 && transformer->ShouldOnlyApplyOnce())
|
||||
continue;
|
||||
|
||||
bool modified = false;
|
||||
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger));
|
||||
graph_changed = graph_changed || modified;
|
||||
|
|
|
|||
|
|
@ -937,6 +937,10 @@ void addObjectMethods(py::module& m, Environment& env) {
|
|||
.value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL)
|
||||
.value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL);
|
||||
|
||||
py::enum_<ExecutionOrder>(m, "ExecutionOrder")
|
||||
.value("DEFAULT", ExecutionOrder::DEFAULT)
|
||||
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED);
|
||||
|
||||
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc");
|
||||
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
|
||||
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
|
||||
|
|
@ -1089,7 +1093,7 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
|
|||
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
|
||||
.def_readwrite("execution_mode", &PySessionOptions::execution_mode,
|
||||
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
|
||||
.def_readwrite("execution_order", &SessionOptions::execution_order,
|
||||
.def_readwrite("execution_order", &PySessionOptions::execution_order,
|
||||
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
|
||||
.def_property(
|
||||
"graph_optimization_level",
|
||||
|
|
|
|||
|
|
@ -74,12 +74,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
if (config.gelu_recompute) {
|
||||
rule_transformer->Register(make_unique<GeluRecompute>());
|
||||
}
|
||||
if (config.attn_dropout_recompute) {
|
||||
rule_transformer->Register(make_unique<AttentionDropoutRecompute>());
|
||||
}
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
|
||||
|
|
@ -106,8 +100,15 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
}
|
||||
transformers.emplace_back(onnxruntime::make_unique<ComputationReductionTransformer>(compatible_eps));
|
||||
|
||||
if (config.gelu_recompute) {
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluRecompute>());
|
||||
}
|
||||
if (config.attn_dropout_recompute) {
|
||||
transformers.emplace_back(onnxruntime::make_unique<AttentionDropoutRecompute>());
|
||||
}
|
||||
if (config.transformer_layer_recompute) {
|
||||
transformers.emplace_back(onnxruntime::make_unique<TransformerLayerRecompute>(compatible_eps));
|
||||
transformers.emplace_back(onnxruntime::make_unique<TransformerLayerRecompute>(
|
||||
config.number_recompute_layers, compatible_eps));
|
||||
}
|
||||
} break;
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,12 @@ using namespace ONNX_NAMESPACE;
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
|
||||
bool GeluRecompute::SatisfyCondition(const Node& node) const {
|
||||
static const std::unordered_set<std::string> target_optypes = {"Gelu", "FastGelu", "BiasGelu"};
|
||||
if (target_optypes.find(node.OpType()) == target_optypes.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto next_node = node.OutputNodesBegin();
|
||||
if (next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") {
|
||||
return true;
|
||||
|
|
@ -18,27 +23,42 @@ bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, c
|
|||
return false;
|
||||
}
|
||||
|
||||
Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
|
||||
const auto& output = node.OutputDefs()[0];
|
||||
Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
|
||||
output->TypeAsProto());
|
||||
for (NodeIndex i : order) {
|
||||
Node& node = *graph.GetNode(i);
|
||||
|
||||
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
|
||||
node.OpType(),
|
||||
"Recompute of " + node.Name(),
|
||||
{node.MutableInputDefs()[0]},
|
||||
{&recomputed_output},
|
||||
&node.GetAttributes(),
|
||||
node.Domain());
|
||||
if (!SatisfyCondition(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
|
||||
const auto& output = node.OutputDefs()[0];
|
||||
|
||||
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
|
||||
output->TypeAsProto());
|
||||
|
||||
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
|
||||
node.OpType(),
|
||||
"Recompute of " + node.Name(),
|
||||
{node.MutableInputDefs()[0]},
|
||||
{&recomputed_output},
|
||||
&node.GetAttributes(),
|
||||
node.Domain());
|
||||
|
||||
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
|
||||
bool AttentionDropoutRecompute::SatisfyCondition(const Node& node) const {
|
||||
if (node.OpType() != "Dropout")
|
||||
return false;
|
||||
|
||||
const auto prev_node = node.InputNodesBegin();
|
||||
const auto next_node = node.OutputNodesBegin();
|
||||
if (prev_node != node.InputNodesEnd() && prev_node->OpType() == "Softmax" &&
|
||||
|
|
@ -48,11 +68,22 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const N
|
|||
return false;
|
||||
}
|
||||
|
||||
Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
|
||||
Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true);
|
||||
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
|
||||
Status AttentionDropoutRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
for (NodeIndex i : order) {
|
||||
Node& node = *graph.GetNode(i);
|
||||
|
||||
if (!SatisfyCondition(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true);
|
||||
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
|
||||
|
||||
modified = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -13,18 +13,16 @@ namespace onnxruntime {
|
|||
Recompute Gelu/BiasGelu/FastGelu
|
||||
|
||||
*/
|
||||
class GeluRecompute : public RewriteRule {
|
||||
class GeluRecompute : public GraphTransformer {
|
||||
public:
|
||||
GeluRecompute() noexcept : RewriteRule("GeluRecompute") {}
|
||||
GeluRecompute() noexcept : GraphTransformer("GeluRecompute") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Gelu", "FastGelu", "BiasGelu"};
|
||||
}
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool ShouldOnlyApplyOnce() const override { return true; }
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
bool SatisfyCondition(const Node& node) const;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -33,18 +31,16 @@ class GeluRecompute : public RewriteRule {
|
|||
Recompute Dropout in the attention layer
|
||||
|
||||
*/
|
||||
class AttentionDropoutRecompute : public RewriteRule {
|
||||
class AttentionDropoutRecompute : public GraphTransformer {
|
||||
public:
|
||||
AttentionDropoutRecompute() noexcept : RewriteRule("AttentionDropoutRecompute") {}
|
||||
AttentionDropoutRecompute() noexcept : GraphTransformer("AttentionDropoutRecompute") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Dropout"};
|
||||
}
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool ShouldOnlyApplyOnce() const override { return true; }
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
bool SatisfyCondition(const Node& node) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ void TransformerLayerRecompute::InsertRecomputeNodes(Graph& graph, const std::ve
|
|||
Node* node = graph.GetNode(n->Index());
|
||||
|
||||
// recomputed Dropout need to produce the same output as original dropout
|
||||
// currently reusing original dropout's mask to achieve this
|
||||
// currently reusing original dropout's mask to achieve this
|
||||
if (node->OpType() == "Dropout") {
|
||||
const NodeArg* input = node->InputDefs()[0];
|
||||
const Node* p_node = graph.GetProducerNode(input->Name());
|
||||
|
|
@ -175,9 +175,25 @@ Status TransformerLayerRecompute::ApplyImpl(Graph& graph, bool& modified, int /*
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// insert recompute nodes expect for the last transformer layer
|
||||
// by default, apply recompute except for the last transformer layer
|
||||
// otherwise, take user specified 'number_recompute_layers_'
|
||||
|
||||
int n_layers;
|
||||
const int n_layers_limit = static_cast<int>(start_end_edges.size() - 1);
|
||||
if (number_recompute_layers_ > n_layers_limit) {
|
||||
LOGS(logger, WARNING) << "User specified number_recompute_layers " << number_recompute_layers_
|
||||
<< " is larger than limit " << n_layers_limit << "."
|
||||
<< "number_recompute_layers is now cliped to limit.";
|
||||
n_layers = n_layers_limit;
|
||||
} else if (number_recompute_layers_ > 0) {
|
||||
n_layers = number_recompute_layers_;
|
||||
} else {
|
||||
LOGS(logger, INFO) << "number_recompute_layers is not set by user, using default " << n_layers_limit << ".";
|
||||
n_layers = n_layers_limit;
|
||||
}
|
||||
|
||||
// later recompute layers have higher execution priority
|
||||
for (size_t i = 0; i < start_end_edges.size() - 1; ++i) {
|
||||
for (int i = 0; i < n_layers; ++i) {
|
||||
std::vector<const Node*> nodes = NodesBetweenEdges(graph, start_end_edges[i].first, start_end_edges[i].second);
|
||||
InsertRecomputeNodes(graph, nodes, static_cast<int>(start_end_edges.size() - i));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,11 +10,15 @@ namespace onnxruntime {
|
|||
|
||||
class TransformerLayerRecompute : public GraphTransformer {
|
||||
public:
|
||||
TransformerLayerRecompute(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("TransformerLayerRecompute", compatible_execution_providers) {}
|
||||
TransformerLayerRecompute(int number_recompute_layers,
|
||||
const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("TransformerLayerRecompute", compatible_execution_providers),
|
||||
number_recompute_layers_(number_recompute_layers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool ShouldOnlyApplyOnce() const override { return true; }
|
||||
|
||||
private:
|
||||
Status IdentifyTransformerLayerEdges(const Graph& graph,
|
||||
std::vector<std::pair<const NodeArg*, const NodeArg*>>& start_end_edges,
|
||||
|
|
@ -23,6 +27,8 @@ class TransformerLayerRecompute : public GraphTransformer {
|
|||
std::vector<const Node*> NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const;
|
||||
|
||||
void InsertRecomputeNodes(Graph& graph, const std::vector<const Node*>& nodes, int priority) const;
|
||||
|
||||
int number_recompute_layers_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -200,6 +200,8 @@ class TrainingSession : public InferenceSession {
|
|||
bool gelu_recompute{false};
|
||||
// Enable recompute of transformer layer output to save memory
|
||||
bool transformer_layer_recompute{false};
|
||||
// Number of layers to apply recompute
|
||||
int number_recompute_layers{0};
|
||||
};
|
||||
|
||||
GraphTransformerConfiguration graph_transformer_config{};
|
||||
|
|
|
|||
|
|
@ -173,6 +173,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
cxxopts::value<bool>()->default_value("false"))
|
||||
("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.",
|
||||
cxxopts::value<bool>()->default_value("false"))
|
||||
("number_recompute_layers", "Number of layers to apply recompute.",
|
||||
cxxopts::value<int>()->default_value("0"))
|
||||
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
options
|
||||
|
|
@ -463,6 +465,7 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
params.attn_dropout_recompute = flags["attn_dropout_recompute"].as<bool>();
|
||||
params.gelu_recompute = flags["gelu_recompute"].as<bool>();
|
||||
params.transformer_layer_recompute = flags["transformer_layer_recompute"].as<bool>();
|
||||
params.number_recompute_layers = flags["number_recompute_layers"].as<int>();
|
||||
|
||||
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
|
||||
ORT_RETURN_IF_NOT(
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ Status TrainingRunner::Initialize() {
|
|||
gt_config.attn_dropout_recompute = params_.attn_dropout_recompute;
|
||||
gt_config.gelu_recompute = params_.gelu_recompute;
|
||||
gt_config.transformer_layer_recompute = params_.transformer_layer_recompute;
|
||||
gt_config.number_recompute_layers = params_.number_recompute_layers;
|
||||
|
||||
config.graph_transformer_config = gt_config;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -180,6 +180,8 @@ class TrainingRunner {
|
|||
bool gelu_recompute = false;
|
||||
// Enable checkpointing of transformer layer output to save memory
|
||||
bool transformer_layer_recompute = false;
|
||||
// Number of layers to apply recompute
|
||||
int number_recompute_layers = 0;
|
||||
// Use invertible layernorm grad
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -47,6 +47,12 @@ struct TrainingParameters {
|
|||
bool enable_grad_norm_clip = true;
|
||||
bool set_gradients_as_graph_outputs = false;
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
|
||||
// recompute
|
||||
bool attn_dropout_recompute = false;
|
||||
bool gelu_recompute = false;
|
||||
bool transformer_layer_recompute = false;
|
||||
int number_recompute_layers = 0;
|
||||
};
|
||||
|
||||
struct TrainingConfigurationResult {
|
||||
|
|
@ -130,6 +136,11 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad;
|
||||
config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
|
||||
|
||||
config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute;
|
||||
config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute;
|
||||
config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute;
|
||||
config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers;
|
||||
|
||||
training::TrainingSession::TrainingConfigurationResult config_result{};
|
||||
|
||||
OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result));
|
||||
|
|
@ -186,7 +197,11 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def_readwrite("deepspeed_zero_stage", &TrainingParameters::deepspeed_zero_stage)
|
||||
.def_readwrite("enable_grad_norm_clip", &TrainingParameters::enable_grad_norm_clip)
|
||||
.def_readwrite("set_gradients_as_graph_outputs", &TrainingParameters::set_gradients_as_graph_outputs)
|
||||
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad);
|
||||
.def_readwrite("use_invertible_layernorm_grad", &TrainingParameters::use_invertible_layernorm_grad)
|
||||
.def_readwrite("attn_dropout_recompute", &TrainingParameters::attn_dropout_recompute)
|
||||
.def_readwrite("gelu_recompute", &TrainingParameters::gelu_recompute)
|
||||
.def_readwrite("transformer_layer_recompute", &TrainingParameters::transformer_layer_recompute)
|
||||
.def_readwrite("number_recompute_layers", &TrainingParameters::number_recompute_layers);
|
||||
|
||||
#if defined(USE_NCCL)
|
||||
m.def("get_mpi_context_local_rank", []() -> int { return MPIContext::GetInstance().GetLocalRank(); });
|
||||
|
|
|
|||
|
|
@ -633,9 +633,18 @@ class ORTTrainer(object):
|
|||
ort_parameters.optimizer_attributes_map = optimizer_attributes_map
|
||||
ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
|
||||
|
||||
ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute
|
||||
ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute
|
||||
ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute
|
||||
ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers
|
||||
|
||||
# SessionOptions
|
||||
session_options = ort.SessionOptions()
|
||||
session_options.use_deterministic_compute = self.options.debug.deterministic_compute
|
||||
if (self.options.graph_transformer.attn_dropout_recompute or
|
||||
self.options.graph_transformer.gelu_recompute or
|
||||
self.options.graph_transformer.transformer_layer_recompute):
|
||||
session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED
|
||||
|
||||
# TrainingSession
|
||||
self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
|
||||
|
|
|
|||
|
|
@ -116,6 +116,30 @@ class ORTTrainerOptions(object):
|
|||
}
|
||||
}
|
||||
},
|
||||
'graph_transformer': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'default': {},
|
||||
'schema': {
|
||||
'attn_dropout_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'gelu_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'transformer_layer_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'number_recompute_layers': {
|
||||
'type': 'integer',
|
||||
'min': 0,
|
||||
'default': 0
|
||||
}
|
||||
}
|
||||
},
|
||||
'utils' : {
|
||||
'type' : 'dict',
|
||||
'required': False,
|
||||
|
|
@ -221,6 +245,17 @@ class ORTTrainerOptions(object):
|
|||
Users can also instantiate :py:class:`.DynamicLossScaler` and
|
||||
override its parameters. Lastly, a completely new implementation
|
||||
can be specified by extending :py:class:`.LossScaler` class from scratch
|
||||
graph_transformer (dict):
|
||||
graph transformer related configurations
|
||||
attn_dropout_recompute (bool, default is False):
|
||||
enable recomputing attention dropout to save memory
|
||||
gelu_recompute (bool, default is False):
|
||||
enable recomputing Gelu activation output to save memory
|
||||
transformer_layer_recompute (bool, default is False):
|
||||
enable recomputing transformer layerwise to save memory
|
||||
number_recompute_layers (int, default is 0):
|
||||
number of layers to apply transformer_layer_recompute, by default system will
|
||||
apply recompute to all the layers, except for the last one
|
||||
utils (dict):
|
||||
miscellaneous options
|
||||
utils.frozen_weights (list of str, []):
|
||||
|
|
@ -435,6 +470,30 @@ _ORTTRAINER_OPTIONS_SCHEMA = {
|
|||
}
|
||||
}
|
||||
},
|
||||
'graph_transformer': {
|
||||
'type': 'dict',
|
||||
'default_setter': lambda _: {},
|
||||
'required': False,
|
||||
'schema': {
|
||||
'attn_dropout_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'gelu_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'transformer_layer_recompute': {
|
||||
'type': 'boolean',
|
||||
'default': False
|
||||
},
|
||||
'number_recompute_layers': {
|
||||
'type': 'integer',
|
||||
'min': 0,
|
||||
'default': 0
|
||||
}
|
||||
}
|
||||
},
|
||||
'utils': {
|
||||
'type': 'dict',
|
||||
'default_setter': lambda _: {},
|
||||
|
|
|
|||
|
|
@ -94,6 +94,12 @@ def testORTTrainerOptionsDefaultValues(test_input):
|
|||
'enabled': False,
|
||||
'loss_scaler': None
|
||||
},
|
||||
'graph_transformer': {
|
||||
'attn_dropout_recompute': False,
|
||||
'gelu_recompute': False,
|
||||
'transformer_layer_recompute': False,
|
||||
'number_recompute_layers': 0
|
||||
},
|
||||
'utils': {
|
||||
'frozen_weights': [],
|
||||
'grad_norm_clip': True,
|
||||
|
|
@ -728,6 +734,61 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches)
|
|||
assert trainer._onnx_model is not None
|
||||
|
||||
|
||||
def _recompute_data():
|
||||
device_capability_major = torch.cuda.get_device_capability()[0]
|
||||
if device_capability_major == 7: # V100 for Dev machine
|
||||
expected_loss = [10.577394, 10.444777, 10.425666, 10.299958, 10.290016]
|
||||
return [
|
||||
(False, False, False, 0, expected_loss), # no recompute
|
||||
(True, False, False, 0, expected_loss), # attn_dropout recompute
|
||||
(False, True, False, 0, expected_loss), # gelu recompute
|
||||
(False, False, True, 0, expected_loss), # transformer_layer recompute
|
||||
(False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer
|
||||
]
|
||||
elif device_capability_major == 5: # M60 for CI machines
|
||||
expected_loss = [10.56341 , 10.461096, 10.364473, 10.297504, 10.249142]
|
||||
return [
|
||||
(False, False, False, 0, expected_loss), # no recompute
|
||||
(True, False, False, 0, expected_loss), # attn_dropout recompute
|
||||
(False, True, False, 0, expected_loss), # gelu recompute
|
||||
(False, False, True, 0, expected_loss), # transformer_layer recompute
|
||||
(False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer
|
||||
]
|
||||
@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data())
|
||||
def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss):
|
||||
seed = 321
|
||||
device = 'cuda'
|
||||
rtol = 1e-3
|
||||
total_steps = len(expected_loss)
|
||||
torch.manual_seed(seed)
|
||||
set_seed(seed)
|
||||
|
||||
# Setup ORTTrainer
|
||||
loss_scaler = amp.DynamicLossScaler()
|
||||
options = orttrainer.ORTTrainerOptions({'device' : {'id' : device},
|
||||
'graph_transformer' : {
|
||||
'attn_dropout_recompute': attn_dropout,
|
||||
'gelu_recompute': gelu,
|
||||
'transformer_layer_recompute': transformer_layer,
|
||||
'number_recompute_layers': number_layers
|
||||
},
|
||||
'debug' : {'deterministic_compute' : True}})
|
||||
model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
|
||||
optim_config = optim.LambConfig(lr=0.001)
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
|
||||
|
||||
# Training loop
|
||||
actual_loss = []
|
||||
for i in range(total_steps):
|
||||
data, targets = batcher_fn(train_data, i)
|
||||
loss, _ = trainer.train_step(data, targets)
|
||||
actual_loss.append(loss.cpu())
|
||||
|
||||
# Compare loss to ground truth computed from current ORTTrainer API
|
||||
_test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol)
|
||||
assert trainer._onnx_model is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed,device,gradient_accumulation_steps,total_steps,expected_loss", [
|
||||
(0, 'cuda', 1, 12, [10.5368022919, 10.4146203995, 10.3635568619, 10.2650547028, 10.2284049988, 10.1304626465,\
|
||||
10.0853414536, 9.9987659454, 9.9472427368, 9.8832416534, 9.8223171234, 9.8222122192]),
|
||||
|
|
|
|||
Loading…
Reference in a new issue