Localized Recompute for Gelu and AttentionDropout (#4402)

* Gelu Activation Recompute Draft

* Prototype for localized recompute

* Introduce localized_recompute rewriter

* Command line args for enabling recompute

* Add logger to Gradient Graph Builder

* use const when possible
This commit is contained in:
Sherlock 2020-08-04 21:48:15 -07:00 committed by GitHub
parent 0933148fc3
commit eb0f57f0e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 256 additions and 75 deletions

View file

@ -14,9 +14,7 @@ namespace onnxruntime {
// use value-based compare to make sure transformer output order is consistent
struct NodeCompare {
bool operator()(const Node* n1, const Node* n2) const {
return n1->Index() < n2->Index();
}
bool operator()(const Node* n1, const Node* n2) const;
};
/**

View file

@ -76,7 +76,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const TensorShap
bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const {
// By default, there is not information about inferred shape, so this default
// implementation always returns false. The derived class of IExecutionFrame
// can override this function to provide, for example, activations' shape information.
// can override this function to provide, for example, activations' shape information.
return false;
}
@ -261,7 +261,7 @@ ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const
// it's less efficient (the arena will add some overhead to coalesce individual allocations
// back into blocks on 'free'), but better than failing completely.
try {
// static_activation_memory_in_bytes_ is max virtual memory size the planner computes
// static_activation_memory_in_bytes_ is max virtual memory size the planner computes
auto peak_size = mem_patterns_->patterns[i].PeakSize();
// Planning of one memory type should only happen once.
buffer = alloc->Alloc(peak_size);
@ -326,7 +326,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
// Lazily get the allocator only if needed.
AllocatorPtr alloc = nullptr;
// create fence if needed
if (create_fence) {
ORT_ENFORCE(ort_value.Fence() == nullptr);
@ -630,7 +630,6 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const {
return planner_->GeneratePatterns(out);
}
bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
// NodeArg index to OrtValue index.
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
@ -639,7 +638,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
if (ort_value_idx == NodeIndexInfo::kInvalidEntry) {
return false;
}
// Search for inferred shape.
// If inferred shape is found, it's assigned to "shape" so that caller can use it.
auto it = inferred_shapes_.find(ort_value_idx);

View file

@ -11,6 +11,11 @@
#include "core/graph/graph_utils.h"
namespace onnxruntime {
bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
return n1->Index() < n2->Index();
}
GraphViewer::GraphViewer(const Graph& graph) {
graph_ = &graph;
std::vector<const Node*> leaf_nodes;

View file

@ -22,13 +22,13 @@ using namespace common;
GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
const unordered_set<string>& y_node_arg_names,
const unordered_set<string>& x_node_arg_names,
string loss_node_arg_name,
const std::string& loss_node_arg_name,
const GradientGraphConfiguration& gradient_graph_config,
const bool set_gradient_as_graph_output)
const logging::Logger& logger)
: graph_(graph),
loss_node_arg_name_(loss_node_arg_name),
gradient_graph_config_(gradient_graph_config),
set_gradient_as_graph_output_(set_gradient_as_graph_output) {
logger_(logger) {
auto rule_based_graph_transformer =
onnxruntime::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
rule_based_graph_transformer->Register(make_unique<InsertMaxPoolOutput>());
@ -91,7 +91,7 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) {
auto it = STOP_GRADIENT_EDGES.find(n->OpType());
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
std::cout << "Skip building gradient for node: " << edge_it->GetNode().Name() << std::endl;
LOGS(logger_, WARNING) << "Skip building gradient for node: " << edge_it->GetNode().Name() ;
continue;
}
@ -126,7 +126,7 @@ Status GradientGraphBuilder::CheckNodeArgsReachable(const NodeSet& reachable_nod
}
Status GradientGraphBuilder::Build() {
auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logging::LoggingManager::DefaultLogger());
auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logger_);
ORT_RETURN_IF_ERROR(opt_ret);
GraphAugmenter::GraphDefs gradient_graph_defs;
@ -175,7 +175,6 @@ Status GradientGraphBuilder::Build() {
// so far, visited are the minimum node in between
// visited_node_args are the node_args involved
for (auto node : visited) {
//TODO: might not need two sets, the union of them might be enough
unordered_set<string> input_args_need_grad, output_args_need_grad;
@ -190,7 +189,7 @@ Status GradientGraphBuilder::Build() {
}
}
GradientDef node_defs = GetGradientForOp(gradient_graph_config_, node, output_args_need_grad, input_args_need_grad);
GradientDef node_defs = GetGradientForOp(gradient_graph_config_, graph_, node, output_args_need_grad, input_args_need_grad, logger_);
// updates arg name if gradient accumulation is needed
for (auto& op_def : node_defs) {
@ -218,7 +217,7 @@ Status GradientGraphBuilder::Build() {
"AccumulateGrad_" + gradient_pair.first.name)});
}
if (set_gradient_as_graph_output_) {
if (gradient_graph_config_.set_gradients_as_graph_outputs) {
for (auto x_node_arg : x_node_args_) {
gradient_graph_defs.AddGraphOutputs({GradientBuilderBase::GradientName(x_node_arg->Name())});
}

View file

@ -32,7 +32,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"Reshape", {1}},
{"Expand", {1}},
{"TrainableDropout", {1}},
{"Dropout", {1}},
{"Dropout", {1, 2}},
{"Slice", {1, 2, 3, 4}},
{"SparseSoftmaxCrossEntropy", {1, 2}},
{"SoftmaxCrossEntropyLoss", {1, 2}},
@ -58,9 +58,9 @@ class GradientGraphBuilder {
GradientGraphBuilder(Graph* graph,
const std::unordered_set<std::string>& y_node_arg_names,
const std::unordered_set<std::string>& x_node_arg_names,
std::string loss_node_arg_name,
const std::string& loss_node_arg_name,
const GradientGraphConfiguration& gradient_graph_config,
const bool set_gradient_as_graph_output = false);
const logging::Logger& logger);
Status Build();
@ -77,6 +77,8 @@ class GradientGraphBuilder {
const GradientGraphConfiguration& gradient_graph_config_;
const logging::Logger& logger_;
onnxruntime::GraphTransformerManager graph_transformation_mgr_{5};
// key: ArgDef for the gradient after accumulation
@ -104,9 +106,6 @@ class GradientGraphBuilder {
@returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status
*/
Status CheckNodeArgsReachable(const NodeSet& reachable_nodes);
// if it is true, set gradient of trainable weight as graph output
const bool set_gradient_as_graph_output_;
};
} // namespace training

