mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add torch.promote_types function
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26655 Test Plan: Imported from OSS Differential Revision: D17556196 Pulled By: nairbv fbshipit-source-id: eeebce8968bfb2ffd25c066595bc19e5dee6ea6f
This commit is contained in:
parent
024a422f41
commit
0c6a18de8d
7 changed files with 35 additions and 17 deletions
|
|
@ -1226,6 +1226,7 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) {
|
|||
{"aten::result_type", "Scalar_Tensor"},
|
||||
{"aten::result_type", "Scalar_Scalar"},
|
||||
{"aten::can_cast", ""},
|
||||
{"aten::promote_types", ""},
|
||||
{"aten::_thnn_fused_lstm_cell", ""},
|
||||
{"aten::_thnn_fused_lstm_cell_backward", ""},
|
||||
{"aten::_thnn_differentiable_lstm_cell_backward", ""},
|
||||
|
|
|
|||
|
|
@ -117,4 +117,8 @@ bool can_cast(const at::ScalarType from, const at::ScalarType to) {
|
|||
return at::canCast(from, to);
|
||||
}
|
||||
|
||||
ScalarType promote_types(ScalarType type1, ScalarType type2) {
|
||||
return promoteTypes(type1, type2);
|
||||
}
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -3734,6 +3734,9 @@
|
|||
- func: can_cast(ScalarType from, ScalarType to) -> bool
|
||||
variants: function
|
||||
|
||||
- func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType
|
||||
variants: function
|
||||
|
||||
# NB: Does NOT check precondition that numel == 1
|
||||
- func: _local_scalar_dense(Tensor self) -> Scalar
|
||||
use_c10_dispatcher: full
|
||||
|
|
|
|||
|
|
@ -361,3 +361,4 @@ Utilities
|
|||
.. autofunction:: compiled_with_cxx11_abi
|
||||
.. autofunction:: result_type
|
||||
.. autofunction:: can_cast
|
||||
.. autofunction:: promote_types
|
||||
|
|
|
|||
|
|
@ -312,6 +312,11 @@ class TestTypePromotion(TestCase):
|
|||
self.assertTrue(actual, expected)
|
||||
self.assertTrue(actual.dtype == torch.bool)
|
||||
|
||||
def test_promote_types(self):
|
||||
self.assertEqual(torch.promote_types(torch.float, torch.int), torch.float)
|
||||
self.assertEqual(torch.promote_types(torch.float, torch.double), torch.double)
|
||||
self.assertEqual(torch.promote_types(torch.int, torch.uint8), torch.int)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "no cuda")
|
||||
class TestTypePromotionCuda(TestTypePromotion):
|
||||
def setUp(self):
|
||||
|
|
|
|||
|
|
@ -355,22 +355,6 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPVariable__promote_types(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"_promote_types(ScalarType type1, ScalarType type2)",
|
||||
});
|
||||
ParsedArgs<2> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
if (r.idx == 0) {
|
||||
ScalarType promoted = at::promoteTypes(r.scalartype(0), r.scalartype(1));
|
||||
return torch::autograd::utils::wrap(torch::getDtype(promoted));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static Tensor dispatch_nonzero(const Tensor & self) {
|
||||
AutoNoGIL no_gil;
|
||||
OptionalDeviceGuard device_guard(device_of(self));
|
||||
|
|
@ -456,7 +440,6 @@ static PyMethodDef torch_functions[] = {
|
|||
{"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
|
||||
{"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"_promote_types", (PyCFunction)(void(*)(void))THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
|
|
|
|||
|
|
@ -3995,6 +3995,27 @@ Example::
|
|||
tensor([-0.2018, -0.2962, -0.0821, -1.1831])
|
||||
""".format(**single_dim_common))
|
||||
|
||||
add_docstr(torch.promote_types,
|
||||
r"""
|
||||
promote_types(type1, type2) -> dtype
|
||||
|
||||
Returns the :class:`torch.dtype` with the smallest size and scalar kind that is
|
||||
not smaller nor of lower kind than either `type1` or `type2`. See type promotion
|
||||
:ref:`documentation <type-promotion-doc>` for more information on the type
|
||||
promotion logic.
|
||||
|
||||
Args:
|
||||
type1 (:class:`torch.dtype`)
|
||||
type2 (:class:`torch.dtype`)
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.promote_types(torch.int32, torch.float32))
|
||||
torch.float32
|
||||
>>> torch.promote_types(torch.uint8, torch.long)
|
||||
torch.long
|
||||
""")
|
||||
|
||||
add_docstr(torch.qr,
|
||||
r"""
|
||||
qr(input, some=True, out=None) -> (Tensor, Tensor)
|
||||
|
|
|
|||
Loading…
Reference in a new issue