diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 40fa2fc5cc..ac619bdc39 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -31,7 +31,6 @@ constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort, // Reset seed attribute for the dropout node if the seed is not set. bool SetSeedForDropoutNode(Node& node) { // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. - // TODO(pengwa): add the opset check in GetAllowedRecomputeOps. if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {12, 13}, kOnnxDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskDropout", {1}, kMSDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasDropout", {1}, kMSDomain) || diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc index 88e93b26e0..d511743c4b 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc @@ -60,9 +60,10 @@ std::vector custom_function_backward_runner(const char* func_name_cha tensor = torch::utils::tensor_fromDLPack(args[arg_index]); } else { TORCH_CHECK(args[arg_index] == Py_None, "Only None is supported for non-tensor input."); - PyObject* fw_kernel_invoke_id = PyObject_GetAttrString(ctx.ptr(), "fw_kernel_invoke_id"); + py::object fw_kernel_invoke_id = PyObject_FastGetAttrString(ctx.ptr(), "fw_kernel_invoke_id"); + TORCH_CHECK(fw_kernel_invoke_id.ptr() != nullptr, "fw_kernel_invoke_id is not found in the context."); std::string fw_kernel_invoke_id_str = - py::cast(py::reinterpret_borrow(fw_kernel_invoke_id)); + py::cast(fw_kernel_invoke_id); CustomFuncOpKernelInfo& fw_kernel_info = KernelInfoStore::GetInstance().GetKernelInfoMap().at(fw_kernel_invoke_id_str); if (fw_kernel_info.materialize_grads) { diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc index 599bdf8139..3bb5151265 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc @@ -255,7 +255,7 @@ static py::object get_mockup_context_class() { throw std::runtime_error("Fails to import the module."); } - auto python_class = py::reinterpret_steal(PyObject_GetAttrString(module.ptr(), "FakeContext")); + auto python_class = PyObject_FastGetAttrString(module.ptr(), "FakeContext"); if (!PyCallable_Check(python_class.ptr())) { throw std::runtime_error("Cannot instantiate the Python class"); }