diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py index 8a75300e435..f0cfe5b0e2f 100644 --- a/functorch/experimental/_cond.py +++ b/functorch/experimental/_cond.py @@ -101,16 +101,18 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) +@cond.py_impl(DispatchKey.CUDA) @cond.py_impl(DispatchKey.CPU) def cond_dense(pred, true_fn, false_fn, operands): mode = _get_current_dispatch_mode() - assert (mode is None), "Mode should never be enabled for CPU key" + assert (mode is None), "Mode should never be enabled for CPU/CUDA key" if pred: return true_fn(*operands) else: return false_fn(*operands) +@cond.py_impl(DispatchKey.AutogradCUDA) @cond.py_impl(DispatchKey.AutogradCPU) def cond_autograd(pred, true_fn, false_fn, *operands): # TODO: support autograd diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py index 568b2de3884..0eb228f0e65 100644 --- a/functorch/experimental/_map.py +++ b/functorch/experimental/_map.py @@ -57,13 +57,15 @@ def trace_map(proxy_mode, func_overload, f, xs, *args): return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) +@map.py_impl(DispatchKey.CUDA) @map.py_impl(DispatchKey.CPU) def map_cpu(f, xs, *args): mode = _get_current_dispatch_mode() - assert (mode is None), "Mode should never be enabled for CPU key" + assert (mode is None), "Mode should never be enabled for CPU/CUDA key" return torch.stack([f(x, *args) for x in xs]) +@map.py_impl(DispatchKey.AutogradCUDA) @map.py_impl(DispatchKey.AutogradCPU) def map_autograd(f, xs, *args): # TODO: support autograd diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 13bafaaf36a..2b270797b91 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1,4 +1,6 @@ # Owner(s): ["module: functorch"] +import unittest + import torch from 
functorch.experimental import control_flow from functorch.experimental.control_flow import cond @@ -20,6 +22,30 @@ class TestControlFlow(TestCase): result = cond(False, true_fn, false_fn, [x]) self.assertEqual(result, torch.cos(x)) + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + def test_cond_gpu(self): + def true_fn(x): + return x.sin() + + def false_fn(x): + return x.cos() + + x = torch.randn(4, device="cuda") + pred = torch.tensor(False, device="cuda") + result = cond(pred, true_fn, false_fn, [x]) + self.assertEqual(result, torch.cos(x)) + + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + def test_map_gpu(self): + def f(x, y): + return x + y + + xs = torch.ones(3, 2, 2, device="cuda") + y = torch.ones(2, device="cuda") + res = control_flow.map(f, xs, y) + + self.assertEqual(res, control_flow.map(f, torch.ones(3, 2, 2), torch.ones(2))) + class TestControlFlowTraced(TestCase): def test_cond_traced_not_nested(self):