View file

@ -8,6 +8,7 @@
#include "core/graph/graph.h"
#include "orttraining/core/graph/graph_augmenter.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "onnx/defs/attr_proto_util.h"
namespace onnxruntime {
@ -35,12 +36,18 @@ typedef std::vector<NodeDef> GradientDef;
class GradientBuilderBase {
public:
GradientBuilderBase(
const GradientGraphConfiguration& gradient_graph_config,
const Node* node,
const std::unordered_set<std::string>& gradient_inputs,
const std::unordered_set<std::string>& gradient_outputs)
: gradient_graph_config_(gradient_graph_config), node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) {
GradientBuilderBase(const GradientGraphConfiguration& gradient_graph_config,
const Graph* graph,
const Node* node,
const std::unordered_set<std::string>& gradient_inputs,
const std::unordered_set<std::string>& gradient_outputs,
const logging::Logger& logger)
: gradient_graph_config_(gradient_graph_config),
graph_(graph),
node_(node),
gradient_inputs_(gradient_inputs),
gradient_outputs_(gradient_outputs),
logger_(logger) {
unique_node_prefix_ = CreateUniqueNodePrefix();
}
@ -71,6 +78,15 @@ class GradientBuilderBase {
// i-th input of forward op
ArgDef I(const size_t i) const {
ORT_ENFORCE(i < node_->InputDefs().size());
const std::string& name = node_->InputDefs()[i]->Name();
const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name));
if (recomputed_nodearg) {
const Node* producer_node = graph_->GetProducerNode(name);
LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name();
return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto());
}
return ArgDef(node_->InputDefs()[i]->Name(), node_->InputDefs()[i]->TypeAsProto());
}
@ -207,6 +223,7 @@ class GradientBuilderBase {
}
const GradientGraphConfiguration& gradient_graph_config_;
const Graph* graph_;
const Node* node_;
std::string unique_node_prefix_;
@ -215,6 +232,8 @@ class GradientBuilderBase {
// contains set of input arg names of node_ which requires gradient
std::unordered_set<std::string> gradient_outputs_;
const logging::Logger& logger_;
};
class EmptyGradientBuilder : public GradientBuilderBase {

View file

@ -9,22 +9,25 @@ namespace onnxruntime {
namespace training {
GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config,
const Graph* graph,
const Node* node,
const std::unordered_set<std::string>& output_args_need_grad,
const std::unordered_set<std::string>& input_args_need_grad) {
const std::unordered_set<std::string>& input_args_need_grad,
const logging::Logger& logger) {
// REVIEW(bahuang): We don't have a version control for forward to backward op mapping.
// Current SliceGrad(kMSDomain, 1) only supports Slice(kOnnxDomain, 10/11) because adding grad operator for versions
// less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11.
auto gradient_builder = GradientBuilderRegistry::GetInstance().MakeUnique(node->OpType(),
gradient_graph_config,
graph,
node,
output_args_need_grad,
input_args_need_grad);
input_args_need_grad,
logger);
ORT_ENFORCE(gradient_builder != nullptr,
"The gradient builder has not been registered:", node->OpType());
"The gradient builder has not been registered:", node->OpType(), " for node ", node->Name());
return gradient_builder->GetGradientDefs();
}

