pytorch/torch/csrc/tensor/python_tensor.h
garfield1997 3a5bf0bc36 expose extra torch_python apis (#144746)
Fixes #144302
After checking the code of our third-party device backend, I found that we also rely on these APIs, so I exposed them as discussed in the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144746
Approved by: https://github.com/albanD
2025-01-16 20:50:31 +00:00

35 lines
1.1 KiB
C++

#pragma once
#include <c10/core/Device.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/python_headers.h>
// Forward declaration only: lets code name at::Tensor (pointers/references)
// without including its full definition.
// NOTE(review): at::Tensor is not referenced by the declarations visible in
// this header; includers may rely on it — confirm before removing.
namespace at {

class Tensor;

} // namespace at
namespace torch::tensors {

// Initializes the Python tensor type objects: torch.FloatTensor,
// torch.DoubleTensor, etc. and binds them in their containing modules.
TORCH_PYTHON_API void initialize_python_bindings();

// Same as set_default_tensor_type() but takes a PyObject*.
TORCH_PYTHON_API void py_set_default_tensor_type(PyObject* type_obj);

// Same as py_set_default_tensor_type, but only changes the dtype (ScalarType).
TORCH_PYTHON_API void py_set_default_dtype(PyObject* dtype_obj);

// Gets the DispatchKey for the default tensor type.
//
// TODO: This is nuts! There is no reason to let the default tensor type id
// change. Probably only store ScalarType, as that's the only flex point
// we support.
//
// Exported with TORCH_PYTHON_API (like every other function in this header)
// rather than TORCH_API, so the export macro matches the library this header
// belongs to. NOTE(review): assumes the definition is compiled into
// torch_python alongside the other functions declared here — confirm.
TORCH_PYTHON_API c10::DispatchKey get_default_dispatch_key();

// Gets the default at::Device.
// NOTE(review): presumably the device configured via torch.set_default_device;
// confirm against the definition.
TORCH_PYTHON_API at::Device get_default_device();

// Gets the ScalarType for the default tensor type.
//
// Previously the only function here without an export macro, which left it
// unusable from outside the library; exported for consistency with the rest
// of this header.
TORCH_PYTHON_API at::ScalarType get_default_scalar_type();

} // namespace torch::tensors