Fix reference count for autograd.Function (#15121)

### Fix reference count for autograd

When a PythonOp kernel is initialized, `AddPointerScalarArgs` adds all
non-tensor references (including ProcessGroup, string, or other user
types) to `const_args_`.

In the kernel's destructor, the ref count of every entry in `const_args_` is decreased:


```
void PythonOpBase::Clear() {
  for (auto ptr : const_args_) {
    auto obj = reinterpret_cast<PyObject*>(ptr);
    Py_DECREF(obj);
  }
}
```

This means we only decreased the ref count without ever increasing it, so
running the unit test throws a segmentation fault. The simple fix is to
remove the Py_DECREF for those pointer-type constant inputs triggered by
the kernel destructor.

NONTENSOR_OBJECT_POINTER_STORE is the place we increase the reference
during export, then the reference will remain until the python program
terminates.


Additional tunings:
1. Move some logs from warning to verbose to avoid flooding the
training logs.
2. Move pointer type ref holding from the Python side
(NONTENSOR_OBJECT_POINTER_STORE) to
orttraining/orttraining/core/framework/torch/custom_function_register.h.
Then we use a consistent approach to manage ref count increasing and
decreasing for all PythonOp related Python objects/methods.
This commit is contained in:
pengwa 2023-03-23 12:51:50 +08:00 committed by GitHub
parent f972d21e81
commit 7bec80d92a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 209 additions and 72 deletions

View file

@ -6,6 +6,8 @@
#include "orttraining/core/framework/torch/refcount_tracker.h"
#include "core/platform/env.h"
#include <cstdio>
#include <sstream>
#include <string>
namespace onnxruntime {
namespace language_interop_ops {
@ -151,6 +153,15 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) {
return iter->second.get();
}
// Registers a non-tensor constant input (for example a ProcessGroup or a string handed to PythonOp)
// in miscellaneous_const_input_pool_. The pool key is the textual form of the object's address.
void OrtTorchFunctionPool::RegisterMiscellaneousConstInput(PyObject* obj) {
  ORT_ENFORCE(obj, "Cannot register NULL reference input.");

  // Stream the object as a void pointer so the address itself is what gets formatted.
  std::ostringstream address_stream;
  address_stream << static_cast<const void*>(obj);
  std::string pool_key = address_stream.str();

  RegisterEntry(mutex_, pool_key, obj, miscellaneous_const_input_pool_);
}
int64_t OrtTorchFunctionPool::RegisterContext(PyObject* autograd_context) {
static int64_t index_ = 0x1000000;
std::lock_guard<std::mutex> lock(mutex_);
@ -185,12 +196,21 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) {
return iter->second.get();
}
void OrtTorchFunctionPool::UnRegisterFunctions() {
// Releases the "global" entries of the pool: resets the forward and backward runners and
// clears func_context_pool_.
void OrtTorchFunctionPool::UnRegisterGlobalFunctions() {
forward_runner_.reset();
backward_runner_.reset();
func_context_pool_.clear();
}
// Releases the "model specific" entries of the pool by clearing the forward/backward core pools,
// the context pool and the miscellaneous const input pool.
// NOTE(review): func_context_pool_ is also cleared by UnRegisterGlobalFunctions() - confirm the
// duplication is intended.
void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() {
forward_core_pool_.clear();
backward_core_pool_.clear();
func_context_pool_.clear();
miscellaneous_const_input_pool_.clear();
}
// Releases everything held by the pool by delegating to UnRegisterGlobalFunctions() and then
// UnRegisterModelSpecificFunctions(), in that order.
void OrtTorchFunctionPool::UnRegisterFunctions() {
UnRegisterGlobalFunctions();
UnRegisterModelSpecificFunctions();
}
} // namespace torch

View file

@ -34,6 +34,14 @@ class OrtTorchFunctionPool final {
// 2. Caller of GetBackwardCore should not decrease the reference count of the returned object.
PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad.
// Autograd function may take input of "non-tensor && non int/float && non int/float tuple" types.
// PythonOp execution requires those inputs to stay alive, otherwise kernel execution will fail.
// So during model exporting, we need to register those inputs with this API, which increases their ref cnt by 1;
// they will not be released until OrtTorchFunctionPool is destroyed.
// We also try to release those registrations in 'UnRegisterFunctions', to avoid the issue of the Python program
// exiting before we decrease the ref cnt, i.e. decreasing it for an already released Python object.
void RegisterMiscellaneousConstInput(PyObject* obj);
// Context is torch backward gradient function pointer, and
// it is a property of forward run outputs (tensors), its lifecycle
// is along with forward run outputs in PyTorch design.
@ -76,11 +84,15 @@ class OrtTorchFunctionPool final {
OrtTorchFunctionPool(){};
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OrtTorchFunctionPool);
void UnRegisterGlobalFunctions();
void UnRegisterModelSpecificFunctions();
PythonObjectPtr forward_runner_;
PythonObjectPtr backward_runner_;
std::unordered_map<std::string, PythonObjectPtr> forward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> backward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> miscellaneous_const_input_pool_;
std::unordered_map<int64_t, PythonObjectPtr> func_context_pool_;
std::mutex mutex_;

View file

@ -27,7 +27,8 @@ void RefCountTracker::TrackPyObject(RefCountTracker::ObjCategory category, PyObj
} else {
addrs[addr].push_back(log_tag);
}
LOGS_DEFAULT(WARNING) << "Track" << ObjCategoryToString(category) << "\tAddress: [" << addr << "]\tRefCnt: " << Py_REFCNT(addr) << "\tLogTag: " << log_tag;
LOGS_DEFAULT(VERBOSE) << "Track" << ObjCategoryToString(category) << "\tAddress: [" << addr << "]\tRefCnt: "
<< Py_REFCNT(addr) << "\tLogTag: " << log_tag;
#endif
}
@ -48,7 +49,7 @@ void RefCountTracker::DumpDetails(const std::string& phase_name) const {
}
}
oss << "==========================================================" << std::endl;
LOGS_DEFAULT(WARNING) << oss.str();
LOGS_DEFAULT(VERBOSE) << oss.str();
#endif
}

