mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Prevent PySparseTensor form being garbage collected if we have an outstanding OrtValue (#9540)
Prevent PySparseTensor form being garbage collected if we have an outstanding OrtValue Improve comments.
This commit is contained in:
parent
aa76520e60
commit
4e76360261
4 changed files with 47 additions and 20 deletions
|
|
@ -110,7 +110,7 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
Tensor::InitOrtValue(ml_type, shape, std::move(allocator), *ml_value);
|
||||
return ml_value;
|
||||
})
|
||||
// This will create a copy of OrtValue (cheap) and will return as a separate OrtValue object
|
||||
|
||||
.def_static("ort_value_from_sparse_tensor", [](const PySparseTensor* py_sparse_tensor) -> std::unique_ptr<OrtValue> {
|
||||
return py_sparse_tensor->AsOrtValue();
|
||||
})
|
||||
|
|
|
|||
|
|
@ -208,6 +208,8 @@ py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, co
|
|||
OrtPybindThrowIfError(status);
|
||||
py_sparse_tensor.reset(new PySparseTensor(std::move(dst_sparse_tensor)));
|
||||
}
|
||||
} else {
|
||||
py_sparse_tensor.reset(new PySparseTensor(ort_value));
|
||||
}
|
||||
|
||||
py::object result = py::cast(py_sparse_tensor.get(), py::return_value_policy::take_ownership);
|
||||
|
|
|
|||
|
|
@ -80,11 +80,18 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) {
|
|||
|
||||
#endif
|
||||
|
||||
void PySparseTensor::Init(std::unique_ptr<SparseTensor>&& instance) {
|
||||
auto sparse_tensor(std::move(instance));
|
||||
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
|
||||
ort_value_.Init(sparse_tensor.get(), ml_type, ml_type->GetDeleteFunc());
|
||||
sparse_tensor.release();
|
||||
std::unique_ptr<OrtValue> PySparseTensor::AsOrtValue() const {
|
||||
if (instance_) {
|
||||
auto ort_value = std::make_unique<OrtValue>();
|
||||
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
|
||||
py::object this_object = py::cast(*this);
|
||||
// Create an std::function deleter that captures and ref-counts this PySparseTensor
|
||||
ort_value->Init(instance_.get(), ml_type, [object = std::move(this_object)](void*) {});
|
||||
return ort_value;
|
||||
}
|
||||
|
||||
assert(ort_value_.IsAllocated());
|
||||
return std::make_unique<OrtValue>(ort_value_);
|
||||
}
|
||||
|
||||
PySparseTensor::~PySparseTensor() {
|
||||
|
|
|
|||
|
|
@ -307,8 +307,15 @@ inline AllocatorPtr& GetAllocator() {
|
|||
// This class exposes SparseTensor to Python
|
||||
// The class serves two major purposes
|
||||
// - to be able to map numpy arrays memory and use it on input, this serves as a reference holder
|
||||
// so incoming arrays do not disappear
|
||||
// - to be able to expose SparseTensor returned from run method
|
||||
// so incoming arrays do not disappear. To this end we create an instance of SparseTensor
|
||||
// on top of the user provided numpy arrays and create a duplicate of py::objects for those
|
||||
// numpy array for ref-counting purposes and store it here.
|
||||
//
|
||||
// - to be able to expose SparseTensor returned from run method. We get an OrtValue from run()
|
||||
// and store a copy of it in ort_value_. The OrtValue shared_ptr ref-counting will make sure
|
||||
// the memory stays around.
|
||||
//
|
||||
// An object of the class must never have both instance_ and ort_value_ have data at the same time.
|
||||
class PySparseTensor {
|
||||
public:
|
||||
/// <summary>
|
||||
|
|
@ -320,8 +327,7 @@ class PySparseTensor {
|
|||
/// <param name="storage">a collection reference guards</param>
|
||||
PySparseTensor(std::unique_ptr<SparseTensor>&& instance,
|
||||
std::vector<pybind11::object>&& storage)
|
||||
: backing_storage_(std::move(storage)), ort_value_() {
|
||||
Init(std::move(instance));
|
||||
: instance_(std::move(instance)), backing_storage_(std::move(storage)), ort_value_() {
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -329,12 +335,16 @@ class PySparseTensor {
|
|||
/// </summary>
|
||||
/// <param name="instance"></param>
|
||||
explicit PySparseTensor(std::unique_ptr<SparseTensor>&& instance)
|
||||
: backing_storage_(), ort_value_() {
|
||||
Init(std::move(instance));
|
||||
: instance_(std::move(instance)), backing_storage_(), ort_value_() {
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Edge case when we can not copy memory on GPU and therefore
|
||||
/// can not own it.
|
||||
/// </summary>
|
||||
/// <param name="ort_value"></param>
|
||||
explicit PySparseTensor(const OrtValue& ort_value)
|
||||
: backing_storage_(), ort_value_(ort_value) {}
|
||||
: instance_(), backing_storage_(), ort_value_(ort_value) {}
|
||||
|
||||
PySparseTensor(const PySparseTensor&) = delete;
|
||||
PySparseTensor& operator=(const PySparseTensor&) = delete;
|
||||
|
|
@ -344,27 +354,35 @@ class PySparseTensor {
|
|||
}
|
||||
|
||||
PySparseTensor& operator=(PySparseTensor&& o) noexcept {
|
||||
ort_value_ = std::move(o.ort_value_);
|
||||
instance_ = std::move(o.instance_);
|
||||
backing_storage_ = std::move(o.backing_storage_);
|
||||
ort_value_ = std::move(o.ort_value_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~PySparseTensor();
|
||||
|
||||
const SparseTensor& Instance() const {
|
||||
if (instance_) {
|
||||
return *instance_;
|
||||
}
|
||||
return ort_value_.Get<SparseTensor>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OrtValue> AsOrtValue() const {
|
||||
return std::make_unique<OrtValue>(ort_value_);
|
||||
}
|
||||
std::unique_ptr<OrtValue> AsOrtValue() const;
|
||||
|
||||
private:
|
||||
void Init(std::unique_ptr<SparseTensor>&& instance);
|
||||
|
||||
// These will hold references to underpinning python array objects
|
||||
// when they serve as a backing storage for a feeding SparseTensor
|
||||
// instance_ represents data that comes as input. Thus we depend on numpy
|
||||
// arrays that own the underlying memory to stay around. We store copies
|
||||
// of py::objects for those arrays in backing_storage_ as an extra ref-count.
|
||||
|
||||
// If we have and are able to copy from the OrtValue returned by run() to CPU, then this owns the data
|
||||
// and backing_storage_ is empty.
|
||||
std::unique_ptr<SparseTensor> instance_;
|
||||
std::vector<pybind11::object> backing_storage_;
|
||||
|
||||
// We create a copy of OrtValue when we obtain it from a run method.
|
||||
OrtValue ort_value_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue