mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Fix convergence for dolly+stage3 training (#17685)
### Fix convergence for dolly+stage3 training In [ZeROOffloadSubscriber](216214b7d3/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py (L359C7-L359C28)), we defined some PythonOp, taking input and returning it inplace, for example:216214b7d3/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py (L223C20-L223C20). While it is possible, when ORT runs such a PythonOp, once it completes, it will release the input OrtValue, triggered the data erasing or overridden. But the PythonOp's returned value OrtValue are still pointing to that address, reading or writting on that may introduce a wrong result or even undefined behaviors. ``` /bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_custom_autograd_function_runner.py:28: UserWarning: .rank-0: onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction->Backward: ONNX Op attribute 'tensor_reuse_map' doesn't indicate 8-th output is reusing any input, but detected inplace_map indicates it is reusing some input index. A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. Please update inplace_map explicitly to avoid such a copy. warnings.warn(f".rank-{get_rank()}: {message}") 0%|▏ | 1/1000 [00:04<1:15:08, 4.51s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,023 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 14.1406, 'learning_rate': 0, 'epoch': 0.0} 0%|▏ | 1/1000 [00:04<1:15:08, 4.51s/it]Invalidate trace cache @ step 5: expected module 6, but got module 7 0%|▍ | 2/1000 [00:04<31:53, 1.92s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,124 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|▋ | 3/1000 [00:04<18:05, 1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,227 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|▋ | 3/1000 [00:04<18:05, 1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,326 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|█▏ | 5/1000 [00:04<08:44, 1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,419 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|█▏ | 5/1000 [00:04<08:44, 1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,505 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|█▋ | 7/1000 [00:05<05:28, 3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,597 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|█▋ | 7/1000 [00:05<05:28, 3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,690 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▏ | 9/1000 [00:05<03:57, 4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,791 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▏ | 9/1000 [00:05<03:57, 4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,889 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▋ | 11/1000 [00:05<03:06, 5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,981 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▋ | 11/1000 [00:05<03:06, 5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,073 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 1%|███▏ | 13/1000 [00:05<02:33, 6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,166 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 1%|███▏ | 13/1000 [00:05<02:33, 6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,256 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|███▌ | 15/1000 [00:05<02:12, 7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,348 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|███▌ | 15/1000 [00:05<02:12, 7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,439 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|████ | 17/1000 [00:06<01:59, 8.22it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,535 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|████ | 17/1000 [00:06<01:59, 8.22it/s]Traceback (most recent call last): File "examples/onnxruntime/training/language-modeling/run_clm.py", line 600, in <module> main() File "examples/onnxruntime/training/language-modeling/run_clm.py", line 548, in main train_result = trainer.train(resume_from_checkpoint=checkpoint) File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 457, in train return inner_training_loop( File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 781, in _inner_training_loop self.deepspeed.step() File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 2084, in step self._take_model_step(lr_kwargs) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 1990, in _take_model_step self.optimizer.step() File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn ret_val = func(*args, **kwargs) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1854, in step if self._overflow_check_and_loss_scale_update(): File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn ret_val = func(*args, **kwargs) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1788, in _overflow_check_and_loss_scale_update self._update_scale(self.overflow) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 2132, in _update_scale self.loss_scaler.update_scale(has_overflow) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/fp16/loss_scaler.py", line 175, in update_scale raise Exception( Exception: Current loss scale already at minimum - cannot decrease scale anymore. Exiting run. 2%|████ | 17/1000 [00:06<06:07, 2.67it/s] [2023-09-25 08:30:51,075] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1065120) of binary: /bert_ort/pengwa/py38/bin/python Traceback (most recent call last): File "/bert_ort/pengwa/py38/bin/torchrun", line 8, in <module> sys.exit(main()) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper return f(*args, **kwargs) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 806, in main run(args) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 797, in run elastic_launch( File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError: ============================================================ examples/onnxruntime/training/language-modeling/run_clm.py FAILED ------------------------------------------------------------ Failures: <NO_OTHER_FAILURES> ------------------------------------------------------------ Root Cause (first observed failure): [0]: time : 2023-09-25_08:30:51 host : orttrainingdev10.internal.cloudapp.net rank : 0 (local_rank: 0) exitcode : 1 (pid: 1065120) error_file: <N/A> traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html ============================================================ (/bert_ort/pengwa/py38) pengwa@microsoft.com@orttrainingdev10:/bert_ort/pengwa/optim ``` ## The Fix For those output that are reusing input, but ORT is not aware of, we detected on the fly (the first iteration, by checking the output tensor addresses with input tensor addresses) , then do implicit copy before set it as PythonOp's output tensors. With this fix: (left: PyTorch, right: ORT) 
This commit is contained in:
parent
891b50cc68
commit
7201def4ec
11 changed files with 659 additions and 157 deletions
|
|
@ -10,9 +10,7 @@
|
|||
#include "orttraining/core/framework/torch/gil.h"
|
||||
#include "core/platform/env.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace language_interop_ops {
|
||||
namespace torch {
|
||||
namespace onnxruntime::language_interop_ops::torch {
|
||||
|
||||
void PythonObjectDeleter(PyObject* ptr) { Py_XDECREF(ptr); };
|
||||
|
||||
|
|
@ -130,6 +128,18 @@ PyObject* CreateRequiresGradFlags(
|
|||
return flags;
|
||||
}
|
||||
|
||||
PyObject* CreateInplaceMap(
|
||||
const std::vector<int64_t>& inplace_map) {
|
||||
PyObject* inplace_map_obj = Ort_PyList_New(inplace_map.size(), "inplace_map");
|
||||
|
||||
for (size_t output_index = 0; output_index < inplace_map.size(); ++output_index) {
|
||||
PyObject* input_index = PyLong_FromLong(inplace_map[output_index]);
|
||||
Ort_PyList_SetItem_NoIncref(inplace_map_obj, output_index, input_index, std::to_string(__LINE__));
|
||||
}
|
||||
|
||||
return inplace_map_obj;
|
||||
}
|
||||
|
||||
void InvokeRunner(
|
||||
PyObject* callback_runner,
|
||||
PyObject* args,
|
||||
|
|
@ -197,14 +207,15 @@ PythonObjectPtr CreatePythonCallArguments(
|
|||
const std::vector<void*>& obj_args,
|
||||
const std::vector<int64_t>& obj_indices,
|
||||
const bool is_training_mode,
|
||||
const bool is_inplace,
|
||||
const std::string& invoke_id) {
|
||||
const std::vector<int64_t>& inplace_map,
|
||||
const std::string& invoke_id,
|
||||
const std::string& func_name) {
|
||||
ORT_ENFORCE(PyCallable_Check(callback), "Forward callback is not callable.");
|
||||
// The number of variables before those of
|
||||
// autograd.Function.apply and autograd.Function.backward.
|
||||
// The extra variables are used to configure the launch
|
||||
// forward and backward runners.
|
||||
constexpr int64_t num_control_args = 6;
|
||||
constexpr int64_t num_control_args = 7;
|
||||
|
||||
// All arguments created for Python call will be destroyed along with PythonObjectPtr.
|
||||
PythonObjectPtr args(Ort_PyTuple_New(num_control_args + len, "forward_arguments_tuple"), PythonObjectDeleter);
|
||||
|
|
@ -216,11 +227,16 @@ PythonObjectPtr CreatePythonCallArguments(
|
|||
Ort_PyTuple_SetItem_NoIncref(args.get(), 2, tensor_flags, "tensor_flags");
|
||||
PyObject* is_training_mode_arg = is_training_mode ? Py_True : Py_False;
|
||||
Ort_PyTuple_SetItem_Incref(args.get(), 3, is_training_mode_arg, "is_training_mode");
|
||||
PyObject* is_inplace_arg = is_inplace ? Py_True : Py_False;
|
||||
Ort_PyTuple_SetItem_Incref(args.get(), 4, is_inplace_arg, "is_inplace_mode");
|
||||
|
||||
PyObject* inplace_map_arg = CreateInplaceMap(inplace_map);
|
||||
Ort_PyTuple_SetItem_NoIncref(args.get(), 4, inplace_map_arg, "inplace_map");
|
||||
|
||||
PyObject* kernel_invoke_id_arg = PyBytes_FromStringAndSize(invoke_id.c_str(), invoke_id.size());
|
||||
Ort_PyTuple_SetItem_NoIncref(args.get(), 5, kernel_invoke_id_arg, "kernel_invoke_id_arg");
|
||||
|
||||
PyObject* func_name_arg = PyBytes_FromStringAndSize(func_name.c_str(), func_name.size());
|
||||
Ort_PyTuple_SetItem_NoIncref(args.get(), 6, func_name_arg, "func_name_arg");
|
||||
|
||||
// Tensor inputs to call autograd.Function.apply or autograd.Function.backward.
|
||||
for (size_t i = 0; i < tensor_args.size(); ++i) {
|
||||
if (!tensor_args[i].has_value()) {
|
||||
|
|
@ -246,6 +262,7 @@ PythonObjectPtr CreatePythonCallArguments(
|
|||
}
|
||||
|
||||
void Invoke(
|
||||
const std::string& func_name,
|
||||
PyObject* runner,
|
||||
PyObject* callback,
|
||||
const std::vector<int64_t>& requires_grads,
|
||||
|
|
@ -253,11 +270,11 @@ void Invoke(
|
|||
const std::vector<int64_t>& tensor_indices,
|
||||
const std::vector<void*>& obj_args,
|
||||
const std::vector<int64_t>& obj_indices,
|
||||
void** diff_ctx,
|
||||
std::vector<OrtValue>& returned_ortvalues,
|
||||
const bool is_training_mode,
|
||||
const bool is_inplace,
|
||||
const std::string& invoke_id) {
|
||||
const std::vector<int64_t>& inplace_map,
|
||||
const std::string& invoke_id,
|
||||
void** diff_ctx,
|
||||
std::vector<OrtValue>& returned_ortvalues) {
|
||||
const auto len = tensor_args.size() + obj_args.size();
|
||||
CheckArguments(len, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices);
|
||||
RefCountTracker::GetInstance().Reset();
|
||||
|
|
@ -271,8 +288,9 @@ void Invoke(
|
|||
obj_args,
|
||||
obj_indices,
|
||||
is_training_mode,
|
||||
is_inplace,
|
||||
invoke_id);
|
||||
inplace_map,
|
||||
invoke_id,
|
||||
func_name);
|
||||
|
||||
RefCountTracker::GetInstance().DumpDetails("Before Invoke Python Call");
|
||||
InvokeRunner(runner, args.get(), is_training_mode, diff_ctx, returned_ortvalues);
|
||||
|
|
@ -282,17 +300,18 @@ void Invoke(
|
|||
}
|
||||
|
||||
void TorchProxy::Forward(
|
||||
const std::string& func_name,
|
||||
void* callback,
|
||||
const std::vector<int64_t>& requires_grads,
|
||||
const std::vector<std::optional<OrtValue>>& tensor_args,
|
||||
const std::vector<int64_t>& tensor_indices,
|
||||
const std::vector<void*>& obj_args,
|
||||
const std::vector<int64_t>& obj_indices,
|
||||
void** diff_ctx,
|
||||
std::vector<OrtValue>& returned_ortvalues,
|
||||
const bool is_training_mode,
|
||||
const bool is_inplace,
|
||||
const std::string& invoke_id) {
|
||||
const std::vector<int64_t>& inplace_map,
|
||||
const std::string& invoke_id,
|
||||
void** diff_ctx,
|
||||
std::vector<OrtValue>& returned_ortvalues) {
|
||||
// Semantically, this lock uniquely takes the ownership of TorchProxy
|
||||
// so that there will be only one of TorchProxy::Forward TorchProxy::Backward
|
||||
// can be run at one time.
|
||||
|
|
@ -301,6 +320,7 @@ void TorchProxy::Forward(
|
|||
GilGuard guard;
|
||||
auto runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner();
|
||||
Invoke(
|
||||
func_name,
|
||||
runner,
|
||||
reinterpret_cast<PyObject*>(callback),
|
||||
requires_grads,
|
||||
|
|
@ -308,22 +328,23 @@ void TorchProxy::Forward(
|
|||
tensor_indices,
|
||||
obj_args,
|
||||
obj_indices,
|
||||
diff_ctx,
|
||||
returned_ortvalues,
|
||||
is_training_mode,
|
||||
is_inplace,
|
||||
invoke_id);
|
||||
inplace_map,
|
||||
invoke_id,
|
||||
diff_ctx,
|
||||
returned_ortvalues);
|
||||
}
|
||||
|
||||
void TorchProxy::Backward(
|
||||
const std::string& func_name,
|
||||
void* callback,
|
||||
const std::vector<std::optional<OrtValue>>& tensor_args,
|
||||
const std::vector<int64_t>& tensor_indices,
|
||||
const std::vector<void*>& obj_args,
|
||||
const std::vector<int64_t>& obj_indices,
|
||||
std::vector<OrtValue>& returned_ortvalues,
|
||||
const bool is_inplace,
|
||||
const std::string& invoke_id) {
|
||||
const std::vector<int64_t>& inplace_map,
|
||||
const std::string& invoke_id,
|
||||
std::vector<OrtValue>& returned_ortvalues) {
|
||||
// Semantically, this lock uniquely takes the ownership of TorchProxy
|
||||
// so that there will be only one of TorchProxy::Forward TorchProxy::Backward
|
||||
// can be run at one time.
|
||||
|
|
@ -336,6 +357,7 @@ void TorchProxy::Backward(
|
|||
const auto all_input_count = tensor_args.size() + obj_args.size();
|
||||
const std::vector<int64_t> requires_grads(all_input_count, 0);
|
||||
Invoke(
|
||||
func_name,
|
||||
runner,
|
||||
reinterpret_cast<PyObject*>(callback),
|
||||
requires_grads,
|
||||
|
|
@ -343,12 +365,11 @@ void TorchProxy::Backward(
|
|||
tensor_indices,
|
||||
obj_args,
|
||||
obj_indices,
|
||||
nullptr /* context to store */,
|
||||
returned_ortvalues,
|
||||
true /* is_training_mode */,
|
||||
is_inplace,
|
||||
invoke_id);
|
||||
inplace_map,
|
||||
invoke_id,
|
||||
nullptr /* context to store */,
|
||||
returned_ortvalues);
|
||||
}
|
||||
} // namespace torch
|
||||
} // namespace language_interop_ops
|
||||
} // namespace onnxruntime
|
||||
|
||||
} // namespace onnxruntime::language_interop_ops::torch
|
||||
|
|
|
|||
|
|
@ -37,27 +37,29 @@ class TorchProxy {
|
|||
};
|
||||
|
||||
void Forward(
|
||||
const std::string& func_name,
|
||||
void* callback,
|
||||
const std::vector<int64_t>& requires_grads,
|
||||
const std::vector<std::optional<OrtValue>>& tensor_args,
|
||||
const std::vector<int64_t>& tensor_indices,
|
||||
const std::vector<void*>& obj_args,
|
||||
const std::vector<int64_t>& obj_indices,
|
||||
void** diff_ctx,
|
||||
std::vector<OrtValue>& returned_ortvalues,
|
||||
const bool is_training_mode,
|
||||
const bool is_inplace,
|
||||
const std::string& invoke_id);
|
||||
const std::vector<int64_t>& inplace_map,
|
||||
const std::string& invoke_id,
|
||||
void** diff_ctx,
|
||||
std::vector<OrtValue>& returned_ortvalues);
|
||||
|
||||
void Backward(
|
||||
const std::string& func_name,
|
||||
void* callback,
|
||||
const std::vector<std::optional<OrtValue>>& tensor_args,
|
||||
const std::vector<int64_t>& tensor_indices,
|
||||
const std::vector<void*>& obj_args,
|
||||
const std::vector<int64_t>& obj_indices,
|
||||
std::vector<OrtValue>& return_args,
|
||||
const bool is_inplace,
|
||||
const std::string& invoke_id);
|
||||
const std::vector<int64_t>& inplace_map,
|
||||
const std::string& invoke_id,
|
||||
std::vector<OrtValue>& return_args);
|
||||
|
||||
private:
|
||||
TorchProxy(){};
|
||||
|
|
|
|||
|
|
@ -1765,7 +1765,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) {
|
|||
ORT_ENFORCE(utils::HasString(src_attrs.at("func_name")));
|
||||
attrs.push_back(MakeAttribute("func_name", src_attrs.at("func_name").s()));
|
||||
attrs.push_back(MakeAttribute("output_convention", src_attrs.at("input_convention").s()));
|
||||
attrs.push_back(MakeAttribute("inplace", src_attrs.at("inplace").i()));
|
||||
|
||||
// input_tensor_types[i] store the type of autograd.Function.apply's ith output.
|
||||
// Note that PythonOpGrad's 0-th input is the Python context generated by PythonOp.
|
||||
|
|
|
|||
|
|
@ -3908,10 +3908,16 @@ Return true if all elements are true and false otherwise.
|
|||
AttributeProto::INTS)
|
||||
// Other attributes.
|
||||
.Attr(
|
||||
"inplace",
|
||||
"Indicate if the output should reuse input memory.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
"tensor_reuse_map",
|
||||
"A int array indicating whether output at each index is reusing specific input or not."
|
||||
"If the given index is -1, it means the output is not reusing any input."
|
||||
"For example, there are 2 tensor inputs and 3 tensor outputs (including ctx), "
|
||||
"tensor_reuse_map = [-1, 1, 0] means"
|
||||
"- the output 0 (ctx) don't reuse any input buffer."
|
||||
"- the output 1 reuses the input 1."
|
||||
"- the output 2 reuses the input 0.",
|
||||
AttributeProto::INTS,
|
||||
false)
|
||||
.Attr(
|
||||
"training_mode",
|
||||
"Indicate if the model is exported in training_mode, by default, False.",
|
||||
|
|
@ -4033,11 +4039,6 @@ Return true if all elements are true and false otherwise.
|
|||
"func_name",
|
||||
"Name of custom class.",
|
||||
AttributeProto::STRING)
|
||||
.Attr(
|
||||
"inplace",
|
||||
"Indicate if the output should reuse input memory. Todo(pengwa): do we need it?",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"input_tensor_types",
|
||||
"Input types of autograd.Function.backward (including only tensor inputs)."
|
||||
|
|
@ -4069,6 +4070,16 @@ Return true if all elements are true and false otherwise.
|
|||
"A string inidicating autograd.Function.backward outputs's type."
|
||||
"value 'c' - non-tensor output; value 'd' - tensor output.",
|
||||
AttributeProto::STRING)
|
||||
.Attr(
|
||||
"tensor_reuse_map",
|
||||
"A int array indicating whether output at each index is reusing specific input or not."
|
||||
"If the given index is -1, it means the output is not reusing any input."
|
||||
"For example, there are 3 inputs (including ctx) and 2 outputs, tensor_reuse_map = [2, 1] means"
|
||||
"- the output 0 reuses the input 2."
|
||||
"- the output 1 reuses the input 1."
|
||||
"Be noted: the input 0 is ctx.",
|
||||
AttributeProto::INTS,
|
||||
false)
|
||||
.Attr(
|
||||
"comment",
|
||||
"comment only for debugging purposes.",
|
||||
|
|
|
|||
|
|
@ -118,7 +118,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
"wrap exportable sub-nn.Module's as ORTModule."
|
||||
)
|
||||
|
||||
inplace = kwargs["inplace"]
|
||||
# TODO move to public API once the exporter team exposes that
|
||||
training_mode = None
|
||||
if get_runtime_pytorch_version() >= version.parse("1.12"):
|
||||
|
|
@ -260,7 +259,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
|
||||
attrs = {
|
||||
"func_name_s": func_full_qual_name,
|
||||
"inplace_i": inplace,
|
||||
"input_convention_s": cconv,
|
||||
"outputs": n.outputsSize(),
|
||||
"input_tensor_types_i": input_tensor_types,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
|
@ -14,10 +15,21 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
|
|||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
|
||||
|
||||
from ._fallback import ORTModuleFallbackException, ORTModuleIOError, _FallbackManager, wrap_exception # noqa: F401
|
||||
from ._utils import get_rank
|
||||
|
||||
|
||||
def _log_warning(message: str):
|
||||
"""Configure the logger for PythonOp runner according to following rules.
|
||||
1. If multiple processes are used, the rank will be appended
|
||||
to the logger name.
|
||||
2. The logger will be disabled for non-zero ranks.
|
||||
"""
|
||||
if get_rank() == 0:
|
||||
warnings.warn(f"[rank-{get_rank()}] {message}")
|
||||
|
||||
|
||||
class CustomFuncOpKernelInfo:
|
||||
"""Store the kernel specific information retrieved with the first-time run."""
|
||||
"""Store the kernel-specific information retrieved with the first-time run."""
|
||||
|
||||
def __init__(self, kernel_invoke_id: str):
|
||||
# kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int,
|
||||
|
|
@ -31,9 +43,9 @@ class CustomFuncOpKernelInfo:
|
|||
# reference, may release the content of the tensor before it is needed in backward). Once
|
||||
# `autograd.Function.apply` completes, by checking the existence of the tensor in the saved_tensors,
|
||||
# `_GlobalOpKernelInfoMap` is updated to save the input indices that are saved in context.
|
||||
# 2. For the subsequent runs, if the input index is in `input_indices_to_save_in_ctx`, the tensor
|
||||
# 2. For the subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor
|
||||
# will be cloned before fed into `autograd.Function.apply` as input.
|
||||
self.input_indices_to_save_in_ctx: List[int] = []
|
||||
self.tensor_input_indices_to_save_in_ctx: Optional[List[int]] = None
|
||||
|
||||
# To align with PyTorch `ctx.set_materialize_grads(False|True)``
|
||||
# materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used
|
||||
|
|
@ -41,27 +53,211 @@ class CustomFuncOpKernelInfo:
|
|||
self.materialize_grads: bool = False
|
||||
self.materialize_grads_config: Optional[Dict[int, Tuple[torch.device, torch.dtype, torch.shape]]] = None
|
||||
|
||||
# For the tensors generated from ORT backend, there is special handling here:
|
||||
# 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id),
|
||||
# all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned, but marked
|
||||
# as dirty, PyTorch will complain the tensor is a leaf, should not be used for inplace update). Once
|
||||
# `autograd.Function.apply` completes, by checking the existence of the tensor in the dirty_tensors,
|
||||
# `_GlobalOpKernelInfoMap` is updated to save the input indices that are marked as dirty.
|
||||
# 2. For the subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor
|
||||
# will be cloned (with gradient) before fed into `autograd.Function.apply` as input.
|
||||
self.tensor_input_indices_for_mark_dirty: Optional[List[int]] = None
|
||||
|
||||
# Store the kernel specific information that cannot be retrieved and saved by PyTorch exporter.
|
||||
# For those infos that can only be retrieved with real run, we try to collect them in the first time run.
|
||||
# A list of output indices that needs to be clone before returned, due to inplace update analysis.
|
||||
self.output_indices_for_clone: Optional[List[int]] = None
|
||||
|
||||
|
||||
# Store the kernel-specific information that cannot be retrieved and saved by PyTorch exporter.
|
||||
# For the infos that can only be retrieved with real run, we try to collect them in the first time run.
|
||||
# key: kernel_invoke_id, value: CustomFuncOpKernelInfo.
|
||||
_GlobalOpKernelInfoMap: Dict[str, CustomFuncOpKernelInfo] = {}
|
||||
|
||||
|
||||
def _process_inplace_outputs(
|
||||
kernel_info: CustomFuncOpKernelInfo,
|
||||
func_name: str,
|
||||
input_tensors_of_kernel_run: List[torch.Tensor],
|
||||
all_outputs_of_kernel_run: List[Union[torch.Tensor, any]],
|
||||
all_outputs_to_tensor_inputs_reuse_map: List[int],
|
||||
raw_input_tensors_used_inplace: Dict[int, torch.Tensor],
|
||||
is_backward=False,
|
||||
):
|
||||
"""Special handling for in-place reusing in forward or backward.
|
||||
|
||||
Args:
|
||||
kernel_info: kernel-specific information.
|
||||
func_name: name of the autograd.Function.
|
||||
input_tensors_of_kernel_run: input tensors used to run the autograd.Function forward/backward.
|
||||
all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward.
|
||||
all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing
|
||||
which input index it is reusing. If there is no reuse, the value is -1.
|
||||
raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in
|
||||
`all_outputs_to_tensor_inputs_reuse_map`, the key is the input index, value is the raw input tensor.
|
||||
is_backward: indicates if this is backward or forward.
|
||||
|
||||
Procedures:
|
||||
1. Detect all outputs to tensor inputs reuse mapping.
|
||||
2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor,
|
||||
2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map:
|
||||
2.0.1 Most likely, we don't need to do anything, except 2.0.2.
|
||||
2.0.2 Conditions:
|
||||
> During forward run,
|
||||
> The output tensor is reusing one of input tensors,
|
||||
> The raw input tensor to be reused given from ORT is copied to run the forward kernels
|
||||
(for two possible reasons:
|
||||
a. the first time forward run, all inputs will be copied to detect
|
||||
`tensor_input_indices_to_save_in_ctx`;
|
||||
b. for every iteration, the input needs to be cloned because it is in
|
||||
`tensor_input_indices_to_save_in_ctx`).
|
||||
|
||||
In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with
|
||||
ORT statistically planned buffer reuse.
|
||||
2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map:
|
||||
2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output),
|
||||
while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error.
|
||||
2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output),
|
||||
while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the
|
||||
output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the
|
||||
input buffer is released by ORT memory planner, the output tensor read/write will be corrupted.
|
||||
Raise a warning to notify users to update inplace_map explicitly for performance consideration.
|
||||
2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an
|
||||
error.
|
||||
3. Do copies for 2.1.2 cases.
|
||||
4. Do copies for 2.0.2 cases.
|
||||
"""
|
||||
|
||||
log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: "
|
||||
input_tensor_address_list = [t.data_ptr() for t in input_tensors_of_kernel_run]
|
||||
if is_backward:
|
||||
input_tensor_address_list = [-1, *input_tensor_address_list] # skip the context input
|
||||
|
||||
is_first_time_init = kernel_info.output_indices_for_clone is None
|
||||
# If this is the first time run, collect runtime tensor reuse mapping.
|
||||
if is_first_time_init:
|
||||
# Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and
|
||||
# `input_tensors_of_kernel_run`.
|
||||
assert len(all_outputs_to_tensor_inputs_reuse_map) == len(all_outputs_of_kernel_run), (
|
||||
f"{log_prefix}all_outputs_to_tensor_inputs_reuse_map and kernel run outputs should have the same length."
|
||||
f"all_outputs_to_tensor_inputs_reuse_map: {all_outputs_to_tensor_inputs_reuse_map}, "
|
||||
f"kernel run outputs: {all_outputs_of_kernel_run}"
|
||||
)
|
||||
|
||||
# Detect all outputs to tensor inputs reuse mapping.
|
||||
detected_reuse_map = [-1] * (len(all_outputs_of_kernel_run))
|
||||
for output_index, arg in enumerate(all_outputs_of_kernel_run):
|
||||
if not isinstance(arg, torch.Tensor):
|
||||
continue
|
||||
if arg.data_ptr() in input_tensor_address_list:
|
||||
input_index = input_tensor_address_list.index(arg.data_ptr())
|
||||
detected_reuse_map[output_index] = input_index
|
||||
|
||||
# Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT.
|
||||
output_indices_for_clone = (
|
||||
[]
|
||||
) # collect the output indices that need to be cloned before returned in case 2.1.2.
|
||||
for output_index, (detected_inplace_index, inplace_index) in enumerate(
|
||||
zip(detected_reuse_map, all_outputs_to_tensor_inputs_reuse_map)
|
||||
):
|
||||
if inplace_index == detected_inplace_index:
|
||||
continue
|
||||
|
||||
# If users register inplace_map (alloc planner will do buffer reuse),
|
||||
# but detected inplace_map indicates it is NO inplace reusing, we raise an error.
|
||||
if inplace_index != -1 and detected_inplace_index == -1:
|
||||
raise RuntimeError(
|
||||
f"{log_prefix}Fatal: "
|
||||
f"ONNX Op attribute 'tensor_reuse_map' indicates {output_index}-th output is reusing input "
|
||||
f"{inplace_index}, but detected inplace_map indicates it is NOT reusing any input. "
|
||||
"Please update inplace_map explicitly to make it consistent "
|
||||
f"to avoid undefined behavior due to ORT's memory reuse plan. "
|
||||
f"inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, "
|
||||
f"detected inplace_map: {detected_reuse_map}"
|
||||
)
|
||||
|
||||
if inplace_index == -1 and detected_inplace_index != -1:
|
||||
output_indices_for_clone.append(output_index)
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f"{log_prefix}Fatal: "
|
||||
f"ONNX Op attribute 'inplace_map' indicates {inplace_index}-th output is reusing "
|
||||
f"input index {detected_inplace_index}, but detected inplace_map indicates it is reusing "
|
||||
f"input index {inplace_index}. Please update inplace_map explicitly to avoid undefined behavior "
|
||||
f"due to memory reuse. inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, "
|
||||
f"detected inplace_map: {detected_reuse_map}"
|
||||
)
|
||||
|
||||
kernel_info.output_indices_for_clone = output_indices_for_clone
|
||||
|
||||
assert kernel_info.output_indices_for_clone is not None
|
||||
|
||||
# Procedure 3: Do copies for 2.1.2 cases.
|
||||
for output_index in kernel_info.output_indices_for_clone:
|
||||
_log_warning(
|
||||
f"{log_prefix}ONNX Op attribute "
|
||||
f"'tensor_reuse_map' doesn't indicate {output_index}-th output is reusing any input, "
|
||||
f"but detected inplace_map indicates it is reusing some input index. "
|
||||
"A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. "
|
||||
"Please update inplace_map explicitly to avoid such a copy."
|
||||
)
|
||||
all_outputs_of_kernel_run[output_index] = all_outputs_of_kernel_run[output_index].detach().clone()
|
||||
|
||||
# Procedure 4: Do copies for 2.0.2 cases.
|
||||
if is_backward is False and (
|
||||
is_first_time_init
|
||||
or kernel_info.tensor_input_indices_to_save_in_ctx
|
||||
or kernel_info.tensor_input_indices_for_mark_dirty
|
||||
):
|
||||
for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items():
|
||||
# raw_input_tensor can be None for backward run, but backward won't go here.
|
||||
assert isinstance(raw_input_tensor, torch.Tensor)
|
||||
|
||||
# We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty
|
||||
# because even for those tensor indices not in
|
||||
# tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the
|
||||
# copy for the first-time run.
|
||||
if raw_input_tensor.data_ptr() == input_tensor_address_list[raw_tensor_input_index]:
|
||||
# If the raw input tensor is not copied, we don't need this handling.
|
||||
continue
|
||||
|
||||
copied = False # for each tensor, we don't do the copy once.
|
||||
output_indices_reusing_current_raw_input = [
|
||||
output_index
|
||||
for output_index, input_index in enumerate(all_outputs_to_tensor_inputs_reuse_map)
|
||||
if input_index == raw_tensor_input_index
|
||||
]
|
||||
output_tensor_address = all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].data_ptr()
|
||||
for output_index in output_indices_reusing_current_raw_input:
|
||||
assert (
|
||||
output_tensor_address == all_outputs_of_kernel_run[output_index].data_ptr()
|
||||
), "Outputs reusing the same input tensor should have the same address."
|
||||
|
||||
if not copied:
|
||||
# Only need a copy once.
|
||||
raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
|
||||
_log_warning(
|
||||
f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}."
|
||||
"Provide output to input reuse mapping to avoid the copy overhead."
|
||||
)
|
||||
copied = True
|
||||
|
||||
all_outputs_of_kernel_run[output_index] = raw_input_tensor
|
||||
|
||||
|
||||
def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optional[torch.Tensor]]:
|
||||
"""Search for context among all outputs.
|
||||
|
||||
Note1: All forward outputs of torch.autograd.Function shared the same gradient function pointer,
|
||||
Note 1: All forward outputs of torch.autograd.Function shared the same gradient function pointer,
|
||||
so here we just get the first tensor having grad_fn attribute.
|
||||
(https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267)
|
||||
|
||||
Note2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function
|
||||
Note 2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function
|
||||
https://github.com/PyTorch/PyTorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209?
|
||||
means if all output of forward function is not differentiable, then grad_fn will be None (not be set).
|
||||
means if all output of the forward function is not differentiable, then grad_fn will be None (not be set).
|
||||
|
||||
For example,
|
||||
class Bar(torch.autograd.Function):
|
||||
# A non-differentiable autograd Function whose forard output
|
||||
# A non-differentiable autograd Function whose forward output
|
||||
# doesn't have grad_fn attribute.
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
|
|
@ -85,7 +281,7 @@ def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optio
|
|||
continue
|
||||
|
||||
if arg.grad_fn is None:
|
||||
# For following case, it is possible grad_fn exist, but its value is None,
|
||||
# For the following case, it is possible grad_fn exists, but its value is None,
|
||||
# so we need to continue to search for the first tensor having a non-None grad_fn.
|
||||
#
|
||||
# >>> w = torch.randn(5, 6)
|
||||
|
|
@ -106,9 +302,10 @@ def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optio
|
|||
return (ctx, first_tensor_output)
|
||||
|
||||
|
||||
def _finalize_traing_mode_forward(
|
||||
def _finalize_training_mode_forward(
|
||||
kernel_invoke_id: str,
|
||||
input_tensors_from_ort: Dict[int, torch.Tensor],
|
||||
func_name: str,
|
||||
input_tensors_used_for_fw_run: Dict[int, torch.Tensor],
|
||||
forward_output_tensors: List[Union[torch.Tensor, None]],
|
||||
):
|
||||
"""Complete the epilogue of forward runner for training mode.
|
||||
|
|
@ -120,16 +317,25 @@ def _finalize_traing_mode_forward(
|
|||
|
||||
Things to do:
|
||||
1. Try to get context from forward output tensors.
|
||||
2. Remove the gradient functions between current autograd.Function and its input's gradient function, because
|
||||
2. Remove the gradient functions between the current autograd.Function and its input's gradient function, because
|
||||
in ORT we don't depend on PyTorch's autograd engine.
|
||||
3. Register the current autograd.Function's gradient function into our PyNodeSharedPointerPool.
|
||||
4. Save kernel specific information into _GlobalOpKernelInfoMap in the first-time kernel run.
|
||||
4. Save kernel-specific information into _GlobalOpKernelInfoMap in the first-time kernel run.
|
||||
"""
|
||||
|
||||
ctx, tensor_owning_ctx = _get_context(forward_output_tensors)
|
||||
|
||||
kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id]
|
||||
|
||||
# ctx being None in training mode means the forward function is not differentiable, so backward is not needed.
|
||||
if ctx is None:
|
||||
# If this is the first time run, collect kernel-specific information.
|
||||
if kernel_info.tensor_input_indices_to_save_in_ctx is None:
|
||||
kernel_info.tensor_input_indices_to_save_in_ctx = []
|
||||
|
||||
if kernel_info.tensor_input_indices_for_mark_dirty is None:
|
||||
kernel_info.tensor_input_indices_for_mark_dirty = []
|
||||
|
||||
return None
|
||||
|
||||
# Filter out the None in the saved_tensors.
|
||||
|
|
@ -137,19 +343,20 @@ def _finalize_traing_mode_forward(
|
|||
|
||||
ctx.fw_kernel_invoke_id = kernel_invoke_id
|
||||
|
||||
# If this is the first time run, collect kernel specific information.
|
||||
if kernel_invoke_id not in _GlobalOpKernelInfoMap:
|
||||
kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
|
||||
_GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
|
||||
# If this is the first time run, collect kernel-specific information.
|
||||
if kernel_info.tensor_input_indices_to_save_in_ctx is None:
|
||||
kernel_info.tensor_input_indices_to_save_in_ctx = []
|
||||
if len(saved_tensors):
|
||||
# Check tensors generated by ORT is in the saved_tensors or not.
|
||||
# Check tensors generated by ORT are in the saved_tensors or not.
|
||||
# If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap.
|
||||
kernel_info.input_indices_to_save_in_ctx = [
|
||||
arg_index
|
||||
for arg_index, tensor in input_tensors_from_ort.items()
|
||||
kernel_info.tensor_input_indices_to_save_in_ctx = [
|
||||
tensor_input_index
|
||||
for tensor_input_index, tensor in input_tensors_used_for_fw_run.items()
|
||||
if any(tensor is saved_tensor for saved_tensor in saved_tensors)
|
||||
]
|
||||
warnings.warn("Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration.")
|
||||
_log_warning(
|
||||
f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration."
|
||||
)
|
||||
kernel_info.materialize_grads = torch_interop_utils.get_materialize_grads(tensor_owning_ctx)
|
||||
kernel_info.materialize_grads_config = OrderedDict()
|
||||
if kernel_info.materialize_grads:
|
||||
|
|
@ -161,6 +368,22 @@ def _finalize_traing_mode_forward(
|
|||
tensor.shape,
|
||||
)
|
||||
|
||||
if kernel_info.tensor_input_indices_for_mark_dirty is None:
|
||||
kernel_info.tensor_input_indices_for_mark_dirty = []
|
||||
# Check tensors generated by ORT are marked as dirty(for inplace update) or not.
|
||||
# If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap.
|
||||
are_tensors_marked_as_dirty = torch_interop_utils.are_tensors_marked_as_dirty(
|
||||
tensor_owning_ctx, [t for t in input_tensors_used_for_fw_run.values()]
|
||||
)
|
||||
kernel_info.tensor_input_indices_for_mark_dirty = [
|
||||
tensor_input_index
|
||||
for is_dirty, (tensor_input_index, tensor) in zip(
|
||||
are_tensors_marked_as_dirty, input_tensors_used_for_fw_run.items()
|
||||
)
|
||||
if is_dirty is True
|
||||
]
|
||||
_log_warning(f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to support leaf node do inplace update.")
|
||||
|
||||
# FORWARD BACKWARD FUNCTION CONNECTIONS
|
||||
# input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function
|
||||
# ↓ ↑
|
||||
|
|
@ -188,8 +411,9 @@ def call_python_forward_function(
|
|||
requires_grad_flags: List[bool],
|
||||
tensor_type_flags: List[int],
|
||||
is_training_mode: bool,
|
||||
inplace: bool,
|
||||
inplace_map: List[int],
|
||||
kernel_invoke_id: str,
|
||||
func_name: Union[bytes, str],
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
|
|
@ -206,93 +430,119 @@ def call_python_forward_function(
|
|||
requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient.
|
||||
tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg, 0 - non-tensor, 1 - tensor.
|
||||
is_training_mode: indicates if this model is running under training mode.
|
||||
inplace: indicates if args can be modified inside the custom function.
|
||||
inplace_map: a list of the same length of kernel outputs, each element represents which input index
|
||||
it is reusing. If there is no reuse, the value is -1.
|
||||
args: inputs to "backward_function".
|
||||
"""
|
||||
|
||||
def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, is_inplace):
|
||||
if is_training_mode and tensor_flag and grad_flag and is_inplace:
|
||||
# "multiply one" helps change the torch tensor's is_leaf to False.
|
||||
# This is required when the torch tensor is updated in-place during forward pass.
|
||||
# We cannot use view here, because PyTorch handles grad_fn for view differently.
|
||||
non_leaf_arg = arg * 1
|
||||
return non_leaf_arg
|
||||
else:
|
||||
return arg
|
||||
|
||||
try:
|
||||
func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
|
||||
# If this is the first time run, collect runtime tensor reuse mapping.
|
||||
if kernel_invoke_id not in _GlobalOpKernelInfoMap:
|
||||
kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
|
||||
_GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
|
||||
|
||||
kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id]
|
||||
|
||||
tensor_input_indices_to_save_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx
|
||||
tensor_input_indices_for_mark_dirty = kernel_info.tensor_input_indices_for_mark_dirty
|
||||
|
||||
# Collect the tensor address for all inputs used for run forward, used for reuse detection.
|
||||
tensor_input_index = 0
|
||||
# If the input is reused, we need to save the raw input tensor for special handling.
|
||||
raw_input_tensors_used_inplace = OrderedDict() # Orders matter here.
|
||||
input_tensors_used_for_fw_run = OrderedDict() # Orders matter here.
|
||||
|
||||
wrapped_args = []
|
||||
tensor_input_args_map = OrderedDict()
|
||||
|
||||
# Be noted: in inference mode, we won't insert any information into _GlobalOpKernelInfoMap, because ctx
|
||||
# will always be None in the first time run.
|
||||
input_indices_to_save_in_ctx = None # Uninitialized
|
||||
if kernel_invoke_id in _GlobalOpKernelInfoMap:
|
||||
input_indices_to_save_in_ctx = _GlobalOpKernelInfoMap[kernel_invoke_id].input_indices_to_save_in_ctx
|
||||
|
||||
for arg_index, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)):
|
||||
for _, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)):
|
||||
if tensor_flag:
|
||||
# Assume it's a DLPack tensor and convert it to PyTorch tensor.
|
||||
wrapped_arg = from_dlpack(arg)
|
||||
|
||||
if tensor_input_index in inplace_map:
|
||||
raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg
|
||||
|
||||
# Note1:
|
||||
# If it's first-time kernel invocation, input_indices_to_save_in_ctx is None, we do the
|
||||
# copy for all tensor. Otherwise, we only copy the tensors whose indices are in
|
||||
# input_indices_to_save_in_ctx.
|
||||
#
|
||||
# If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the
|
||||
# copy for all tensors. Otherwise, we only copy the tensors whose indices are in
|
||||
# tensor_input_indices_to_save_in_ctx.
|
||||
# Note2:
|
||||
# For inference mode, we don't need do the copy because ctx will be None,
|
||||
# For inference mode, we don't need to do the copy because ctx will be None,
|
||||
# so nothing will be saved for ctx.
|
||||
if is_training_mode and (
|
||||
input_indices_to_save_in_ctx is None or arg_index in input_indices_to_save_in_ctx
|
||||
tensor_input_indices_to_save_in_ctx is None
|
||||
or tensor_input_index in tensor_input_indices_to_save_in_ctx
|
||||
):
|
||||
wrapped_arg = from_dlpack(arg).detach().clone()
|
||||
else:
|
||||
wrapped_arg = from_dlpack(arg)
|
||||
wrapped_arg = wrapped_arg.detach().clone()
|
||||
|
||||
# Only requires gradient when running under training mode
|
||||
# and the associated tensor has grad_flag=True (i.e.,
|
||||
# "requires_grad=True" in the original PyTorch script).
|
||||
wrapped_arg.requires_grad = is_training_mode and grad_flag
|
||||
wrapped_args.append(wrapped_arg)
|
||||
tensor_input_args_map[arg_index] = wrapped_arg
|
||||
|
||||
# Note3:
|
||||
# If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
|
||||
# mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in
|
||||
# tensor_input_indices_for_mark_dirty.
|
||||
if is_training_mode and (
|
||||
tensor_input_indices_for_mark_dirty is None
|
||||
or tensor_input_index in tensor_input_indices_for_mark_dirty
|
||||
):
|
||||
# To fix this issue:
|
||||
# "a leaf Variable that requires grad has been used in an in-place operation."
|
||||
with torch.set_grad_enabled(True):
|
||||
wrapped_arg = wrapped_arg.clone()
|
||||
|
||||
wrapped_args.append(wrapped_arg)
|
||||
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
|
||||
|
||||
tensor_input_index += 1
|
||||
else:
|
||||
# Use non-tensor as is. It's a PyObject*.
|
||||
wrapped_args.append(arg)
|
||||
|
||||
with torch.set_grad_enabled(is_training_mode):
|
||||
# Another level of wrap to avoid requires_grad=True for leaf variables.
|
||||
new_wrapped_args = list(
|
||||
generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, inplace)
|
||||
for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args)
|
||||
)
|
||||
|
||||
# Run autograd.Function.apply(...).
|
||||
# TODO(pengwa): looks we are assuming all outputs will be either Tensor or None.
|
||||
# TODO(pengwa): looks like we are assuming all outputs will be either Tensor or None.
|
||||
# We should revisit if it is possible to support other types of output, for example int, or, etc.
|
||||
# But that might also requires some work in backend.
|
||||
result = forward_function(*new_wrapped_args)
|
||||
# But that might also require some work in backend.
|
||||
result = forward_function(*wrapped_args)
|
||||
|
||||
# Extract results as DLPack tensors plus autograd context. Also skips all None values.
|
||||
results = []
|
||||
if isinstance(result, torch.Tensor):
|
||||
ctx = None
|
||||
if is_training_mode:
|
||||
ctx = _finalize_traing_mode_forward(kernel_invoke_id, tensor_input_args_map, [result])
|
||||
unwrapped_values = [ctx, to_dlpack(result)]
|
||||
results = [result]
|
||||
elif isinstance(result, (tuple, list)):
|
||||
ctx = None
|
||||
if is_training_mode:
|
||||
ctx = _finalize_traing_mode_forward(kernel_invoke_id, tensor_input_args_map, result)
|
||||
wrapped = [ctx]
|
||||
wrapped.extend(list(to_dlpack(value) if value is not None else None for value in result))
|
||||
# Inside the returned list, first element is context and the rest
|
||||
# are DLPack tensors.
|
||||
unwrapped_values = wrapped
|
||||
results = [r for r in result]
|
||||
else:
|
||||
raise wrap_exception(
|
||||
ORTModuleIOError,
|
||||
TypeError(f"ORTModule does not support the following model output type {type(result)}."),
|
||||
)
|
||||
return tuple(unwrapped_values)
|
||||
|
||||
ctx = None
|
||||
if is_training_mode:
|
||||
ctx = _finalize_training_mode_forward(
|
||||
kernel_invoke_id, func_name, input_tensors_used_for_fw_run, results
|
||||
)
|
||||
|
||||
final_rets = [ctx]
|
||||
final_rets.extend(results)
|
||||
|
||||
_process_inplace_outputs(
|
||||
kernel_info,
|
||||
func_name,
|
||||
input_tensors_used_for_fw_run.values(),
|
||||
final_rets,
|
||||
inplace_map,
|
||||
raw_input_tensors_used_inplace,
|
||||
)
|
||||
|
||||
dlpacks = [final_rets[0]]
|
||||
dlpacks.extend(list(to_dlpack(value) if value is not None else None for value in final_rets[1:]))
|
||||
|
||||
# Inside the returned list, the first element is context and the rest
|
||||
# are DLPack tensors.
|
||||
return tuple(dlpacks)
|
||||
except Exception as e:
|
||||
# Flush buffers. Otherwise, calling this from C++ may lose them.
|
||||
print("Exception happens when running ", forward_function)
|
||||
|
|
@ -306,8 +556,9 @@ def call_python_backward_function(
|
|||
requires_grad_flags: List[bool],
|
||||
tensor_type_flags: List[int],
|
||||
is_training_mode: bool,
|
||||
inplace: bool,
|
||||
inplace_map: List[int],
|
||||
kernel_invoke_id: str,
|
||||
func_name: Union[bytes, str],
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
|
|
@ -319,11 +570,13 @@ def call_python_backward_function(
|
|||
Args:
|
||||
backward_function: pointer to autograd.Function.backward (e.g., MyReLU.backward).
|
||||
requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient.
|
||||
tensor_type_flags: tensor_type_flagsi] indicates the type of the i-th arg.
|
||||
tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg.
|
||||
is_training_mode: indicates if this model is running under training mode.
|
||||
inplace: indicates if args can be modified inside the custom function.
|
||||
inplace_map: a list of the same length of kernel outputs, each element represents which input index
|
||||
it is reusing. If there is no reuse, the value is -1.
|
||||
args: inputs to "backward_function".
|
||||
"""
|
||||
func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
|
||||
with torch.no_grad():
|
||||
|
||||
def wrap_all_outputs(result):
|
||||
|
|
@ -338,6 +591,13 @@ def call_python_backward_function(
|
|||
)
|
||||
|
||||
try:
|
||||
# If this is the first time run, collect runtime tensor reuse mapping.
|
||||
if kernel_invoke_id not in _GlobalOpKernelInfoMap:
|
||||
kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
|
||||
_GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
|
||||
|
||||
kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id]
|
||||
|
||||
# Backward inputs should not require gradients.
|
||||
assert all(grad_flag == 0 for grad_flag in requires_grad_flags)
|
||||
|
||||
|
|
@ -345,6 +605,12 @@ def call_python_backward_function(
|
|||
ctx = args[0]
|
||||
fw_kernel_invoke_id = ctx.fw_kernel_invoke_id
|
||||
wrapped_args = []
|
||||
|
||||
# Collect the tensor address for all inputs used for run backward, used for reuse detection.
|
||||
tensor_input_index = 1 # skip the context input
|
||||
# If input is reused, we need to save the raw input tensor for special handling.
|
||||
raw_input_tensors_used_inplace = OrderedDict() # Orders matter here.
|
||||
input_tensors_used_for_bw_run = OrderedDict() # Orders matter here.
|
||||
for grad_input_index, (grad_flag, tensor_flag, arg) in enumerate(
|
||||
zip(requires_grad_flags, tensor_type_flags, args)
|
||||
):
|
||||
|
|
@ -362,12 +628,19 @@ def call_python_backward_function(
|
|||
# Assume it's a DLPack tensor# and convert it to PyTorch tensor.
|
||||
wrapped_arg = from_dlpack(arg)
|
||||
|
||||
if grad_input_index in inplace_map:
|
||||
raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg
|
||||
|
||||
input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg
|
||||
|
||||
if wrapped_arg is not None:
|
||||
# Only requires gradient when running under training mode
|
||||
# and the associated tensor has grad_flag=True (i.e.,
|
||||
# "requires_grad=True" in the original PyTorch script).
|
||||
wrapped_arg.requires_grad = is_training_mode and grad_flag
|
||||
|
||||
wrapped_args.append(wrapped_arg)
|
||||
tensor_input_index += 1
|
||||
else:
|
||||
# Use non-tensor as is. It's a PyObject*.
|
||||
wrapped_args.append(arg)
|
||||
|
|
@ -386,6 +659,16 @@ def call_python_backward_function(
|
|||
TypeError(f"ORTModule does not support the following model output type {type(result)}."),
|
||||
)
|
||||
|
||||
_process_inplace_outputs(
|
||||
kernel_info,
|
||||
func_name,
|
||||
input_tensors_used_for_bw_run.values(),
|
||||
result,
|
||||
inplace_map,
|
||||
raw_input_tensors_used_inplace,
|
||||
is_backward=True,
|
||||
)
|
||||
|
||||
wrapped_returned_args = wrap_all_outputs(result)
|
||||
|
||||
torch_interop_utils.unregister_grad_fn(id(ctx))
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -234,16 +235,16 @@ def _create_weight_retrieval_pythonop(
|
|||
func_full_qual_name: str,
|
||||
input_name: str,
|
||||
output_names: List[str],
|
||||
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE,
|
||||
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int],
|
||||
pull_weight_trigger_output_dtype: int,
|
||||
pull_weight_trigger_output_shape: List[int],
|
||||
) -> Tuple[ValueInfoProto, NodeProto]:
|
||||
"""This function is used to create a weight retrieving PythonOp."""
|
||||
offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params)
|
||||
new_input = helper.make_tensor_value_info(
|
||||
input_name, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE
|
||||
input_name, pull_weight_trigger_output_dtype, pull_weight_trigger_output_shape
|
||||
)
|
||||
output_rank_for_pull_weight_trigger = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)
|
||||
output_dtype_for_pull_weight_trigger = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE
|
||||
output_rank_for_pull_weight_trigger = len(pull_weight_trigger_output_shape)
|
||||
output_dtype_for_pull_weight_trigger = pull_weight_trigger_output_dtype
|
||||
output_tensor_ranks = [
|
||||
output_rank_for_pull_weight_trigger,
|
||||
] * offload_param_count
|
||||
|
|
@ -253,10 +254,9 @@ def _create_weight_retrieval_pythonop(
|
|||
|
||||
node_attributes = {
|
||||
"comment": "",
|
||||
"inplace": 0,
|
||||
"input_convention": "d",
|
||||
"input_tensor_ranks": [len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)],
|
||||
"input_tensor_types": [STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE],
|
||||
"input_tensor_ranks": [len(pull_weight_trigger_output_shape)],
|
||||
"input_tensor_types": [pull_weight_trigger_output_dtype],
|
||||
"output_tensor_ranks": output_tensor_ranks,
|
||||
"output_tensor_types": output_tensor_types,
|
||||
"training_mode": 1,
|
||||
|
|
|
|||
|
|
@ -150,6 +150,34 @@ bool get_materialize_grads(at::Tensor target) {
|
|||
return py_fn->materialize_grads;
|
||||
}
|
||||
|
||||
std::vector<bool> are_tensors_marked_as_dirty(at::Tensor target, std::vector<at::Tensor> tensors_to_check) {
|
||||
torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target);
|
||||
const auto& grad_fn = autograd_meta->grad_fn_;
|
||||
auto py_node_fn = dynamic_cast<torch::autograd::PyNode*>(grad_fn.get());
|
||||
TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type.");
|
||||
THPFunction* py_fn = (THPFunction*)py_node_fn->obj;
|
||||
std::vector<bool> are_tensors_marked_dirty(tensors_to_check.size(), false);
|
||||
if (!py_fn->dirty_tensors)
|
||||
return are_tensors_marked_dirty;
|
||||
|
||||
Py_ssize_t num_dirty = PyTuple_GET_SIZE(py_fn->dirty_tensors);
|
||||
for (const auto j : c10::irange(tensors_to_check.size())) {
|
||||
bool is_tensor_marked_dirty = false;
|
||||
for (const auto i : c10::irange(num_dirty)) {
|
||||
PyObject* obj = PyTuple_GET_ITEM(py_fn->dirty_tensors, i);
|
||||
const auto& tensor = THPVariable_Unpack(obj);
|
||||
if (tensor.is_same(tensors_to_check[j])) {
|
||||
is_tensor_marked_dirty = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
are_tensors_marked_dirty[j] = is_tensor_marked_dirty;
|
||||
}
|
||||
|
||||
return are_tensors_marked_dirty;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("register_grad_fn_and_remove_from_autograd", ®ister_grad_fn_and_remove_from_autograd,
|
||||
"Increase grad_fn shared pointer reference.");
|
||||
|
|
@ -158,4 +186,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges,
|
||||
"Remove reference on next edges' gradient functions.");
|
||||
m.def("get_materialize_grads", &get_materialize_grads, "Return whether materialize_grads is enabled or not.");
|
||||
m.def("are_tensors_marked_as_dirty", &are_tensors_marked_as_dirty, "Return whether the tensors are marked dirty or not.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1549,7 +1549,7 @@ def test_python_op_save_input_for_backward():
|
|||
count += 1
|
||||
|
||||
if index == 0:
|
||||
assert count == 1
|
||||
assert count == 2
|
||||
else:
|
||||
assert count == 0
|
||||
|
||||
|
|
@ -1717,3 +1717,97 @@ def test_customized_shape_inference():
|
|||
).train()
|
||||
_ = ortmodule(torch.randn(output_size, dtype=torch.float))
|
||||
_check_pythonop_shape(ortmodule)
|
||||
|
||||
|
||||
def test_python_op_return_persistent_param_as_value():
|
||||
"""Some PythonOp return values that are still used by PyTorch computation. This test makes sure that ORTModule
|
||||
will not release/erase the storage of those return values during tear down OrtValue of the corresponding PythonOp
|
||||
return values.
|
||||
"""
|
||||
|
||||
class SimplePassThrough(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
class GeluWithExternalOutput(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, bias_param):
|
||||
ctx.save_for_backward(x)
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))), bias_param.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
(x,) = ctx.saved_tensors
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
g = ff * grad_outputs[0]
|
||||
return g, grad_outputs[1]
|
||||
|
||||
class TestLayer(torch.nn.Module):
|
||||
def __init__(self, output_size):
|
||||
super().__init__()
|
||||
self.relu = GeluWithExternalOutput.apply
|
||||
self._output_size = output_size
|
||||
self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float))
|
||||
self.w = Parameter(
|
||||
torch.empty(output_size, output_size, device=torch.cuda.current_device(), dtype=torch.float)
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.bias.uniform_()
|
||||
self.w.uniform_()
|
||||
|
||||
def forward(self, model_input):
|
||||
activation0 = torch.add(model_input, 0.4)
|
||||
activation1 = activation0.view(self._output_size, -1)
|
||||
|
||||
# Returned detached_bias_param Tensor shares the same storage with self.bias
|
||||
# We are testing to make sure ORT will not erase the storage of self.bias during tear down OrtValue as
|
||||
# the returned value of the SimplePassThrough PythonOp.
|
||||
detached_bias_param = SimplePassThrough.apply(self.bias)
|
||||
relu_out, detached_bias_param = self.relu(activation1, detached_bias_param)
|
||||
activation2 = torch.add(relu_out, self.bias)
|
||||
activation3 = torch.add(activation2, detached_bias_param)
|
||||
activation3 = torch.matmul(self.w, activation3)
|
||||
activation4 = torch.div(activation3, 1000)
|
||||
return activation4
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self, output_size) -> None:
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList([TestLayer(output_size) for i in range(6)])
|
||||
|
||||
def forward(self, x):
|
||||
# ModuleList can act as an iterable, or be indexed using ints
|
||||
for layer in self.layers:
|
||||
x = x.view(-1)
|
||||
x = torch.nn.functional.relu(layer(x))
|
||||
return x
|
||||
|
||||
device = "cuda"
|
||||
output_size = 1024
|
||||
pt_model = TestModule(output_size).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
def _run_step(model, input):
|
||||
loss = model(input).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
for _ in range(5):
|
||||
input = torch.randn(output_size, device=device, dtype=torch.float)
|
||||
_run_step(pt_model, input)
|
||||
_run_step(ort_model, input)
|
||||
|
||||
pt_params = {n: p for n, p in pt_model.named_parameters()}
|
||||
for name, param in ort_model.named_parameters():
|
||||
assert_values_are_close(param, pt_params[name], rtol=1e-04, atol=1e-3)
|
||||
if param.grad is not None:
|
||||
assert pt_params[name].grad is not None, f"pt param.grad is None for {name}"
|
||||
assert_values_are_close(param.grad, pt_params[name].grad, rtol=1e-04, atol=1e-3)
|
||||
else:
|
||||
assert pt_params[name].grad is None
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ std::vector<std::optional<OrtValue>> CreateOrtValueArgs(OpKernelContext* context
|
|||
|
||||
void PythonOpBase::Init(const OpKernelInfo& info) {
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("func_name", &name_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("inplace", &inplace_));
|
||||
|
||||
is_training_mode_ = static_cast<bool>(info.GetAttrOrDefault("training_mode", static_cast<int64_t>(0)));
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("input_convention", &input_convention_));
|
||||
|
|
@ -117,6 +116,9 @@ void PythonOpBase::Init(const OpKernelInfo& info) {
|
|||
// Output tensors.
|
||||
ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_));
|
||||
|
||||
all_output_to_tensor_input_reuse_map_ =
|
||||
info.GetAttrsOrDefault("tensor_reuse_map", std::vector<int64_t>((info.node().OutputDefs().size()), -1));
|
||||
|
||||
CreateConstArgs();
|
||||
CreateArgPositions();
|
||||
|
||||
|
|
@ -141,17 +143,18 @@ void PythonOpBase::RunForward(OpKernelContext* context,
|
|||
std::vector<std::optional<OrtValue>> args = CreateOrtValueArgs(context, 0, context->InputCount());
|
||||
// Invoke Python calls.
|
||||
TorchProxy::GetInstance().Forward(
|
||||
name_,
|
||||
OrtTorchFunctionPool::GetInstance().GetForwardCore(name_),
|
||||
input_requires_grads_,
|
||||
args,
|
||||
arg_positions_,
|
||||
const_arg_set_.GetDataPtrs(),
|
||||
const_arg_set_.GetPositions(),
|
||||
diff_ctx,
|
||||
returned_ortvalues,
|
||||
is_training_mode_,
|
||||
inplace_ != 0,
|
||||
kernel_invoke_id_);
|
||||
all_output_to_tensor_input_reuse_map_,
|
||||
kernel_invoke_id_,
|
||||
diff_ctx,
|
||||
returned_ortvalues);
|
||||
|
||||
const size_t returned_output_count = 1 + returned_ortvalues.size();
|
||||
const size_t kernel_output_count = static_cast<size_t>(context->OutputCount());
|
||||
|
|
@ -291,14 +294,32 @@ void PythonOpBase::SetContextOutput(OpKernelContext* context, void* diff_ctx) co
|
|||
|
||||
void PythonOpBase::SetOtherOutputs(OpKernelContext* context, std::vector<OrtValue>& returned_ortvalues) const {
|
||||
auto* ctx_internal = reinterpret_cast<onnxruntime::OpKernelContextInternal*>(context);
|
||||
ORT_ENFORCE(returned_ortvalues.size() == all_output_to_tensor_input_reuse_map_.size() - 1, "PythonOp output count mismatch inplace map count.",
|
||||
returned_ortvalues.size(), " != ", all_output_to_tensor_input_reuse_map_.size() - 1);
|
||||
for (size_t i = 0; i < returned_ortvalues.size(); ++i) {
|
||||
size_t output_index = i + 1;
|
||||
if (all_output_to_tensor_input_reuse_map_[output_index] != -1) {
|
||||
const void* tensor_address = returned_ortvalues[i].Get<Tensor>().DataRaw();
|
||||
const void* input_tensor_address = context->Input<Tensor>(all_output_to_tensor_input_reuse_map_[output_index])->DataRaw();
|
||||
ORT_ENFORCE(tensor_address == input_tensor_address,
|
||||
"PythonOp inplace tensor address mismatch, output index: ", output_index, ", input index: ",
|
||||
all_output_to_tensor_input_reuse_map_[output_index]);
|
||||
}
|
||||
|
||||
// Notes: if the buffer is created, managed by PyTorch, converted to OrtValue through dlpack here,
|
||||
// but also be used outside ORT later, we don't need to be concerned about
|
||||
// "when the buffer of returned_ortvalues[i] is erased by ORT during releasing that OrtValue causing
|
||||
// the PyTorch code still using that buffer will be failed".
|
||||
// In this case, the created OrtValue's destructor will not release the buffer,
|
||||
// instead it will release a tensor pointing to that buffer, where PyTorch will decide whether to release
|
||||
// the buffer or not, if the tensor storage is not used by any other tensors
|
||||
// (https://github.com/PyTorch/PyTorch/blob/ac603bc2f8ffac8fc061cfb99e77537464da4b18/aten/src/ATen/DLConvertor.cpp#L257C25-L257C29).
|
||||
ORT_THROW_IF_ERROR(ctx_internal->SetOutputMLValue(static_cast<int>(i + 1), returned_ortvalues[i]));
|
||||
}
|
||||
}
|
||||
|
||||
void PythonOpGradBase::Init(const OpKernelInfo& info) {
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("func_name", &name_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("inplace", &inplace_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttrs("input_tensor_types", input_tensor_types_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("output_convention", &output_convention_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_));
|
||||
|
|
@ -306,6 +327,24 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) {
|
|||
ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(),
|
||||
"backward tensor output count mismatch");
|
||||
|
||||
std::vector<int64_t> tensor_output_to_tensor_input_alias_map =
|
||||
info.GetAttrsOrDefault("tensor_reuse_map",
|
||||
std::vector<int64_t>((info.node().OutputDefs().size()), -1));
|
||||
all_output_to_tensor_input_reuse_map_.clear();
|
||||
all_output_to_tensor_input_reuse_map_.reserve(output_convention_.size());
|
||||
size_t tensor_output_index = 0;
|
||||
for (size_t i = 0; i < output_convention_.size(); ++i) {
|
||||
if (output_convention_[i] == 'd') {
|
||||
all_output_to_tensor_input_reuse_map_.push_back(
|
||||
tensor_output_to_tensor_input_alias_map[tensor_output_index] == -1
|
||||
? -1
|
||||
: tensor_output_to_tensor_input_alias_map[tensor_output_index]);
|
||||
++tensor_output_index;
|
||||
} else {
|
||||
all_output_to_tensor_input_reuse_map_.push_back(-1);
|
||||
}
|
||||
}
|
||||
|
||||
SetPositions();
|
||||
|
||||
kernel_invoke_id_ = GetInvokeIdString(this);
|
||||
|
|
@ -314,7 +353,7 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) {
|
|||
void PythonOpGradBase::RunBackward(OpKernelContext* context,
|
||||
std::vector<OrtValue>& returned_ortvalues) const {
|
||||
std::vector<std::optional<OrtValue>> args = CreateOrtValueArgs(context, 1, context->InputCount() - 1);
|
||||
// This is called "const" because that's how Pytorch calls all non-tensor inputs.
|
||||
// This is called "const" because that's how PyTorch calls all non-tensor inputs.
|
||||
const Tensor* context_id_tensor = context->Input<Tensor>(0);
|
||||
ORT_ENFORCE(context_id_tensor, "Context ID (first input) should not be null.");
|
||||
const int64_t* context_index_ptr = context_id_tensor->template Data<int64_t>();
|
||||
|
|
@ -323,15 +362,15 @@ void PythonOpGradBase::RunBackward(OpKernelContext* context,
|
|||
|
||||
std::string err;
|
||||
TorchProxy::GetInstance().Backward(
|
||||
OrtTorchFunctionPool::GetInstance()
|
||||
.GetBackwardCore(name_),
|
||||
name_,
|
||||
OrtTorchFunctionPool::GetInstance().GetBackwardCore(name_),
|
||||
args,
|
||||
arg_positions_,
|
||||
const_args,
|
||||
const_arg_positions_,
|
||||
returned_ortvalues,
|
||||
inplace_ != 0,
|
||||
kernel_invoke_id_);
|
||||
all_output_to_tensor_input_reuse_map_,
|
||||
kernel_invoke_id_,
|
||||
returned_ortvalues);
|
||||
|
||||
OrtTorchFunctionPool::GetInstance().UnregisterContext(*context_index_ptr);
|
||||
}
|
||||
|
|
@ -343,6 +382,29 @@ void PythonOpGradBase::SetOutputs(OpKernelContext* context, std::vector<OrtValue
|
|||
for (size_t i = 0; i < returned_ortvalues.size(); ++i) {
|
||||
if (output_convention_[i] == 'd') {
|
||||
if (output_tensor_requires_grads_[tensor_output_index]) {
|
||||
if (all_output_to_tensor_input_reuse_map_[i] != -1) {
|
||||
const Tensor* input_tensor = context->Input<Tensor>(all_output_to_tensor_input_reuse_map_[i]);
|
||||
if (input_tensor) {
|
||||
ORT_ENFORCE(input_tensor, "PythonOpGrad input tensor should not be null. input index: ", all_output_to_tensor_input_reuse_map_[i]);
|
||||
|
||||
// Be noted: PythonOpGrad's input won't be non-tensor.
|
||||
ORT_ENFORCE(all_output_to_tensor_input_reuse_map_[i] < context->InputCount(), "PythonOpGrad inplace tensor index out of bound.");
|
||||
const void* tensor_address = returned_ortvalues[i].Get<Tensor>().DataRaw();
|
||||
|
||||
const void* input_tensor_address = input_tensor->DataRaw();
|
||||
ORT_ENFORCE(tensor_address == input_tensor_address,
|
||||
"PythonOpGrad inplace tensor address mismatch, output index: ", i, ", input index: ", all_output_to_tensor_input_reuse_map_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Notes: if the buffer is created, managed by PyTorch, converted to OrtValue through dlpack here,
|
||||
// but also be used outside ORT later, we don't need to be concerned about
|
||||
// "when the buffer of returned_ortvalues[i] is erased by ORT during releasing that OrtValue causing
|
||||
// the PyTorch code still using that buffer will be failed".
|
||||
// In this case, the created OrtValue's destructor will not release the buffer,
|
||||
// instead it will release a tensor pointing to that buffer, where PyTorch will decide whether to release
|
||||
// the buffer or not, if the tensor storage is not used by any other tensors
|
||||
// (https://github.com/PyTorch/PyTorch/blob/ac603bc2f8ffac8fc061cfb99e77537464da4b18/aten/src/ATen/DLConvertor.cpp#L257C25-L257C29).
|
||||
ORT_THROW_IF_ERROR(ctx_internal->SetOutputMLValue(tensor_output_index, returned_ortvalues.at(i)));
|
||||
}
|
||||
++tensor_output_index;
|
||||
|
|
@ -356,11 +418,11 @@ void PythonOpGradBase::SetPositions() {
|
|||
ORT_ENFORCE(const_arg_positions_.size() == 0);
|
||||
ORT_ENFORCE(arg_positions_.size() == 0);
|
||||
|
||||
// Pytorch's autograd context is the first (indexed by 0) input of the called Python function.
|
||||
// PyTorch's autograd context is the first (indexed by 0) input of the called Python function.
|
||||
// Note that here we will call autograd.Function.backward(ctx, tensor0, tensor1, ...).
|
||||
const_arg_positions_ = {0};
|
||||
|
||||
// The rest inputs are just Pytorch tensors.
|
||||
// The rest inputs are just PyTorch tensors.
|
||||
arg_positions_.resize(input_tensor_types_.size());
|
||||
for (size_t i = 0; i < arg_positions_.size(); ++i) {
|
||||
// i-th tensor is the (i+1)-th input of autograd.Function.backward.
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class PythonOpBase {
|
|||
|
||||
// Name of containing class. For example, MyReLU.
|
||||
std::string name_;
|
||||
int64_t inplace_;
|
||||
std::vector<int64_t> all_output_to_tensor_input_reuse_map_;
|
||||
std::string input_convention_;
|
||||
bool is_training_mode_;
|
||||
// input_requires_grads_[i] indicates if the i-th inputs of apply() should have gradient.
|
||||
|
|
@ -179,7 +179,7 @@ class PythonOpGradBase {
|
|||
protected:
|
||||
// Name of containing class. For example, MyReLU.
|
||||
std::string name_;
|
||||
int64_t inplace_;
|
||||
|
||||
// Input types of MyReLU.backward(...).
|
||||
std::vector<int64_t> input_tensor_types_;
|
||||
|
||||
|
|
@ -190,6 +190,9 @@ class PythonOpGradBase {
|
|||
std::vector<int64_t> arg_positions_;
|
||||
std::vector<int64_t> const_arg_positions_;
|
||||
|
||||
// Memory reuse map for all outputs.
|
||||
std::vector<int64_t> all_output_to_tensor_input_reuse_map_;
|
||||
|
||||
private:
|
||||
void SetPositions();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue