mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix reference count for autograd.Function (#15121)
### Fix reference count for autograd
When the PythonOp kernel is initialized, `AddPointerScalarArgs` populates
`const_args_` with all non-tensor references (including
ProcessGroup, string, or other user types).
In the kernel's destructor, the reference count of every entry in `const_args_` is decreased:
```
void PythonOpBase::Clear() {
for (auto ptr : const_args_) {
auto obj = reinterpret_cast<PyObject*>(ptr);
Py_DECREF(obj);
}
}
```
This means we never increased the reference count, but only decreased it. Running the
unit test, a segmentation fault is thrown. The simple fix is to remove the
Py_DECREF for those pointer-type constant inputs triggered by the kernel
destructor.
NONTENSOR_OBJECT_POINTER_STORE is the place where we increase the reference count
during export; the reference then remains until the Python program
terminates.
Additional tunings:
1. Move some logs from warning to verbose level to avoid flooding
training logs.
2. Move pointer-type reference holding from the Python side
(NONTENSOR_OBJECT_POINTER_STORE) to
orttraining/orttraining/core/framework/torch/custom_function_register.h.
We then use a consistent approach to manage reference count increases and decreases for all PythonOp-related Python
objects/methods.
This commit is contained in:
parent
f972d21e81
commit
7bec80d92a
11 changed files with 209 additions and 72 deletions
|
|
@ -6,6 +6,8 @@
|
|||
#include "orttraining/core/framework/torch/refcount_tracker.h"
|
||||
#include "core/platform/env.h"
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace language_interop_ops {
|
||||
|
|
@ -151,6 +153,15 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) {
|
|||
return iter->second.get();
|
||||
}
|
||||
|
||||
void OrtTorchFunctionPool::RegisterMiscellaneousConstInput(PyObject* obj) {
|
||||
ORT_ENFORCE(obj, "Cannot register NULL reference input.");
|
||||
const void* address = static_cast<const void*>(obj);
|
||||
std::stringstream ss;
|
||||
ss << address;
|
||||
std::string key = ss.str();
|
||||
RegisterEntry(mutex_, key, obj, miscellaneous_const_input_pool_);
|
||||
}
|
||||
|
||||
int64_t OrtTorchFunctionPool::RegisterContext(PyObject* autograd_context) {
|
||||
static int64_t index_ = 0x1000000;
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
|
@ -185,12 +196,21 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) {
|
|||
return iter->second.get();
|
||||
}
|
||||
|
||||
void OrtTorchFunctionPool::UnRegisterFunctions() {
|
||||
void OrtTorchFunctionPool::UnRegisterGlobalFunctions() {
|
||||
forward_runner_.reset();
|
||||
backward_runner_.reset();
|
||||
func_context_pool_.clear();
|
||||
}
|
||||
|
||||
void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() {
|
||||
forward_core_pool_.clear();
|
||||
backward_core_pool_.clear();
|
||||
func_context_pool_.clear();
|
||||
miscellaneous_const_input_pool_.clear();
|
||||
}
|
||||
|
||||
void OrtTorchFunctionPool::UnRegisterFunctions() {
|
||||
UnRegisterGlobalFunctions();
|
||||
UnRegisterModelSpecificFunctions();
|
||||
}
|
||||
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -34,6 +34,14 @@ class OrtTorchFunctionPool final {
|
|||
// 2. Caller of GetBackwardCore should not decrease the reference count of the returned object.
|
||||
PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad.
|
||||
|
||||
// Autograd function may take input of "non-tensor && non int/float && non int/float tuple" types.
|
||||
// While PythonOp running requires those inputs be there otherwise kernel execution will fail.
|
||||
// So during model exporting, we need register those input with this API, then a ref cnt is increased by 1,
|
||||
// they will not be released until OrtTorchFunctionPool is destroyed.
|
||||
// We also try to release those registrations in 'UnRegisterFunctions' to avoid the issue of the python program
|
||||
// exiting before we decrease the ref cnt of the already released python objects.
|
||||
void RegisterMiscellaneousConstInput(PyObject* obj);
|
||||
|
||||
// Context is torch backward gradient function pointer, and
|
||||
// it is a property of forward run outputs (tensors), its lifecycle
|
||||
// is along with forward run outputs in PyTorch design.
|
||||
|
|
@ -76,11 +84,15 @@ class OrtTorchFunctionPool final {
|
|||
OrtTorchFunctionPool(){};
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OrtTorchFunctionPool);
|
||||
|
||||
void UnRegisterGlobalFunctions();
|
||||
void UnRegisterModelSpecificFunctions();
|
||||
|
||||
PythonObjectPtr forward_runner_;
|
||||
PythonObjectPtr backward_runner_;
|
||||
|
||||
std::unordered_map<std::string, PythonObjectPtr> forward_core_pool_;
|
||||
std::unordered_map<std::string, PythonObjectPtr> backward_core_pool_;
|
||||
std::unordered_map<std::string, PythonObjectPtr> miscellaneous_const_input_pool_;
|
||||
std::unordered_map<int64_t, PythonObjectPtr> func_context_pool_;
|
||||
|
||||
std::mutex mutex_;
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ void RefCountTracker::TrackPyObject(RefCountTracker::ObjCategory category, PyObj
|
|||
} else {
|
||||
addrs[addr].push_back(log_tag);
|
||||
}
|
||||
LOGS_DEFAULT(WARNING) << "Track" << ObjCategoryToString(category) << "\tAddress: [" << addr << "]\tRefCnt: " << Py_REFCNT(addr) << "\tLogTag: " << log_tag;
|
||||
LOGS_DEFAULT(VERBOSE) << "Track" << ObjCategoryToString(category) << "\tAddress: [" << addr << "]\tRefCnt: "
|
||||
<< Py_REFCNT(addr) << "\tLogTag: " << log_tag;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -48,7 +49,7 @@ void RefCountTracker::DumpDetails(const std::string& phase_name) const {
|
|||
}
|
||||
}
|
||||
oss << "==========================================================" << std::endl;
|
||||
LOGS_DEFAULT(WARNING) << oss.str();
|
||||
LOGS_DEFAULT(VERBOSE) << oss.str();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,8 @@ void CheckArguments(
|
|||
ORT_ENFORCE(obj_args.size() == obj_indices.size());
|
||||
|
||||
for (const auto i : requires_grads) {
|
||||
ORT_ENFORCE(i == 0 || i == 1, "Flag of requiring gradient must be either 0 (not required) or 1 (required) but got ", i);
|
||||
ORT_ENFORCE(i == 0 || i == 1,
|
||||
"Flag of requiring gradient must be either 0 (not required) or 1 (required) but got ", i);
|
||||
}
|
||||
|
||||
std::vector<int64_t> counts(len, 0);
|
||||
|
|
@ -150,15 +151,19 @@ void InvokeRunner(
|
|||
// from Pytorch.
|
||||
PyObject* py_obj = PyTuple_GetItem(result_ptr.get(), 0);
|
||||
if (is_training_mode) {
|
||||
const auto& refcnt = Py_REFCNT(py_obj);
|
||||
// We don't need do ref increase here because, python returns tensor.grad_fn as part of
|
||||
// tuple, who increased the refcnt already (and tensor persist until the backward kernels completed).
|
||||
// Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2.
|
||||
// We say "at least" 2 because user could increase the context refcnt as well in their autograd forward()
|
||||
// and backward() functions.
|
||||
ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, ".");
|
||||
if (refcnt > 2) {
|
||||
LOGS_DEFAULT(WARNING) << "Autograd context refcnt > 2";
|
||||
if (py_obj == Py_None) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Under training mode, autograd context found to be Py_None.";
|
||||
} else {
|
||||
const auto refcnt = Py_REFCNT(py_obj);
|
||||
// We don't need do ref increase here because, python returns tensor.grad_fn as part of
|
||||
// tuple, who increased the refcnt already (and tensor persist until the backward kernels completed).
|
||||
// Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2.
|
||||
// We say "at least" 2 because user could increase the context refcnt as well in their autograd forward()
|
||||
// and backward() functions.
|
||||
ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, ".");
|
||||
if (refcnt > 2) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ORT_ENFORCE(py_obj == Py_None, "Under inference mode, autograd context should be Py_None.");
|
||||
|
|
|
|||
|
|
@ -169,7 +169,6 @@ struct PyOptimizer {
|
|||
PyOptimizer(const std::string optimizer_model_uri,
|
||||
onnxruntime::training::api::Module* model, std::vector<std::shared_ptr<IExecutionProvider>> provider)
|
||||
: optimizer_() {
|
||||
|
||||
auto env = GetTrainingEnv().GetORTEnv();
|
||||
// XXX: We hope that env will be around when optimizer needs it.
|
||||
optimizer_ = std::make_shared<onnxruntime::training::api::Optimizer>(optimizer_model_uri,
|
||||
|
|
@ -524,6 +523,14 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
#else
|
||||
ORT_UNUSED_PARAMETER(key);
|
||||
ORT_UNUSED_PARAMETER(obj);
|
||||
#endif
|
||||
});
|
||||
m.def("register_miscellaneous_const_input", [](py::object obj) -> void {
|
||||
#ifdef ENABLE_TRAINING_TORCH_INTEROP
|
||||
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
|
||||
pool.RegisterMiscellaneousConstInput(obj.ptr());
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(obj);
|
||||
#endif
|
||||
});
|
||||
m.def("unregister_python_functions", []() -> void {
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def enable_custom_autograd_support(to_enable=True):
|
|||
)
|
||||
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
|
||||
|
||||
from ._custom_autograd_function_exporter import _clear_nontensor_object_references, _export
|
||||
from ._custom_autograd_function_exporter import _export
|
||||
|
||||
if to_enable is True and custom_autograd_function_enabler.state is False:
|
||||
if custom_autograd_function_enabler.already_enabled is False:
|
||||
|
|
@ -59,8 +59,6 @@ def enable_custom_autograd_support(to_enable=True):
|
|||
# Clear all gradient functions, to avoid a deadlock issue.
|
||||
# Check the called function for more detailed comments.
|
||||
atexit.register(torch_interop_utils.clear_all_grad_fns)
|
||||
# Clear all non-tensor object reference (for example, ProcessGroup passed to PythonOp).
|
||||
atexit.register(_clear_nontensor_object_references)
|
||||
|
||||
try:
|
||||
# This is for the latest Pytorch nightly after this commit:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import torch.utils.checkpoint
|
|||
from packaging import version
|
||||
from torch.onnx import symbolic_helper
|
||||
|
||||
from onnxruntime.capi._pybind_state import register_torch_autograd_function
|
||||
from onnxruntime.capi._pybind_state import register_torch_autograd_function, register_miscellaneous_const_input
|
||||
from onnxruntime.training import ortmodule
|
||||
|
||||
from . import _logger
|
||||
|
|
@ -58,15 +58,6 @@ def pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType:
|
|||
return _CAST_PYTORCH_TO_ONNX[scalar_type]
|
||||
|
||||
|
||||
# For pointer needed for PythonOp execution, we firstly append it into a global store to hold a
|
||||
# reference (in case it is released after module exported).
|
||||
NONTENSOR_OBJECT_POINTER_STORE = {}
|
||||
|
||||
|
||||
def _clear_nontensor_object_references():
|
||||
NONTENSOR_OBJECT_POINTER_STORE.clear()
|
||||
|
||||
|
||||
def _export_pt_1_10(g, n, *args, **kwargs):
|
||||
"""
|
||||
This function exports PythonOp (input: "n") into a graph
|
||||
|
|
@ -168,7 +159,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
input_pointer_scalar_positions.append(i)
|
||||
input_pointer_scalars.append(id(arg))
|
||||
|
||||
NONTENSOR_OBJECT_POINTER_STORE[id(arg)] = arg
|
||||
# For pointer (for example, ProcessGroup passed to PythonOp) needed for PythonOp execution,
|
||||
# we append it into a global store to hold a reference (in case it is released after module exported).
|
||||
register_miscellaneous_const_input(arg)
|
||||
else:
|
||||
raise wrap_exception(
|
||||
ORTModuleONNXModelException,
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.
|
|||
if attribute_name == "forward":
|
||||
continue
|
||||
|
||||
# This is a user defined/overriden method. Check for collisions.
|
||||
# This is a user defined/overridden method. Check for collisions.
|
||||
if attribute_name in ortmodule_attributes:
|
||||
# This is a user defined method, issue a warning.
|
||||
warnings.warn(
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def test_gelu():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_GeLU_custom_func_rets_not_as_module_output():
|
||||
def test_gelu_custom_func_rets_not_as_module_output():
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
|
|
@ -147,7 +147,7 @@ def test_GeLU_custom_func_rets_not_as_module_output():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_GeLU_multiple_forward_runs():
|
||||
def test_gelu_multiple_forward_runs():
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
|
|
@ -199,7 +199,7 @@ def test_GeLU_multiple_forward_runs():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True)
|
||||
|
||||
|
||||
def test_MegatronF():
|
||||
def test_megatronf():
|
||||
# MegatronGFunction is tested in distributed test files.
|
||||
class MegatronFFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
@ -239,7 +239,7 @@ def test_MegatronF():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_ScalarAndTuple():
|
||||
def test_scalar_and_tuple():
|
||||
class ScalarAndTupleFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, alpha, beta, gamma):
|
||||
|
|
@ -286,7 +286,7 @@ def test_ScalarAndTuple():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_ScalarAndTupleReordered():
|
||||
def test_scalar_and_tuple_reordered():
|
||||
class ScalarAndTupleReorderedFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, alpha, beta, input, gamma):
|
||||
|
|
@ -333,6 +333,41 @@ def test_ScalarAndTupleReordered():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_pointer_type():
|
||||
class StringInputFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, name: str):
|
||||
ctx.save_for_backward(input)
|
||||
ctx.name = name
|
||||
return input.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, None
|
||||
|
||||
class StringInputFunctionTestModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.func = StringInputFunction.apply
|
||||
|
||||
def forward(self, x):
|
||||
h = self.func(x, "temp_name")
|
||||
return h
|
||||
|
||||
output_size = 2
|
||||
|
||||
def model_builder():
|
||||
return StringInputFunctionTestModel()
|
||||
|
||||
def input_generator():
|
||||
return torch.randn(output_size, dtype=torch.float).requires_grad_()
|
||||
|
||||
# generate a label that have same shape as forward output.
|
||||
label_input = torch.ones([output_size])
|
||||
|
||||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)."
|
||||
)
|
||||
|
|
@ -633,7 +668,7 @@ def test_InplaceUpdateInputAsOutputRequireGradWithMarkDirty():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_EvalTest():
|
||||
def test_evaluation():
|
||||
class EvalTestFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
|
|
@ -679,7 +714,7 @@ def test_EvalTest():
|
|||
torch_version_lower_than("1.10.0"),
|
||||
reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function",
|
||||
)
|
||||
def test_TwoOutputFunction():
|
||||
def test_two_outputs_function():
|
||||
class TwoOutputFunction1(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
|
|
@ -739,7 +774,7 @@ def test_TwoOutputFunction():
|
|||
run_training_test_and_compare(model_builder, input_generator, label_input)
|
||||
|
||||
|
||||
def test_InnerModuleCall():
|
||||
def test_inner_module_call():
|
||||
class InnerModel(torch.nn.Module):
|
||||
def __init__(self, dim, device):
|
||||
super(InnerModel, self).__init__()
|
||||
|
|
@ -814,7 +849,7 @@ def test_InnerModuleCall():
|
|||
torch_version_lower_than("1.10.0"),
|
||||
reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function",
|
||||
)
|
||||
def test_Share_Input():
|
||||
def test_share_input():
|
||||
class TwoOutputFunction2(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
|
|
@ -865,7 +900,7 @@ def test_Share_Input():
|
|||
run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input)
|
||||
|
||||
|
||||
def test_MultipleStream_InForwardFunction():
|
||||
def test_multiple_stream_in_forward_function():
|
||||
class MultipleStreamFunction1(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
|
@ -913,7 +948,7 @@ def test_MultipleStream_InForwardFunction():
|
|||
)
|
||||
|
||||
|
||||
def test_NonDefaultStream_InForwardFunction1():
|
||||
def test_nondefault_stream_in_forward_function1():
|
||||
class MultipleStreamFunction2(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
|
@ -961,7 +996,7 @@ def test_NonDefaultStream_InForwardFunction1():
|
|||
)
|
||||
|
||||
|
||||
def test_NonDefaultStream_InForwardFunction2():
|
||||
def test_nondefault_stream_in_forward_function2():
|
||||
class MultipleStreamFunction3(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
|
@ -1008,7 +1043,7 @@ def test_NonDefaultStream_InForwardFunction2():
|
|||
)
|
||||
|
||||
|
||||
def test_NonDefaultStreamInplaceUpdate_InForwardFunction():
|
||||
def test_nondefault_stream_inplace_update_in_forward_function():
|
||||
class MultipleStreamFunction4(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
|
|
|||
|
|
@ -85,9 +85,12 @@ void PythonOpBase::Init(const OpKernelInfo& info) {
|
|||
}
|
||||
|
||||
void PythonOpBase::Clear() {
|
||||
for (auto ptr : const_args_) {
|
||||
auto obj = reinterpret_cast<PyObject*>(ptr);
|
||||
Py_DECREF(obj);
|
||||
for (const auto& arg : const_arg_set_.GetArgs()) {
|
||||
// Only release owned PyObject.
|
||||
if (arg.is_owned) {
|
||||
auto obj = reinterpret_cast<PyObject*>(arg.data_ptr);
|
||||
Py_DECREF(obj);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -103,8 +106,8 @@ void PythonOpBase::RunForward(OpKernelContext* context,
|
|||
input_requires_grads_,
|
||||
args,
|
||||
arg_positions_,
|
||||
const_args_,
|
||||
const_arg_positions_,
|
||||
const_arg_set_.GetDataPtrs(),
|
||||
const_arg_set_.GetPositions(),
|
||||
diff_ctx,
|
||||
returned_ortvalues,
|
||||
is_training_mode_,
|
||||
|
|
@ -120,70 +123,73 @@ void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vec
|
|||
}
|
||||
|
||||
void PythonOpBase::AddIntScalarArgs() {
|
||||
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
|
||||
for (size_t i = 0; i < input_int_scalars_.size(); ++i) {
|
||||
const_arg_positions_.emplace_back(input_int_scalar_positions_.at(i));
|
||||
const_args_.emplace_back(Py_BuildValue("L", static_cast<long long>(input_int_scalars_.at(i))));
|
||||
const_arg_set_.Add(input_int_scalar_positions_.at(i),
|
||||
Py_BuildValue("L", static_cast<long long>(input_int_scalars_.at(i))),
|
||||
true /*owned*/);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_float_scalars_.size(); ++i) {
|
||||
const_arg_positions_.emplace_back(input_float_scalar_positions_.at(i));
|
||||
const_args_.emplace_back(Py_BuildValue("f", input_float_scalars_.at(i)));
|
||||
const_arg_set_.Add(input_float_scalar_positions_.at(i), Py_BuildValue("f", input_float_scalars_.at(i)),
|
||||
true /*owned*/);
|
||||
}
|
||||
}
|
||||
|
||||
void PythonOpBase::AddInputTupleArgs() {
|
||||
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
|
||||
for (size_t i = 0; i < input_int_tuple_begins_.size(); ++i) {
|
||||
// Process i-th tuple.
|
||||
// Starting index of i-th tuple in the concatenation buffer.
|
||||
const size_t begin = input_int_tuple_begins_.at(i);
|
||||
// Ending (exclusive) index of i-th tuple in the concatenation buffer.
|
||||
const size_t end = (i + 1 == input_int_tuple_begins_.size()) ? input_int_tuples_.size() : input_int_tuple_begins_.at(i + 1);
|
||||
const size_t end =
|
||||
(i + 1 == input_int_tuple_begins_.size()) ? input_int_tuples_.size() : input_int_tuple_begins_.at(i + 1);
|
||||
PyObject* tuple = PyTuple_New(end - begin);
|
||||
for (size_t j = begin; j < end; ++j) {
|
||||
PyObject* item = Py_BuildValue("L", input_int_tuples_.at(j));
|
||||
PyTuple_SetItem(tuple, j - begin, item);
|
||||
}
|
||||
const_arg_positions_.emplace_back(input_int_tuple_positions_.at(i));
|
||||
const_args_.emplace_back(tuple);
|
||||
|
||||
const_arg_set_.Add(input_int_tuple_positions_.at(i), tuple, true /*owned*/);
|
||||
}
|
||||
}
|
||||
|
||||
void PythonOpBase::AddFloatTupleArgs() {
|
||||
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
|
||||
for (size_t i = 0; i < input_float_tuple_begins_.size(); ++i) {
|
||||
// Process i-th tuple.
|
||||
// Starting index of i-th tuple in the concatenation buffer.
|
||||
const size_t begin = input_float_tuple_begins_.at(i);
|
||||
// Ending (exclusive) index of i-th tuple in the concatenation buffer.
|
||||
const size_t end = (i + 1 == input_float_tuple_begins_.size()) ? input_float_tuples_.size() : input_float_tuple_begins_.at(i + 1);
|
||||
const size_t end =
|
||||
(i + 1 == input_float_tuple_begins_.size()) ? input_float_tuples_.size() : input_float_tuple_begins_.at(i + 1);
|
||||
PyObject* tuple = PyTuple_New(end - begin);
|
||||
for (size_t j = begin; j < end; ++j) {
|
||||
PyObject* item = Py_BuildValue("f", input_float_tuples_.at(j));
|
||||
PyTuple_SetItem(tuple, j - begin, item);
|
||||
}
|
||||
const_arg_positions_.emplace_back(input_float_tuple_positions_.at(i));
|
||||
const_args_.emplace_back(tuple);
|
||||
|
||||
const_arg_set_.Add(input_float_tuple_positions_.at(i), tuple, true /*owned*/);
|
||||
}
|
||||
}
|
||||
|
||||
void PythonOpBase::AddPointerScalarArgs() {
|
||||
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
|
||||
for (size_t i = 0; i < input_pointer_scalars_.size(); ++i) {
|
||||
const_arg_positions_.emplace_back(input_pointer_scalar_positions_.at(i));
|
||||
PyObject* ptr = reinterpret_cast<PyObject*>(input_pointer_scalars_.at(i));
|
||||
const_args_.emplace_back(ptr);
|
||||
// We don't want to own the Python object from the C++ side because once the C++ destructor is called through pybind,
|
||||
// it may trigger destruction of the python side object, which potentially requires the GIL, resulting in a hang.
|
||||
// Instead, a mechanism during exporting has already increased the reference count.
|
||||
const_arg_set_.Add(input_pointer_scalar_positions_.at(i), ptr, false /*owned*/);
|
||||
}
|
||||
}
|
||||
|
||||
void PythonOpBase::CreateConstArgs() {
|
||||
ORT_ENFORCE(const_args_.size() == 0);
|
||||
ORT_ENFORCE(const_arg_positions_.size() == 0);
|
||||
ORT_ENFORCE(const_arg_set_.Size() == 0);
|
||||
AddIntScalarArgs();
|
||||
AddInputTupleArgs();
|
||||
AddFloatTupleArgs();
|
||||
AddPointerScalarArgs();
|
||||
|
||||
// Freeze the constant arg.
|
||||
const_arg_set_.Finalize();
|
||||
}
|
||||
|
||||
void PythonOpBase::CreateArgPositions() {
|
||||
|
|
@ -191,11 +197,11 @@ void PythonOpBase::CreateArgPositions() {
|
|||
|
||||
// occupied[i] being true means the i-th input argument
|
||||
// to Python function has been set.
|
||||
std::vector<bool> occupied(input_tensor_types_.size() + const_args_.size(), false);
|
||||
std::vector<bool> occupied(input_tensor_types_.size() + const_arg_set_.Size(), false);
|
||||
|
||||
// We know all non-tensors were set above, so let's catch up.
|
||||
for (const auto pos : const_arg_positions_) {
|
||||
occupied.at(pos) = true;
|
||||
for (auto& arg : const_arg_set_.GetArgs()) {
|
||||
occupied.at(arg.position) = true;
|
||||
}
|
||||
|
||||
// Search for empty slots for tensors.
|
||||
|
|
@ -234,7 +240,8 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) {
|
|||
ORT_THROW_IF_ERROR(info.GetAttr("output_convention", &output_convention_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_));
|
||||
output_tensor_requires_grads_ = info.GetAttrsOrDefault("output_tensor_requires_grads", std::vector<int64_t>());
|
||||
ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), "backward tensor output count mismatch");
|
||||
ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(),
|
||||
"backward tensor output count mismatch");
|
||||
|
||||
SetPositions();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,8 +45,67 @@ class PythonOpBase {
|
|||
void Clear();
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> const_arg_positions_;
|
||||
std::vector<void*> const_args_;
|
||||
class ConstantArgSet {
|
||||
private:
|
||||
class ConstantArg {
|
||||
public:
|
||||
ConstantArg(int64_t position, void* data_ptr, bool is_owned)
|
||||
: position(position), data_ptr(data_ptr), is_owned(is_owned) {}
|
||||
int64_t position; // input offset in the input lists
|
||||
void* data_ptr; // pointer to the data
|
||||
bool is_owned; // whether the data is owned by this PythonOp kernel.
|
||||
};
|
||||
|
||||
public:
|
||||
// Append new constant argument. Fail when called after Finalize() got called.
|
||||
void Add(int64_t position, void* data_ptr, bool owned) {
|
||||
ORT_ENFORCE(positions_.empty() && data_ptrs_.empty(),
|
||||
"Cannot add constant arg after Finalize()");
|
||||
args_.emplace_back(ConstantArg(position, data_ptr, owned));
|
||||
}
|
||||
|
||||
// Finalize the constant arg set. This is called after all constant args are added.
|
||||
// Fail when called more than once.
|
||||
void Finalize() {
|
||||
ORT_ENFORCE(positions_.empty() && data_ptrs_.empty());
|
||||
positions_.reserve(args_.size());
|
||||
for (auto& arg : args_) {
|
||||
positions_.push_back(arg.position);
|
||||
}
|
||||
|
||||
data_ptrs_.reserve(args_.size());
|
||||
for (auto& arg : args_) {
|
||||
data_ptrs_.push_back(arg.data_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
return args_.size();
|
||||
}
|
||||
|
||||
const std::vector<ConstantArg>& GetArgs() const {
|
||||
return args_;
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& GetPositions() const {
|
||||
return positions_;
|
||||
}
|
||||
|
||||
const std::vector<void*>& GetDataPtrs() const {
|
||||
return data_ptrs_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<ConstantArg> args_;
|
||||
std::vector<int64_t> positions_;
|
||||
std::vector<void*> data_ptrs_;
|
||||
};
|
||||
|
||||
// A collection for all non-tensor input arguments, we treated them all as constants, including primitive types and
|
||||
// tuples, and also string or other user defined data types (represented in pointer in the attribute
|
||||
// "input_pointer_scalars").
|
||||
ConstantArgSet const_arg_set_;
|
||||
|
||||
std::vector<int64_t> arg_positions_;
|
||||
|
||||
// Name of containing class. For example, MyReLU.
|
||||
|
|
|
|||
Loading…
Reference in a new issue