From 4e76360261259fff3fc316b3889766ff4239c32f Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 27 Oct 2021 11:28:37 -0700 Subject: [PATCH] Prevent PySparseTensor form being garbage collected if we have an outstanding OrtValue (#9540) Prevent PySparseTensor form being garbage collected if we have an outstanding OrtValue Improve comments. --- .../python/onnxruntime_pybind_ortvalue.cc | 2 +- .../python/onnxruntime_pybind_state.cc | 2 + .../python/onnxruntime_pybind_state_common.cc | 17 +++++-- .../python/onnxruntime_pybind_state_common.h | 46 +++++++++++++------ 4 files changed, 47 insertions(+), 20 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index ade210bc33..3911689379 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -110,7 +110,7 @@ void addOrtValueMethods(pybind11::module& m) { Tensor::InitOrtValue(ml_type, shape, std::move(allocator), *ml_value); return ml_value; }) - // This will create a copy of OrtValue (cheap) and will return as a separate OrtValue object + .def_static("ort_value_from_sparse_tensor", [](const PySparseTensor* py_sparse_tensor) -> std::unique_ptr { return py_sparse_tensor->AsOrtValue(); }) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9124696a2a..2081047fe5 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -208,6 +208,8 @@ py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, co OrtPybindThrowIfError(status); py_sparse_tensor.reset(new PySparseTensor(std::move(dst_sparse_tensor))); } + } else { + py_sparse_tensor.reset(new PySparseTensor(ort_value)); } py::object result = py::cast(py_sparse_tensor.get(), py::return_value_policy::take_ownership); diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index 637206bec9..e9a81be84c 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -80,11 +80,18 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) { #endif -void PySparseTensor::Init(std::unique_ptr&& instance) { - auto sparse_tensor(std::move(instance)); - auto ml_type = DataTypeImpl::GetType(); - ort_value_.Init(sparse_tensor.get(), ml_type, ml_type->GetDeleteFunc()); - sparse_tensor.release(); +std::unique_ptr PySparseTensor::AsOrtValue() const { + if (instance_) { + auto ort_value = std::make_unique(); + auto ml_type = DataTypeImpl::GetType(); + py::object this_object = py::cast(*this); + // Create an std::function deleter that captures and ref-counts this PySparseTensor + ort_value->Init(instance_.get(), ml_type, [object = std::move(this_object)](void*) {}); + return ort_value; + } + + assert(ort_value_.IsAllocated()); + return std::make_unique(ort_value_); } PySparseTensor::~PySparseTensor() { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index f244a2c8f6..9d9df82318 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -307,8 +307,15 @@ inline AllocatorPtr& GetAllocator() { // This class exposes SparseTensor to Python // The class serves two major purposes // - to be able to map numpy arrays memory and use it on input, this serves as a reference holder -// so incoming arrays do not disappear -// - to be able to expose SparseTensor returned from run method +// so incoming arrays do not disappear. To this end we create an instance of SparseTensor +// on top of the user provided numpy arrays and create a duplicate of py::objects for those +// numpy array for ref-counting purposes and store it here. +// +// - to be able to expose SparseTensor returned from run method. We get an OrtValue from run() +// and store a copy of it in ort_value_. The OrtValue shared_ptr ref-counting will make sure +// the memory stays around. +// +// An object of the class must never have both instance_ and ort_value_ have data at the same time. class PySparseTensor { public: /// @@ -320,8 +327,7 @@ class PySparseTensor { /// a collection reference guards PySparseTensor(std::unique_ptr&& instance, std::vector&& storage) - : backing_storage_(std::move(storage)), ort_value_() { - Init(std::move(instance)); + : instance_(std::move(instance)), backing_storage_(std::move(storage)), ort_value_() { } /// @@ -329,12 +335,16 @@ class PySparseTensor { /// /// explicit PySparseTensor(std::unique_ptr&& instance) - : backing_storage_(), ort_value_() { - Init(std::move(instance)); + : instance_(std::move(instance)), backing_storage_(), ort_value_() { } + /// + /// Edge case when we can not copy memory on GPU and therefore + /// can not own it. + /// + /// explicit PySparseTensor(const OrtValue& ort_value) - : backing_storage_(), ort_value_(ort_value) {} + : instance_(), backing_storage_(), ort_value_(ort_value) {} PySparseTensor(const PySparseTensor&) = delete; PySparseTensor& operator=(const PySparseTensor&) = delete; @@ -344,27 +354,35 @@ class PySparseTensor { } PySparseTensor& operator=(PySparseTensor&& o) noexcept { - ort_value_ = std::move(o.ort_value_); + instance_ = std::move(o.instance_); backing_storage_ = std::move(o.backing_storage_); + ort_value_ = std::move(o.ort_value_); return *this; } ~PySparseTensor(); const SparseTensor& Instance() const { + if (instance_) { + return *instance_; + } return ort_value_.Get(); } - std::unique_ptr AsOrtValue() const { - return std::make_unique(ort_value_); - } + std::unique_ptr AsOrtValue() const; private: - void Init(std::unique_ptr&& instance); - // These will hold references to underpinning python array objects - // when they serve as a backing storage for a feeding SparseTensor + // instance_ represents data that comes as input. Thus we depend on numpy + // arrays that own the underlying memory to stay around. We store copies + // of py::objects for those arrays in backing_storage_ as an extra ref-count. + + // If we have and are able to copy from the OrtValue returned by run() to CPU, then this owns the data + // and backing_storage_ is empty. + std::unique_ptr instance_; std::vector backing_storage_; + + // We create a copy of OrtValue when we obtain it from a run method. OrtValue ort_value_; };