diff --git a/test/test_utils.py b/test/test_utils.py index e119faf9573..5c38dda0748 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,16 +5,18 @@ import shutil import random import tempfile import unittest -import torch -import torch.cuda import sys import traceback +import torch +import torch.cuda from torch.autograd import Variable from torch.utils.trainer import Trainer from torch.utils.trainer.plugins import * from torch.utils.trainer.plugins.plugin import Plugin from torch.utils.data import * +HAS_CUDA = torch.cuda.is_available() + from common import TestCase try: @@ -326,7 +328,7 @@ class TestFFI(TestCase): self.assertRaises(torch.FatalError, lambda: cpulib.bad_func(tensor, 2, 1.5)) - @unittest.skipIf(not HAS_CFFI, "ffi tests require cffi package") + @unittest.skipIf(not HAS_CFFI or not HAS_CUDA, "ffi tests require cffi package and CUDA") def test_gpu(self): compile_extension( name='gpulib',