diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index d891731903..0cf18a36df 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -223,8 +223,10 @@ if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS) endif() if (NOT (UNIX AND onnxruntime_ENABLE_PYTHON AND onnxruntime_ENABLE_TRAINING AND (NOT onnxruntime_BUILD_SHARED_LIB))) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + message(WARNING "onnxruntime_ENABLE_TRAINING_TORCH_INTEROP is turned OFF due to incompatible build combinations.") + endif() set(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP OFF) - message(WARNING "onnxruntime_ENABLE_TRAINING_TORCH_INTEROP is turned OFF due to incompatible build combinations.") endif() set(onnxruntime_REQUIRE_PYTHON_EMBED_LIB OFF) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 8333928eb0..7d95f2a60d 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -10,13 +10,9 @@ set(onnxruntime_pybind_srcs_pattern ) if (onnxruntime_ENABLE_TRAINING) - # todo: move dlpack/dlpack_python.* to ${ONNXRUNTIME_ROOT}/python folder. list(APPEND onnxruntime_pybind_srcs_pattern "${ORTTRAINING_ROOT}/orttraining/python/*.cc" "${ORTTRAINING_ROOT}/orttraining/python/*.h" - "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_python.cc" - "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_python.h" - "${ONNXRUNTIME_ROOT}/core/dlpack/python_common.h" ) endif() diff --git a/onnxruntime/core/dlpack/dlpack_python.cc b/onnxruntime/core/dlpack/dlpack_python.cc deleted file mode 100644 index 781c1c7c64..0000000000 --- a/onnxruntime/core/dlpack/dlpack_python.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/dlpack/dlpack_python.h" - -namespace onnxruntime { -namespace dlpack { - -static void DlpackCapsuleDestructor(PyObject* data) { - DLManagedTensor* dlmanged_tensor = reinterpret_cast( - PyCapsule_GetPointer(data, "dltensor")); - if (dlmanged_tensor) { - // The dlmanged_tensor has not been consumed, call deleter ourselves. - dlmanged_tensor->deleter(const_cast(dlmanged_tensor)); - } else { - // The dlmanged_tensor has been consumed, - // PyCapsule_GetPointer has set an error indicator. - PyErr_Clear(); - } -} - -PyObject* ToDlpack(OrtValue ort_value) { - DLManagedTensor* dlmanaged_tensor = dlpack::OrtValueToDlpack(ort_value); - return PyCapsule_New(dlmanaged_tensor, "dltensor", DlpackCapsuleDestructor); -} - -OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) { - // Extract DLPack tensor pointer from the capsule carrier. - DLManagedTensor* dlmanaged_tensor = (DLManagedTensor*)PyCapsule_GetPointer(dlpack_tensor, "dltensor"); - OrtValue ort_value = dlpack::DlpackToOrtValue(dlmanaged_tensor, is_bool_tensor); - // Make sure this capsule will never be used again. - PyCapsule_SetName(dlpack_tensor, "used_dltensor"); - return ort_value; -} - -} // namespace dlpack -} // namespace onnxruntime diff --git a/onnxruntime/core/dlpack/dlpack_python.h b/onnxruntime/core/dlpack/dlpack_python.h deleted file mode 100644 index 082e886e21..0000000000 --- a/onnxruntime/core/dlpack/dlpack_python.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -/// Python level interface for DLPack conversion. - -/// todo(pengwa) move this file back to pybind projects. -#pragma once - -#include "core/dlpack/dlpack_converter.h" -#include "core/dlpack/python_common.h" - -namespace onnxruntime { - -namespace dlpack { - -// Allocate a new Capsule object, which takes the ownership of OrtValue. -// Caller is responsible for releasing. -// This function calls OrtValueToDlpack(...). -PyObject* ToDlpack(OrtValue ort_value); - -// Consume a Capsule object and claims the ownership of its underlying tensor to -// create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion. -OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor); - -} // namespace dlpack -} // namespace onnxruntime diff --git a/onnxruntime/core/dlpack/python_common.h b/onnxruntime/core/dlpack/python_common.h deleted file mode 100644 index 0908e65819..0000000000 --- a/onnxruntime/core/dlpack/python_common.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -// Avoid linking to pythonX_d.lib on Windows in debug build -#ifdef _WIN32 -#pragma warning(push) -#pragma warning(disable : 4510 4610 4512 4005) -#ifdef _DEBUG -#define ORT_DISABLE_INCLUDE_DEBUG_PYTHON_LIB -#undef _DEBUG -#endif -#endif - -#include - -#ifdef _WIN32 -#ifdef ORT_DISABLE_INCLUDE_DEBUG_PYTHON_LIB -#define _DEBUG -#undef ORT_DISABLE_INCLUDE_DEBUG_PYTHON_LIB -#endif -#pragma warning(pop) -#endif diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 2301db3e36..21ced780af 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -10,10 +10,6 @@ #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API #include -#ifdef ENABLE_TRAINING -#include "core/dlpack/dlpack_python.h" -#endif - #include "core/framework/ml_value.h" #include "core/framework/tensor.h" #include "core/framework/sparse_tensor.h" @@ -207,10 +203,10 @@ void addOrtValueMethods(pybind11::module& m) { }) #ifdef ENABLE_TRAINING .def("to_dlpack", [](OrtValue* ort_value) -> py::object { - return py::reinterpret_steal(dlpack::ToDlpack(*ort_value)); + return py::reinterpret_steal(ToDlpack(*ort_value)); }) .def_static("from_dlpack", [](py::object data, bool is_bool_tensor = false) { - return dlpack::FromDlpack(data.ptr(), is_bool_tensor); + return FromDlpack(data.ptr(), is_bool_tensor); }) #endif ; diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index fb6b35aabe..cd9dc1e29f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -19,7 +19,6 @@ std::string openvino_device_type; std::string nuphar_settings; #endif - // TODO remove deprecated global config OrtDevice::DeviceId cuda_device_id = 0; // TODO remove deprecated global config @@ -44,5 +43,41 @@ onnxruntime::ROCMExecutionProviderExternalAllocatorInfo external_allocator_info{ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; #endif +#ifdef ENABLE_TRAINING + +static void DlpackCapsuleDestructor(PyObject* data) { + DLManagedTensor* dlmanged_tensor = reinterpret_cast( + PyCapsule_GetPointer(data, "dltensor")); + if (dlmanged_tensor) { + // The dlmanged_tensor has not been consumed, call deleter ourselves. + dlmanged_tensor->deleter(const_cast(dlmanged_tensor)); + } else { + // The dlmanged_tensor has been consumed, + // PyCapsule_GetPointer has set an error indicator. + PyErr_Clear(); + } +} + +// Allocate a new Capsule object, which takes the ownership of OrtValue. +// Caller is responsible for releasing. +// This function calls OrtValueToDlpack(...). +PyObject* ToDlpack(OrtValue ort_value) { + DLManagedTensor* dlmanaged_tensor = dlpack::OrtValueToDlpack(ort_value); + return PyCapsule_New(dlmanaged_tensor, "dltensor", DlpackCapsuleDestructor); +} + +// Consume a Capsule object and claims the ownership of its underlying tensor to +// create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion. +OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor) { + // Extract DLPack tensor pointer from the capsule carrier. + DLManagedTensor* dlmanaged_tensor = (DLManagedTensor*)PyCapsule_GetPointer(dlpack_tensor, "dltensor"); + OrtValue ort_value = dlpack::DlpackToOrtValue(dlmanaged_tensor, is_bool_tensor); + // Make sure this capsule will never be used again. + PyCapsule_SetName(dlpack_tensor, "used_dltensor"); + return ort_value; +} + +#endif + } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index f6fe658f8f..3e3b7d46a1 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -341,5 +341,18 @@ bool CheckIfTensor(const std::vector& def_list, const std::string& name, /*out*/ ONNX_NAMESPACE::TypeProto& type_proto); +#ifdef ENABLE_TRAINING + +// Allocate a new Capsule object, which takes the ownership of OrtValue. +// Caller is responsible for releasing. +// This function calls OrtValueToDlpack(...). +PyObject* ToDlpack(OrtValue ort_value); + +// Consume a Capsule object and claims the ownership of its underlying tensor to +// create a OrtValue. This function calls DlpackToOrtValue(...) to do the conversion. +OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor); + +#endif + } // namespace python } // namespace onnxruntime diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a7eabe9b4b..c1cbff56da 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -8,7 +8,6 @@ #include #include -#include "core/dlpack/dlpack_python.h" #include "core/session/environment.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/agent/training_agent.h" @@ -311,21 +310,20 @@ void addObjectMethodsForTraining(py::module& m) { v->push_back(ortvalue); }) .def("push_back", [](std::vector* v, py::object dlpack_tensor, const bool is_bool_tensor) { - v->push_back(dlpack::FromDlpack(dlpack_tensor.ptr(), is_bool_tensor)); + v->push_back(FromDlpack(dlpack_tensor.ptr(), is_bool_tensor)); }) .def("reserve", [](std::vector* v, const size_t len) { v->reserve(len); }) .def("shrink_to_fit", [](std::vector* v) { v->shrink_to_fit(); }) .def("__len__", [](const std::vector& v) { return v.size(); }) - .def( - "__iter__", [](const std::vector& v) { - return py::make_iterator(v.cbegin(), v.cend()); - }, - py::keep_alive<0, 1>()) + .def("__iter__", [](const std::vector& v) { + return py::make_iterator(v.cbegin(), v.cend()); + }, + py::keep_alive<0, 1>()) .def("__getitem__", [](const std::vector& v, const size_t idx) { return v.at(idx); }) .def("dlpack_at", [](std::vector* v, const size_t idx) { - return py::reinterpret_steal(dlpack::ToDlpack(v->at(idx))); + return py::reinterpret_steal(ToDlpack(v->at(idx))); }); py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");