mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Localized Recompute for Gelu and AttentionDropout (#4402)
* Gelu Activation Recompute Draft * Prototype for localized recompute * Introduce localized_recompute rewriter * Command line args for enabling recompute * Add logger to Gradient Graph Builder * use const when possible
This commit is contained in:
parent
0933148fc3
commit
eb0f57f0e4
22 changed files with 256 additions and 75 deletions
|
|
@ -14,9 +14,7 @@ namespace onnxruntime {
|
|||
|
||||
// use value-based compare to make sure transformer output order is consistent
|
||||
struct NodeCompare {
|
||||
bool operator()(const Node* n1, const Node* n2) const {
|
||||
return n1->Index() < n2->Index();
|
||||
}
|
||||
bool operator()(const Node* n1, const Node* n2) const;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const TensorShap
|
|||
bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const {
|
||||
// By default, there is no information about inferred shape, so this default
|
||||
// implementation always returns false. The derived class of IExecutionFrame
|
||||
// can override this function to provide, for example, activations' shape information.
|
||||
// can override this function to provide, for example, activations' shape information.
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -261,7 +261,7 @@ ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const
|
|||
// it's less efficient (the arena will add some overhead to coalesce individual allocations
|
||||
// back into blocks on 'free'), but better than failing completely.
|
||||
try {
|
||||
// static_activation_memory_in_bytes_ is max virtual memory size the planner computes
|
||||
// static_activation_memory_in_bytes_ is max virtual memory size the planner computes
|
||||
auto peak_size = mem_patterns_->patterns[i].PeakSize();
|
||||
// Planning of one memory type should only happen once.
|
||||
buffer = alloc->Alloc(peak_size);
|
||||
|
|
@ -326,7 +326,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
|
|||
|
||||
// Lazily get the allocator only if needed.
|
||||
AllocatorPtr alloc = nullptr;
|
||||
|
||||
|
||||
// create fence if needed
|
||||
if (create_fence) {
|
||||
ORT_ENFORCE(ort_value.Fence() == nullptr);
|
||||
|
|
@ -630,7 +630,6 @@ Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const {
|
|||
return planner_->GeneratePatterns(out);
|
||||
}
|
||||
|
||||
|
||||
bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
|
||||
// NodeArg index to OrtValue index.
|
||||
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
|
||||
|
|
@ -639,7 +638,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
|
|||
if (ort_value_idx == NodeIndexInfo::kInvalidEntry) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Search for inferred shape.
|
||||
// If inferred shape is found, it's assigned to "shape" so that caller can use it.
|
||||
auto it = inferred_shapes_.find(ort_value_idx);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@
|
|||
#include "core/graph/graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Strict-weak ordering for Node pointers based on Node::Index() instead of the pointer value.
// Returns true when n1's index is smaller than n2's index.
bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
  const auto lhs_index = n1->Index();
  const auto rhs_index = n2->Index();
  return lhs_index < rhs_index;
}
|
||||
|
||||
GraphViewer::GraphViewer(const Graph& graph) {
|
||||
graph_ = &graph;
|
||||
std::vector<const Node*> leaf_nodes;
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ using namespace common;
|
|||
GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
|
||||
const unordered_set<string>& y_node_arg_names,
|
||||
const unordered_set<string>& x_node_arg_names,
|
||||
string loss_node_arg_name,
|
||||
const std::string& loss_node_arg_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output)
|
||||
const logging::Logger& logger)
|
||||
: graph_(graph),
|
||||
loss_node_arg_name_(loss_node_arg_name),
|
||||
gradient_graph_config_(gradient_graph_config),
|
||||
set_gradient_as_graph_output_(set_gradient_as_graph_output) {
|
||||
logger_(logger) {
|
||||
auto rule_based_graph_transformer =
|
||||
onnxruntime::make_unique<RuleBasedGraphTransformer>("pre_training_rule_based_graph_transformer");
|
||||
rule_based_graph_transformer->Register(make_unique<InsertMaxPoolOutput>());
|
||||
|
|
@ -91,7 +91,7 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
|
|||
for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) {
|
||||
auto it = STOP_GRADIENT_EDGES.find(n->OpType());
|
||||
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
|
||||
std::cout << "Skip building gradient for node: " << edge_it->GetNode().Name() << std::endl;
|
||||
LOGS(logger_, WARNING) << "Skip building gradient for node: " << edge_it->GetNode().Name() ;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -126,7 +126,7 @@ Status GradientGraphBuilder::CheckNodeArgsReachable(const NodeSet& reachable_nod
|
|||
}
|
||||
|
||||
Status GradientGraphBuilder::Build() {
|
||||
auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logging::LoggingManager::DefaultLogger());
|
||||
auto opt_ret = graph_transformation_mgr_.ApplyTransformers(*graph_, TransformerLevel::Level2, logger_);
|
||||
ORT_RETURN_IF_ERROR(opt_ret);
|
||||
|
||||
GraphAugmenter::GraphDefs gradient_graph_defs;
|
||||
|
|
@ -175,7 +175,6 @@ Status GradientGraphBuilder::Build() {
|
|||
|
||||
// so far, visited are the minimum node in between
|
||||
// visited_node_args are the node_args involved
|
||||
|
||||
for (auto node : visited) {
|
||||
//TODO: might not need two sets, the union of them might be enough
|
||||
unordered_set<string> input_args_need_grad, output_args_need_grad;
|
||||
|
|
@ -190,7 +189,7 @@ Status GradientGraphBuilder::Build() {
|
|||
}
|
||||
}
|
||||
|
||||
GradientDef node_defs = GetGradientForOp(gradient_graph_config_, node, output_args_need_grad, input_args_need_grad);
|
||||
GradientDef node_defs = GetGradientForOp(gradient_graph_config_, graph_, node, output_args_need_grad, input_args_need_grad, logger_);
|
||||
|
||||
// updates arg name if gradient accumulation is needed
|
||||
for (auto& op_def : node_defs) {
|
||||
|
|
@ -218,7 +217,7 @@ Status GradientGraphBuilder::Build() {
|
|||
"AccumulateGrad_" + gradient_pair.first.name)});
|
||||
}
|
||||
|
||||
if (set_gradient_as_graph_output_) {
|
||||
if (gradient_graph_config_.set_gradients_as_graph_outputs) {
|
||||
for (auto x_node_arg : x_node_args_) {
|
||||
gradient_graph_defs.AddGraphOutputs({GradientBuilderBase::GradientName(x_node_arg->Name())});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
|
|||
{"Reshape", {1}},
|
||||
{"Expand", {1}},
|
||||
{"TrainableDropout", {1}},
|
||||
{"Dropout", {1}},
|
||||
{"Dropout", {1, 2}},
|
||||
{"Slice", {1, 2, 3, 4}},
|
||||
{"SparseSoftmaxCrossEntropy", {1, 2}},
|
||||
{"SoftmaxCrossEntropyLoss", {1, 2}},
|
||||
|
|
@ -58,9 +58,9 @@ class GradientGraphBuilder {
|
|||
GradientGraphBuilder(Graph* graph,
|
||||
const std::unordered_set<std::string>& y_node_arg_names,
|
||||
const std::unordered_set<std::string>& x_node_arg_names,
|
||||
std::string loss_node_arg_name,
|
||||
const std::string& loss_node_arg_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output = false);
|
||||
const logging::Logger& logger);
|
||||
|
||||
Status Build();
|
||||
|
||||
|
|
@ -77,6 +77,8 @@ class GradientGraphBuilder {
|
|||
|
||||
const GradientGraphConfiguration& gradient_graph_config_;
|
||||
|
||||
const logging::Logger& logger_;
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr_{5};
|
||||
|
||||
// key: ArgDef for the gradient after accumulation
|
||||
|
|
@ -104,9 +106,6 @@ class GradientGraphBuilder {
|
|||
@returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status
|
||||
*/
|
||||
Status CheckNodeArgsReachable(const NodeSet& reachable_nodes);
|
||||
|
||||
// if it is true, set gradient of trainable weight as graph output
|
||||
const bool set_gradient_as_graph_output_;
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/graph/graph.h"
|
||||
#include "orttraining/core/graph/graph_augmenter.h"
|
||||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "orttraining/core/graph/recompute_graph_utils.h"
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -35,12 +36,18 @@ typedef std::vector<NodeDef> GradientDef;
|
|||
|
||||
class GradientBuilderBase {
|
||||
public:
|
||||
GradientBuilderBase(
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& gradient_inputs,
|
||||
const std::unordered_set<std::string>& gradient_outputs)
|
||||
: gradient_graph_config_(gradient_graph_config), node_(node), gradient_inputs_(gradient_inputs), gradient_outputs_(gradient_outputs) {
|
||||
GradientBuilderBase(const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Graph* graph,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& gradient_inputs,
|
||||
const std::unordered_set<std::string>& gradient_outputs,
|
||||
const logging::Logger& logger)
|
||||
: gradient_graph_config_(gradient_graph_config),
|
||||
graph_(graph),
|
||||
node_(node),
|
||||
gradient_inputs_(gradient_inputs),
|
||||
gradient_outputs_(gradient_outputs),
|
||||
logger_(logger) {
|
||||
unique_node_prefix_ = CreateUniqueNodePrefix();
|
||||
}
|
||||
|
||||
|
|
@ -71,6 +78,15 @@ class GradientBuilderBase {
|
|||
// i-th input of forward op
|
||||
ArgDef I(const size_t i) const {
|
||||
ORT_ENFORCE(i < node_->InputDefs().size());
|
||||
|
||||
const std::string& name = node_->InputDefs()[i]->Name();
|
||||
const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name));
|
||||
if (recomputed_nodearg) {
|
||||
const Node* producer_node = graph_->GetProducerNode(name);
|
||||
LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name();
|
||||
return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto());
|
||||
}
|
||||
|
||||
return ArgDef(node_->InputDefs()[i]->Name(), node_->InputDefs()[i]->TypeAsProto());
|
||||
}
|
||||
|
||||
|
|
@ -207,6 +223,7 @@ class GradientBuilderBase {
|
|||
}
|
||||
|
||||
const GradientGraphConfiguration& gradient_graph_config_;
|
||||
const Graph* graph_;
|
||||
const Node* node_;
|
||||
std::string unique_node_prefix_;
|
||||
|
||||
|
|
@ -215,6 +232,8 @@ class GradientBuilderBase {
|
|||
|
||||
// contains set of input arg names of node_ which requires gradient
|
||||
std::unordered_set<std::string> gradient_outputs_;
|
||||
|
||||
const logging::Logger& logger_;
|
||||
};
|
||||
|
||||
class EmptyGradientBuilder : public GradientBuilderBase {
|
||||
|
|
|
|||
|
|
@ -9,22 +9,25 @@ namespace onnxruntime {
|
|||
namespace training {
|
||||
|
||||
GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Graph* graph,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& output_args_need_grad,
|
||||
const std::unordered_set<std::string>& input_args_need_grad) {
|
||||
|
||||
const std::unordered_set<std::string>& input_args_need_grad,
|
||||
const logging::Logger& logger) {
|
||||
// REVIEW(bahuang): We don't have a version control for forward to backward op mapping.
|
||||
// Current SliceGrad(kMSDomain, 1) only supports Slice(kOnnxDomain, 10/11) because adding grad operator for versions
|
||||
// less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11.
|
||||
|
||||
auto gradient_builder = GradientBuilderRegistry::GetInstance().MakeUnique(node->OpType(),
|
||||
gradient_graph_config,
|
||||
graph,
|
||||
node,
|
||||
output_args_need_grad,
|
||||
input_args_need_grad);
|
||||
input_args_need_grad,
|
||||
logger);
|
||||
|
||||
ORT_ENFORCE(gradient_builder != nullptr,
|
||||
"The gradient builder has not been registered:", node->OpType());
|
||||
"The gradient builder has not been registered:", node->OpType(), " for node ", node->Name());
|
||||
|
||||
return gradient_builder->GetGradientDefs();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,9 +13,11 @@ namespace training {
|
|||
|
||||
typedef GenericRegistry<GradientBuilderBase,
|
||||
const GradientGraphConfiguration&,
|
||||
const Node*&, //node
|
||||
const Graph*&, // graph
|
||||
const Node*&, // node
|
||||
const std::unordered_set<std::string>&, // gradient_inputs
|
||||
const std::unordered_set<std::string>&> // gradient_outputs
|
||||
const std::unordered_set<std::string>&, // gradient_outputs
|
||||
const logging::Logger&>
|
||||
GradientRegistryType;
|
||||
|
||||
class GradientBuilderRegistry : public GradientRegistryType {
|
||||
|
|
@ -33,9 +35,11 @@ class GradientBuilderRegistry : public GradientRegistryType {
|
|||
};
|
||||
|
||||
GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_config,
|
||||
const Graph* graph,
|
||||
const Node* node,
|
||||
const std::unordered_set<std::string>& output_args_need_grad,
|
||||
const std::unordered_set<std::string>& input_args_need_grad);
|
||||
const std::unordered_set<std::string>& input_args_need_grad,
|
||||
const logging::Logger& logger);
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ struct GradientGraphConfiguration {
|
|||
// To save memory, ideally, only one(input vs output) should be stashed rather than both.
|
||||
// By default, the input based algorithm is used. This flag is to enable the output based algorithm.
|
||||
bool use_invertible_layernorm_grad{false};
|
||||
|
||||
// If set to true, all gradients will be exposed as graph output.
|
||||
// This should only be used for unit test or debugging purpose.
|
||||
bool set_gradients_as_graph_outputs{false};
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@
|
|||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
namespace graph_utils {
|
||||
|
||||
void RegisterGradientSchemas();
|
||||
// Builds the name of the recomputed counterpart of `name` by appending "_recompute".
inline std::string RecomputeName(const std::string& name) {
  std::string recompute_name{name};
  recompute_name += "_recompute";
  return recompute_name;
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace graph_utils
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -38,6 +38,7 @@
|
|||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
|
||||
#include "orttraining/core/optimizer/insert_output_rewriter.h"
|
||||
#include "orttraining/core/optimizer/localized_recompute.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
|
||||
|
|
@ -48,7 +49,7 @@ namespace transformer_utils {
|
|||
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
||||
TransformerLevel level,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
bool enable_gelu_approximation,
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
const std::vector<std::string>& transformers_and_rules_to_enable) {
|
||||
std::vector<std::unique_ptr<GraphTransformer>> transformers;
|
||||
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
|
||||
|
|
@ -69,6 +70,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
if (config.gelu_checkpoint) {
|
||||
rule_transformer->Register(make_unique<GeluRecompute>());
|
||||
}
|
||||
if (config.attn_dropout_checkpoint) {
|
||||
rule_transformer->Register(make_unique<AttentionDropoutRecompute>());
|
||||
}
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
|
||||
|
|
@ -76,7 +83,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
|
||||
transformers.emplace_back(onnxruntime::make_unique<BiasGeluFusion>(compatible_eps));
|
||||
|
||||
if (enable_gelu_approximation) {
|
||||
if (config.enable_gelu_approximation) {
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluApproximation>(compatible_eps));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <gsl/gsl>
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
struct FreeDimensionOverride;
|
||||
|
|
@ -17,7 +18,7 @@ namespace transformer_utils {
|
|||
std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
||||
TransformerLevel level,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
bool enable_gelu_approximation,
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
const std::vector<std::string>& rules_and_transformers_to_enable = {});
|
||||
|
||||
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "orttraining/core/graph/recompute_graph_utils.h"
|
||||
#include "orttraining/core/optimizer/localized_recompute.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// The rule fires only when the node has at least one output node and the first one yielded by
// the output-node iterator is a MatMul. The graph and logger arguments are unused.
bool GeluRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
  const auto first_consumer = node.OutputNodesBegin();
  if (first_consumer == node.OutputNodesEnd()) {
    return false;
  }
  return first_consumer->OpType() == "MatMul";
}
|
||||
|
||||
// Adds a duplicate of `node` (same op type, attributes and domain) named "<node name>_recompute"
// whose single output is the NodeArg RecomputeName(<original output 0 name>).
// Sets rule_effect to kModifiedRestOfGraph and returns Status::OK().
Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
  const auto& output = node.OutputDefs()[0];

  auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
                                                     output->TypeAsProto());

  // Forward every input of the original node, not just input 0: this rule also targets
  // BiasGelu and FastGelu, whose bias input would otherwise be dropped from the recomputed value.
  graph.AddNode(node.Name() + "_recompute",
                node.OpType(),
                "Recompute of " + node.Name(),
                node.MutableInputDefs(),
                {&recomputed_output},
                &node.GetAttributes(),
                node.Domain());

  rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
  return Status::OK();
}
|
||||
|
||||
// The rule fires for a Dropout node when:
//  - it carries at least three inputs and two outputs, because Apply() indexes inputs 0..2
//    (X, ratio, training_mode) and output 1 (mask) unconditionally;
//  - the first node yielded by its input-node iterator is a Softmax; and
//  - the first node yielded by its output-node iterator is a MatMul.
// The graph and logger arguments are unused.
bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const Node& node, const logging::Logger& /*logger*/) const {
  // Guard against Dropout nodes that omit optional inputs/outputs; rewriting them would
  // read past the end of the input/output definition lists in Apply().
  if (node.InputDefs().size() < 3 || node.OutputDefs().size() < 2) {
    return false;
  }

  const auto prev_node = node.InputNodesBegin();
  const auto next_node = node.OutputNodesBegin();
  if (prev_node != node.InputNodesEnd() && prev_node->OpType() == "Softmax" &&
      next_node != node.OutputNodesEnd() && next_node->OpType() == "MatMul") {
    return true;
  }
  return false;
}
|
||||
|
||||
// Adds a DropoutGrad node (kMSDomain) named "<node name>_recompute" that is fed the Dropout's
// X, mask, ratio and training_mode and writes to the NodeArg RecomputeName(<output 0 name>).
// Sets rule_effect to kModifiedRestOfGraph and returns Status::OK().
Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
  const auto& forward_output = node.OutputDefs()[0];
  auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(forward_output->Name()),
                                                     forward_output->TypeAsProto());

  auto& dropout_inputs = node.MutableInputDefs();
  auto& dropout_outputs = node.MutableOutputDefs();

  const std::vector<NodeArg*> recompute_inputs{
      dropout_inputs[0],   // X
      dropout_outputs[1],  // mask
      dropout_inputs[1],   // ratio
      dropout_inputs[2]    // training_mode
  };

  // Reusing DropoutGrad as the recompute op
  graph.AddNode(node.Name() + "_recompute",
                "DropoutGrad",
                "Recompute of " + node.Name(),
                recompute_inputs,
                {&recomputed_output},
                {},
                kMSDomain);

  rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
  return Status::OK();
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
50
orttraining/orttraining/core/optimizer/localized_recompute.h
Normal file
50
orttraining/orttraining/core/optimizer/localized_recompute.h
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class GeluRecompute
|
||||
|
||||
Recompute Gelu/BiasGelu/FastGelu
|
||||
|
||||
*/
|
||||
class GeluRecompute : public RewriteRule {
|
||||
public:
|
||||
GeluRecompute() noexcept : RewriteRule("GeluRecompute") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Gelu", "FastGelu", "BiasGelu"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
/**
|
||||
@Class AttentionDropoutRecompute
|
||||
|
||||
Recompute Dropout in the attention layer
|
||||
|
||||
*/
|
||||
class AttentionDropoutRecompute : public RewriteRule {
|
||||
public:
|
||||
AttentionDropoutRecompute() noexcept : RewriteRule("AttentionDropoutRecompute") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"Dropout"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -216,7 +216,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.enable_gelu_approximation));
|
||||
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));
|
||||
|
||||
// derive actual set of weights to train
|
||||
std::unordered_set<std::string> weight_names_to_train =
|
||||
|
|
@ -237,7 +237,7 @@ Status TrainingSession::ConfigureForTraining(
|
|||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(BuildGradientGraph(
|
||||
weight_names_to_train, loss_name, config.gradient_graph_config, config.set_gradients_as_graph_outputs));
|
||||
weight_names_to_train, loss_name, config.gradient_graph_config, *session_logger_));
|
||||
|
||||
// transform for mixed precision
|
||||
std::unordered_map<std::string, NodeArg*> fp32_weight_name_to_fp16_node_arg{};
|
||||
|
|
@ -443,14 +443,14 @@ static Status BuildGradientGraphInternal(Graph& graph,
|
|||
const std::string& loss_function_output_name,
|
||||
const std::unordered_set<std::string>& node_arg_names_to_train,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output = false) {
|
||||
const logging::Logger& logger) {
|
||||
// Compute the gradient graph def.
|
||||
GradientGraphBuilder grad_graph_builder(&graph,
|
||||
{loss_function_output_name},
|
||||
node_arg_names_to_train,
|
||||
loss_function_output_name,
|
||||
gradient_graph_config,
|
||||
set_gradient_as_graph_output);
|
||||
logger);
|
||||
return grad_graph_builder.Build();
|
||||
}
|
||||
|
||||
|
|
@ -488,10 +488,10 @@ static Status AddGradientAccumulationNodes(Graph& graph,
|
|||
return GraphAugmenter::AugmentGraph(graph, graph_defs);
|
||||
}
|
||||
|
||||
Status TrainingSession::ApplyTransformationsToMainGraph(
|
||||
const std::unordered_set<std::string>& weights_to_train, bool enable_gelu_approximation) {
|
||||
Status TrainingSession::ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config) {
|
||||
GraphTransformerManager graph_transformation_mgr{1};
|
||||
AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, enable_gelu_approximation);
|
||||
AddPreTrainingTransformers(graph_transformation_mgr, weights_to_train, config);
|
||||
|
||||
// apply transformers
|
||||
Graph& graph = model_->MainGraph();
|
||||
|
|
@ -505,13 +505,13 @@ Status TrainingSession::ApplyTransformationsToMainGraph(
|
|||
// Registers all the pre transformers with transformer manager
|
||||
void TrainingSession::AddPreTrainingTransformers(GraphTransformerManager& transformer_manager,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
bool enable_gelu_approximation,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
TransformerLevel graph_optimization_level,
|
||||
const std::vector<std::string>& custom_list) {
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
// Generate and register transformers for level
|
||||
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
|
||||
level, weights_to_train, enable_gelu_approximation, custom_list);
|
||||
level, weights_to_train, config, custom_list);
|
||||
for (auto& entry : transformers_to_register) {
|
||||
transformer_manager.Register(std::move(entry), level);
|
||||
}
|
||||
|
|
@ -664,7 +664,7 @@ Status TrainingSession::EnableMixedPrecision(
|
|||
Status TrainingSession::BuildGradientGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const std::string& loss_function_output_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output) {
|
||||
const logging::Logger& logger) {
|
||||
// Fill weights_to_train_ according to weights_to_train
|
||||
weights_to_train_ = weights_to_train;
|
||||
gradient_graph_config_ = gradient_graph_config;
|
||||
|
|
@ -673,7 +673,7 @@ Status TrainingSession::BuildGradientGraph(const std::unordered_set<std::string>
|
|||
loss_function_output_name,
|
||||
weights_to_train_,
|
||||
gradient_graph_config_,
|
||||
set_gradient_as_graph_output));
|
||||
logger));
|
||||
|
||||
return DoPostLoadProcessing(*model_);
|
||||
}
|
||||
|
|
@ -791,7 +791,7 @@ Status TrainingSession::Save(const PathString& model_uri, TrainingSession::SaveO
|
|||
actual_loss_name,
|
||||
weights_to_train_,
|
||||
gradient_graph_config_,
|
||||
false));
|
||||
*session_logger_));
|
||||
|
||||
OptimizerOutputKeyMap<std::string> opt_graph_outputs;
|
||||
std::unordered_set<std::string> opt_state_initializer_names;
|
||||
|
|
|
|||
|
|
@ -46,9 +46,6 @@ class TrainingSession : public InferenceSession {
|
|||
// Gradient graph configuration
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
|
||||
// Whether to set the gradients as graph outputs.
|
||||
bool set_gradients_as_graph_outputs{false};
|
||||
|
||||
// The number of gradient accumulation steps.
|
||||
int gradient_accumulation_steps{1};
|
||||
|
||||
|
|
@ -177,8 +174,16 @@ class TrainingSession : public InferenceSession {
|
|||
// Otherwise, it returns false.
|
||||
optional<PipelineConfiguration> pipeline_config{};
|
||||
|
||||
// Whether to enable GELU approximation which is faster but produces different results.
|
||||
bool enable_gelu_approximation{false};
|
||||
// Switches consumed when the pre-training graph transformers are generated.
// Every option defaults to off.
struct GraphTransformerConfiguration {
  // Whether to enable GELU approximation which is faster but produces different results.
  bool enable_gelu_approximation = false;

  // Enable checkpointing of attention dropout to save memory
  bool attn_dropout_checkpoint = false;

  // Enable checkpointing of Gelu activation output to save memory
  bool gelu_checkpoint = false;
};
|
||||
|
||||
GraphTransformerConfiguration graph_transformer_config{};
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -390,13 +395,13 @@ class TrainingSession : public InferenceSession {
|
|||
std::string& backward_waited_event_after_recv_name,
|
||||
std::string& backward_recorded_event_before_send_name);
|
||||
|
||||
common::Status ApplyTransformationsToMainGraph(
|
||||
const std::unordered_set<std::string>& weights_to_train, bool enable_gelu_approximation);
|
||||
common::Status ApplyTransformationsToMainGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config);
|
||||
|
||||
/** configure initial transformers for training */
|
||||
void AddPreTrainingTransformers(GraphTransformerManager& transformer_manager,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
bool enable_gelu_approximation,
|
||||
const TrainingConfiguration::GraphTransformerConfiguration& config,
|
||||
TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel,
|
||||
const std::vector<std::string>& custom_list = {});
|
||||
|
||||
|
|
@ -408,12 +413,11 @@ class TrainingSession : public InferenceSession {
|
|||
/** Perform auto-diff to add backward graph into the model.
|
||||
@param weights_to_train a set of weights to be trained.
|
||||
@param loss_function_output_name the name of the loss function's output.
|
||||
@param set_gradient_as_graph_output if it is true, set gradient of trainable weight as graph output
|
||||
*/
|
||||
common::Status BuildGradientGraph(const std::unordered_set<std::string>& weights_to_train,
|
||||
const std::string& loss_function_output_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const bool set_gradient_as_graph_output = false);
|
||||
const logging::Logger& logger);
|
||||
|
||||
common::Status BuildAccumulationNode(const std::unordered_set<std::string>& weights_to_train);
|
||||
|
||||
|
|
|
|||
|
|
@ -166,6 +166,10 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
cxxopts::value<bool>()->default_value("true"))
|
||||
("enable_gelu_approximation", "Specify whether to enable GELU approximation.",
|
||||
cxxopts::value<bool>()->default_value("true"))
|
||||
("attn_dropout_checkpoint", "Enable checkpointing of attention dropout to save memory.",
|
||||
cxxopts::value<bool>()->default_value("false"))
|
||||
("gelu_checkpoint", "Enable checkpointing of Gelu activation output to save memory.",
|
||||
cxxopts::value<bool>()->default_value("false"))
|
||||
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
options
|
||||
|
|
@ -453,6 +457,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
}
|
||||
|
||||
params.enable_gelu_approximation = flags["enable_gelu_approximation"].as<bool>();
|
||||
params.attn_dropout_checkpoint = flags["attn_dropout_checkpoint"].as<bool>();
|
||||
params.gelu_checkpoint = flags["gelu_checkpoint"].as<bool>();
|
||||
|
||||
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
|
||||
ORT_RETURN_IF_NOT(
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ Status TrainingRunner::Initialize() {
|
|||
config.immutable_weights = params_.immutable_weights;
|
||||
|
||||
config.gradient_graph_config.use_invertible_layernorm_grad = params_.use_invertible_layernorm_grad;
|
||||
config.set_gradients_as_graph_outputs = false;
|
||||
config.gradient_graph_config.set_gradients_as_graph_outputs = false;
|
||||
|
||||
config.gradient_accumulation_steps = params_.gradient_accumulation_steps;
|
||||
|
||||
|
|
@ -164,7 +164,15 @@ Status TrainingRunner::Initialize() {
|
|||
config.pipeline_config = pipe;
|
||||
}
|
||||
|
||||
config.enable_gelu_approximation = params_.enable_gelu_approximation;
|
||||
// always configure the graph transformer
|
||||
{
|
||||
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration gt_config{};
|
||||
gt_config.enable_gelu_approximation = params_.enable_gelu_approximation;
|
||||
gt_config.attn_dropout_checkpoint = params_.attn_dropout_checkpoint;
|
||||
gt_config.gelu_checkpoint = params_.gelu_checkpoint;
|
||||
|
||||
config.graph_transformer_config = gt_config;
|
||||
}
|
||||
|
||||
TrainingSession::TrainingConfigurationResult config_result{};
|
||||
|
||||
|
|
|
|||
|
|
@ -169,7 +169,10 @@ class TrainingRunner {
|
|||
|
||||
// Enable GELU approximation
|
||||
bool enable_gelu_approximation = false;
|
||||
|
||||
// Enable checkpointing of attention dropout to save memory
|
||||
bool attn_dropout_checkpoint = false;
|
||||
// Enable checkpointing of Gelu activation output to save memory
|
||||
bool gelu_checkpoint = false;
|
||||
// Use invertible layernorm grad
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@
|
|||
#include "orttraining/core/framework/mpi_setup.h"
|
||||
#include "python/onnxruntime_pybind_mlvalue.h"
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
namespace py = pybind11;
|
||||
|
|
@ -91,8 +90,6 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
config.weight_names_to_not_train = parameters.weights_not_to_train;
|
||||
config.immutable_weights = parameters.immutable_weights;
|
||||
|
||||
config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
|
||||
|
||||
config.gradient_accumulation_steps = parameters.gradient_accumulation_steps;
|
||||
|
||||
config.distributed_config.world_rank = parameters.world_rank;
|
||||
|
|
@ -146,6 +143,7 @@ TrainingConfigurationResult ConfigureSessionForTraining(
|
|||
}
|
||||
|
||||
config.gradient_graph_config.use_invertible_layernorm_grad = parameters.use_invertible_layernorm_grad;
|
||||
config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs;
|
||||
|
||||
training::TrainingSession::TrainingConfigurationResult config_result{};
|
||||
|
||||
|
|
@ -194,9 +192,9 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
|
||||
py::class_<onnxruntime::training::TrainingSession, InferenceSession> training_session(m, "TrainingSession");
|
||||
training_session.def(py::init([](const SessionOptions& so) {
|
||||
Environment& env = get_env();
|
||||
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
|
||||
}))
|
||||
Environment& env = get_env();
|
||||
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(so, env);
|
||||
}))
|
||||
.def(py::init([]() {
|
||||
Environment& env = get_env();
|
||||
return onnxruntime::make_unique<onnxruntime::training::TrainingSession>(GetDefaultCPUSessionOptions(), env);
|
||||
|
|
@ -274,7 +272,6 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
.def("is_output_fp32_node", [](onnxruntime::training::TrainingSession* sess, const std::string& output_name) {
|
||||
return sess->IsGraphOutputFp32Node(output_name);
|
||||
});
|
||||
|
||||
}
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -307,12 +307,13 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGradGraph(
|
|||
}
|
||||
|
||||
training::GradientGraphConfiguration gradient_graph_config;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = true;
|
||||
training::GradientGraphBuilder grad_graph_builder(&graph,
|
||||
dy_values,
|
||||
weights_to_train,
|
||||
"",
|
||||
gradient_graph_config,
|
||||
true);
|
||||
logging::LoggingManager::DefaultLogger());
|
||||
Status status = grad_graph_builder.Build();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
|
|
|
|||
|
|
@ -71,12 +71,13 @@ void GradientOpTester::Run(
|
|||
}
|
||||
|
||||
training::GradientGraphConfiguration gradient_graph_config;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = true;
|
||||
training::GradientGraphBuilder grad_graph_builder(&graph,
|
||||
dy_values,
|
||||
weights_to_train,
|
||||
"",
|
||||
gradient_graph_config,
|
||||
true);
|
||||
logging::LoggingManager::DefaultLogger());
|
||||
status = grad_graph_builder.Build();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue