mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Revert "Disable flaky test TestCppExtensionAOT.test_cuda_extension in… (#33404)
Summary:
… Windows CI (https://github.com/pytorch/pytorch/issues/33282)"
This reverts commit 5b922918d0.
Fixes https://github.com/pytorch/pytorch/issues/33270.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33404
Differential Revision: D19972594
Pulled By: ezyang
fbshipit-source-id: c8f67536fd6e4b7135171d621ad671b1b2a21fd4
This commit is contained in:
parent
05fb160048
commit
ffe327f7d9
2 changed files with 6 additions and 10 deletions
|
|
@ -55,7 +55,6 @@ class TestCppExtensionAOT(common.TestCase):
|
|||
self.assertEqual(tensor.grad, expected_tensor_grad)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
@unittest.skipIf(IS_WINDOWS, "Flaky on Windows, see issue #33270")
|
||||
def test_cuda_extension(self):
|
||||
import torch_test_cpp_extension.cuda as cuda_extension
|
||||
|
||||
|
|
|
|||
|
|
@ -454,13 +454,13 @@ class BuildExtension(build_ext, object):
|
|||
include_dirs, sources,
|
||||
depends, extra_postargs)
|
||||
common_cflags = extra_preargs or []
|
||||
common_cflags.append('/c')
|
||||
cflags = []
|
||||
if debug:
|
||||
common_cflags.extend(self.compiler.compile_options_debug)
|
||||
cflags.extend(self.compiler.compile_options_debug)
|
||||
else:
|
||||
common_cflags.extend(self.compiler.compile_options)
|
||||
cflags.extend(self.compiler.compile_options)
|
||||
common_cflags.extend(COMMON_MSVC_FLAGS)
|
||||
cflags = common_cflags + pp_opts
|
||||
cflags = cflags + common_cflags + pp_opts
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
|
||||
# extra_postargs can be either:
|
||||
|
|
@ -477,11 +477,8 @@ class BuildExtension(build_ext, object):
|
|||
if with_cuda:
|
||||
cuda_cflags = []
|
||||
for common_cflag in common_cflags:
|
||||
if common_cflag == '/c':
|
||||
cuda_cflags.append('-c')
|
||||
else:
|
||||
cuda_cflags.append('-Xcompiler')
|
||||
cuda_cflags.append(common_cflag)
|
||||
cuda_cflags.append('-Xcompiler')
|
||||
cuda_cflags.append(common_cflag)
|
||||
cuda_cflags.extend(pp_opts)
|
||||
if isinstance(extra_postargs, dict):
|
||||
cuda_post_cflags = extra_postargs['nvcc']
|
||||
|
|
|
|||
Loading…
Reference in a new issue