Thank you for your interest in contributing to PyTorch!
If you're a new contributor, please first take a read through our
[Contributing Guide](https://github.com/pytorch/pytorch/wiki/The-Ultimate-Guide-to-PyTorch-Contributions), specifically the [Submitting a Change](https://github.com/pytorch/pytorch/wiki/The-Ultimate-Guide-to-PyTorch-Contributions#submitting-a-change) section
that walks through the process of contributing a change to PyTorch.

The rest of this document (CONTRIBUTING.md) covers some of the more technical
aspects of contributing to PyTorch.

# Table of Contents

<!-- toc -->

- [Developing PyTorch](#developing-pytorch)
- [Setup the development environment](#setup-the-development-environment)
- [Tips and Debugging](#tips-and-debugging)
- [Nightly Checkout & Pull](#nightly-checkout--pull)
- [Codebase structure](#codebase-structure)
- [Unit testing](#unit-testing)
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
- [Writing documentation](#writing-documentation)
- [Docstring type formatting](#docstring-type-formatting)
- [Building documentation](#building-documentation)
- [Tips](#tips)
- [Building C++ Documentation](#building-c-documentation)
- [Previewing changes locally](#previewing-changes-locally)
- [Previewing documentation on PRs](#previewing-documentation-on-prs)
- [Adding documentation tests](#adding-documentation-tests)
- [Profiling with `py-spy`](#profiling-with-py-spy)
- [Managing multiple build trees](#managing-multiple-build-trees)
- [C++ development tips](#c-development-tips)
- [Build only what you need](#build-only-what-you-need)
- [Code completion and IDE support](#code-completion-and-ide-support)
- [Make no-op build fast](#make-no-op-build-fast)
- [Use Ninja](#use-ninja)
- [Use CCache](#use-ccache)
- [Use a faster linker](#use-a-faster-linker)
- [Use pre-compiled headers](#use-pre-compiled-headers)
- [Workaround for header dependency bug in nvcc](#workaround-for-header-dependency-bug-in-nvcc)
- [Rebuild few files with debug information](#rebuild-few-files-with-debug-information)
- [C++ frontend development tips](#c-frontend-development-tips)
- [GDB integration](#gdb-integration)
- [C++ stacktraces](#c-stacktraces)
- [CUDA development tips](#cuda-development-tips)
- [Windows development tips](#windows-development-tips)
- [Known MSVC (and MSVC with NVCC) bugs](#known-msvc-and-msvc-with-nvcc-bugs)
- [Building on legacy code and CUDA](#building-on-legacy-code-and-cuda)
- [Pre-commit tidy/linting hook](#pre-commit-tidylinting-hook)
- [Building PyTorch with ASAN](#building-pytorch-with-asan)
- [Getting `ccache` to work](#getting-ccache-to-work)
- [Why this stuff with `LD_PRELOAD` and `LIBASAN_RT`?](#why-this-stuff-with-ld_preload-and-libasan_rt)
- [Why LD_PRELOAD in the build function?](#why-ld_preload-in-the-build-function)
- [Why no leak detection?](#why-no-leak-detection)
- [Caffe2 notes](#caffe2-notes)
- [CI failure tips](#ci-failure-tips)
- [Which commit is used in CI?](#which-commit-is-used-in-ci)
- [Dev Infra Office Hours](#dev-infra-office-hours)

<!-- tocstop -->

## Developing PyTorch
Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.

### Setup the development environment
First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to set up your SSH authentication credentials.
Then clone the PyTorch project and set up the development environment:
```bash
git clone git@github.com:<USERNAME>/pytorch.git
cd pytorch
git remote add upstream git@github.com:pytorch/pytorch.git
make setup-env
# Or run `make setup-env-cuda` for pre-built CUDA binaries
# Or run `make setup-env-rocm` for pre-built ROCm binaries
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
### Tips and Debugging
* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
* When installing with `python setup.py develop` (in contrast to `python setup.py install`), the Python runtime will use
the current local source tree when importing the `torch` package. (This is done by creating an [`.egg-link`](https://wiki.python.org/moin/PythonPackagingTerminology#egg-link) file in the `site-packages` folder.)
This way you do not need to repeatedly install after modifying Python files (`.py`).
However, you would need to reinstall if you modify Python interface (`.pyi`, `.pyi.in`) or
non-Python files (`.cpp`, `.cc`, `.cu`, `.h`, ...).
One way to avoid running `python setup.py develop` every time you change C++/CUDA/ObjectiveC files on Linux/Mac
is to create a symbolic link from the `build` folder to `torch/lib`, for example by running:
```bash
pushd torch/lib; sh -c "ln -sf ../../build/lib/libtorch_cpu.* ."; popd
```
Afterwards, rebuilding a library (for example, running `ninja torch_cpu` from the `build` folder to rebuild `libtorch_cpu.so`)
is sufficient to make the change visible in the `torch` package.
To reinstall, first uninstall all existing PyTorch installs. You may need to run
`pip uninstall torch` multiple times. You'll know `torch` is fully uninstalled
when you see `WARNING: Skipping torch as it is not installed`. (You should only
have to `pip uninstall` a few times, but you can always `uninstall` with
`timeout` or in a loop if you're feeling lazy.)
```bash
conda uninstall pytorch -y
yes | pip uninstall torch
```
Next run `python setup.py clean`. After that, you can install in `develop` mode again.
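
The "uninstall in a loop" idea mentioned above can be sketched as follows. This is only an illustration, assuming `pip` prints `Successfully uninstalled` for each copy it removes; the function name and the `PIP_CMD` override are hypothetical and exist only so the loop can be tried with a stand-in command.

```bash
# uninstall_all_torch and PIP_CMD are hypothetical names used for this sketch.
uninstall_all_torch() {
  pip_cmd="${PIP_CMD:-pip}"
  removed=0
  # Keep uninstalling while pip reports that it actually removed a copy.
  while "$pip_cmd" uninstall -y torch 2>&1 | grep -q "Successfully uninstalled"; do
    removed=$((removed + 1))
  done
  echo "removed $removed torch install(s)"
}
```

Calling `uninstall_all_torch` stops as soon as `pip` no longer reports a removal, which corresponds to the `Skipping torch as it is not installed` warning described above.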
* If you run into errors when running `python setup.py develop`, here are some debugging steps:
1. Run `printf '#include <stdio.h>\nint main() { printf("Hello World");}'|clang -x c -; ./a.out` to make sure
your C compiler works and can compile this simple Hello World program without errors.
2. Nuke your `build` directory. The `setup.py` script compiles binaries into the `build` folder and caches many
details along the way, which saves time the next time you build. If you're running into issues, you can always
`rm -rf build` from the toplevel `pytorch` directory and start over.
3. If you have made edits to the PyTorch repo, commit any change you'd like to keep and clean the repo with the
following commands (note that `git clean -xdf` _really_ removes all untracked files and changes):
```bash
git submodule deinit -f .
git clean -xdf
python setup.py clean
git submodule update --init --recursive # very important to sync the submodules
python setup.py develop # then try running the command again
```
4. The main step within `python setup.py develop` is running `make` from the `build` directory. If you want to
experiment with some environment variables, you can pass them into the command:
```bash
ENV_KEY1=ENV_VAL1 [ENV_KEY2=ENV_VAL2 ...] python setup.py develop
```
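
Step 4 relies on the shell's per-command environment prefix. A minimal sketch of how that scoping behaves; `DEMO_BUILD_FLAG`, `DEMO_OTHER_FLAG`, and `show_scoped_env` are hypothetical names used only for illustration, not real PyTorch build options:

```bash
# DEMO_BUILD_FLAG and DEMO_OTHER_FLAG are hypothetical, not real build options.
show_scoped_env() {
  unset DEMO_BUILD_FLAG DEMO_OTHER_FLAG
  # Assignments written before a command are exported to that command only;
  # several assignments are separated by spaces, not commas.
  DEMO_BUILD_FLAG=1 DEMO_OTHER_FLAG=0 sh -c 'echo "during: $DEMO_BUILD_FLAG $DEMO_OTHER_FLAG"'
  # Back in the calling shell the variable is still unset.
  echo "after: ${DEMO_BUILD_FLAG:-unset}"
}
show_scoped_env
```

Because the values never leak into the calling shell, you can try different settings on successive builds without having to unset anything in between.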
* If you run into issues running `git submodule update --init --recursive`, please try the following:
- If you encounter an error such as
```
error: Submodule 'third_party/pybind11' could not be updated
```
check whether your Git local or global config file contains any `submodule.*` settings. If yes, remove them and try again.
(See [this doc](https://git-scm.com/docs/git-config#Documentation/git-config.txt-submoduleltnamegturl) for more info.)
- If you encounter an error such as
```
fatal: unable to access 'https://github.com/pybind11/pybind11.git': could not load PEM client certificate ...
```
it is likely that you are using an HTTP proxy and the certificate has expired. To check whether the certificate is valid, run
`git config --global --list` and search for a setting like `http.proxysslcert=<cert_file>`. Then check the certificate's validity dates by running
```bash
openssl x509 -noout -in <cert_file> -dates
```
- If you encounter an error that some third_party modules are not checked out correctly, such as
```
Could not find .../pytorch/third_party/pybind11/CMakeLists.txt
```
remove any `submodule.*` settings in your local git config (`.git/config` of your pytorch repo) and try again.
* If you're a Windows contributor, please check out [Best Practices](https://github.com/pytorch/pytorch/wiki/Best-Practices-to-Edit-and-Compile-Pytorch-Source-Code-On-Windows).
* For help with any part of the contributing process, please don't hesitate to utilize our Zoom office hours! See details [here](https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours).
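
For the `submodule.*` checks mentioned in the submodule tips above, it can help to list the settings before deciding what to remove. A minimal sketch, assuming `git` is on your `PATH`; `list_submodule_settings` is a hypothetical helper name:

```bash
# list_submodule_settings is a hypothetical helper name.
# Usage: list_submodule_settings --local   (or --global)
list_submodule_settings() {
  # --get-regexp prints every key matching the pattern together with its value.
  # It exits non-zero when nothing matches, which is treated here as "no settings".
  git config "$1" --get-regexp '^submodule\.' || true
}
```

Run it with `--local` inside your PyTorch checkout and again with `--global`; empty output means there is nothing to remove in that scope.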
## Nightly Checkout & Pull
The `tools/nightly.py` script is provided to ease pure Python development of
PyTorch. This uses `venv` and `git` to check out the nightly development
version of PyTorch and installs pre-built binaries into the current repository.
This is like a development or editable install, but without needing the ability
to compile any C++ code.

You can use this script to check out a new nightly branch with the following:
```bash
./tools/nightly.py checkout -b my-nightly-branch
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
Or if you would like to re-use an existing virtual environment, you can pass in
the prefix argument (`--prefix`):
```bash
./tools/nightly.py checkout -b my-nightly-branch -p my-env
source my-env/bin/activate # or `& .\my-env\Scripts\Activate.ps1` on Windows
```
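
After activating, it can be useful to confirm which environment the shell is actually using. A minimal sketch, assuming the standard `venv` activation script, which exports `VIRTUAL_ENV`; the function name `active_venv` is hypothetical:

```bash
# active_venv is a hypothetical helper name.
active_venv() {
  # The venv activate script sets VIRTUAL_ENV to the environment's path.
  if [ -n "${VIRTUAL_ENV:-}" ]; then
    echo "active: $VIRTUAL_ENV"
  else
    echo "no virtual environment active"
  fi
}
```

If this reports a path other than the one you passed with `-p`, a different environment is active and the nightly binaries will not be picked up from it.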
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
```bash
./tools/nightly.py checkout -b my-nightly-branch --cuda
```console
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141281
Approved by: https://github.com/seemethere
2024-11-22 17:46:42 +00:00
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```

To install the nightly binaries built with ROCm, you can pass in the flag `--rocm`:

```bash
./tools/nightly.py checkout -b my-nightly-branch --rocm
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
```
You can also use this tool to pull the nightly commits into the current branch:
```bash
./tools/nightly.py pull -p my-env
source my-env/bin/activate  # or `& .\my-env\Scripts\Activate.ps1` on Windows
```
Pulling will recreate a fresh virtual environment and reinstall the development
dependencies as well as the nightly binaries into the repo directory.
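
After activating the environment, a quick sanity check can confirm that Python is really running from it. The sketch below uses only the standard library and assumes the environment was created by `venv`; the helper name `in_virtualenv` is hypothetical and is not part of `tools/nightly.py`.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv.

    Inside a venv, sys.prefix points at the environment directory while
    sys.base_prefix still points at the base Python installation.
    """
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print(f"interpreter: {sys.executable}")
    print(f"inside a virtual environment: {in_virtualenv()}")
```

If this reports `False` after activation, the shell is most likely still picking up a different `python` from `PATH`.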

## Codebase structure

* [c10](c10) - Core library files that work everywhere, both server
  and mobile. We are slowly moving pieces from [ATen/core](aten/src/ATen/core)
  here. This library is intended only to contain essential functionality,
  and to be appropriate to use in settings where binary size matters. (But
  you'll have a lot of missing functionality if you try to use it
  directly.)
* [aten](aten) - C++ tensor library for PyTorch (no autograd support)
  * [src](aten/src) - [README](aten/src/README.md)
    * [ATen](aten/src/ATen)
      * [core](aten/src/ATen/core) - Core functionality of ATen. This
        is migrating to top-level c10 folder.
      * [native](aten/src/ATen/native) - Modern implementations of
        operators. If you want to write a new operator, here is where
        it should go. Most CPU operators go in the top level directory,
        except for operators which need to be compiled specially; see
        cpu below.
        * [cpu](aten/src/ATen/native/cpu) - Not actually CPU
          implementations of operators, but specifically implementations
          which are compiled with processor-specific instructions, like
          AVX. See the [README](aten/src/ATen/native/cpu/README.md) for more
          details.
        * [cuda](aten/src/ATen/native/cuda) - CUDA implementations of
          operators.
        * [sparse](aten/src/ATen/native/sparse) - CPU and CUDA
          implementations of COO sparse tensor operations
        * [mkl](aten/src/ATen/native/mkl) [mkldnn](aten/src/ATen/native/mkldnn)
          [miopen](aten/src/ATen/native/miopen) [cudnn](aten/src/ATen/native/cudnn) -
          implementations of operators which simply bind to some
          backend library.
        * [quantized](aten/src/ATen/native/quantized/) - Quantized tensor (i.e. QTensor) operation implementations. [README](aten/src/ATen/native/quantized/README.md) contains details including how to implement native quantized operations.
* [torch](torch) - The actual PyTorch library. Everything that is not
  in [csrc](torch/csrc) is a Python module, following the PyTorch Python
  frontend module structure.
  * [csrc](torch/csrc) - C++ files composing the PyTorch library. Files
    in this directory tree are a mix of Python binding code, and C++
    heavy lifting. Consult `setup.py` for the canonical list of Python
    binding files; conventionally, they are often prefixed with
    `python_`. [README](torch/csrc/README.md)
    * [jit](torch/csrc/jit) - Compiler and frontend for the TorchScript
      JIT. [README](torch/csrc/jit/README.md)
    * [autograd](torch/csrc/autograd) - Implementation of reverse-mode automatic differentiation. [README](torch/csrc/autograd/README.md)
    * [api](torch/csrc/api) - The PyTorch C++ frontend.
    * [distributed](torch/csrc/distributed) - Distributed training
      support for PyTorch.
* [tools](tools) - Code generation scripts for the PyTorch library.
  See [README](tools/README.md) of this directory for more details.
* [test](test) - Python unit tests for PyTorch Python frontend.
  * [test_torch.py](test/test_torch.py) - Basic tests for PyTorch
    functionality.
  * [test_autograd.py](test/test_autograd.py) - Tests for non-NN
    automatic differentiation support.
  * [test_nn.py](test/test_nn.py) - Tests for NN operators and
    their automatic differentiation.
  * [test_jit.py](test/test_jit.py) - Tests for the JIT compiler
    and TorchScript.
  * ...
  * [cpp](test/cpp) - C++ unit tests for PyTorch C++ frontend.
    * [api](test/cpp/api) - [README](test/cpp/api/README.md)
    * [jit](test/cpp/jit) - [README](test/cpp/jit/README.md)
    * [tensorexpr](test/cpp/tensorexpr) - [README](test/cpp/tensorexpr/README.md)
  * [expect](test/expect) - Automatically generated "expect" files
    which are used to compare against expected output.
  * [onnx](test/onnx) - Tests for ONNX export functionality,
    using both PyTorch and Caffe2.
* [caffe2](caffe2) - The Caffe2 library.
  * [core](caffe2/core) - Core files of Caffe2, e.g., tensor, workspace,
    blobs, etc.
  * [operators](caffe2/operators) - Operators of Caffe2.
  * [python](caffe2/python) - Python bindings to Caffe2.
  * ...
* [.circleci](.circleci) - CircleCI configuration management. [README](.circleci/README.md)

## Unit testing

### Python Unit Testing

**Prerequisites**:
The following packages should be installed with either `conda` or `pip`:

- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively

Running

```
pip install -r requirements.txt
```

will install these dependencies for you.

All PyTorch test suites are located in the `test` folder and start with
`test_`. Run the entire test suite with

```bash
python test/run_test.py
```

or run individual test suites using the command `python test/FILENAME.py`,
where `FILENAME` represents the file containing the test suite you wish
to run.

For example, to run all the TorchScript JIT tests (located at
`test/test_jit.py`), you would run:

```bash
python test/test_jit.py
```

You can narrow down what you're testing even further by specifying the
name of an individual test with `TESTCLASSNAME.TESTNAME`. Here,
`TESTNAME` is the name of the test you want to run, and `TESTCLASSNAME`
is the name of the class in which it is defined.

Going off the above example, let's say you want to run
`test_Sequential`, which is defined as part of the `TestJit` class
in `test/test_jit.py`. Your command would be:

```bash
python test/test_jit.py TestJit.test_Sequential
```
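
To make the `TESTCLASSNAME.TESTNAME` selector concrete, here is a minimal sketch using only the standard `unittest` module. It assumes the selector is resolved the way the standard `unittest` loader resolves dotted names; `TestExample`, its methods, and `run_selected` are hypothetical names chosen for illustration and do not exist in the PyTorch test suite.

```python
import types
import unittest


class TestExample(unittest.TestCase):  # hypothetical test class
    def test_add(self):
        self.assertEqual(1 + 1, 2)

    def test_mul(self):
        self.assertEqual(2 * 3, 6)


def run_selected(selector: str) -> unittest.TestResult:
    """Run only the tests named by a TESTCLASSNAME.TESTNAME selector."""
    # Stand-in for the test module that the selector is resolved against.
    namespace = types.SimpleNamespace(TestExample=TestExample)
    suite = unittest.defaultTestLoader.loadTestsFromName(selector, namespace)
    result = unittest.TestResult()
    suite.run(result)
    return result


result = run_selected("TestExample.test_add")
print(result.testsRun, result.wasSuccessful())  # only test_add is run
```

Passing just `"TestExample"` as the selector would run every test method in the class instead of a single one.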

**Weird note:** In our CI (Continuous Integration) jobs, we actually run the tests from the `test` folder and **not** the root of the repo, since there are various dependencies we set up for CI that expect the tests to be run from the `test` folder. As such, there may be some inconsistencies between local testing and CI testing. If you observe an inconsistency, please [file an issue](https://github.com/pytorch/pytorch/issues/new/choose).

### Better local unit tests with `pytest`

We don't officially support `pytest`, but it works well with our
`unittest` tests and offers a number of useful features for local
development. Install it via `pip install pytest`.

If you want to just run tests that contain a specific substring, you can
use the `-k` flag:

```bash
pytest test/test_nn.py -k Loss -v
```

The above is an example of testing a change to all Loss functions: this
command runs tests such as `TestNN.test_BCELoss` and
`TestNN.test_MSELoss` and can be useful to save keystrokes.
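
As a rough mental model of what `-k Loss` selects, the sketch below filters a list of test names by substring. It is a simplification under stated assumptions: the real `-k` option also accepts boolean expressions (`and`, `or`, `not`) and matches against more than the bare test name, and the case-insensitive comparison reflects recent `pytest` versions. The helper `select_by_substring` and the list of names are hypothetical.

```python
def select_by_substring(test_names, substring):
    """Keep the test names that contain `substring`, ignoring case."""
    needle = substring.lower()
    return [name for name in test_names if needle in name.lower()]


# Hypothetical test names, used only for illustration.
names = [
    "TestNN.test_BCELoss",
    "TestNN.test_MSELoss",
    "TestNN.test_Conv2d",
]
selected = select_by_substring(names, "Loss")
print(selected)
```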

### Local linting

Install all prerequisites by running

```bash
make setup-lint
```

You can now run the same linting steps that are used in CI locally via `make`:

```bash
make lint
```

Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner).

#### Running `mypy`

`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase, and the code is checked against them automatically whenever the linter is run.

See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.

### C++ Unit Testing

PyTorch offers a series of tests located in the `test/cpp` folder.
These tests are written in C++ and use the Google Test testing framework.
After compiling PyTorch from source, the test runner binaries will be
written to the `build/bin` folder. The command to run one of these tests
is `./build/bin/FILENAME --gtest_filter=TESTSUITE.TESTNAME`, where
`TESTNAME` is the name of the test you'd like to run and `TESTSUITE` is
the suite that test is defined in.

For example, if you wanted to run the test `MayContainAlias`, which
is part of the test suite `ContainerAliasingTest` in the file
`test/cpp/jit/test_alias_analysis.cpp`, the command would be:

```bash
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.MayContainAlias
```

### Run Specific CI Jobs

You can generate a commit that limits the CI to only run a specific job by using
`tools/testing/explicit_ci_jobs.py` like so:

```bash
# --job: specify one or more times to filter to a specific job + its dependencies
# --filter-gha: specify GitHub Actions workflows to keep
# --make-commit: commit CI changes to git with a message explaining the change
python tools/testing/explicit_ci_jobs.py --job binary_linux_manywheel_3_6m_cpu_devtoolset7_nightly_test --filter-gha '*generated*gcc5.4*' --make-commit

# Make your changes

ghstack submit
```

**NB**: It is not recommended to use this workflow unless you are also using
[`ghstack`](https://github.com/ezyang/ghstack). It creates a large commit that is
of very low signal to reviewers.

## Merging your Change

If you know the right people or team that should approve your PR (and you have the required permissions to do so), add them to the Reviewers list.

If not, leave the Reviewers section empty. Our triage squad will review your PR, add a module label, and assign it to the appropriate reviewer in a couple of business days. The reviewer will then look at your PR and respond.

Occasionally, things might fall through the cracks (sorry!). In case your PR either doesn't get assigned to a reviewer or doesn't get any response from the reviewer for 4 business days, please leave a comment on the PR (mentioning the reviewer if one has been assigned). That'll get it nudged back onto people's radar.

If that still doesn't help, come see us during [our office hours](https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours).

Once your PR is approved, you can merge it in by entering a comment with the content `@pytorchmergebot merge` ([what's this bot?](https://github.com/pytorch/pytorch/wiki/Bot-commands)).

## Writing documentation

So you want to write some documentation and don't know where to start?
PyTorch has two main types of documentation:

- **User facing documentation**:
  These are the docs that you see over at [our docs website](https://pytorch.org/docs).
- **Developer facing documentation**:
  Developer facing documentation is spread around our READMEs in our codebase and in
  the [PyTorch Developer Wiki](https://pytorch.org/wiki).
  If you're interested in adding new developer docs, please read this [page on the wiki](https://github.com/pytorch/pytorch/wiki/Where-or-how-should-I-add-documentation) on our best practices for where to put it.

The rest of this section is about user-facing documentation.

PyTorch uses [Google style](https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html)
for formatting docstrings. Each line inside a docstring block must be limited to 80 characters so that it fits into Jupyter documentation popups.
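
One way to catch overlong docstring lines before review is a small check like the sketch below. It is not an official PyTorch tool: the helper `long_docstring_lines` and the sample function `example` are hypothetical, and measuring line length after `inspect.getdoc` strips the common indentation is an assumption about how the 80-character limit should be counted.

```python
import inspect

MAX_DOCSTRING_WIDTH = 80  # limit stated above


def long_docstring_lines(obj, limit=MAX_DOCSTRING_WIDTH):
    """Return (line_number, line) pairs for docstring lines over `limit`.

    Line numbers are 1-based and relative to the start of the docstring.
    """
    doc = inspect.getdoc(obj) or ""
    return [
        (number, line)
        for number, line in enumerate(doc.splitlines(), start=1)
        if len(line) > limit
    ]


def example(x):  # hypothetical function used only for this demonstration
    """Scale a value.

    Args:
        x (int): This description is deliberately written to run well past the eighty character limit.
    """
    return 2 * x


print(long_docstring_lines(example))
```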

### Docstring type formatting

In addition to the standard Google Style docstring formatting rules, the following guidelines should be followed for docstring types (docstring types are the type information contained in the round brackets after the variable name):

* The "`Callable`", "`Any`", "`Iterable`", "`Iterator`", "`Generator`" types should have their first letter capitalized.

* The "`list`" and "`tuple`" types should be completely lowercase.

* Types should not be made plural. For example: `tuple of int` should be used instead of `tuple of ints`.

* The only acceptable delimiter words for types are `or` and `of`. No other non-type words should be used other than `optional`.

* The word `optional` should only be used after the types, and it is only used if the user does not have to specify a value for the variable. Default values are listed after the variable description. Example:

    ```
    my_var (int, optional): Variable description. Default: 1
    ```

* Basic Python types should match their type name so that the [Intersphinx](https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html) extension can correctly identify them. For example:

    * Use `str` instead of `string`.
    * Use `bool` instead of `boolean`.
    * Use `dict` instead of `dictionary`.

* Square brackets should be used for the dictionary type. For example:

    ```
    my_var (dict[str, int]): Variable description.
    ```

* If a variable has two different possible types, then the word `or` should be used without a comma. Otherwise variables with 3 or more types should use commas to separate the types. Example:

    ```
    x (type1 or type2): Variable description.
    y (type1, type2, or type3): Variable description.
    ```
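
Putting these rules together, a docstring might look like the sketch below. The function `apply_to_all` and its parameters are hypothetical and exist only to show the type formatting (`Callable` capitalized, `list of float` singular, `int or float` without a comma, `dict[str, int]` with square brackets, `optional` after the types, and the default listed after the description).

```python
def apply_to_all(fn, values, scale=1, labels=None):
    """Apply a function to every value and scale the results.

    Args:
        fn (Callable): Function applied to each element.
        values (list of float): Input values.
        scale (int or float, optional): Multiplier for each result. Default: 1
        labels (dict[str, int], optional): Mapping from label to index into
            ``values``. Default: ``None``

    Returns:
        tuple of float: The transformed values.
    """
    results = tuple(scale * fn(v) for v in values)
    if labels is not None:
        # Fail early if a label points outside the input.
        for name, index in labels.items():
            if not 0 <= index < len(values):
                raise IndexError(f"label {name!r} has invalid index {index}")
    return results


print(apply_to_all(abs, [-1.0, 2.0], scale=2))
```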

### Building documentation

To build the documentation:

1. Build and install PyTorch

2. Install the prerequisites

```bash
cd docs
pip install -r requirements.txt
# `katex` must also be available in your PATH.
# You can either install katex globally if you have properly configured npm:
# npm install -g katex
# Or if you prefer an uncontaminated global executable environment or do not want to go through the node configuration:
# npm install katex && export PATH="$PATH:$(pwd)/node_modules/.bin"
```

> Note: if you installed `nodejs` with a different package manager (e.g.,
`conda`) then `npm` will probably install a version of `katex` that is not
compatible with your version of `nodejs` and doc builds will fail.
A combination of versions that is known to work is `node@6.13.1` and
`katex@0.13.18`. To install the latter with `npm` you can run
`npm install -g katex@0.13.18`

> Note that if you are a Facebook employee using a devserver, yarn may be more convenient to install katex:

```bash
yarn global add katex
```

> If a specific version is required you can use for example `yarn global add katex@0.13.18`.

3. Generate the documentation HTML files. The generated files will be in `docs/build/html`.

```bash
make html
```

#### Tips

The `.rst` source files live in [docs/source](docs/source). Some of the `.rst`
files pull in docstrings from PyTorch Python code (for example, via
the `autofunction` or `autoclass` directives). To vastly shorten doc build times,
it is helpful to remove the files you are not working on, only keeping the base
`index.rst` file and the files you are editing. The Sphinx build will produce
missing file warnings but will still complete. For example, to work on `jit.rst`:

```bash
cd docs/source
find . -type f | grep rst | grep -v index | grep -v jit | xargs rm

# Make your changes, build the docs, etc.

# Don't commit the deletions!
git add index.rst jit.rst
...
```

#### Building C++ Documentation

For C++ documentation (https://pytorch.org/cppdocs), we use
[Doxygen](http://www.doxygen.nl/) and then convert it to
[Sphinx](http://www.sphinx-doc.org/) via
[Breathe](https://github.com/michaeljones/breathe) and
[Exhale](https://github.com/svenevs/exhale). Check the [Doxygen
reference](https://www.doxygen.nl/manual/) for more
information on the documentation syntax.

We run Doxygen in CI (Travis) to verify that you do not use invalid Doxygen
commands. To run this check locally, run `./check-doxygen.sh` from inside
`docs/cpp/source`.

To build the documentation, follow the same steps as above, but run them from
`docs/cpp` instead of `docs`.

### Previewing changes locally

To view HTML files locally, you can open the files in your web browser. For example,
navigate to `file:///your_pytorch_folder/docs/build/html/index.html` in a web
browser.

If you are developing on a remote machine, you can set up an SSH tunnel so that
you can access the HTTP server on the remote machine from your local machine. To map
remote port 8000 to local port 8000, use either of the following commands.

```bash
# For SSH
ssh my_machine -L 8000:my_machine:8000

# For Eternal Terminal
et my_machine -t="8000:8000"
```

Then navigate to `localhost:8000` in your web browser.

**Tip:**
You can start a lightweight HTTP server on the remote machine with:

```bash
python -m http.server 8000 --directory <path_to_html_output>
```
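
If you would rather start the preview server from a script, the same thing can be sketched with the standard library's `http.server` module (the `directory` argument requires Python 3.7 or newer). The helper names `make_docs_server` and `serve_docs` are hypothetical, and the port simply mirrors the one used above.

```python
import functools
import http.server


def make_docs_server(html_dir, port=8000, host=""):
    """Create an HTTP server that serves the files under `html_dir`.

    An empty host binds to all interfaces, like `python -m http.server`.
    """
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=html_dir
    )
    return http.server.ThreadingHTTPServer((host, port), handler)


def serve_docs(html_dir, port=8000):
    """Serve `html_dir` until the process is interrupted (e.g. Ctrl-C)."""
    with make_docs_server(html_dir, port) as server:
        server.serve_forever()
```

For example, calling `serve_docs("docs/build/html")` from the root of your checkout serves the generated HTML on port 8000.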

Alternatively, you can run `rsync` on your local machine to copy the files from
your remote machine:

```bash
mkdir -p build cpp/build
rsync -az me@my_machine:/path/to/pytorch/docs/build/html build
rsync -az me@my_machine:/path/to/pytorch/docs/cpp/build/html cpp/build
```

### Previewing documentation on PRs

PyTorch will host documentation previews at `https://docs-preview.pytorch.org/pytorch/pytorch/<pr number>/index.html` once the
`pytorch_python_doc_build` GitHub Actions job has completed on your PR. You can visit that page directly
or find its link in the automated Dr. CI comment on your PR.

### Adding documentation tests

It is easy for code snippets in docstrings and `.rst` files to get out of date. The docs
build includes the [Sphinx Doctest Extension](https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html),
which can run code in documentation as a unit test. To use the extension, use
the `.. testcode::` directive in your `.rst` and docstrings.

To manually run these tests, follow steps 1 and 2 above, then run:

```bash
cd docs
make doctest
```
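
As an illustration, a docstring that uses the `.. testcode::` directive, paired with `.. testoutput::` from the same Sphinx extension, might look like the sketch below. The function `double` and the module name `mymodule` are hypothetical; in a real docstring the import would point at wherever the documented function actually lives.

```python
def double(x):
    """Return twice the input.

    .. testcode::

        from mymodule import double  # hypothetical import path

        print(double(21))

    .. testoutput::

        42

    Args:
        x (int or float): Value to double.

    Returns:
        int or float: The doubled value.
    """
    return 2 * x
```

When the doctest builder runs, the code under `.. testcode::` is executed and its printed output is compared against the `.. testoutput::` block.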
2019-11-20 23:04:37 +00:00
## Profiling with `py-spy`
Evaluating the performance impact of code changes in PyTorch can be complicated,
particularly if code changes happen in compiled code. One simple way to profile
both Python and C++ code in PyTorch is to use
[`py-spy` ](https://github.com/benfred/py-spy ), a sampling profiler for Python
that has the ability to profile native code and Python code in the same session.
`py-spy` can be installed via `pip` :
```bash
2021-09-08 02:00:18 +00:00
pip install py-spy
2019-11-20 23:04:37 +00:00
```
To use `py-spy` , first write a Python test script that exercises the
functionality you would like to profile. For example, this script profiles
`torch.add` :
```python
import torch
t1 = torch.tensor([[1, 1], [1, 1.]])
t2 = torch.tensor([[0, 0], [0, 0.]])
for _ in range(1000000):
torch.add(t1, t2)
```
Since the `torch.add` operation happens in microseconds, we repeat it a large
number of times to get good statistics. The most straightforward way to use
`py-spy` with such a script is to generate a [flame
graph](http://www.brendangregg.com/flamegraphs.html):
```bash
2021-09-08 02:00:18 +00:00
py-spy record -o profile.svg --native -- python test_tensor_tensor_add.py
2019-11-20 23:04:37 +00:00
```
This will output a file named `profile.svg` containing a flame graph you can
view in a web browser or SVG viewer. Individual stack frame entries in the graph
can be selected interactively with your mouse to zoom in on a particular part of
the program execution timeline. The `--native` command-line option tells
`py-spy` to record stack frame entries for PyTorch C++ code. To get line numbers
for C++ code it may be necessary to compile PyTorch in debug mode by prepending
`DEBUG=1` to your `setup.py develop` call. Depending on your operating system
it may also be necessary to run `py-spy` with root privileges.

`py-spy` can also work in an `htop`-like "live profiling" mode and can be
tweaked to adjust the stack sampling rate; see the `py-spy` README for more
details.
## Managing multiple build trees
One downside to using `python setup.py develop` is that your development
version of PyTorch will be installed globally on your account (e.g., if
you run `import torch` anywhere else, the development version will be
used).

If you want to manage multiple builds of PyTorch, you can make use of
[conda environments](https://conda.io/docs/using/envs.html) to maintain
separate Python package environments, each of which can be tied to a
specific build of PyTorch. To set one up:

```bash
conda create -n pytorch-myfeature
conda activate pytorch-myfeature
# if you run python now, torch will NOT be installed
python setup.py develop
```
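
To confirm which build tree a given environment actually picks up, one option is to ask Python where it would import the package from. The helper below is only a sketch (`locate_package` is a made-up name); it relies on the standard library's `importlib.util.find_spec`, which locates a top-level package without importing it.

```python
import importlib.util


def locate_package(name):
    """Return the file a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None  # not installed in the active environment
    return spec.origin


if __name__ == "__main__":
    # With a develop install, this is expected to point into your source checkout.
    print(locate_package("torch"))
```

Running this in each conda environment should show a different path per build tree, or `None` where PyTorch has not been installed yet.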
## C++ development tips
If you are working on the C++ code, there are a few important things that you
will want to keep in mind:

1. How to rebuild only the code you are working on.
2. How to make rebuilds in the absence of changes go faster.

### Build only what you need
`python setup.py build` will build everything by default, but sometimes you are
only interested in a specific component.

- Working on a test binary? Run `(cd build && ninja bin/test_binary_name)` to
  rebuild only that test binary (without rerunning cmake). (Replace `ninja` with
  `make` if you don't have ninja installed).

On the initial build, you can also speed things up with the environment
variables `DEBUG`, `REL_WITH_DEB_INFO`, `USE_DISTRIBUTED`, `USE_MKLDNN`, `USE_CUDA`, `USE_FLASH_ATTENTION`, `USE_MEM_EFF_ATTENTION`, `BUILD_TEST`, `USE_FBGEMM`, `USE_NNPACK`, `USE_QNNPACK` and `USE_XNNPACK`.

- `DEBUG=1` will enable debug builds (-g -O0).
- `REL_WITH_DEB_INFO=1` will enable debug symbols with optimizations (-g -O3).
- `USE_DISTRIBUTED=0` will disable distributed (c10d, gloo, mpi, etc.) build.
- `USE_MKLDNN=0` will disable using MKL-DNN.
- `USE_CUDA=0` will disable compiling CUDA (in case you are developing on something not CUDA related), to save compile time.
- `BUILD_TEST=0` will disable building C++ test binaries.
- `USE_FBGEMM=0` will disable using FBGEMM (quantized 8-bit server operators).
- `USE_NNPACK=0` will disable compiling with NNPACK.
- `USE_QNNPACK=0` will disable QNNPACK build (quantized 8-bit operators).
- `USE_XNNPACK=0` will disable compiling with XNNPACK.
- `USE_FLASH_ATTENTION=0` and `USE_MEM_EFF_ATTENTION=0` will disable compiling flash attention and memory-efficient attention kernels, respectively.

For example:

```bash
DEBUG=1 USE_DISTRIBUTED=0 USE_MKLDNN=0 USE_CUDA=0 BUILD_TEST=0 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 python setup.py develop
```

For subsequent builds (i.e., when `build/CMakeCache.txt` exists), the build
options passed for the first time will persist; please run `ccmake build/`, run
`cmake-gui build/`, or directly edit `build/CMakeCache.txt` to adapt build
options.
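
Before editing the cache, it can help to see which values were actually recorded. The following sketch parses the `NAME:TYPE=VALUE` entries that CMake writes to `CMakeCache.txt`. The helper name `read_cmake_cache` and the sample text are illustrative only, and the entries in a real PyTorch build may look different.

```python
def read_cmake_cache(text):
    """Parse CMakeCache.txt content into a {name: value} dict."""
    entries = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines and the two comment styles used in the cache file.
        if not line or line.startswith(("#", "//")):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            continue
        name = key.split(":", 1)[0]  # drop the ":TYPE" suffix
        entries[name] = value
    return entries


# Illustrative sample; on a real checkout, read build/CMakeCache.txt instead.
sample = """\
// Build type
CMAKE_BUILD_TYPE:STRING=Debug
USE_CUDA:BOOL=0
USE_DISTRIBUTED:BOOL=0
"""
cache = read_cmake_cache(sample)
print(cache["USE_CUDA"])
```

Comparing this dictionary with the environment variables you intended to set is a quick way to notice that an option from an earlier build is still in effect.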
### Code completion and IDE support
When using `python setup.py develop`, PyTorch will generate
a `compile_commands.json` file that can be used by many editors
to provide code completion and error highlighting for PyTorch's
C++ code. You need to `pip install ninja` to generate accurate
information for the code in `torch/csrc`. More information at:

- https://sarcasm.github.io/notes/dev/compilation-database.html
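
If an editor does not seem to pick up the database, inspecting it directly can help. A compilation database is a JSON array whose entries carry `directory`, `file`, and either `command` or `arguments`. The sketch below looks up the entry for one source file. The helper name and the sample entry are made up for illustration and do not reflect the real contents of PyTorch's database.

```python
import json
import os


def find_compile_entry(database, source_path):
    """Return the first entry whose file resolves to source_path, else None."""
    for entry in database:
        # "file" may be relative to the entry's "directory".
        path = os.path.normpath(os.path.join(entry["directory"], entry["file"]))
        if path == os.path.normpath(source_path):
            return entry
    return None


# Illustrative database; a real one would be loaded from compile_commands.json.
database = json.loads("""
[
  {
    "directory": "/work/pytorch/build",
    "file": "../torch/csrc/Module.cpp",
    "command": "c++ -std=c++17 -c ../torch/csrc/Module.cpp -o Module.cpp.o"
  }
]
""")
entry = find_compile_entry(database, "/work/pytorch/torch/csrc/Module.cpp")
print(entry["command"] if entry else "no entry found")
```

A missing entry for a file under `torch/csrc` would be consistent with the database having been generated without ninja installed.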
### Make no-op build fast
#### Use Ninja
By default, cmake will use its Makefile generator to generate your build
system. You can get faster builds if you install the ninja build system
with `pip install ninja` . If PyTorch was already built, you will need
to run `python setup.py clean` once after installing ninja for builds to
succeed.

Note: Building on a machine with a larger number of CPU cores will significantly reduce your build times.

#### Use CCache
Even when dependencies are tracked with file modification, there are many
situations where files get rebuilt when a previous compilation was exactly the
same. Using ccache in a situation like this is a real time-saver.
Before building PyTorch, install ccache from your package manager of choice
(run only the command that matches your platform):

```bash
conda install ccache -c conda-forge  # conda
sudo apt install ccache              # Debian/Ubuntu
sudo yum install ccache              # RHEL/CentOS/Fedora
brew install ccache                  # macOS (Homebrew)
```

You may also find the default cache size in ccache is too small to be useful.
The cache sizes can be increased from the command line:
```bash
# config: cache dir is ~/.ccache, conf file ~/.ccache/ccache.conf
# max size of cache
ccache -M 25Gi # -M 0 for unlimited
# unlimited number of files
ccache -F 0
```
To check this is working, do two clean builds of PyTorch in a row. The second
build should be substantially and noticeably faster than the first build. If
this doesn't seem to be the case, check the `CMAKE_<LANG>_COMPILER_LAUNCHER`
rules in `build/CMakeCache.txt`, where `<LANG>` is one of `C`, `CXX`, or `CUDA`.
Each of these 3 variables should contain ccache, e.g.

```
//CXX compiler launcher
CMAKE_CXX_COMPILER_LAUNCHER:STRING=/usr/bin/ccache
```
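
The same check can be scripted. This sketch scans cache text for the three launcher variables and reports which ones do not mention ccache. `missing_ccache_launchers` is a hypothetical helper, and the sample text stands in for the contents of `build/CMakeCache.txt`.

```python
import re

LANGS = ("C", "CXX", "CUDA")


def missing_ccache_launchers(cache_text):
    """Return the languages whose compiler launcher does not contain ccache."""
    missing = []
    for lang in LANGS:
        pattern = rf"^CMAKE_{lang}_COMPILER_LAUNCHER:[A-Z]+=(.*)$"
        match = re.search(pattern, cache_text, flags=re.MULTILINE)
        if match is None or "ccache" not in match.group(1):
            missing.append(lang)
    return missing


# Illustrative sample in which the CUDA launcher was never set.
sample = """\
//CXX compiler launcher
CMAKE_CXX_COMPILER_LAUNCHER:STRING=/usr/bin/ccache
//C compiler launcher
CMAKE_C_COMPILER_LAUNCHER:STRING=/usr/bin/ccache
"""
print(missing_ccache_launchers(sample))
```

An empty list suggests all three launchers are routed through ccache; any language that is reported is a candidate for the explicit launcher variables.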

If not, you can define these variables on the command line before invoking `setup.py`.

```bash
export CMAKE_C_COMPILER_LAUNCHER=ccache
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export CMAKE_CUDA_COMPILER_LAUNCHER=ccache
python setup.py develop
```
#### Use a faster linker

If you are editing a single file and rebuilding in a tight loop, the time spent
linking will dominate. The system linker available in most Linux distributions
(GNU `ld`) is quite slow. Use a faster linker, like [lld](https://lld.llvm.org/).
On macOS, follow [this guide](https://stackoverflow.com/questions/42730345/how-to-install-llvm-for-mac) instead.

The easiest way to use `lld` is to download the
[latest LLVM binaries](http://releases.llvm.org/download.html#8.0.0) and run:

```bash
ln -s /path/to/downloaded/ld.lld /usr/local/bin/ld
```

#### Use pre-compiled headers
Sometimes there's no way of getting around rebuilding lots of files; for example,
editing `native_functions.yaml` usually means 1000+ files being rebuilt. If
you're using CMake newer than 3.16, you can enable pre-compiled headers by
setting `USE_PRECOMPILED_HEADERS=1` either on first setup, or in the
`CMakeCache.txt` file.
```sh
USE_PRECOMPILED_HEADERS=1 python setup.py develop
```
This adds a build step where the compiler takes `<ATen/ATen.h>` and essentially
dumps its internal AST to a file so the compiler can avoid repeating itself for
every `.cpp` file.

One caveat is that when enabled, this header gets included in every file by default,
which may change what code is legal. For example:

- internal functions can never alias existing names in `<ATen/ATen.h>`
- names in `<ATen/ATen.h>` will work even if you don't explicitly include it.

#### Workaround for header dependency bug in nvcc
If re-building without modifying any files results in several CUDA files being
re-compiled, you may be running into an `nvcc` bug where header dependencies are
not converted to absolute paths before being reported to the build system. This
makes `ninja` think one of the header files has been deleted, so it runs the
build again.

A compiler wrapper to fix this is provided in `tools/nvcc_fix_deps.py`. You can use
this as a compiler launcher, similar to `ccache`:
```bash
export CMAKE_CUDA_COMPILER_LAUNCHER="python;`pwd`/tools/nvcc_fix_deps.py;ccache"
python setup.py develop
```
### Rebuild few files with debug information
While debugging a problem, one often has to maintain a debug build in a separate folder.
But often only a few files need to be rebuilt with debug info to get a symbolicated backtrace or enable source debugging.
One can easily solve this with the help of `tools/build_with_debinfo.py`.

For example, suppose one wants to debug what is going on while a tensor index is selected, which can be achieved by setting a breakpoint at the `applySelect` function:
```
% lldb -o "b applySelect" -o "process launch" -- python3 -c "import torch;print(torch.rand(5)[3])"
(lldb) target create "python"
Current executable set to '/usr/bin/python3' (arm64).
(lldb) settings set -- target.run-args "-c" "import torch;print(torch.rand(5)[3])"
(lldb) b applySelect
Breakpoint 1: no locations (pending).
WARNING: Unable to resolve breakpoint to any actual locations.
(lldb) process launch
2 locations added to breakpoint 1
Process 87729 stopped
* thread #1, queue = 'com.apple.main-thread', stop reason = breakpoint 1.1
    frame #0: 0x00000001023d55a8 libtorch_python.dylib`at::indexing::impl::applySelect(at::Tensor const&, long long, c10::SymInt, long long, c10::Device const&, std::__1::optional<c10::ArrayRef<c10::SymInt>> const&)
libtorch_python.dylib`at::indexing::impl::applySelect:
->  0x1023d55a8 <+0>:  sub    sp, sp, #0xd0
    0x1023d55ac <+4>:  stp    x24, x23, [sp, #0x90]
    0x1023d55b0 <+8>:  stp    x22, x21, [sp, #0xa0]
    0x1023d55b4 <+12>: stp    x20, x19, [sp, #0xb0]
Target 0: (python) stopped.
Process 87729 launched: '/usr/bin/python' (arm64)
```
This is not very informative, but it can be easily remedied by rebuilding `python_variable_indexing.cpp` with debug information:
```
% ./tools/build_with_debinfo.py torch/csrc/autograd/python_variable_indexing.cpp
[1 / 2] Building caffe2/torch/CMakeFiles/torch_python.dir/csrc/autograd/python_variable_indexing.cpp.o
[2 / 2] Building lib/libtorch_python.dylib
```
And afterwards:
```
% lldb -o "b applySelect" -o "process launch" -- python3 -c "import torch;print(torch.rand(5)[3])"
(lldb) target create "python"
Current executable set to '/usr/bin/python3' (arm64).
(lldb) settings set -- target.run-args "-c" "import torch;print(torch.rand(5)[3])"
(lldb) b applySelect
Breakpoint 1: no locations (pending).
WARNING: Unable to resolve breakpoint to any actual locations.
(lldb) process launch
2 locations added to breakpoint 1
Process 87741 stopped
* thread #1, queue = 'com.apple.main-thread', stop reason = breakpoint 1.1
    frame #0: 0x00000001024e2628 libtorch_python.dylib`at::indexing::impl::applySelect(self=0x00000001004ee8a8, dim=0, index=(data_ = 3), real_dim=0, (null)=0x000000016fdfe535, self_sizes= Has Value=true ) at TensorIndexing.h:239:7
   236      const at::Device& /*self_device*/,
   237      const std::optional<SymIntArrayRef>& self_sizes) {
   238    // See NOTE [nested tensor size for indexing]
-> 239    if (self_sizes.has_value()) {
   240      auto maybe_index = index.maybe_as_int();
   241      if (maybe_index.has_value()) {
   242        TORCH_CHECK_INDEX(
Target 0: (python) stopped.
Process 87741 launched: '/usr/bin/python3' (arm64)
```
Which is much more useful, isn't it?
### C++ frontend development tips
We have very extensive tests in the [test/cpp/api](test/cpp/api) folder. The
tests are a great way to see how certain components are intended to be used.
When compiling PyTorch from source, the test runner binary will be written to
`build/bin/test_api`. The tests use the [GoogleTest](https://github.com/google/googletest/blob/master/googletest)
framework, which you can read up on to learn how to configure the test runner. When
submitting a new feature, we care very much that you write appropriate tests.
Please follow the lead of the other tests to see how to write a new test case.
### GDB integration
If you are debugging PyTorch inside GDB, you might be interested in
[pytorch-gdb](tools/gdb/pytorch-gdb.py). This script introduces some
PyTorch-specific commands which you can use from the GDB prompt. In
particular, `torch-tensor-repr` prints a human-readable repr of an `at::Tensor`
object. Example of usage:
```
$ gdb python
GNU gdb (GDB) 9.2
[...]
(gdb) # insert a breakpoint when we call .neg()
(gdb) break at::Tensor::neg
Function "at::Tensor::neg" not defined.
Make breakpoint pending on future shared library load? (y or [n]) y
Breakpoint 1 (at::Tensor::neg) pending.
(gdb) run
[...]
>>> import torch
>>> t = torch.tensor([1, 2, 3, 4], dtype=torch.float64)
>>> t
tensor([1., 2., 3., 4.], dtype=torch.float64)
>>> t.neg()
Thread 1 "python" hit Breakpoint 1, at::Tensor::neg (this=0x7ffb118a9c88) at aten/src/ATen/core/TensorBody.h:3295
3295 inline at::Tensor Tensor::neg() const {
(gdb) # the default repr of 'this' is not very useful
(gdb) p this
$1 = (const at::Tensor * const) 0x7ffb118a9c88
(gdb) p *this
$2 = {impl_ = {target_ = 0x55629b5cd330}}
(gdb) torch-tensor-repr *this
Python-level repr of *this:
tensor([1., 2., 3., 4.], dtype=torch.float64)
```
GDB tries to automatically load `pytorch-gdb` thanks to the
[.gdbinit](.gdbinit) at the root of the PyTorch repo. However, auto-loading is disabled by default for security reasons:
```bash
$ gdb
warning: File "/path/to/pytorch/.gdbinit" auto-loading has been declined by your `auto-load safe-path' set to "$debugdir:$datadir/auto-load".
To enable execution of this file add
add-auto-load-safe-path /path/to/pytorch/.gdbinit
line to your configuration file "/home/YOUR-USERNAME/.gdbinit".
To completely disable this security protection add
set auto-load safe-path /
line to your configuration file "/home/YOUR-USERNAME/.gdbinit".
For more information about this security protection see the
"Auto-loading safe path" section in the GDB manual. E.g., run from the shell:
info "(gdb)Auto-loading safe path"
(gdb)
```
As gdb itself suggests, the best way to enable auto-loading of `pytorch-gdb`
is to add the following line to your `~/.gdbinit` (i.e., the `.gdbinit` file
which is in your home directory, **not** `/path/to/pytorch/.gdbinit`):
```bash
add-auto-load-safe-path /path/to/pytorch/.gdbinit
```
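If you would rather script this step, the following is a minimal Python sketch that appends the line only when it is not already present. The helper name `allow_pytorch_gdbinit` and the temporary demo file are assumptions made for illustration; for real use you would pass your own checkout path and `~/.gdbinit`.

```python
import tempfile
from pathlib import Path


def allow_pytorch_gdbinit(pytorch_root, user_gdbinit):
    """Append the safe-path line for pytorch_root to user_gdbinit if missing.

    Hypothetical helper, not part of PyTorch. Returns True if the file changed.
    """
    line = f"add-auto-load-safe-path {str(pytorch_root).rstrip('/')}/.gdbinit"
    user_gdbinit = Path(user_gdbinit)
    text = user_gdbinit.read_text() if user_gdbinit.exists() else ""
    if line in text.splitlines():
        return False  # already allowed, nothing to do
    # Make sure the new entry starts on its own line.
    prefix = "" if text == "" or text.endswith("\n") else "\n"
    with user_gdbinit.open("a") as f:
        f.write(prefix + line + "\n")
    return True


# Demo against a throwaway file; for real use pass Path.home() / ".gdbinit".
with tempfile.TemporaryDirectory() as tmp:
    gdbinit = Path(tmp) / ".gdbinit"
    first = allow_pytorch_gdbinit("/path/to/pytorch", gdbinit)
    second = allow_pytorch_gdbinit("/path/to/pytorch", gdbinit)  # no-op
    contents = gdbinit.read_text()
```

Checking for the line first keeps the helper safe to run repeatedly without duplicating entries.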
### C++ stacktraces
Set `TORCH_SHOW_CPP_STACKTRACES=1` to get the C++ stacktrace when an error occurs in Python.
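One way to apply this without exporting the variable in your shell is to set it only for a child process. The following is a minimal sketch; the helper name `run_with_cpp_stacktraces` is hypothetical and not part of PyTorch, and the child command here merely echoes the variable so the example runs without a PyTorch build.

```python
import os
import subprocess
import sys


def run_with_cpp_stacktraces(argv):
    """Run argv in a child process with TORCH_SHOW_CPP_STACKTRACES=1 set.

    Hypothetical helper for illustration; it is not part of PyTorch.
    """
    env = dict(os.environ)  # copy, so the parent environment is untouched
    env["TORCH_SHOW_CPP_STACKTRACES"] = "1"
    return subprocess.run(argv, env=env, capture_output=True, text=True)


# The child only echoes the variable here; in practice argv would be
# something like [sys.executable, "your_failing_script.py"].
result = run_with_cpp_stacktraces(
    [sys.executable, "-c", "import os; print(os.environ['TORCH_SHOW_CPP_STACKTRACES'])"]
)
print(result.stdout.strip())
```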
## CUDA development tips
If you are working on the CUDA code, here are some useful CUDA debugging tips:
1. `CUDA_DEVICE_DEBUG=1` will enable CUDA device function debug symbols (`-g -G`).
   This will be particularly helpful in debugging device code. However, it will
   slow down the build process by about 50% (compared to only `DEBUG=1`), so use wisely.
2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debugging friends. Unlike `gdb`,
   `cuda-gdb` can display actual values in a CUDA tensor (rather than all zeros).
3. CUDA supports a lot of C++11/14 features such as `std::numeric_limits`, `std::nextafter`,
   `std::tuple`, etc. in device code. Many such features are possible because of the
   [--expt-relaxed-constexpr](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#constexpr-functions)
   nvcc flag. There is a known [issue](https://github.com/ROCm-Developer-Tools/HIP/issues/374)
   that ROCm errors out on device code that uses such STL functions.
4. A good performance metric for a CUDA kernel is the
   [Effective Memory Bandwidth](https://devblogs.nvidia.com/how-implement-performance-metrics-cuda-cc/).
   It is useful to measure this metric whenever you are writing/optimizing a CUDA
   kernel. The following script shows how we can measure the effective bandwidth of the CUDA `uniform_`
   kernel.
```python
import torch
from torch.utils.benchmark import Timer

size = 128 * 512
# Number of bytes read + written per element by the kernel. Change this to fit your kernel.
nbytes_read_write = 4

for i in range(10):
    a = torch.empty(size).cuda().uniform_()
    torch.cuda.synchronize()
    out = a.uniform_()  # extra call before timing starts
    torch.cuda.synchronize()
    t = Timer(stmt="a.uniform_()", globals=globals())
    res = t.blocked_autorange()
    timec = res.median  # median time per call, in seconds
    print("uniform, size, elements", size, "forward", timec, "bandwidth (GB/s)", size * nbytes_read_write * 1e-9 / timec)
    size *= 2
```
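The bandwidth figure printed by the script is simply the number of bytes moved divided by the elapsed time. The sketch below isolates that arithmetic so it can be checked without a GPU; the helper name `effective_bandwidth_gbs` is hypothetical and the timing value is made up for illustration, not a measurement.

```python
def effective_bandwidth_gbs(num_elements, bytes_per_element, seconds):
    """Effective bandwidth in GB/s: bytes read + written, divided by time."""
    return num_elements * bytes_per_element * 1e-9 / seconds


# Illustrative numbers only (not a measurement): 128 * 512 elements,
# 4 bytes moved per element, finished in 2 microseconds.
print(effective_bandwidth_gbs(128 * 512, 4, 2e-6))
```

Comparing this value against the peak memory bandwidth of your device gives a rough idea of how close the kernel is to being memory bound.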
See more CUDA development tips [here](https://github.com/pytorch/pytorch/wiki/CUDA-basics).
## Windows development tips
For building from source on Windows, consult
[our documentation](https://pytorch.org/docs/stable/notes/windows.html) on it.

Occasionally, you will write a patch which works on Linux, but fails CI on Windows.
There are a few aspects in which MSVC (the Windows compiler toolchain we use) is stricter
than Linux, which are worth keeping in mind when fixing these problems.
1. Symbols are NOT exported by default on Windows; instead, you have to explicitly
   mark a symbol as exported/imported in a header file with `__declspec(dllexport)` /
   `__declspec(dllimport)`. We have codified this pattern into a set of macros
   which follow the convention `*_API`, e.g., `TORCH_API` inside Caffe2, ATen and Torch.
   (Every separate shared library needs a unique macro name, because symbol visibility
   is on a per shared library basis. See `c10/macros/Macros.h` for more details.)

   The upshot is if you see an "unresolved external" error in your Windows build, this
   is probably because you forgot to mark a function with `*_API`. However, there is
   one important counterexample to this principle: if you want a *templated* function
   to be instantiated at the call site, do NOT mark it with `*_API` (if you do mark it,
   you'll have to explicitly instantiate all of the specializations used by the call
   sites.)

2. If you link against a library, this does not make its dependencies transitively
   visible. You must explicitly specify a link dependency against every library whose
   symbols you use. (This is different from Linux where in most environments,
   transitive dependencies can be used to fulfill unresolved symbols.)

3. If you have a Windows box (we have a few on EC2 which you can request access to) and
   you want to run the build, the easiest way is to just run `.ci/pytorch/win-build.sh`.
   If you need to rebuild, run `REBUILD=1 .ci/pytorch/win-build.sh` (this will avoid
   blowing away your Conda environment.)

Even if you don't know anything about MSVC, you can use cmake to build simple programs on
Windows; this can be helpful if you want to learn more about some peculiar linking behavior
by reproducing it on a small example. Here's a simple example cmake file that defines
two dynamic libraries, one linking with the other:
```CMake
project(myproject CXX)
set(CMAKE_CXX_STANDARD 14)
add_library(foo SHARED foo.cpp)
add_library(bar SHARED bar.cpp)
# NB: don't forget to __declspec(dllexport) at least one symbol from foo,
# otherwise foo.lib will not be created.
target_link_libraries(bar PUBLIC foo)
```
You can build it with:
```bash
mkdir build
cd build
cmake ..
cmake --build .
```
### Known MSVC (and MSVC with NVCC) bugs
The PyTorch codebase sometimes likes to use exciting C++ features, and
these exciting features lead to exciting bugs in Windows compilers.
To add insult to injury, the error messages will often not tell you
which line of code actually induced the erroring template instantiation.
We've found the most effective way to debug these problems is to
carefully read over diffs, keeping in mind known bugs in MSVC/NVCC.
Here are a few well-known pitfalls and workarounds:
* This is not actually a bug per se, but in general, code generated by MSVC
  is more sensitive to memory errors; you may have written some code
  that does a use-after-free or overflows the stack; on Linux the code
  might work, but on Windows your program will crash. ASAN may not
  catch all of these problems: stay vigilant to the possibility that
  your crash is due to a real memory problem.
* `constexpr` generally works less well on MSVC.
  * The idiom `static_assert(f() == f())` to test if `f` is constexpr
    does not work; you'll get "error C2131: expression did not evaluate
    to a constant". Don't use these asserts on Windows.
    (Example: `c10/util/intrusive_ptr.h`)
  * (NVCC) Code you access inside a `static_assert` will eagerly be
    evaluated as if it were device code, and so you might get an error
    that the code is "not accessible".
```cpp
#include <type_traits>

class A {
  static A singleton_;
  static constexpr inline A* singleton() {
    return &singleton_;
  }
};
static_assert(std::is_same<A*, decltype(A::singleton())>::value, "hmm");
```
* The compiler will run out of heap space if you attempt to compile files that
  are too large. Splitting such files into separate files helps.
  (Example: `THTensorMath`, `THTensorMoreMath`, `THTensorEvenMoreMath`.)
* MSVC's preprocessor (but not the standard compiler) has a bug
  where it incorrectly tokenizes raw string literals, ending when it sees a `"`.
  This causes preprocessor tokens inside the literal like an `#endif` to be incorrectly
  treated as preprocessor directives. See https://godbolt.org/z/eVTIJq as an example.
* Either MSVC or the Windows headers have a PURE macro defined and will replace
  any occurrences of the PURE token in code with an empty string. This is why
  we have `AliasAnalysisKind::PURE_FUNCTION` and not `AliasAnalysisKind::PURE`.
  The same is likely true for other identifiers that we just didn't try to use yet.
### Building on legacy code and CUDA
CUDA, MSVC, and PyTorch versions are interdependent; please install matching versions from this table:

| CUDA version | Newest supported VS version                             | PyTorch version |
| ------------ | ------------------------------------------------------- | --------------- |
| 10.1         | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930)           | 1.3.0 ~ 1.7.0   |
| 10.2         | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930)           | 1.5.0 ~ 1.7.0   |
| 11.0         | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930)           | 1.7.0           |

Note: There's a [compilation issue](https://github.com/oneapi-src/oneDNN/issues/812) in several Visual Studio 2019 versions since 16.7.1, so please make sure your Visual Studio 2019 version is not in the range 16.7.1 ~ 16.7.5.
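As a quick check of the version range in the note, here is a small Python sketch. The helper name `is_affected_vs2019` is hypothetical, and treating both ends of the 16.7.1 ~ 16.7.5 range as inclusive is an assumption.

```python
def is_affected_vs2019(version):
    """Return True if a Visual Studio 2019 version string is in 16.7.1 ~ 16.7.5.

    Hypothetical helper; both ends of the range are assumed to be inclusive.
    """
    parts = tuple(int(p) for p in version.split("."))
    # Compare only major.minor.patch; tuples compare element by element.
    return (16, 7, 1) <= parts[:3] <= (16, 7, 5)


print(is_affected_vs2019("16.7.3"))
print(is_affected_vs2019("16.8.0"))
```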
## Pre-commit tidy/linting hook
We use clang-tidy to perform additional
formatting and semantic checking of code. We provide a pre-commit git hook for
performing these checks, before a commit is created:
```bash
ln -s ../../tools/git-pre-commit .git/hooks/pre-commit
```
If you have already committed files and
CI reports `flake8` errors, you can run the check locally in your PR branch with:
```bash
flake8 $(git diff --name-only $(git merge-base --fork-point main))
```
You'll need to install an appropriately configured flake8; see
[Lint as you type](https://github.com/pytorch/pytorch/wiki/Lint-as-you-type)
for documentation on how to do this.

Fix the code so that no errors are reported when you re-run the above check,
and then commit the fix.
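The shell one-liner above passes every changed file to flake8. The following Python sketch is one possible variation that keeps only Python sources before building the command. The helper names are hypothetical, and filtering on the `.py` and `.pyi` suffixes is an assumption that is not part of the original command; `changed_files` needs to run inside a git checkout and is not called in the demo.

```python
import subprocess


def python_files(paths):
    """Keep only the paths that look like Python sources."""
    return [p for p in paths if p.endswith((".py", ".pyi"))]


def changed_files(base="main"):
    """List files changed since the fork point from `base` (needs a git checkout)."""
    fork_point = subprocess.run(
        ["git", "merge-base", "--fork-point", base],
        check=True, capture_output=True, text=True,
    ).stdout.strip()
    diff = subprocess.run(
        ["git", "diff", "--name-only", fork_point],
        check=True, capture_output=True, text=True,
    ).stdout
    return diff.splitlines()


def flake8_command(paths):
    """Build the flake8 argv for the given paths, or None if nothing to lint."""
    files = python_files(paths)
    return ["flake8", *files] if files else None


# Demo with a hard-coded list; in a checkout you would pass changed_files().
print(flake8_command(["torch/example.py", "README.md"]))
```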
## Building PyTorch with ASAN
[ASAN](https://github.com/google/sanitizers/wiki/AddressSanitizer) is very
useful for debugging memory errors in C++. We run it in CI, but here's how to
get the same thing to run on your local machine.

First, install LLVM 8. The easiest way is to get [prebuilt
binaries](http://releases.llvm.org/download.html#8.0.0) and extract them to
a folder (later called `$LLVM_ROOT`).

Then set up the appropriate scripts. You can put this in your `.bashrc`:
```bash
LLVM_ROOT=<wherever your llvm install is>
PYTORCH_ROOT=<wherever your pytorch checkout is>

LIBASAN_RT="$LLVM_ROOT/lib/clang/8.0.0/lib/linux/libclang_rt.asan-x86_64.so"
build_with_asan()
{
  LD_PRELOAD=${LIBASAN_RT} \
  CC="$LLVM_ROOT/bin/clang" \
  CXX="$LLVM_ROOT/bin/clang++" \
  LDSHARED="clang --shared" \
  LDFLAGS="-stdlib=libstdc++" \
  CFLAGS="-fsanitize=address -fno-sanitize-recover=all -shared-libasan -pthread" \
  CXX_FLAGS="-pthread" \
  USE_CUDA=0 USE_OPENMP=0 USE_DISTRIBUTED=0 DEBUG=1 \
  python setup.py develop
}

run_with_asan()
{
  # Quote "$@" so arguments containing spaces are passed through intact.
  LD_PRELOAD=${LIBASAN_RT} "$@"
}

# you can look at build-asan.sh to find the latest options the CI uses
export ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true
export UBSAN_OPTIONS=print_stacktrace=1:suppressions=$PYTORCH_ROOT/ubsan.supp
export ASAN_SYMBOLIZER_PATH=$LLVM_ROOT/bin/llvm-symbolizer
```
Then you can use the scripts like:
```
suo-devfair ~/pytorch ❯ build_with_asan
suo-devfair ~/pytorch ❯ run_with_asan python test/test_jit.py
```
### Getting `ccache` to work
The scripts above specify the `clang` and `clang++` binaries directly, which
bypasses `ccache`. Here's how to get `ccache` to work:
1. Make sure the ccache symlinks for `clang` and `clang++` are set up (see
   CONTRIBUTING.md).
2. Make sure `$LLVM_ROOT/bin` is available on your `$PATH`.
3. Change the `CC` and `CXX` variables in `build_with_asan()` to the bare
   command names `clang` and `clang++`, so that they are resolved through `$PATH`.
### Why this stuff with `LD_PRELOAD` and `LIBASAN_RT`?
The “standard” workflow for ASAN assumes you have a standalone binary:

1. Recompile your binary with `-fsanitize=address`.
2. Run the binary, and ASAN will report whatever errors it finds.

Unfortunately, PyTorch is distributed as a shared library that is loaded by
a third-party executable (Python). It’s too much of a hassle to recompile all
of Python every time we want to use ASAN. Luckily, the ASAN folks have a
workaround for cases like this:

1. Recompile your library with `-fsanitize=address -shared-libasan`. The
   extra `-shared-libasan` tells the compiler to ask for the shared ASAN
   runtime library.
2. Use `LD_PRELOAD` to tell the dynamic linker to load the ASAN runtime
   library before anything else.

More information can be found
[here](https://github.com/google/sanitizers/wiki/AddressSanitizerAsDso).
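The preload step can also be done from Python when launching tests programmatically. The following is a minimal sketch that mirrors the `run_with_asan` shell function; the helper name `asan_env` is hypothetical, `/opt/llvm` is a placeholder path, and the runtime library location assumes the LLVM 8 layout used in the shell snippet.

```python
import os

# Values copied from the shell snippet; the LLVM 8 directory layout is assumed.
ASAN_RUNTIME = "lib/clang/8.0.0/lib/linux/libclang_rt.asan-x86_64.so"
ASAN_OPTIONS = "detect_leaks=0:symbolize=1:strict_init_order=true"


def asan_env(llvm_root, base_env=None):
    """Return a copy of the environment with the ASAN runtime preloaded.

    Hypothetical helper mirroring run_with_asan(); not part of PyTorch.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["LD_PRELOAD"] = f"{llvm_root}/{ASAN_RUNTIME}"
    env["ASAN_OPTIONS"] = ASAN_OPTIONS
    env["ASAN_SYMBOLIZER_PATH"] = f"{llvm_root}/bin/llvm-symbolizer"
    return env


env = asan_env("/opt/llvm", base_env={})  # "/opt/llvm" is a placeholder path
print(env["LD_PRELOAD"])
```

The resulting mapping could be passed as `env=` to `subprocess.run` when starting Python on a test file such as `test/test_jit.py`.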
### Why LD_PRELOAD in the build function?
We need `LD_PRELOAD` because there is a cmake check that ensures that a
simple program builds and runs. If we are building with ASAN as a shared
library, we need to `LD_PRELOAD` the runtime library, otherwise there will be
dynamic linker errors and the check will fail.

We don’t actually need either of these if we fix the cmake checks.
### Why no leak detection?
Python leaks a lot of memory. Possibly we could configure a suppression file,
but we haven’t gotten around to it.
## Caffe2 notes
In 2018, we merged Caffe2 into the PyTorch source repository. While the
steady state aspiration is that Caffe2 and PyTorch share code freely,
in the meantime there will be some separation.

There are a few "unusual" directories which, for historical reasons,
are Caffe2/PyTorch specific. Here they are:

- `CMakeLists.txt`, `Makefile`, `binaries`, `cmake`, `conda`, `modules`,
  `scripts` are Caffe2-specific. Don't put PyTorch code in them without
  extra coordination.
- `mypy*`, `requirements.txt`, `setup.py`, `test`, `tools` are
  PyTorch-specific. Don't put Caffe2 code in them without extra
  coordination.
## CI failure tips
Once you submit a PR or push a new commit to a branch that is in
an active PR, CI jobs will be run automatically. Some of these may
fail and you will need to find out why, by looking at the logs.

Fairly often, a CI failure might be unrelated to your changes. You can
confirm by going to our [HUD](https://hud.pytorch.org) and seeing if the CI job
is failing upstream already. In this case, you
can usually ignore the failure. See [the following
subsection](#which-commit-is-used-in-ci) for more details.

Some failures might be related to specific hardware or environment
configurations. In this case, if you're a Meta employee, you can ssh into
the job's session to perform manual debugging following the instructions in
our [CI wiki](https://github.com/pytorch/pytorch/wiki/Debugging-using-with-ssh-for-Github-Actions).
### Which commit is used in CI?
For CI run on `main`, this repository is checked out for a given `main`
commit, and CI is run on that commit (there isn't really any other choice).

For PRs, however, it's a bit more complicated. Consider this commit graph, where
`main` is at commit `A`, and the branch for PR #42 (just a placeholder) is at
commit `B`:
```
       o---o---B (refs/pull/42/head)
      /         \
     /           C (refs/pull/42/merge)
    /           /
---o---o---o---A (merge-destination) - usually main
```
There are two possible choices for which commit to use:

1. Checkout commit `B`, the head of the PR (manually committed by the PR
   author).
2. Checkout commit `C`, the hypothetical result of what would happen if the PR
   were merged into its destination (usually `main`).

For all practical purposes, most people can think of the commit being used as
commit `B` (choice **1**).

However, if workflow files (which govern CI behavior) were modified (either by your PR or since the dev branch was created), there's
a nuance to know about:
The workflow files themselves get taken from checkpoint `C`, the merger of your
PR and the `main` branch. But only the workflow files get taken from that merged
checkpoint. Everything else (tests, code, etc.) gets taken directly from your
PR's commit (commit `B`). Please note, this scenario would never affect PRs authored by `ghstack`, as they would not automatically ingest the updates from the default branch.
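The rule above can be summarized as a small decision function. This is only a sketch of the description, not CI code: the helper name `commit_used_in_ci` is hypothetical, the file name `lint.yml` is a made-up example, and it assumes workflow files are the ones under `.github/workflows/` (the usual GitHub Actions location). Returning `B` for `ghstack` PRs is a simplification of the note that they are never affected by this scenario.

```python
def commit_used_in_ci(path, authored_by_ghstack=False):
    """Return "B" (PR head) or "C" (merge commit) for a file in a PR run.

    Hypothetical helper; assumes workflow files live under .github/workflows/.
    """
    is_workflow = path.startswith(".github/workflows/")
    if is_workflow and not authored_by_ghstack:
        return "C"  # workflow definitions come from the merged checkpoint
    return "B"  # tests, code, and everything else come from the PR head


for example in (".github/workflows/lint.yml", "torch/nn/functional.py"):
    print(example, "->", commit_used_in_ci(example))
```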
## Dev Infra Office Hours
[Dev Infra Office Hours](https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours) are hosted every Friday to answer any questions regarding developer experience, Green HUD, and CI.