diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.cc b/orttraining/orttraining/core/framework/torch/custom_function_register.cc index 6681b32d66..2bf0be1d71 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.cc +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.cc @@ -6,6 +6,8 @@ #include "orttraining/core/framework/torch/refcount_tracker.h" #include "core/platform/env.h" #include +#include +#include namespace onnxruntime { namespace language_interop_ops { @@ -151,6 +153,15 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) { return iter->second.get(); } +void OrtTorchFunctionPool::RegisterMiscellaneousConstInput(PyObject* obj) { + ORT_ENFORCE(obj, "Cannot register NULL reference input."); + const void* address = static_cast(obj); + std::stringstream ss; + ss << address; + std::string key = ss.str(); + RegisterEntry(mutex_, key, obj, miscellaneous_const_input_pool_); +} + int64_t OrtTorchFunctionPool::RegisterContext(PyObject* autograd_context) { static int64_t index_ = 0x1000000; std::lock_guard lock(mutex_); @@ -185,12 +196,21 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) { return iter->second.get(); } -void OrtTorchFunctionPool::UnRegisterFunctions() { +void OrtTorchFunctionPool::UnRegisterGlobalFunctions() { forward_runner_.reset(); backward_runner_.reset(); + func_context_pool_.clear(); +} + +void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() { forward_core_pool_.clear(); backward_core_pool_.clear(); - func_context_pool_.clear(); + miscellaneous_const_input_pool_.clear(); +} + +void OrtTorchFunctionPool::UnRegisterFunctions() { + UnRegisterGlobalFunctions(); + UnRegisterModelSpecificFunctions(); } } // namespace torch diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h index 48164548c9..0dea6d036a 100644 --- 
a/orttraining/orttraining/core/framework/torch/custom_function_register.h +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h @@ -34,6 +34,14 @@ class OrtTorchFunctionPool final { // 2. Caller of GetBackwardCore should not decrease the reference count of the returned object. PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad. + // An autograd function may take inputs of "non-tensor && non int/float && non int/float tuple" types. + // PythonOp requires those inputs to be present when it runs; otherwise kernel execution will fail. + // So during model export we need to register those inputs with this API, which increases their ref count by 1; + // they will not be released until OrtTorchFunctionPool is destroyed. + // We also try to release those registrations in 'UnRegisterFunctions', to avoid the issue of the Python program + // exiting before we decrease the ref count of an already-released Python object. + void RegisterMiscellaneousConstInput(PyObject* obj); + // Context is torch backward gradient function pointer, and // it is a property of forward run outputs (tensors), its lifecycle // is along with forward run outputs in PyTorch design. 
@@ -76,11 +84,15 @@ class OrtTorchFunctionPool final { OrtTorchFunctionPool(){}; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OrtTorchFunctionPool); + void UnRegisterGlobalFunctions(); + void UnRegisterModelSpecificFunctions(); + PythonObjectPtr forward_runner_; PythonObjectPtr backward_runner_; std::unordered_map forward_core_pool_; std::unordered_map backward_core_pool_; + std::unordered_map miscellaneous_const_input_pool_; std::unordered_map func_context_pool_; std::mutex mutex_; diff --git a/orttraining/orttraining/core/framework/torch/refcount_tracker.cc b/orttraining/orttraining/core/framework/torch/refcount_tracker.cc index 9feea6bdf0..27eaa90d07 100644 --- a/orttraining/orttraining/core/framework/torch/refcount_tracker.cc +++ b/orttraining/orttraining/core/framework/torch/refcount_tracker.cc @@ -27,7 +27,8 @@ void RefCountTracker::TrackPyObject(RefCountTracker::ObjCategory category, PyObj } else { addrs[addr].push_back(log_tag); } - LOGS_DEFAULT(WARNING) << "Track" << ObjCategoryToString(category) << "\tAddress: [" << addr << "]\tRefCnt: " << Py_REFCNT(addr) << "\tLogTag: " << log_tag; + LOGS_DEFAULT(VERBOSE) << "Track" << ObjCategoryToString(category) << "\tAddress: [" << addr << "]\tRefCnt: " + << Py_REFCNT(addr) << "\tLogTag: " << log_tag; #endif } @@ -48,7 +49,7 @@ void RefCountTracker::DumpDetails(const std::string& phase_name) const { } } oss << "==========================================================" << std::endl; - LOGS_DEFAULT(WARNING) << oss.str(); + LOGS_DEFAULT(VERBOSE) << oss.str(); #endif } diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.cc b/orttraining/orttraining/core/framework/torch/torch_proxy.cc index ad72227a02..3bdcc85fb0 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.cc +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.cc @@ -71,7 +71,8 @@ void CheckArguments( ORT_ENFORCE(obj_args.size() == obj_indices.size()); for (const auto i : requires_grads) { - ORT_ENFORCE(i == 0 || i == 1, 
"Flag of requiring gradient must be either 0 (not required) or 1 (required) but got ", i); + ORT_ENFORCE(i == 0 || i == 1, + "Flag of requiring gradient must be either 0 (not required) or 1 (required) but got ", i); } std::vector counts(len, 0); @@ -150,15 +151,19 @@ void InvokeRunner( // from Pytorch. PyObject* py_obj = PyTuple_GetItem(result_ptr.get(), 0); if (is_training_mode) { - const auto& refcnt = Py_REFCNT(py_obj); - // We don't need do ref increase here because, python returns tensor.grad_fn as part of - // tuple, who increased the refcnt already (and tensor persist until the backward kernels completed). - // Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2. - // We say "at least" 2 because user could increase the context refcnt as well in their autograd forward() - // and backward() functions. - ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, "."); - if (refcnt > 2) { - LOGS_DEFAULT(WARNING) << "Autograd context refcnt > 2"; + if (py_obj == Py_None) { + LOGS_DEFAULT(VERBOSE) << "Under training mode, autograd context found to be Py_None."; + } else { + const auto refcnt = Py_REFCNT(py_obj); + // We don't need do ref increase here because, python returns tensor.grad_fn as part of + // tuple, who increased the refcnt already (and tensor persist until the backward kernels completed). + // Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2. + // We say "at least" 2 because user could increase the context refcnt as well in their autograd forward() + // and backward() functions. 
+ ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, "."); + if (refcnt > 2) { + LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt; + } } } else { ORT_ENFORCE(py_obj == Py_None, "Under inference mode, autograd context should be Py_None."); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 71c0fe5654..415dd7dd97 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -169,7 +169,6 @@ struct PyOptimizer { PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::Module* model, std::vector> provider) : optimizer_() { - auto env = GetTrainingEnv().GetORTEnv(); // XXX: We hope that env will be around when optimizer needs it. optimizer_ = std::make_shared(optimizer_model_uri, @@ -524,6 +523,14 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn #else ORT_UNUSED_PARAMETER(key); ORT_UNUSED_PARAMETER(obj); +#endif + }); + m.def("register_miscellaneous_const_input", [](py::object obj) -> void { +#ifdef ENABLE_TRAINING_TORCH_INTEROP + auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); + pool.RegisterMiscellaneousConstInput(obj.ptr()); +#else + ORT_UNUSED_PARAMETER(obj); #endif }); m.def("unregister_python_functions", []() -> void { diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index 1c2fce2b1a..0466195744 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -44,7 +44,7 @@ def enable_custom_autograd_support(to_enable=True): ) from onnxruntime.training.ortmodule.torch_cpp_extensions import 
torch_interop_utils - from ._custom_autograd_function_exporter import _clear_nontensor_object_references, _export + from ._custom_autograd_function_exporter import _export if to_enable is True and custom_autograd_function_enabler.state is False: if custom_autograd_function_enabler.already_enabled is False: @@ -59,8 +59,6 @@ def enable_custom_autograd_support(to_enable=True): # Clear all gradient functions, to avoid a deadlock issue. # Check the called function for more detailed comments. atexit.register(torch_interop_utils.clear_all_grad_fns) - # Clear all non-tensor object reference (for example, ProcessGroup passed to PythonOp). - atexit.register(_clear_nontensor_object_references) try: # This is for the latest Pytorch nightly after this commit: diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 96f9ccd4ae..6787c8603f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -11,7 +11,7 @@ import torch.utils.checkpoint from packaging import version from torch.onnx import symbolic_helper -from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.capi._pybind_state import register_torch_autograd_function, register_miscellaneous_const_input from onnxruntime.training import ortmodule from . import _logger @@ -58,15 +58,6 @@ def pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType: return _CAST_PYTORCH_TO_ONNX[scalar_type] -# For pointer needed for PythonOp execution, we firstly append it into a global store to hold a -# reference (in case it is released after module exported). 
-NONTENSOR_OBJECT_POINTER_STORE = {} - - -def _clear_nontensor_object_references(): - NONTENSOR_OBJECT_POINTER_STORE.clear() - - def _export_pt_1_10(g, n, *args, **kwargs): """ This function exports PythonOp (input: "n") into a graph @@ -168,7 +159,9 @@ def _export_pt_1_10(g, n, *args, **kwargs): input_pointer_scalar_positions.append(i) input_pointer_scalars.append(id(arg)) - NONTENSOR_OBJECT_POINTER_STORE[id(arg)] = arg + # For pointer (for example, ProcessGroup passed to PythonOp) needed for PythonOp execution, + # we append it into a global store to hold a reference (in case it is released after module exported). + register_miscellaneous_const_input(arg) else: raise wrap_exception( ORTModuleONNXModelException, diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index a43f0c3e66..d256c91810 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -244,7 +244,7 @@ def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn. if attribute_name == "forward": continue - # This is a user defined/overriden method. Check for collisions. + # This is a user defined/overridden method. Check for collisions. if attribute_name in ortmodule_attributes: # This is a user defined method, issue a warning. 
warnings.warn( diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index a625735c8c..52b838df7b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -89,7 +89,7 @@ def test_gelu(): run_training_test_and_compare(model_builder, input_generator, label_input) -def test_GeLU_custom_func_rets_not_as_module_output(): +def test_gelu_custom_func_rets_not_as_module_output(): @torch.jit.script def bias_gelu(bias, y): x = bias + y @@ -147,7 +147,7 @@ def test_GeLU_custom_func_rets_not_as_module_output(): run_training_test_and_compare(model_builder, input_generator, label_input) -def test_GeLU_multiple_forward_runs(): +def test_gelu_multiple_forward_runs(): @torch.jit.script def bias_gelu(bias, y): x = bias + y @@ -199,7 +199,7 @@ def test_GeLU_multiple_forward_runs(): run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True) -def test_MegatronF(): +def test_megatronf(): # MegatronGFunction is tested in distributed test files. 
class MegatronFFunction(torch.autograd.Function): @staticmethod @@ -239,7 +239,7 @@ def test_MegatronF(): run_training_test_and_compare(model_builder, input_generator, label_input) -def test_ScalarAndTuple(): +def test_scalar_and_tuple(): class ScalarAndTupleFunction(torch.autograd.Function): @staticmethod def forward(ctx, input, alpha, beta, gamma): @@ -286,7 +286,7 @@ def test_ScalarAndTuple(): run_training_test_and_compare(model_builder, input_generator, label_input) -def test_ScalarAndTupleReordered(): +def test_scalar_and_tuple_reordered(): class ScalarAndTupleReorderedFunction(torch.autograd.Function): @staticmethod def forward(ctx, alpha, beta, input, gamma): @@ -333,6 +333,41 @@ def test_ScalarAndTupleReordered(): run_training_test_and_compare(model_builder, input_generator, label_input) +def test_pointer_type(): + class StringInputFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, name: str): + ctx.save_for_backward(input) + ctx.name = name + return input.detach() + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + class StringInputFunctionTestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.func = StringInputFunction.apply + + def forward(self, x): + h = self.func(x, "temp_name") + return h + + output_size = 2 + + def model_builder(): + return StringInputFunctionTestModel() + + def input_generator(): + return torch.randn(output_size, dtype=torch.float).requires_grad_() + + # generate a label that have same shape as forward output. + label_input = torch.ones([output_size]) + + run_training_test_and_compare(model_builder, input_generator, label_input) + + @pytest.mark.skip( reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)." 
) @@ -633,7 +668,7 @@ def test_InplaceUpdateInputAsOutputRequireGradWithMarkDirty(): run_training_test_and_compare(model_builder, input_generator, label_input) -def test_EvalTest(): +def test_evaluation(): class EvalTestFunction(torch.autograd.Function): @staticmethod # bias is an optional argument @@ -679,7 +714,7 @@ def test_EvalTest(): torch_version_lower_than("1.10.0"), reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function", ) -def test_TwoOutputFunction(): +def test_two_outputs_function(): class TwoOutputFunction1(torch.autograd.Function): @staticmethod # bias is an optional argument @@ -739,7 +774,7 @@ def test_TwoOutputFunction(): run_training_test_and_compare(model_builder, input_generator, label_input) -def test_InnerModuleCall(): +def test_inner_module_call(): class InnerModel(torch.nn.Module): def __init__(self, dim, device): super(InnerModel, self).__init__() @@ -814,7 +849,7 @@ def test_InnerModuleCall(): torch_version_lower_than("1.10.0"), reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function", ) -def test_Share_Input(): +def test_share_input(): class TwoOutputFunction2(torch.autograd.Function): @staticmethod # bias is an optional argument @@ -865,7 +900,7 @@ def test_Share_Input(): run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input) -def test_MultipleStream_InForwardFunction(): +def test_multiple_stream_in_forward_function(): class MultipleStreamFunction1(torch.autograd.Function): @staticmethod def forward(ctx, input): @@ -913,7 +948,7 @@ def test_MultipleStream_InForwardFunction(): ) -def test_NonDefaultStream_InForwardFunction1(): +def test_nondefault_stream_in_forward_function1(): class MultipleStreamFunction2(torch.autograd.Function): @staticmethod def forward(ctx, input): @@ -961,7 +996,7 @@ def test_NonDefaultStream_InForwardFunction1(): ) -def test_NonDefaultStream_InForwardFunction2(): +def 
test_nondefault_stream_in_forward_function2(): class MultipleStreamFunction3(torch.autograd.Function): @staticmethod def forward(ctx, input): @@ -1008,7 +1043,7 @@ def test_NonDefaultStream_InForwardFunction2(): ) -def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): +def test_nondefault_stream_inplace_update_in_forward_function(): class MultipleStreamFunction4(torch.autograd.Function): @staticmethod def forward(ctx, input): diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index d943e27c25..faaee0cb3a 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -85,9 +85,12 @@ void PythonOpBase::Init(const OpKernelInfo& info) { } void PythonOpBase::Clear() { - for (auto ptr : const_args_) { - auto obj = reinterpret_cast(ptr); - Py_DECREF(obj); + for (const auto& arg : const_arg_set_.GetArgs()) { + // Only release owned PyObject. 
+ if (arg.is_owned) { + auto obj = reinterpret_cast(arg.data_ptr); + Py_DECREF(obj); + } } } @@ -103,8 +106,8 @@ void PythonOpBase::RunForward(OpKernelContext* context, input_requires_grads_, args, arg_positions_, - const_args_, - const_arg_positions_, + const_arg_set_.GetDataPtrs(), + const_arg_set_.GetPositions(), diff_ctx, returned_ortvalues, is_training_mode_, @@ -120,70 +123,73 @@ void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vec } void PythonOpBase::AddIntScalarArgs() { - ORT_ENFORCE(const_args_.size() == const_arg_positions_.size()); for (size_t i = 0; i < input_int_scalars_.size(); ++i) { - const_arg_positions_.emplace_back(input_int_scalar_positions_.at(i)); - const_args_.emplace_back(Py_BuildValue("L", static_cast(input_int_scalars_.at(i)))); + const_arg_set_.Add(input_int_scalar_positions_.at(i), + Py_BuildValue("L", static_cast(input_int_scalars_.at(i))), + true /*owned*/); } for (size_t i = 0; i < input_float_scalars_.size(); ++i) { - const_arg_positions_.emplace_back(input_float_scalar_positions_.at(i)); - const_args_.emplace_back(Py_BuildValue("f", input_float_scalars_.at(i))); + const_arg_set_.Add(input_float_scalar_positions_.at(i), Py_BuildValue("f", input_float_scalars_.at(i)), + true /*owned*/); } } void PythonOpBase::AddInputTupleArgs() { - ORT_ENFORCE(const_args_.size() == const_arg_positions_.size()); for (size_t i = 0; i < input_int_tuple_begins_.size(); ++i) { // Process i-th tuple. // Starting index of i-th tuple in the concatenation buffer. const size_t begin = input_int_tuple_begins_.at(i); // Endding (exclusive) index of i-th tuple in the concatenation buffer. - const size_t end = (i + 1 == input_int_tuple_begins_.size()) ? input_int_tuples_.size() : input_int_tuple_begins_.at(i + 1); + const size_t end = + (i + 1 == input_int_tuple_begins_.size()) ? 
input_int_tuples_.size() : input_int_tuple_begins_.at(i + 1); PyObject* tuple = PyTuple_New(end - begin); for (size_t j = begin; j < end; ++j) { PyObject* item = Py_BuildValue("L", input_int_tuples_.at(j)); PyTuple_SetItem(tuple, j - begin, item); } - const_arg_positions_.emplace_back(input_int_tuple_positions_.at(i)); - const_args_.emplace_back(tuple); + + const_arg_set_.Add(input_int_tuple_positions_.at(i), tuple, true /*owned*/); } } void PythonOpBase::AddFloatTupleArgs() { - ORT_ENFORCE(const_args_.size() == const_arg_positions_.size()); for (size_t i = 0; i < input_float_tuple_begins_.size(); ++i) { // Process i-th tuple. // Starting index of i-th tuple in the concatenation buffer. const size_t begin = input_float_tuple_begins_.at(i); // Endding (exclusive) index of i-th tuple in the concatenation buffer. - const size_t end = (i + 1 == input_float_tuple_begins_.size()) ? input_float_tuples_.size() : input_float_tuple_begins_.at(i + 1); + const size_t end = + (i + 1 == input_float_tuple_begins_.size()) ? 
input_float_tuples_.size() : input_float_tuple_begins_.at(i + 1); PyObject* tuple = PyTuple_New(end - begin); for (size_t j = begin; j < end; ++j) { PyObject* item = Py_BuildValue("f", input_float_tuples_.at(j)); PyTuple_SetItem(tuple, j - begin, item); } - const_arg_positions_.emplace_back(input_float_tuple_positions_.at(i)); - const_args_.emplace_back(tuple); + + const_arg_set_.Add(input_float_tuple_positions_.at(i), tuple, true /*owned*/); } } void PythonOpBase::AddPointerScalarArgs() { - ORT_ENFORCE(const_args_.size() == const_arg_positions_.size()); for (size_t i = 0; i < input_pointer_scalars_.size(); ++i) { - const_arg_positions_.emplace_back(input_pointer_scalar_positions_.at(i)); PyObject* ptr = reinterpret_cast(input_pointer_scalars_.at(i)); - const_args_.emplace_back(ptr); + // We don't want to own the Python object on the C++ side, because once a C++ destructor is called through pybind, + // it may trigger destruction of the Python-side object, which potentially requires the GIL, resulting in a hang. + // Instead, we rely on the export-time mechanism that has already increased the reference count. + const_arg_set_.Add(input_pointer_scalar_positions_.at(i), ptr, false /*owned*/); } } void PythonOpBase::CreateConstArgs() { - ORT_ENFORCE(const_args_.size() == 0); - ORT_ENFORCE(const_arg_positions_.size() == 0); + ORT_ENFORCE(const_arg_set_.Size() == 0); AddIntScalarArgs(); AddInputTupleArgs(); AddFloatTupleArgs(); AddPointerScalarArgs(); + + // Freeze the constant arg. + const_arg_set_.Finalize(); } void PythonOpBase::CreateArgPositions() { @@ -191,11 +197,11 @@ void PythonOpBase::CreateArgPositions() { // occupied[i] being true means the i-th input argument // to Python function has been set. - std::vector occupied(input_tensor_types_.size() + const_args_.size(), false); + std::vector occupied(input_tensor_types_.size() + const_arg_set_.Size(), false); // We know all non-tensors were set above, so let's catch up. 
- for (const auto pos : const_arg_positions_) { - occupied.at(pos) = true; + for (auto& arg : const_arg_set_.GetArgs()) { + occupied.at(arg.position) = true; } // Search for empty slots for tensors. @@ -234,7 +240,8 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) { ORT_THROW_IF_ERROR(info.GetAttr("output_convention", &output_convention_)); ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_)); output_tensor_requires_grads_ = info.GetAttrsOrDefault("output_tensor_requires_grads", std::vector()); - ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), "backward tensor output count mismatch"); + ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), + "backward tensor output count mismatch"); SetPositions(); } diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h index 4347ba86f1..a9cb4960b9 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h @@ -45,8 +45,67 @@ class PythonOpBase { void Clear(); protected: - std::vector const_arg_positions_; - std::vector const_args_; + class ConstantArgSet { + private: + class ConstantArg { + public: + ConstantArg(int64_t position, void* data_ptr, bool is_owned) + : position(position), data_ptr(data_ptr), is_owned(is_owned) {} + int64_t position; // input offset in the input lists + void* data_ptr; // pointer to the data + bool is_owned; // whether the data is owned by this PythonOp kernel. + }; + + public: + // Append new constant argument. Fail when called after Finalize() got called. 
+ void Add(int64_t position, void* data_ptr, bool owned) { + ORT_ENFORCE(positions_.empty() && data_ptrs_.empty(), + "Cannot add constant arg after Finalize()"); + args_.emplace_back(ConstantArg(position, data_ptr, owned)); + } + + // Finalize the constant arg set. This is called after all constant args are added. + // Fail when called more than once. + void Finalize() { + ORT_ENFORCE(positions_.empty() && data_ptrs_.empty()); + positions_.reserve(args_.size()); + for (auto& arg : args_) { + positions_.push_back(arg.position); + } + + data_ptrs_.reserve(args_.size()); + for (auto& arg : args_) { + data_ptrs_.push_back(arg.data_ptr); + } + } + + size_t Size() const { + return args_.size(); + } + + const std::vector& GetArgs() const { + return args_; + } + + const std::vector& GetPositions() const { + return positions_; + } + + const std::vector& GetDataPtrs() const { + return data_ptrs_; + } + + private: + std::vector args_; + std::vector positions_; + std::vector data_ptrs_; + }; + + // A collection for all non-tensor input arguments, we treated them all as constants, including primitive types and + // tuples, and also string or other user defined data types (represented in pointer in the attribute + // "input_pointer_scalars"). + ConstantArgSet const_arg_set_; + std::vector arg_positions_; // Name of containing class. For example, MyReLU.