update type() calling to not use unneeded device (#110163)

Previous code path was doing an unnecessary cuda init as well as causing an unnecessary "device" to occur in the jit trace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110163
Approved by: https://github.com/henryhu6, https://github.com/albanD
This commit is contained in:
eellison 2023-09-27 16:17:30 +00:00 committed by PyTorch MergeBot
parent 7f5fd92372
commit 2a246c5259
2 changed files with 22 additions and 7 deletions

View file

@ -1010,6 +1010,21 @@ class TestTracer(JitTestCase):
self.assertEqual(out, out_state)
self.assertNotEqual(out, out_ones)
@unittest.skipIf(not RUN_CUDA, "uses cuda")
def test_type_same_device(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.dtype = torch.float16
def forward(self, x=None):
h = x.type(self.dtype)
return h
a = Model()
b = torch.jit.trace(a, example_inputs=(torch.ones([1], device=torch.device("cuda")),))
FileCheck().check_not("device").run(b.code)
def test_export_no_reorder(self):
def func(a, b):
return a * b / (a - 2 * b) + b

View file

@ -1067,13 +1067,13 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
Device device = self_.device();
if (is_dtype) {
scalar_type = r.scalartype(0);
} else {
at::TensorOptions options = torch::utils::options_from_string(type_name);
scalar_type = at::typeMetaToScalarType(options.dtype());
auto device_type = options.device().type();
if (device_type != device.type()) {
device = at::Device(device_type);
}
return THPVariable_Wrap(dispatch_to(self_, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
}
at::TensorOptions options = torch::utils::options_from_string(type_name);
scalar_type = at::typeMetaToScalarType(options.dtype());
auto device_type = options.device().type();
if (device_type != device.type()) {
device = at::Device(device_type);
}
if (device.is_cuda()) {
torch::utils::cuda_lazy_init();