View file

@ -13,9 +13,11 @@ namespace training {
typedef GenericRegistry<GradientBuilderBase,
const GradientGraphConfiguration&,
const Node*&, //node
const Graph*&, // graph
const Node*&, // node
const std::unordered_set<std::string>&, // gradient_inputs
const std::unordered_set<std::string>&> // gradient_outputs
const std::unordered_set<std::string>&, // gradient_outputs
const logging::Logger&>
GradientRegistryType;
class GradientBuilderRegistry : public GradientRegistryType {
@ -33,9 +35,11 @@ class GradientBuilderRegistry : public GradientRegistryType {
};
GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config,
const Graph* graph,
const Node* node,
const std::unordered_set<std::string>& output_args_need_grad,
const std::unordered_set<std::string>& input_args_need_grad);
const std::unordered_set<std::string>& input_args_need_grad,
const logging::Logger& logger);
} // namespace training
} // namespace onnxruntime

View file

@ -12,6 +12,10 @@ struct GradientGraphConfiguration {
// To save memory, ideally, only one(input vs output) should be stashed rather than both.
// By default, the input based algorithm is used. This flag is to enable the output based algorithm.
bool use_invertible_layernorm_grad{false};
// If set to true, all gradients will be exposed as graph output.
// This should only be used for unit test or debugging purpose.
bool set_gradients_as_graph_outputs{false};
};
} // namespace training

View file

@ -4,9 +4,11 @@
#pragma once
namespace onnxruntime {
namespace training {
namespace graph_utils {
void RegisterGradientSchemas();
inline std::string RecomputeName(const std::string& name) {
return name + "_recompute";
}
} // namespace training
} // namespace graph_utils
} // namespace onnxruntime

View file

