Prevent PySparseTensor from being garbage collected if we have an outstanding OrtValue (#9540)

Prevent PySparseTensor from being garbage collected if we have an outstanding OrtValue
  Improve comments.
This commit is contained in:
Dmitri Smirnov 2021-10-27 11:28:37 -07:00 committed by GitHub
parent aa76520e60
commit 4e76360261
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 20 deletions

View file

@ -110,7 +110,7 @@ void addOrtValueMethods(pybind11::module& m) {
Tensor::InitOrtValue(ml_type, shape, std::move(allocator), *ml_value);
return ml_value;
})
// This will create a copy of OrtValue (cheap) and will return as a separate OrtValue object
.def_static("ort_value_from_sparse_tensor", [](const PySparseTensor* py_sparse_tensor) -> std::unique_ptr<OrtValue> {
return py_sparse_tensor->AsOrtValue();
})

View file

@ -208,6 +208,8 @@ py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, co
OrtPybindThrowIfError(status);
py_sparse_tensor.reset(new PySparseTensor(std::move(dst_sparse_tensor)));
}
} else {
py_sparse_tensor.reset(new PySparseTensor(ort_value));
}
py::object result = py::cast(py_sparse_tensor.get(), py::return_value_policy::take_ownership);

View file

@ -80,11 +80,18 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) {
#endif
void PySparseTensor::Init(std::unique_ptr<SparseTensor>&& instance) {
auto sparse_tensor(std::move(instance));
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
ort_value_.Init(sparse_tensor.get(), ml_type, ml_type->GetDeleteFunc());
sparse_tensor.release();
// Returns a newly allocated OrtValue that refers to the SparseTensor held by this object.
// Two cases are handled:
//  - instance_ is set: the new OrtValue points at the tensor owned by instance_ and
//    keeps a py::object for this PySparseTensor alive inside its deleter.
//  - instance_ is empty: ort_value_ is expected to be allocated and a copy of it is returned.
std::unique_ptr<OrtValue> PySparseTensor::AsOrtValue() const {
if (instance_) {
auto ort_value = std::make_unique<OrtValue>();
auto ml_type = DataTypeImpl::GetType<SparseTensor>();
// NOTE(review): the copy constructor of PySparseTensor is deleted, so this cast presumably
// resolves to the Python wrapper that already exists for this object rather than copying — confirm.
py::object this_object = py::cast(*this);
// Create an std::function deleter that captures and ref-counts this PySparseTensor
// The deleter body is intentionally empty: instance_ remains the owner of the tensor
// (only instance_.get() is passed), and the captured py::object is released when the
// deleter itself is destroyed.
ort_value->Init(instance_.get(), ml_type, [object = std::move(this_object)](void*) {});
return ort_value;
}
// Without instance_ the data must come from ort_value_; the assert is only active in builds
// where NDEBUG is not defined.
assert(ort_value_.IsAllocated());
return std::make_unique<OrtValue>(ort_value_);
}
PySparseTensor::~PySparseTensor() {

View file

@ -307,8 +307,15 @@ inline AllocatorPtr& GetAllocator() {
// This class exposes SparseTensor to Python
// The class serves two major purposes
// - to be able to map numpy arrays memory and use it on input, this serves as a reference holder
// so incoming arrays do not disappear
// - to be able to expose SparseTensor returned from run method
// so incoming arrays do not disappear. To this end we create an instance of SparseTensor
// on top of the user provided numpy arrays, and store here copies of the py::objects for
// those numpy arrays for ref-counting purposes.
//
// - to be able to expose SparseTensor returned from run method. We get an OrtValue from run()
// and store a copy of it in ort_value_. The OrtValue shared_ptr ref-counting will make sure
// the memory stays around.
//
// An object of the class must never have both instance_ and ort_value_ have data at the same time.
class PySparseTensor {
public:
/// <summary>
@ -320,8 +327,7 @@ class PySparseTensor {
/// <param name="storage">a collection of reference guards</param>
PySparseTensor(std::unique_ptr<SparseTensor>&& instance,
std::vector<pybind11::object>&& storage)
: backing_storage_(std::move(storage)), ort_value_() {
Init(std::move(instance));
: instance_(std::move(instance)), backing_storage_(std::move(storage)), ort_value_() {
}
/// <summary>
@ -329,12 +335,16 @@ class PySparseTensor {
/// </summary>
/// <param name="instance"></param>
explicit PySparseTensor(std::unique_ptr<SparseTensor>&& instance)
: backing_storage_(), ort_value_() {
Init(std::move(instance));
: instance_(std::move(instance)), backing_storage_(), ort_value_() {
}
/// <summary>
/// Edge case when we can not copy memory on GPU and therefore
/// can not own it.
/// </summary>
/// <param name="ort_value"></param>
explicit PySparseTensor(const OrtValue& ort_value)
: backing_storage_(), ort_value_(ort_value) {}
: instance_(), backing_storage_(), ort_value_(ort_value) {}
PySparseTensor(const PySparseTensor&) = delete;
PySparseTensor& operator=(const PySparseTensor&) = delete;
@ -344,27 +354,35 @@ class PySparseTensor {
}
PySparseTensor& operator=(PySparseTensor&& o) noexcept {
ort_value_ = std::move(o.ort_value_);
instance_ = std::move(o.instance_);
backing_storage_ = std::move(o.backing_storage_);
ort_value_ = std::move(o.ort_value_);
return *this;
}
~PySparseTensor();
const SparseTensor& Instance() const {
if (instance_) {
return *instance_;
}
return ort_value_.Get<SparseTensor>();
}
std::unique_ptr<OrtValue> AsOrtValue() const {
return std::make_unique<OrtValue>(ort_value_);
}
std::unique_ptr<OrtValue> AsOrtValue() const;
private:
void Init(std::unique_ptr<SparseTensor>&& instance);
// These will hold references to underpinning python array objects
// when they serve as a backing storage for a feeding SparseTensor
// instance_ represents data that comes as input. Thus we depend on numpy
// arrays that own the underlying memory to stay around. We store copies
// of py::objects for those arrays in backing_storage_ as an extra ref-count.
// If we have an OrtValue returned by run() and are able to copy its data to CPU, then
// instance_ owns the data and backing_storage_ is empty.
std::unique_ptr<SparseTensor> instance_;
std::vector<pybind11::object> backing_storage_;
// We create a copy of OrtValue when we obtain it from a run method.
OrtValue ort_value_;
};