[ROCm] Enable/fix unit tests test_stream_args and test_event_args (#82346)

### Description
Removed some stubbed out code that was necessary for ROCm builds to support JIT compilation of Event and Stream classes. Original motivation for the code to be stubbed out in the ROCm case was likely due to this pull request:
https://github.com/pytorch/pytorch/pull/48020
In this PR, the include statement at the top of cuda.h incorrectly pointed to aten/src/ATen/cuda/CUDAEvent.h when it should have been ATen/cuda/CUDAEvent.h. This error caused the hipification process of build_amd.py to fail to hipify this include statement correctly, causing build errors. The include statement in question was subsequently fixed in the following commit:
acd072967a

This PR removes the USE_ROCM guards so that the previously stubbed-out code is re-enabled for the ROCm build, and "unskips" the associated unit tests.

### Testing
Note: steps prefixed with (ROCm) were run on systems with AMD GPUs, while the others were run on systems with NVIDIA GPUs.
- apply commit
- (ROCm)`python tools/amd_build/build_amd.py`
- `python setup.py develop`
- (ROCm)`PYTORCH_TEST_WITH_ROCM=1 python test/test_jit.py TestCUDA.test_event_args`
- (ROCm)`PYTORCH_TEST_WITH_ROCM=1 python test/test_jit.py TestCUDA.test_stream_args`
- `python test/test_jit.py TestCUDA.test_event_args`
- `python test/test_jit.py TestCUDA.test_stream_args`
- Confirm tests pass in all scenarios

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82346
Approved by: https://github.com/malfet
This commit is contained in:
Andres Lugo-Reyes 2022-08-01 22:55:15 +00:00 committed by PyTorch MergeBot
parent 6ca95547ac
commit f1a1356907
4 changed files with 0 additions and 10 deletions

View file

@ -89,7 +89,6 @@ class TestCUDA(JitTestCase):
FileCheck().check("cuda::synchronize(") \
.run(test_multi_device_synchronize.graph)
@skipIfRocm
def test_stream_args(self):
# Test stream creation with default arguments
@torch.jit.script
@ -119,7 +118,6 @@ class TestCUDA(JitTestCase):
self.assertTrue(stream_default_args_for_priority)
self.assertTrue(stream_args_all)
@skipIfRocm
def test_event_args(self):
# Test Event creation with default arguments
@torch.jit.script

View file

@ -219,7 +219,6 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#if !defined(USE_ROCM)
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
const SourceRange& loc,
GraphFunction& m,
@ -259,7 +258,6 @@ std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
// even though it is possible, though rare, for someone to mutate them
return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#endif
Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
return self_;
@ -1199,12 +1197,10 @@ std::shared_ptr<SugaredValue> toSugaredValue(
if (auto callee = as_function(obj)) {
return std::make_shared<FunctionValue>(callee->function_);
} else if (py::isinstance<py::module>(obj)) {
#ifndef USE_ROCM
std::string obj_name = py::cast<py::str>(py::getattr(obj, "__name__"));
if (obj_name.compare("torch.cuda") == 0) {
return std::make_shared<CUDAPythonModuleValue>(obj);
}
#endif
return std::make_shared<PythonModuleValue>(obj);
} else if (
obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() ||

View file

@ -96,7 +96,6 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
// torch.cuda.* are resolved using CUDAPythonModuleValue.
#if !defined(USE_ROCM)
struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
explicit CUDAPythonModuleValue(py::object mod)
: PythonValue(std::move(mod)) {}
@ -106,7 +105,6 @@ struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
GraphFunction& m,
const std::string& field) override;
};
#endif
// Represents all the parameters of a module as a List[Tensor]
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {

View file

@ -1,6 +1,5 @@
// This file registers special JIT operators used to implement the PyTorch CUDA
// API in TorchScript.
#if !defined(USE_ROCM)
#include <torch/csrc/api/include/torch/utils.h>
#include <torch/csrc/jit/cuda/cuda.h>
#include <torch/csrc/jit/ir/ir.h>
@ -167,4 +166,3 @@ RegisterOperators const reg({
} // namespace
} // namespace jit
} // namespace torch
#endif