mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
decouple the shared python dependency (#8294)
* remove warnining message for non-training build * move to/from dlpack for onnxruntime_python back into python project
This commit is contained in:
parent
067759b387
commit
5454af4b95
9 changed files with 60 additions and 107 deletions
|
|
@ -223,8 +223,10 @@ if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
|
|||
endif()
|
||||
|
||||
if (NOT (UNIX AND onnxruntime_ENABLE_PYTHON AND onnxruntime_ENABLE_TRAINING AND (NOT onnxruntime_BUILD_SHARED_LIB)))
|
||||
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
|
||||
message(WARNING "onnxruntime_ENABLE_TRAINING_TORCH_INTEROP is turned OFF due to incompatible build combinations.")
|
||||
endif()
|
||||
set(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OFF)
|
||||
message(WARNING "onnxruntime_ENABLE_TRAINING_TORCH_INTEROP is turned OFF due to incompatible build combinations.")
|
||||
endif()
|
||||
|
||||
set(onnxruntime_REQUIRE_PYTHON_EMBED_LIB OFF)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,9 @@ set(onnxruntime_pybind_srcs_pattern
|
|||
)
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
# todo: move dlpack/dlpack_python.* to ${ONNXRUNTIME_ROOT}/python folder.
|
||||
list(APPEND onnxruntime_pybind_srcs_pattern
|
||||
"${ORTTRAINING_ROOT}/orttraining/python/*.cc"
|
||||
"${ORTTRAINING_ROOT}/orttraining/python/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_python.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_python.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/dlpack/python_common.h"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/dlpack/dlpack_python.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace dlpack {
|
||||
|
||||
static void DlpackCapsuleDestructor(PyObject* data) {
|
||||
DLManagedTensor* dlmanged_tensor = reinterpret_cast<DLManagedTensor*>(
|
||||
PyCapsule_GetPointer(data, "dltensor"));
|
||||
if (dlmanged_tensor) {
|
||||
// The dlmanged_tensor has not been consumed, call deleter ourselves.
|
||||
dlmanged_tensor->deleter(const_cast<DLManagedTensor*>(dlmanged_tensor));
|
||||
} else {
|
||||
// The dlmanged_tensor has been consumed,
|
||||
// PyCapsule_GetPointer has set an error indicator.
|
||||
PyErr_Clear();
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* ToDlpack(OrtValue ort_value) {
|
||||
DLManagedTensor* dlmanaged_tensor = dlpack::OrtValueToDlpack(ort_value);
|
||||
return PyCapsule_New(dlmanaged_tensor, "dltensor", DlpackCapsuleDestructor);
|
||||
}
|
||||
|
||||
OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) {
|
||||
// Extract DLPack tensor pointer from the capsule carrier.
|
||||
DLManagedTensor* dlmanaged_tensor = (DLManagedTensor*)PyCapsule_GetPointer(dlpack_tensor, "dltensor");
|
||||
OrtValue ort_value = dlpack::DlpackToOrtValue(dlmanaged_tensor, is_bool_tensor);
|
||||
// Make sure this capsule will never be used again.
|
||||
PyCapsule_SetName(dlpack_tensor, "used_dltensor");
|
||||
return ort_value;
|
||||
}
|
||||
|
||||
} // namespace dlpack
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/// Python level interface for DLPack conversion.
|
||||
|
||||
/// todo(pengwa) move this file back to pybind projects.
|
||||
#pragma once
|
||||
|
||||
#include "core/dlpack/dlpack_converter.h"
|
||||
#include "core/dlpack/python_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace dlpack {
|
||||
|
||||
// Allocate a new Capsule object, which takes the ownership of OrtValue.
|
||||
// Caller is responsible for releasing.
|
||||
// This function calls OrtValueToDlpack(...).
|
||||
PyObject* ToDlpack(OrtValue ort_value);
|
||||
|
||||
// Consume a Capsule object and claims the ownership of its underlying tensor to
|
||||
// create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion.
|
||||
OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor);
|
||||
|
||||
} // namespace dlpack
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Avoid linking to pythonX_d.lib on Windows in debug build
|
||||
#ifdef _WIN32
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4510 4610 4512 4005)
|
||||
#ifdef _DEBUG
|
||||
#define ORT_DISABLE_INCLUDE_DEBUG_PYTHON_LIB
|
||||
#undef _DEBUG
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifdef ORT_DISABLE_INCLUDE_DEBUG_PYTHON_LIB
|
||||
#define _DEBUG
|
||||
#undef ORT_DISABLE_INCLUDE_DEBUG_PYTHON_LIB
|
||||
#endif
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
|
@ -10,10 +10,6 @@
|
|||
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
|
||||
#include <numpy/arrayobject.h>
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "core/dlpack/dlpack_python.h"
|
||||
#endif
|
||||
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/sparse_tensor.h"
|
||||
|
|
@ -207,10 +203,10 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
})
|
||||
#ifdef ENABLE_TRAINING
|
||||
.def("to_dlpack", [](OrtValue* ort_value) -> py::object {
|
||||
return py::reinterpret_steal<py::object>(dlpack::ToDlpack(*ort_value));
|
||||
return py::reinterpret_steal<py::object>(ToDlpack(*ort_value));
|
||||
})
|
||||
.def_static("from_dlpack", [](py::object data, bool is_bool_tensor = false) {
|
||||
return dlpack::FromDlpack(data.ptr(), is_bool_tensor);
|
||||
return FromDlpack(data.ptr(), is_bool_tensor);
|
||||
})
|
||||
#endif
|
||||
;
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ std::string openvino_device_type;
|
|||
std::string nuphar_settings;
|
||||
#endif
|
||||
|
||||
|
||||
// TODO remove deprecated global config
|
||||
OrtDevice::DeviceId cuda_device_id = 0;
|
||||
// TODO remove deprecated global config
|
||||
|
|
@ -44,5 +43,41 @@ onnxruntime::ROCMExecutionProviderExternalAllocatorInfo external_allocator_info{
|
|||
onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
||||
static void DlpackCapsuleDestructor(PyObject* data) {
|
||||
DLManagedTensor* dlmanged_tensor = reinterpret_cast<DLManagedTensor*>(
|
||||
PyCapsule_GetPointer(data, "dltensor"));
|
||||
if (dlmanged_tensor) {
|
||||
// The dlmanged_tensor has not been consumed, call deleter ourselves.
|
||||
dlmanged_tensor->deleter(const_cast<DLManagedTensor*>(dlmanged_tensor));
|
||||
} else {
|
||||
// The dlmanged_tensor has been consumed,
|
||||
// PyCapsule_GetPointer has set an error indicator.
|
||||
PyErr_Clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate a new Capsule object, which takes the ownership of OrtValue.
|
||||
// Caller is responsible for releasing.
|
||||
// This function calls OrtValueToDlpack(...).
|
||||
PyObject* ToDlpack(OrtValue ort_value) {
|
||||
DLManagedTensor* dlmanaged_tensor = dlpack::OrtValueToDlpack(ort_value);
|
||||
return PyCapsule_New(dlmanaged_tensor, "dltensor", DlpackCapsuleDestructor);
|
||||
}
|
||||
|
||||
// Consume a Capsule object and claims the ownership of its underlying tensor to
|
||||
// create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion.
|
||||
OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) {
|
||||
// Extract DLPack tensor pointer from the capsule carrier.
|
||||
DLManagedTensor* dlmanaged_tensor = (DLManagedTensor*)PyCapsule_GetPointer(dlpack_tensor, "dltensor");
|
||||
OrtValue ort_value = dlpack::DlpackToOrtValue(dlmanaged_tensor, is_bool_tensor);
|
||||
// Make sure this capsule will never be used again.
|
||||
PyCapsule_SetName(dlpack_tensor, "used_dltensor");
|
||||
return ort_value;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -341,5 +341,18 @@ bool CheckIfTensor(const std::vector<const NodeArg*>& def_list,
|
|||
const std::string& name,
|
||||
/*out*/ ONNX_NAMESPACE::TypeProto& type_proto);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
|
||||
// Allocate a new Capsule object, which takes the ownership of OrtValue.
|
||||
// Caller is responsible for releasing.
|
||||
// This function calls OrtValueToDlpack(...).
|
||||
PyObject* ToDlpack(OrtValue ort_value);
|
||||
|
||||
// Consume a Capsule object and claims the ownership of its underlying tensor to
|
||||
// create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion.
|
||||
OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "core/dlpack/dlpack_python.h"
|
||||
#include "core/session/environment.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/agent/training_agent.h"
|
||||
|
|
@ -311,21 +310,20 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
v->push_back(ortvalue);
|
||||
})
|
||||
.def("push_back", [](std::vector<OrtValue>* v, py::object dlpack_tensor, const bool is_bool_tensor) {
|
||||
v->push_back(dlpack::FromDlpack(dlpack_tensor.ptr(), is_bool_tensor));
|
||||
v->push_back(FromDlpack(dlpack_tensor.ptr(), is_bool_tensor));
|
||||
})
|
||||
.def("reserve", [](std::vector<OrtValue>* v, const size_t len) { v->reserve(len); })
|
||||
.def("shrink_to_fit", [](std::vector<OrtValue>* v) { v->shrink_to_fit(); })
|
||||
.def("__len__", [](const std::vector<OrtValue>& v) { return v.size(); })
|
||||
.def(
|
||||
"__iter__", [](const std::vector<OrtValue>& v) {
|
||||
return py::make_iterator(v.cbegin(), v.cend());
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def("__iter__", [](const std::vector<OrtValue>& v) {
|
||||
return py::make_iterator(v.cbegin(), v.cend());
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def("__getitem__", [](const std::vector<OrtValue>& v, const size_t idx) {
|
||||
return v.at(idx);
|
||||
})
|
||||
.def("dlpack_at", [](std::vector<OrtValue>* v, const size_t idx) {
|
||||
return py::reinterpret_steal<py::object>(dlpack::ToDlpack(v->at(idx)));
|
||||
return py::reinterpret_steal<py::object>(ToDlpack(v->at(idx)));
|
||||
});
|
||||
|
||||
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
|
||||
|
|
|
|||
Loading…
Reference in a new issue