Add torch.promote_types function

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26655

Test Plan: Imported from OSS

Differential Revision: D17556196

Pulled By: nairbv

fbshipit-source-id: eeebce8968bfb2ffd25c066595bc19e5dee6ea6f
Authored by: Brian Vaughan on 2019-09-27 16:46:43 -07:00
Committed by: Facebook Github Bot
Parent: 024a422f41
Commit: 0c6a18de8d

7 changed files with 35 additions and 17 deletions

@@ -1226,6 +1226,7 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) {
     {"aten::result_type", "Scalar_Tensor"},
     {"aten::result_type", "Scalar_Scalar"},
     {"aten::can_cast", ""},
+    {"aten::promote_types", ""},
     {"aten::_thnn_fused_lstm_cell", ""},
     {"aten::_thnn_fused_lstm_cell_backward", ""},
     {"aten::_thnn_differentiable_lstm_cell_backward", ""},

@@ -117,4 +117,8 @@ bool can_cast(const at::ScalarType from, const at::ScalarType to) {
   return at::canCast(from, to);
 }
 
+ScalarType promote_types(ScalarType type1, ScalarType type2) {
+  return promoteTypes(type1, type2);
+}
+
 }} // namespace at::native
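The new native function is a thin wrapper over at::promoteTypes, the same promotion table ATen consults when mixing dtypes in binary ops. A sketch of the Python-visible behavior once the generated binding lands (the dtype pairs here are illustrative, not from this diff; outputs follow the standard promotion table):

    >>> import torch
    >>> torch.promote_types(torch.int16, torch.int32)    # same kind: the wider size wins
    torch.int32
    >>> torch.promote_types(torch.int64, torch.float16)  # mixed kind: floating point wins
    torch.float16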

@@ -3734,6 +3734,9 @@
 - func: can_cast(ScalarType from, ScalarType to) -> bool
   variants: function
 
+- func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType
+  variants: function
+
 # NB: Does NOT check precondition that numel == 1
 - func: _local_scalar_dense(Tensor self) -> Scalar
   use_c10_dispatcher: full
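This schema entry is what drives the codegen: it produces the at::promote_types C++ API and the torch.promote_types Python binding, which is what makes the hand-written binding removed further down redundant. Assuming the generated binding accepts the schema's parameter names as keywords (typical for generated torch functions, but an assumption here, not shown in this diff):

    >>> torch.promote_types(type1=torch.int16, type2=torch.float64)  # kwargs mirror the schema names (assumed)
    torch.float64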

@@ -361,3 +361,4 @@ Utilities
 .. autofunction:: compiled_with_cxx11_abi
 .. autofunction:: result_type
 .. autofunction:: can_cast
+.. autofunction:: promote_types

@@ -312,6 +312,11 @@ class TestTypePromotion(TestCase):
         self.assertTrue(actual, expected)
         self.assertTrue(actual.dtype == torch.bool)
 
+    def test_promote_types(self):
+        self.assertEqual(torch.promote_types(torch.float, torch.int), torch.float)
+        self.assertEqual(torch.promote_types(torch.float, torch.double), torch.double)
+        self.assertEqual(torch.promote_types(torch.int, torch.uint8), torch.int)
+
 @unittest.skipIf(not torch.cuda.is_available(), "no cuda")
 class TestTypePromotionCuda(TestTypePromotion):
     def setUp(self):
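The added test pins down three representative pairs. For completeness, a sketch (not part of this commit) of property-style checks the promotion table is expected to satisfy, written against the same TestTypePromotion fixture:

    # Sketch only, not in this PR: invariants of the promotion table.
    def test_promote_types_properties(self):
        dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32,
                  torch.int64, torch.float16, torch.float32, torch.float64]
        for a in dtypes:
            # Promoting a dtype with itself is a no-op.
            self.assertEqual(torch.promote_types(a, a), a)
            for b in dtypes:
                # Promotion is symmetric in its arguments.
                self.assertEqual(torch.promote_types(a, b), torch.promote_types(b, a))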

@@ -355,22 +355,6 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject * THPVariable__promote_types(PyObject* self, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "_promote_types(ScalarType type1, ScalarType type2)",
-  });
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (r.idx == 0) {
-    ScalarType promoted = at::promoteTypes(r.scalartype(0), r.scalartype(1));
-    return torch::autograd::utils::wrap(torch::getDtype(promoted));
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
 static Tensor dispatch_nonzero(const Tensor & self) {
   AutoNoGIL no_gil;
   OptionalDeviceGuard device_guard(device_of(self));

@@ -456,7 +440,6 @@ static PyMethodDef torch_functions[] = {
   {"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
   {"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
-  {"_promote_types", (PyCFunction)(void(*)(void))THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},

@@ -3995,6 +3995,27 @@ Example::
     tensor([-0.2018, -0.2962, -0.0821, -1.1831])
 """.format(**single_dim_common))
 
+add_docstr(torch.promote_types,
+           r"""
+promote_types(type1, type2) -> dtype
+
+Returns the :class:`torch.dtype` with the smallest size and scalar kind that is
+not smaller nor of lower kind than either `type1` or `type2`. See type promotion
+:ref:`documentation <type-promotion-doc>` for more information on the type
+promotion logic.
+
+Args:
+    type1 (:class:`torch.dtype`)
+    type2 (:class:`torch.dtype`)
+
+Example::
+
+    >>> torch.promote_types(torch.int32, torch.float32)
+    torch.float32
+    >>> torch.promote_types(torch.uint8, torch.long)
+    torch.long
+""")
+
 add_docstr(torch.qr,
            r"""
 qr(input, some=True, out=None) -> (Tensor, Tensor)
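A couple more outcomes implied by the documented "smallest size and kind not below either input" rule, beyond the docstring's two examples (outputs per the standard promotion table):

    >>> torch.promote_types(torch.bool, torch.int32)   # bool is the lowest kind
    torch.int32
    >>> torch.promote_types(torch.uint8, torch.int8)   # no 8-bit dtype holds both ranges; widen to int16
    torch.int16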