Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
update type() calling to not use unneeded device (#110163)
The previous code path performed an unnecessary CUDA init and caused an unnecessary "device" to appear in the JIT trace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110163
Approved by: https://github.com/henryhu6, https://github.com/albanD
parent 7f5fd92372
commit 2a246c5259
2 changed files with 22 additions and 7 deletions
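For context, a minimal repro sketch of the behavior the message describes (assuming a CUDA-enabled build; the CastOnly module and variable names here are illustrative, not part of the commit): tracing a forward that only changes dtype via Tensor.type() used to record a device cast in the trace and trigger CUDA initialization.

import torch

class CastOnly(torch.nn.Module):
    # Only the dtype changes here; no device move is requested.
    def forward(self, x):
        return x.type(torch.float16)

if torch.cuda.is_available():
    traced = torch.jit.trace(CastOnly(), (torch.ones(1, device="cuda"),))
    # With this fix, the printed code carries no "device" argument for
    # the cast, which is what the new test below asserts via FileCheck.
    print(traced.code)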
test/jit/test_tracer.py
@@ -1010,6 +1010,21 @@ class TestTracer(JitTestCase):
         self.assertEqual(out, out_state)
         self.assertNotEqual(out, out_ones)
 
+    @unittest.skipIf(not RUN_CUDA, "uses cuda")
+    def test_type_same_device(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dtype = torch.float16
+
+            def forward(self, x=None):
+                h = x.type(self.dtype)
+                return h
+
+        a = Model()
+        b = torch.jit.trace(a, example_inputs=(torch.ones([1], device=torch.device("cuda")),))
+        FileCheck().check_not("device").run(b.code)
+
     def test_export_no_reorder(self):
         def func(a, b):
             return a * b / (a - 2 * b) + b
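A note on the assertion style used in the test above (a minimal sketch; FileCheck is the pattern-matching helper used throughout PyTorch's JIT tests, and the add_one function here is illustrative): check() requires a substring to occur in the checked text, while check_not() fails if it occurs.

import torch
from torch.testing import FileCheck

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# Passes: the graph contains an aten::add node and the generated
# code mentions no device, mirroring the check in the new test.
FileCheck().check("aten::add").run(add_one.graph)
FileCheck().check_not("device").run(add_one.code)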
torch/csrc/autograd/python_variable_methods.cpp
@@ -1067,13 +1067,13 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs)
   Device device = self_.device();
   if (is_dtype) {
     scalar_type = r.scalartype(0);
-  } else {
-    at::TensorOptions options = torch::utils::options_from_string(type_name);
-    scalar_type = at::typeMetaToScalarType(options.dtype());
-    auto device_type = options.device().type();
-    if (device_type != device.type()) {
-      device = at::Device(device_type);
-    }
+    return THPVariable_Wrap(dispatch_to(self_, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
   }
+  at::TensorOptions options = torch::utils::options_from_string(type_name);
+  scalar_type = at::typeMetaToScalarType(options.dtype());
+  auto device_type = options.device().type();
+  if (device_type != device.type()) {
+    device = at::Device(device_type);
+  }
   if (device.is_cuda()) {
     torch::utils::cuda_lazy_init();
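Restating the shape of the fix outside diff context: Tensor.type() accepts either a dtype or a legacy type string, and only the string form can imply a device change, so the dtype path now returns early before any device logic. A small sketch of the two call forms (CPU-only; names are illustrative):

import torch

t = torch.ones(2)
# dtype argument: only the scalar type changes; with this commit the
# early return skips device handling and CUDA lazy init entirely.
a = t.type(torch.float16)
# string type name: may encode a device (e.g. "torch.cuda.FloatTensor"),
# so the device-resolution branch after the early return still runs.
b = t.type("torch.FloatTensor")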