From 2c0d607882e5520d86aef27fe510e289f27a351e Mon Sep 17 00:00:00 2001 From: Sergei Vorobev Date: Thu, 18 May 2023 20:29:03 +0000 Subject: [PATCH] [bazel] add build for functorch (#101475) Fixes #101469 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101475 Approved by: https://github.com/ezyang --- BUILD.bazel | 36 ++++++++++++++++++++++++++++++++++-- test/_test_bazel.py | 30 ++++++++++++++++++++++++++++++ test/test_bazel.py | 15 --------------- torch/version.py.tpl | 2 ++ 4 files changed, 66 insertions(+), 17 deletions(-) create mode 100644 test/_test_bazel.py delete mode 100644 test/test_bazel.py diff --git a/BUILD.bazel b/BUILD.bazel index c7b89b233c9..992353d5258 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1697,6 +1697,36 @@ pybind_extension( ], ) +cc_library( + name = "functorch", + hdrs = glob([ + "functorch/csrc/dim/*.h", + ]), + srcs = glob([ + "functorch/csrc/dim/*.cpp", + ]), + deps = [ + ":aten_nvrtc", + ":torch_python", + "@pybind11", + ], +) + +pybind_extension( + name = "functorch/_C", + copts=[ + "-DTORCH_EXTENSION_NAME=_C" + ], + srcs = [ + "functorch/csrc/init_dim_only.cpp", + ], + deps = [ + ":functorch", + ":torch_python", + ":aten_nvrtc", + ], +) + cc_binary( name = "torch/bin/torch_shm_manager", srcs = [ @@ -1725,7 +1755,7 @@ template_rule( rules.py_library( name = "pytorch_py", visibility = ["//visibility:public"], - srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"], + srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]), deps = [ rules.requirement("future"), rules.requirement("numpy"), @@ -1738,6 +1768,7 @@ rules.py_library( ], data = [ ":torch/_C.so", + ":functorch/_C.so", ":torch/bin/torch_shm_manager", ], ) @@ -1904,7 +1935,8 @@ cc_test( py_test( name = "test_bazel", - srcs = ["test/test_bazel.py"], + srcs = ["test/_test_bazel.py"], + main = "test/_test_bazel.py", deps = [":pytorch_py"], ) diff --git a/test/_test_bazel.py b/test/_test_bazel.py new file mode 100644 index 00000000000..9c3bb6f87b5 --- /dev/null +++ b/test/_test_bazel.py @@ -0,0 +1,30 @@ +# Owner(s): ["module: bazel"] + +""" +This test module contains a minimalistic "smoke tests" for the bazel build. + +Currently it doesn't use any testing framework (i.e. pytest) +TODO: integrate this into the existing pytorch testing framework. + +The name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel configurations. +""" + +import torch + +def test_sum() -> None: + assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all() + +def test_simple_compile_eager() -> None: + + def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + a = torch.sin(x) + b = torch.cos(y) + return a + b + + opt_foo1 = torch.compile(foo, backend="eager") + # just check that we can run without raising an Exception + assert opt_foo1(torch.randn(10, 10), torch.randn(10, 10)) is not None + + +test_sum() +test_simple_compile_eager() diff --git a/test/test_bazel.py b/test/test_bazel.py deleted file mode 100644 index 645f7f98351..00000000000 --- a/test/test_bazel.py +++ /dev/null @@ -1,15 +0,0 @@ -# Owner(s): ["module: bazel"] - -""" -This test module contains a minimalistic "smoke tests" for the bazel build. - -Currently it doesn't use any testing framework (i.e. pytest) -TODO: integrate this into the existing pytorch testing framework. -""" - -import torch - -def test_sum(): - assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all() - -test_sum() diff --git a/torch/version.py.tpl b/torch/version.py.tpl index 26b80ac323b..1b7eab07ac9 100644 --- a/torch/version.py.tpl +++ b/torch/version.py.tpl @@ -1,6 +1,8 @@ __version__ = '{{VERSION}}' debug = False cuda = '{{CUDA_VERSION}}' +# TODO: use workspace status to stamp the correct version +git_version = "" hip = None # This is a gross monkey-patch hack that depends on the order of imports