@ -38,6 +38,7 @@
#include "orttraining/core/framework/distributed_run_context.h"
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/localized_recompute.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
@ -48,7 +49,7 @@ namespace transformer_utils {
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> transformers;
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
@ -69,6 +70,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
if (config.gelu_checkpoint) {
rule_transformer->Register(make_unique<GeluRecompute>());
}
if (config.attn_dropout_checkpoint) {
rule_transformer->Register(make_unique<AttentionDropoutRecompute>());
}
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
@ -76,7 +83,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(onnxruntime::make_unique<BiasGeluFusion>(compatible_eps));
if (enable_gelu_approximation) {
if (config.enable_gelu_approximation) {
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(compatible_eps));
}

View file

@ -6,6 +6,7 @@
#include <gsl/gsl>
#include "core/optimizer/graph_transformer.h"
#include "orttraining/core/session/training_session.h"
namespace onnxruntime {
struct FreeDimensionOverride;
@ -17,7 +18,7 @@ namespace transformer_utils {
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
TransformerLevel level,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.

View file

@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/graph_utils.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "orttraining/core/optimizer/localized_recompute.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
const auto next_node = node.OutputNodesBegin();
if (next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") {
return true;
}
return false;
}
Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
const auto& output = node.OutputDefs()[0];
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
graph.AddNode(node.Name() + "_recompute",
node.OpType(),
"Recompute of " + node.Name(),
{node.MutableInputDefs()[0]},
{&recomputed_output},
&node.GetAttributes(),
node.Domain());
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
return Status::OK();
}
bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
const auto prev_node = node.InputNodesBegin();
const auto next_node = node.OutputNodesBegin();
if (prev_node != node.InputNodesEnd() && prev_node->OpType() == "Softmax" &&
next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") {
return true;
}
return false;
}
Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
const auto& output = node.OutputDefs()[0];
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
graph.AddNode(node.Name() + "_recompute",
"DropoutGrad", // Reusing DropoutGrad as the recompute op
"Recompute of " + node.Name(),
{
node.MutableInputDefs()[0], // X
node.MutableOutputDefs()[1], // mask
node.MutableInputDefs()[1], // ratio
node.MutableInputDefs()[2] // training_mode
},
{&recomputed_output},
{},
kMSDomain);
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@Class GeluRecompute
Recompute Gelu/BiasGelu/FastGelu
*/
class GeluRecompute : public RewriteRule {
public:
GeluRecompute() noexcept : RewriteRule("GeluRecompute") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Gelu", "FastGelu", "BiasGelu"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
/**
@Class AttentionDropoutRecompute
Recompute Dropout in the attention layer
*/
class AttentionDropoutRecompute : public RewriteRule {
public:
AttentionDropoutRecompute() noexcept : RewriteRule("AttentionDropoutRecompute") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Dropout"};
}
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -216,7 +216,7 @@ Status TrainingSession::ConfigureForTraining(
}
}
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.enable_gelu_approximation));
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));
// derive actual set of weights to train
std::unordered_set<std::string> weight_names_to_train =
@ -237,7 +237,7 @@ Status TrainingSession::ConfigureForTraining(
}
ORT_RETURN_IF_ERROR(BuildGradientGraph(
weight_names_to_train, loss_name, config.gradient_graph_config, config.set_gradients_as_graph_outputs));
weight_names_to_train, loss_name, config.gradient_graph_config, *session_logger_));
// transform for mixed precision
std::unordered_map<std::string, NodeArg*> fp32_weight_name_to_fp16_node_arg{};
@ -443,14 +443,14 @@ static Status BuildGradientGraphInternal(Graph& graph,
const std::string& loss_function_output_name,
const std::unordered_set<std::string>& node_arg_names_to_train,
const GradientGraphConfiguration& gradient_graph_config,
const bool set_gradient_as_graph_output = false) {
const logging::Logger& logger) {
// Compute the gradient graph def.
GradientGraphBuilder grad_graph_builder(&graph,
{loss_function_output_name},
node_arg_names_to_train,
loss_function_output_name,
gradient_graph_config,
set_gradient_as_graph_output);
logger);
return grad_graph_builder.Build();
}
@ -488,10 +488,10 @@ static Status AddGradientAccumulationNodes(Graph& graph,
return GraphAugmenter::AugmentGraph(graph, graph_defs);
}
Status TrainingSession::ApplyTransformationsToMainGraph(
const std::unordered_set<std::string>& weights_to_train, bool enable_gelu_approximation) {
Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config) {
GraphTransformerManager graph_transformation_mgr{1};
AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, enable_gelu_approximation);
AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, config);
// apply transformers
Graph& graph = model_->MainGraph();
@ -505,13 +505,13 @@ Status TrainingSession::ApplyTransformationsToMainGraph(
// Registers all the pre transformers with transformer manager
void TrainingSession::AddPreTrainingTransformers(GraphTransformerManager& transformer_manager,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const TrainingConfiguration::GraphTransformerConfiguration& config,
TransformerLevel graph_optimization_level,
const std::vector<std::string>& custom_list) {
auto add_transformers = [&](TransformerLevel level) {
// Generate and register transformers for level
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
level, weights_to_train, enable_gelu_approximation, custom_list);
level, weights_to_train, config, custom_list);
for (auto& entry : transformers_to_register) {
transformer_manager.Register(std::move(entry), level);
}
@ -664,7 +664,7 @@ Status TrainingSession::EnableMixedPrecision(
Status TrainingSession::BuildGradientGraph(const std::unordered_set<std::string>& weights_to_train,
const std::string& loss_function_output_name,
const GradientGraphConfiguration& gradient_graph_config,
const bool set_gradient_as_graph_output) {
const logging::Logger& logger) {
// Fill weights_to_train_ according to weights_to_train
weights_to_train_ = weights_to_train;
gradient_graph_config_ = gradient_graph_config;
@ -673,7 +673,7 @@ Status TrainingSession::BuildGradientGraph(const std::unordered_set<std::string>
loss_function_output_name,
weights_to_train_,
gradient_graph_config_,
set_gradient_as_graph_output));
logger));
return DoPostLoadProcessing(*model_);
}
@ -791,7 +791,7 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO
actual_loss_name,
weights_to_train_,
gradient_graph_config_,
false));
*session_logger_));
OptimizerOutputKeyMap<std::string> opt_graph_outputs;
std::unordered_set<std::string> opt_state_initializer_names;