View file

@ -71,7 +71,8 @@ void CheckArguments(
ORT_ENFORCE(obj_args.size() == obj_indices.size());
for (const auto i : requires_grads) {
ORT_ENFORCE(i == 0 || i == 1, "Flag of requiring gradient must be either 0 (not required) or 1 (required) but got ", i);
ORT_ENFORCE(i == 0 || i == 1,
"Flag of requiring gradient must be either 0 (not required) or 1 (required) but got ", i);
}
std::vector<int64_t> counts(len, 0);
@ -150,15 +151,19 @@ void InvokeRunner(
// from Pytorch.
PyObject* py_obj = PyTuple_GetItem(result_ptr.get(), 0);
if (is_training_mode) {
const auto& refcnt = Py_REFCNT(py_obj);
// We don't need do ref increase here because, python returns tensor.grad_fn as part of
// tuple, who increased the refcnt already (and tensor persist until the backward kernels completed).
// Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2.
// We say "at least" 2 because user could increase the context refcnt as well in their autograd forward()
// and backward() functions.
ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, ".");
if (refcnt > 2) {
LOGS_DEFAULT(WARNING) << "Autograd context refcnt > 2";
if (py_obj == Py_None) {
LOGS_DEFAULT(VERBOSE) << "Under training mode, autograd context found to be Py_None.";
} else {
const auto refcnt = Py_REFCNT(py_obj);
// We don't need do ref increase here because, python returns tensor.grad_fn as part of
// tuple, who increased the refcnt already (and tensor persist until the backward kernels completed).
// Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2.
// We say "at least" 2 because user could increase the context refcnt as well in their autograd forward()
// and backward() functions.
ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, ".");
if (refcnt > 2) {
LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt;
}
}
} else {
ORT_ENFORCE(py_obj == Py_None, "Under inference mode, autograd context should be Py_None.");

View file

@ -169,7 +169,6 @@ struct PyOptimizer {
PyOptimizer(const std::string optimizer_model_uri,
onnxruntime::training::api::Module* model, std::vector<std::shared_ptr<IExecutionProvider>> provider)
: optimizer_() {
auto env = GetTrainingEnv().GetORTEnv();
// XXX: We hope that env will be around when optimizer needs it.
optimizer_ = std::make_shared<onnxruntime::training::api::Optimizer>(optimizer_model_uri,
@ -524,6 +523,14 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
#else
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("register_miscellaneous_const_input", [](py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP
auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
pool.RegisterMiscellaneousConstInput(obj.ptr());
#else
ORT_UNUSED_PARAMETER(obj);
#endif
});
m.def("unregister_python_functions", []() -> void {

View file

@ -44,7 +44,7 @@ def enable_custom_autograd_support(to_enable=True):
)
from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils
from ._custom_autograd_function_exporter import _clear_nontensor_object_references, _export
from ._custom_autograd_function_exporter import _export
if to_enable is True and custom_autograd_function_enabler.state is False:
if custom_autograd_function_enabler.already_enabled is False:
@ -59,8 +59,6 @@ def enable_custom_autograd_support(to_enable=True):
# Clear all gradient functions, to avoid a deadlock issue.
# Check the called function for more detailed comments.
atexit.register(torch_interop_utils.clear_all_grad_fns)
# Clear all non-tensor object reference (for example, ProcessGroup passed to PythonOp).
atexit.register(_clear_nontensor_object_references)
try:
# This is for the latest Pytorch nightly after this commit:

View file

@ -11,7 +11,7 @@ import torch.utils.checkpoint
from packaging import version
from torch.onnx import symbolic_helper
from onnxruntime.capi._pybind_state import register_torch_autograd_function
from onnxruntime.capi._pybind_state import register_torch_autograd_function, register_miscellaneous_const_input
from onnxruntime.training import ortmodule
from . import _logger
@ -58,15 +58,6 @@ def pytorch_type_to_onnx(scalar_type: str) -> torch.onnx.TensorProtoDataType:
return _CAST_PYTORCH_TO_ONNX[scalar_type]
# For pointer needed for PythonOp execution, we firstly append it into a global store to hold a
# reference (in case it is released after module exported).
NONTENSOR_OBJECT_POINTER_STORE = {}
def _clear_nontensor_object_references():
NONTENSOR_OBJECT_POINTER_STORE.clear()
def _export_pt_1_10(g, n, *args, **kwargs):
"""
This function exports PythonOp (input: "n") into a graph
@ -168,7 +159,9 @@ def _export_pt_1_10(g, n, *args, **kwargs):
input_pointer_scalar_positions.append(i)
input_pointer_scalars.append(id(arg))
NONTENSOR_OBJECT_POINTER_STORE[id(arg)] = arg
# For pointer (for example, ProcessGroup passed to PythonOp) needed for PythonOp execution,
# we append it into a global store to hold a reference (in case it is released after module exported).
register_miscellaneous_const_input(arg)
else:
raise wrap_exception(
ORTModuleONNXModelException,

View file

@ -244,7 +244,7 @@ def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.
if attribute_name == "forward":
continue
# This is a user defined/overriden method. Check for collisions.
# This is a user defined/overridden method. Check for collisions.
if attribute_name in ortmodule_attributes:
# This is a user defined method, issue a warning.
warnings.warn(

View file

@ -89,7 +89,7 @@ def test_gelu():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_GeLU_custom_func_rets_not_as_module_output():
def test_gelu_custom_func_rets_not_as_module_output():
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
@ -147,7 +147,7 @@ def test_GeLU_custom_func_rets_not_as_module_output():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_GeLU_multiple_forward_runs():
def test_gelu_multiple_forward_runs():
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
@ -199,7 +199,7 @@ def test_GeLU_multiple_forward_runs():
run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True)
def test_MegatronF():
def test_megatronf():
# MegatronGFunction is tested in distributed test files.
class MegatronFFunction(torch.autograd.Function):
@staticmethod
@ -239,7 +239,7 @@ def test_MegatronF():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_ScalarAndTuple():
def test_scalar_and_tuple():
class ScalarAndTupleFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, alpha, beta, gamma):
@ -286,7 +286,7 @@ def test_ScalarAndTuple():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_ScalarAndTupleReordered():
def test_scalar_and_tuple_reordered():
class ScalarAndTupleReorderedFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, alpha, beta, input, gamma):
@ -333,6 +333,41 @@ def test_ScalarAndTupleReordered():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_pointer_type():
    """Run the training comparison on a model whose autograd function takes a string input."""

    class StringInputFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, name: str):
            # Keep both the tensor and the non-tensor (string) argument on the context.
            ctx.save_for_backward(input)
            ctx.name = name
            return input.detach()

        @staticmethod
        def backward(ctx, grad_output):
            # One return value per forward input; the string input gets None.
            return grad_output, None

    class StringInputFunctionTestModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.func = StringInputFunction.apply

        def forward(self, x):
            return self.func(x, "temp_name")

    element_count = 2

    def model_builder():
        return StringInputFunctionTestModel()

    def input_generator():
        return torch.randn(element_count, dtype=torch.float).requires_grad_()

    # The label has the same shape as the forward output.
    label_input = torch.ones([element_count])

    run_training_test_and_compare(model_builder, input_generator, label_input)
@pytest.mark.skip(
reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)."
)
@ -633,7 +668,7 @@ def test_InplaceUpdateInputAsOutputRequireGradWithMarkDirty():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_EvalTest():
def test_evaluation():
class EvalTestFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
@ -679,7 +714,7 @@ def test_EvalTest():
torch_version_lower_than("1.10.0"),
reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function",
)
def test_TwoOutputFunction():
def test_two_outputs_function():
class TwoOutputFunction1(torch.autograd.Function):
@staticmethod
# bias is an optional argument
@ -739,7 +774,7 @@ def test_TwoOutputFunction():
run_training_test_and_compare(model_builder, input_generator, label_input)
def test_InnerModuleCall():
def test_inner_module_call():
class InnerModel(torch.nn.Module):
def __init__(self, dim, device):
super(InnerModel, self).__init__()
@ -814,7 +849,7 @@ def test_InnerModuleCall():
torch_version_lower_than("1.10.0"),
reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function",
)
def test_Share_Input():
def test_share_input():
class TwoOutputFunction2(torch.autograd.Function):
@staticmethod
# bias is an optional argument
@ -865,7 +900,7 @@ def test_Share_Input():
run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input)
def test_MultipleStream_InForwardFunction():
def test_multiple_stream_in_forward_function():
class MultipleStreamFunction1(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
@ -913,7 +948,7 @@ def test_MultipleStream_InForwardFunction():
)
def test_NonDefaultStream_InForwardFunction1():
def test_nondefault_stream_in_forward_function1():
class MultipleStreamFunction2(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
@ -961,7 +996,7 @@ def test_NonDefaultStream_InForwardFunction1():
)
def test_NonDefaultStream_InForwardFunction2():
def test_nondefault_stream_in_forward_function2():
class MultipleStreamFunction3(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
@ -1008,7 +1043,7 @@ def test_NonDefaultStream_InForwardFunction2():
)
def test_NonDefaultStreamInplaceUpdate_InForwardFunction():
def test_nondefault_stream_inplace_update_in_forward_function():
class MultipleStreamFunction4(torch.autograd.Function):
@staticmethod
def forward(ctx, input):

View file

@ -85,9 +85,12 @@ void PythonOpBase::Init(const OpKernelInfo& info) {
}
void PythonOpBase::Clear() {
for (auto ptr : const_args_) {
auto obj = reinterpret_cast<PyObject*>(ptr);
Py_DECREF(obj);
for (const auto& arg : const_arg_set_.GetArgs()) {
// Only release owned PyObject.
if (arg.is_owned) {
auto obj = reinterpret_cast<PyObject*>(arg.data_ptr);
Py_DECREF(obj);
}
}
}
@ -103,8 +106,8 @@ void PythonOpBase::RunForward(OpKernelContext* context,
input_requires_grads_,
args,
arg_positions_,
const_args_,
const_arg_positions_,
const_arg_set_.GetDataPtrs(),
const_arg_set_.GetPositions(),
diff_ctx,
returned_ortvalues,
is_training_mode_,
@ -120,70 +123,73 @@ void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vec
}
void PythonOpBase::AddIntScalarArgs() {
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
for (size_t i = 0; i < input_int_scalars_.size(); ++i) {
const_arg_positions_.emplace_back(input_int_scalar_positions_.at(i));
const_args_.emplace_back(Py_BuildValue("L", static_cast<long long>(input_int_scalars_.at(i))));
const_arg_set_.Add(input_int_scalar_positions_.at(i),
Py_BuildValue("L", static_cast<long long>(input_int_scalars_.at(i))),
true /*owned*/);
}
for (size_t i = 0; i < input_float_scalars_.size(); ++i) {
const_arg_positions_.emplace_back(input_float_scalar_positions_.at(i));
const_args_.emplace_back(Py_BuildValue("f", input_float_scalars_.at(i)));
const_arg_set_.Add(input_float_scalar_positions_.at(i), Py_BuildValue("f", input_float_scalars_.at(i)),
true /*owned*/);
}
}
void PythonOpBase::AddInputTupleArgs() {
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
for (size_t i = 0; i < input_int_tuple_begins_.size(); ++i) {
// Process i-th tuple.
// Starting index of i-th tuple in the concatenation buffer.
const size_t begin = input_int_tuple_begins_.at(i);
// Ending (exclusive) index of i-th tuple in the concatenation buffer.
const size_t end = (i + 1 == input_int_tuple_begins_.size()) ? input_int_tuples_.size() : input_int_tuple_begins_.at(i + 1);
const size_t end =
(i + 1 == input_int_tuple_begins_.size()) ? input_int_tuples_.size() : input_int_tuple_begins_.at(i + 1);
PyObject* tuple = PyTuple_New(end - begin);
for (size_t j = begin; j < end; ++j) {
PyObject* item = Py_BuildValue("L", input_int_tuples_.at(j));
PyTuple_SetItem(tuple, j - begin, item);
}
const_arg_positions_.emplace_back(input_int_tuple_positions_.at(i));
const_args_.emplace_back(tuple);
const_arg_set_.Add(input_int_tuple_positions_.at(i), tuple, true /*owned*/);
}
}
void PythonOpBase::AddFloatTupleArgs() {
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
for (size_t i = 0; i < input_float_tuple_begins_.size(); ++i) {
// Process i-th tuple.
// Starting index of i-th tuple in the concatenation buffer.
const size_t begin = input_float_tuple_begins_.at(i);
// Ending (exclusive) index of i-th tuple in the concatenation buffer.
const size_t end = (i + 1 == input_float_tuple_begins_.size()) ? input_float_tuples_.size() : input_float_tuple_begins_.at(i + 1);
const size_t end =
(i + 1 == input_float_tuple_begins_.size()) ? input_float_tuples_.size() : input_float_tuple_begins_.at(i + 1);
PyObject* tuple = PyTuple_New(end - begin);
for (size_t j = begin; j < end; ++j) {
PyObject* item = Py_BuildValue("f", input_float_tuples_.at(j));
PyTuple_SetItem(tuple, j - begin, item);
}
const_arg_positions_.emplace_back(input_float_tuple_positions_.at(i));
const_args_.emplace_back(tuple);
const_arg_set_.Add(input_float_tuple_positions_.at(i), tuple, true /*owned*/);
}
}
void PythonOpBase::AddPointerScalarArgs() {
ORT_ENFORCE(const_args_.size() == const_arg_positions_.size());
for (size_t i = 0; i < input_pointer_scalars_.size(); ++i) {
const_arg_positions_.emplace_back(input_pointer_scalar_positions_.at(i));
PyObject* ptr = reinterpret_cast<PyObject*>(input_pointer_scalars_.at(i));
const_args_.emplace_back(ptr);
// We don't want to own the Python object from C++ side because once C++ destructor called through pybind,
// it may trigger python side object destroying, potentially requires GILs, resulting in a hang.
// Instead, we have mechanism during exporting we increase the reference count already.
const_arg_set_.Add(input_pointer_scalar_positions_.at(i), ptr, false /*owned*/);
}
}
void PythonOpBase::CreateConstArgs() {
ORT_ENFORCE(const_args_.size() == 0);
ORT_ENFORCE(const_arg_positions_.size() == 0);
ORT_ENFORCE(const_arg_set_.Size() == 0);
AddIntScalarArgs();
AddInputTupleArgs();
AddFloatTupleArgs();
AddPointerScalarArgs();
// Freeze the constant arg.
const_arg_set_.Finalize();
}
void PythonOpBase::CreateArgPositions() {
@ -191,11 +197,11 @@ void PythonOpBase::CreateArgPositions() {
// occupied[i] being true means the i-th input argument
// to Python function has been set.
std::vector<bool> occupied(input_tensor_types_.size() + const_args_.size(), false);
std::vector<bool> occupied(input_tensor_types_.size() + const_arg_set_.Size(), false);
// We know all non-tensors were set above, so let's catch up.
for (const auto pos : const_arg_positions_) {
occupied.at(pos) = true;
for (auto& arg : const_arg_set_.GetArgs()) {
occupied.at(arg.position) = true;
}
// Search for empty slots for tensors.
@ -234,7 +240,8 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) {
ORT_THROW_IF_ERROR(info.GetAttr("output_convention", &output_convention_));
ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_));
output_tensor_requires_grads_ = info.GetAttrsOrDefault("output_tensor_requires_grads", std::vector<int64_t>());
ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), "backward tensor output count mismatch");
ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(),
"backward tensor output count mismatch");
SetPositions();
}

View file

@ -45,8 +45,67 @@ class PythonOpBase {
void Clear();
protected:
std::vector<int64_t> const_arg_positions_;
std::vector<void*> const_args_;
// Ordered collection of the constant (non-tensor) arguments of a PythonOp.
//
// Usage: call Add() once per argument, then Finalize() exactly once. After Finalize(),
// GetPositions() and GetDataPtrs() return parallel vectors in insertion order.
//
// Finalization is tracked with an explicit flag. Inferring it from "positions_ is non-empty"
// cannot detect a Finalize() on an empty set, so a later Add() or a second Finalize() would pass
// silently and leave Size() out of sync with GetPositions()/GetDataPtrs().
class ConstantArgSet {
 private:
  class ConstantArg {
   public:
    ConstantArg(int64_t position, void* data_ptr, bool is_owned)
        : position(position), data_ptr(data_ptr), is_owned(is_owned) {}
    int64_t position;  // input offset in the input lists
    void* data_ptr;    // pointer to the data
    bool is_owned;     // whether the data is owned by this PythonOp kernel.
  };

 public:
  // Append a new constant argument. Fails when called after Finalize().
  void Add(int64_t position, void* data_ptr, bool owned) {
    ORT_ENFORCE(!finalized_, "Cannot add constant arg after Finalize()");
    args_.emplace_back(position, data_ptr, owned);
  }

  // Freeze the set and build the flattened position / data pointer views.
  // Must be called after all constant args are added; fails when called more than once.
  void Finalize() {
    ORT_ENFORCE(!finalized_, "Finalize() must not be called more than once");
    positions_.reserve(args_.size());
    data_ptrs_.reserve(args_.size());
    for (const auto& arg : args_) {
      positions_.push_back(arg.position);
      data_ptrs_.push_back(arg.data_ptr);
    }
    finalized_ = true;
  }

  // Number of arguments added so far.
  size_t Size() const {
    return args_.size();
  }

  // All arguments with their ownership flag, in insertion order.
  const std::vector<ConstantArg>& GetArgs() const {
    return args_;
  }

  // Positions of the arguments; filled by Finalize(), empty before.
  const std::vector<int64_t>& GetPositions() const {
    return positions_;
  }

  // Data pointers of the arguments; filled by Finalize(), empty before.
  const std::vector<void*>& GetDataPtrs() const {
    return data_ptrs_;
  }

 private:
  std::vector<ConstantArg> args_;
  std::vector<int64_t> positions_;
  std::vector<void*> data_ptrs_;
  bool finalized_{false};  // set by Finalize(); guards Add() and repeated Finalize().
};
// A collection for all non-tensor input arguments, we treated them all as constants, including primitive types and
// tuples, and also string or other user defined data types (represented in pointer in the attribute
// "input_pointer_scalars").
ConstantArgSet const_arg_set_;
std::vector<int64_t> arg_positions_;
// Name of containing class. For example, MyReLU.