View file

@ -46,9 +46,6 @@ class TrainingSession : public InferenceSession {
// Gradient graph configuration
GradientGraphConfiguration gradient_graph_config{};
// Whether to set the gradients as graph outputs.
bool set_gradients_as_graph_outputs{false};
// The number of gradient accumulation steps.
int gradient_accumulation_steps{1};
@ -177,8 +174,16 @@ class TrainingSession : public InferenceSession {
// Otherwise, it returns false.
optional<PipelineConfiguration> pipeline_config{};
// Whether to enable GELU approximation which is faster but produces different results.
bool enable_gelu_approximation{false};
struct GraphTransformerConfiguration {
// Whether to enable GELU approximation which is faster but produces different results.
bool enable_gelu_approximation{false};
// Enable checkpointing of attention dropout to save memory
bool attn_dropout_checkpoint{false};
// Enable checkpointing of Gelu activation output to save memory
bool gelu_checkpoint{false};
};
GraphTransformerConfiguration graph_transformer_config{};
};
/**
@ -390,13 +395,13 @@ class TrainingSession : public InferenceSession {
std::string& backward_waited_event_after_recv_name,
std::string& backward_recorded_event_before_send_name);
common::Status ApplyTransformationsToMainGraph(
const std::unordered_set<std::string>& weights_to_train, bool enable_gelu_approximation);
common::Status ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
const TrainingConfiguration::GraphTransformerConfiguration& config);
/** configure initial transformers for training */
void AddPreTrainingTransformers(GraphTransformerManager& transformer_manager,
const std::unordered_set<std::string>& weights_to_train,
bool enable_gelu_approximation,
const TrainingConfiguration::GraphTransformerConfiguration& config,
TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel,
const std::vector<std::string>& custom_list = {});
@ -408,12 +413,11 @@ class TrainingSession : public InferenceSession {
/** Perform auto-diff to add backward graph into the model.
@param weights_to_train a set of weights to be training.
@param loss_function_output_name the name of the loss function's output.
@param set_gradient_as_graph_output if it is true, set gradient of trainable weight as graph output
*/
common::Status BuildGradientGraph(const std::unordered_set<std::string>& weights_to_train,
const std::string& loss_function_output_name,
const GradientGraphConfiguration& gradient_graph_config,
const bool set_gradient_as_graph_output = false);
const logging::Logger& logger);
common::Status BuildAccumulationNode(const std::unordered_set<std::string>& weights_to_train);

View file

@ -166,6 +166,10 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
cxxopts::value<bool>()->default_value("true"))
("enable_gelu_approximation", "Specify whether to enable GELU approximation.",
cxxopts::value<bool>()->default_value("true"))
("attn_dropout_checkpoint", "Enable checkpointing of attention dropout to save memory.",
cxxopts::value<bool>()->default_value("false"))
("gelu_checkpoint", "Enable checkpointing of Gelu activation output to save memory.",
cxxopts::value<bool>()->default_value("false"))
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
cxxopts::value<bool>()->default_value("false"));
options
@ -453,6 +457,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
}
params.enable_gelu_approximation = flags["enable_gelu_approximation"].as<bool>();
params.attn_dropout_checkpoint = flags["attn_dropout_checkpoint"].as<bool>();
params.gelu_checkpoint = flags["gelu_checkpoint"].as<bool>();
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
ORT_RETURN_IF_NOT(

View file

@ -92,7 +92,7 @@ Status TrainingRunner::Initialize() {
config.immutable_weights = params_.immutable_weights;
config.gradient_graph_config.use_invertible_layernorm_grad = params_.use_invertible_layernorm_grad;
config.set_gradients_as_graph_outputs = false;
config.gradient_graph_config.set_gradients_as_graph_outputs = false;
config.gradient_accumulation_steps = params_.gradient_accumulation_steps;
@ -164,7 +164,15 @@ Status TrainingRunner::Initialize() {
config.pipeline_config = pipe;
}
config.enable_gelu_approximation = params_.enable_gelu_approximation;
// always configure the graph transformer
{
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration gt_config{};
gt_config.enable_gelu_approximation = params_.enable_gelu_approximation;
gt_config.attn_dropout_checkpoint = params_.attn_dropout_checkpoint;
gt_config.gelu_checkpoint = params_.gelu_checkpoint;
config.graph_transformer_config = gt_config;
}
TrainingSession::TrainingConfigurationResult config_result{};

View file

@ -169,7 +169,10 @@ class TrainingRunner {
// Enable GELU approximation
bool enable_gelu_approximation = false;
// Enable checkpointing of attention dropout to save memory
bool attn_dropout_checkpoint = false;
// Enable checkpointing of Gelu activation output to save memory
bool gelu_checkpoint = false;
// Use invertible layernorm grad
bool use_invertible_layernorm_grad = false;
};

View file

@ -14,7 +14,6 @@
#include "orttraining/core/framework/mpi_setup.h"
#include "python/onnxruntime_pybind_mlvalue.h"
namespace onnxruntime {
namespace python {
namespace py = pybind11;
@ -91,8 +90,6 @@ TrainingConfigurationResult ConfigureSessionForTraining(
config.weight_names_to_not_train = parameters.weights_not_to_train;
config.immutable_weights = parameters.immutable_weights;
config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
config.gradient_accumulation_steps = parameters.gradient_accumulation_steps;
config.distributed_config.world_rank = parameters.world_rank;
@ -146,6 +143,7 @@ TrainingConfigurationResult ConfigureSessionForTraining(
}
config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad;
config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
training::TrainingSession::TrainingConfigurationResult config_result{};
@ -194,9 +192,9 @@ void addObjectMethodsForTraining(py::module& m) {
py::class_<onnxruntime::training::TrainingSession, InferenceSession> training_session(m, "TrainingSession");
training_session.def(py::init([](const SessionOptions& so) {
Environment& env = get_env();
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
}))
Environment& env = get_env();
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
}))
.def(py::init([]() {
Environment& env = get_env();
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(GetDefaultCPUSessionOptions(), env);
@ -274,7 +272,6 @@ void addObjectMethodsForTraining(py::module& m) {
.def("is_output_fp32_node", [](onnxruntime::training::TrainingSession* sess, const std::string& output_name) {
return sess->IsGraphOutputFp32Node(output_name);
});
}
} // namespace python
} // namespace onnxruntime

View file

@ -307,12 +307,13 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGradGraph(
}
training::GradientGraphConfiguration gradient_graph_config;
gradient_graph_config.set_gradients_as_graph_outputs = true;
training::GradientGraphBuilder grad_graph_builder(&graph,
dy_values,
weights_to_train,
"",
gradient_graph_config,
true);
logging::LoggingManager::DefaultLogger());
Status status = grad_graph_builder.Build();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();

View file

@ -71,12 +71,13 @@ void GradientOpTester::Run(
}
training::GradientGraphConfiguration gradient_graph_config;
gradient_graph_config.set_gradients_as_graph_outputs = true;
training::GradientGraphBuilder grad_graph_builder(&graph,
dy_values,
weights_to_train,
"",
gradient_graph_config,
true);
logging::LoggingManager::DefaultLogger());
status = grad_graph_builder.Build();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
}