Fixes #10536
Reattempt of #61467. Thank you so much to @mskoh52 for your excellent work!
As I was trying to create a more efficient LLM data collator, I realized that `pad_sequence` only supports right padding, even though left padding is a very common format for LLMs, like Llama and Mistral.
The proposed alternative implementation was to use multiple flips, which tends to be 1.5x-2x slower. Instead, we can add a [`padding_side` parameter, as there is for Hugging Face tokenizers](9d6c0641c4/src/transformers/tokenization_utils_base.py (L1565)), which requires only a very small change in the C++ code.
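For illustration, a minimal sketch of the resulting API (expected outputs in comments):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Default right padding (existing behavior)
pad_sequence(seqs, batch_first=True, padding_value=0)
# tensor([[1, 2, 3],
#         [4, 5, 0]])

# Left padding (added by this PR), as used by LLM data collators
pad_sequence(seqs, batch_first=True, padding_value=0, padding_side="left")
# tensor([[1, 2, 3],
#         [0, 4, 5]])
```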
Here are the benchmarks of the new implementation!
`float32` benchmark plot:

`bool` benchmark plot:
Code:
```python
from __future__ import annotations

import random
import time
from typing import Literal

import numpy as np
import torch


def pad_sequence_with_flips(
    sequences: list[torch.Tensor],
    batch_first: bool = False,
    padding_value: int | float | bool = 0.0,
    padding_side: Literal["left", "right"] | str = "left",
) -> torch.Tensor:
    """Old approach: emulate left padding by flipping, right-padding, and flipping back."""
    if padding_side == "right":
        padded_sequence = torch._C._nn.pad_sequence(
            [t.flatten() for t in sequences],
            batch_first=batch_first,
            padding_value=padding_value,
        )
    elif padding_side == "left":
        padded_sequence = torch._C._nn.pad_sequence(
            [t.flatten().flip(0) for t in sequences],
            batch_first=batch_first,
            padding_value=padding_value,
        )  # pyright: ignore[reportArgumentType]
        # Flip back along the time dimension (dim 1 if batch_first, else dim 0).
        padded_sequence = padded_sequence.flip(int(batch_first))
    else:
        raise ValueError(
            f"padding_side should be either 'right' or 'left', but got {padding_side}"
        )
    return padded_sequence


sequence_lengths: list[int] = []
flip_left_pad_times: list[float] = []
flip_left_pad_times_std: list[float] = []
left_pad_times: list[float] = []
left_pad_times_std: list[float] = []

RUNS_PER_LOOP: int = 100

for i in range(1, 7):
    sequence_length = i * int(1e6) // 6
    sequence_lengths.append(sequence_length)
    sequences = [
        torch.randint(0, 2, (random.randint(1, sequence_length),), dtype=torch.bool)
        for _ in range(64)
    ]

    inner_left_pad_times: list[float] = []
    inner_flip_left_pad_times: list[float] = []
    for _ in range(RUNS_PER_LOOP):
        # New C++ implementation with padding_side
        start = time.perf_counter()
        torch._C._nn.pad_sequence(sequences, batch_first=True, padding_value=False, padding_side="left")
        end = time.perf_counter()
        inner_left_pad_times.append(end - start)

        # Old two-flip implementation
        start = time.perf_counter()
        pad_sequence_with_flips(sequences, batch_first=True, padding_value=False, padding_side="left")
        end = time.perf_counter()
        inner_flip_left_pad_times.append(end - start)

    left_pad_times.append(sum(inner_left_pad_times) / len(inner_left_pad_times))
    left_pad_times_std.append(np.std(inner_left_pad_times))
    flip_left_pad_times.append(sum(inner_flip_left_pad_times) / len(inner_flip_left_pad_times))
    flip_left_pad_times_std.append(np.std(inner_flip_left_pad_times))
    print(f"Sequence Length: {sequence_length}, Left Pad Time: {left_pad_times[-1]}, Left with Flips Pad Time: {flip_left_pad_times[-1]}")

import matplotlib.pyplot as plt

plt.plot(sequence_lengths, left_pad_times, label="new pad_sequence left")
plt.scatter(sequence_lengths, left_pad_times)
plt.errorbar(sequence_lengths, left_pad_times, yerr=left_pad_times_std, linestyle="None", marker="^")
plt.plot(sequence_lengths, flip_left_pad_times, label="old pad_sequence left (2 flips)")
plt.scatter(sequence_lengths, flip_left_pad_times)
plt.errorbar(sequence_lengths, flip_left_pad_times, yerr=flip_left_pad_times_std, linestyle="None", marker="^")
plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.legend(loc="upper right")

# Sequence Length: 166666, Left Pad Time: 0.06147645162009212, Left with Flips Pad Time: 0.09842291727001794
# Sequence Length: 333333, Left Pad Time: 0.08933195920990329, Left with Flips Pad Time: 0.15597836187991562
# Sequence Length: 500000, Left Pad Time: 0.08863158334006585, Left with Flips Pad Time: 0.15224887342999863
# Sequence Length: 666666, Left Pad Time: 0.10524682551997103, Left with Flips Pad Time: 0.18177212480995877
# Sequence Length: 833333, Left Pad Time: 0.11801802741003485, Left with Flips Pad Time: 0.20821274195001024
# Sequence Length: 1000000, Left Pad Time: 0.131894061660023, Left with Flips Pad Time: 0.23223503091008751
```
Co-authored-by: mskoh52 <mskoh52@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131884
Approved by: https://github.com/ezyang
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion (sketched below).
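For illustration, a hedged sketch of the round trip; the exact spelling of the layout argument to `.to()` is an assumption here:

```python
import torch

# A strided nested tensor built from variable-length rows.
nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

# With this PR, .to() can convert between NT layouts without copying.
nt_jagged = nt.to(layout=torch.jagged)            # strided -> jagged
nt_strided = nt_jagged.to(layout=torch.strided)   # jagged -> strided
```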
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
Fixes #132290
This PR attempts a more invasive / complete solution than the one from #132338, which removes immediate tensor fields from the `tensor_dict` copy stored in node meta. The approach taken here is to store only those fields of the `tensor_dict` that are actually used somewhere else.
So far, this appears to be limited to:
* `_dynamo_static_input_type`
* `tag` (at least in the tests). Discussion at #94080 appears to indicate this is depended on for export
(CI may point out more)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132805
Approved by: https://github.com/mlazos
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks, and then evaluate them (e.g., as occurs if you sum(long list of symint)). The basic idea behind this PR is to only thunkify proxies if they're being created in places where they may or may not be used--crucially, symint operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead.
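As a hedged illustration of the trigger described above (shapes and counts made up):

```python
import torch

@torch.compile(dynamic=True)
def total_length(tensors):
    # Under dynamic shapes, each t.size(0) is a SymInt; summing a long list
    # of them used to build a deep chain of thunks that could overflow the
    # stack once evaluated. With this PR, these user-code symint ops are
    # placed into the graph eagerly instead.
    return sum(t.size(0) for t in tensors)

total_length([torch.randn(i + 1) for i in range(1000)])
```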
I annotated the PR with explanation of changes.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421
Approved by: https://github.com/Skylion007, https://github.com/zou3519
ghstack dependencies: #132674, #132675
Previously, when we sliced out a submesh from a mesh, we assigned the mesh as the parent mesh of the submesh. As a result, with a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from a 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2, 2, 2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]
# This would evaluate to True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0_2))
```
We can always reconstruct the mesh needed from the mesh dim names, as long as the dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent the child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv`, so we would have:
```
mesh_3d = init_device_mesh("cuda", (2, 2, 2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]
# This would evaluate to True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0_2))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
Summary: We observe that a stack node can be transformed into a cat node to eliminate split nodes, which can further enable the unbind-cat optimization. We therefore add a more advanced pattern to perform this graph transformation.
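For intuition, the rewrite leans on the usual stack/cat equivalence; a hedged sketch of the semantics, not the pass's exact FX pattern:

```python
import torch

a, b = torch.randn(4), torch.randn(4)
stacked = torch.stack([a, b])                         # shape (2, 4)
catted = torch.cat([a.unsqueeze(0), b.unsqueeze(0)])  # same values
assert torch.equal(stacked, catted)
```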
Test Plan:
# unit test
```
CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 test //caffe2/test/inductor:split_cat_fx_passes
```
Buck UI: https://www.internalfb.com/buck2/de6c1cda-3d74-4a30-8980-7b209b6fe5dc
Test UI: https://www.internalfb.com/intern/testinfra/testrun/12103424042268125
Network: Up: 485KiB Down: 728KiB (reSessionID-2f2c01c3-79bb-4e37-b5be-fb77ec09b264)
Jobs completed: 29. Time elapsed: 5:19.8s.
Cache hits: 0%. Commands: 4 (cached: 0, remote: 0, local: 4)
Tests finished: Pass 9. Fail 0. Fatal 0. Skip 1. Build failure 0
# benchmark
```
CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode batch-split --model_type "ig_ctr" --flow_id 584880697
```
P1503698962
Before and after graph transformation:
https://www.internalfb.com/intern/diffing/?paste_number=1504050718
Differential Revision: D60411560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132542
Approved by: https://github.com/jackiexu1992
Summary:
- We add Inductor logs for which tensors we tried to reinplace, which tensors we were unable to reinplace, and, of those, which might be bugs (the "missed reinplacing opportunities"). You can tell this by reading the Inductor output graph, but the logs make it easier to figure out.
- Add a dynamo_compile counter for missed reinplacing opportunities. The
goal is to see how widespread existing problems (if any) are. We've had
trouble getting all of the edge cases for the reinplacing pass; the
counter will help us hunt down issues.
Test Plan:
- tested locally
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132758
Approved by: https://github.com/eellison
Summary:
- make default DCE pass check schema (see the sketch after this list),
- need to rebase onto https://github.com/pytorch/pytorch/pull/131651 after it's in Phabricator (for now the change is manually added).
- mark Proxy dump as NotImplemented for a better error message.
- Remove Proxy from tensors when dumping models, as Proxy cannot be dumped.
More details in https://docs.google.com/document/d/1G5vmTXjzxoyVGRI2kpA1gQukK_Glyg2NrE0Oh6Nlg9A/edit?usp=sharing.
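As a hedged sketch of what the schema-checking DCE in the first bullet is meant to guarantee (illustrative standard FX code, not the exact diff):

```python
import torch
from torch.fx import symbolic_trace

def f(x):
    x.add_(1)   # in-place mutation: its schema marks it side-effectful, so DCE must keep it
    y = x + 2   # pure and unused: safe to eliminate
    return x

gm = symbolic_trace(f)
gm.graph.eliminate_dead_code()  # a schema-aware DCE drops `y` but preserves `add_`
gm.recompile()
print(gm.code)
```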
Test Plan:
CI
```
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r qat_conv2d
- test_export.py
- buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export
- buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:fx -- -r dce
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_fold_bn_erases_bn_node
```
Reviewed By: angelayi
Differential Revision: D60319175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132764
Approved by: https://github.com/angelayi
More context in [#132471](https://github.com/pytorch/pytorch/issues/132471) and https://github.com/pytorch/pytorch/issues/132366.
TLDR:
When cuda is available and users move tensors to cuda, we cannot really reuse the default pg if the default pg is gloo, as lots of collectives are not supported on gloo for cuda tensors. For example, `dtensor.full_tensor()` would result in a mysterious SIGTERM when all_gathering a cuda tensor using gloo. Without the change in this PR, users would have to know this context and explicitly move cuda tensors to cpu before invoking most collectives, which I think is not ideal UX.
Therefore, given most collectives are not supported on gloo for cuda tensors, we should init a new pg if the default pg is gloo when torch.cuda.is_available() and device_type is cuda.
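A hedged sketch of the failure mode this sidesteps (assumes a gloo default pg with the usual rendezvous env vars set):

```python
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # gloo default pg

# all_gather of a CUDA tensor over gloo is unsupported and can end in a
# mysterious SIGTERM; dtensor.full_tensor() runs into this internally.
t = torch.ones(4, device="cuda")
out = [torch.empty_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(out, t)
```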
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132709
Approved by: https://github.com/awgu, https://github.com/wanchaol
This PR makes sure all current tests in the sparsity export test suite pass. Note that there will probably be isolated cases that need fixing after this, but the general goal of preserving sparsity metadata has been accomplished (a usage sketch follows the test output below).
Fixes: https://github.com/pytorch/pytorch/issues/117188
```
$ PYTORCH_TEST_WITH_DYNAMO=0 python test/export/test_sparse.py
........................................................................................................................................................
----------------------------------------------------------------------
Ran 152 tests
OK
```
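For reference, a minimal sketch of the kind of program such tests exercise; the module and the CSR input here are illustrative, not taken from the suite:

```python
import torch

class SparseSum(torch.nn.Module):
    def forward(self, x):
        return x.sum()

x = torch.randn(4, 4).to_sparse_csr()
# With this PR, export preserves the sparsity metadata of `x` in the program.
ep = torch.export.export(SparseSum(), (x,))
print(ep)
```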
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132690
Approved by: https://github.com/ezyang
Bumps [rexml](https://github.com/ruby/rexml) from 3.2.8 to 3.3.3.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a href="https://github.com/ruby/rexml/releases">rexml's releases</a>.</em></p>
<blockquote>
<h2>REXML 3.3.3 - 2024-08-01</h2>
<h3>Improvements</h3>
<ul>
<li>
<p>Added support for detecting invalid XML that has unsupported
content before root element</p>
<ul>
<li><a href="https://redirect.github.com/ruby/rexml/issues/184">GH-184</a></li>
<li>Patch by NAITOH Jun.</li>
</ul>
</li>
<li>
<p>Added support for <code>REXML::Security.entity_expansion_limit=</code> and
<code>REXML::Security.entity_expansion_text_limit=</code> in SAX2 and pull
parsers</p>
<ul>
<li><a href="https://redirect.github.com/ruby/rexml/issues/187">GH-187</a></li>
<li>Patch by NAITOH Jun.</li>
</ul>
</li>
<li>
<p>Added more tests for invalid XMLs.</p>
<ul>
<li><a href="https://redirect.github.com/ruby/rexml/issues/183">GH-183</a></li>
<li>Patch by Watson.</li>
</ul>
</li>
<li>
<p>Added more performance tests.</p>
<ul>
<li>Patch by Watson.</li>
</ul>
</li>
<li>
<p>Improved parse performance.</p>
<ul>
<li><a href="https://redirect.github.com/ruby/rexml/issues/186">GH-186</a></li>
<li>Patch by tomoya ishida.</li>
</ul>
</li>
</ul>
<h3>Thanks</h3>
<ul>
<li>
<p>NAITOH Jun</p>
</li>
<li>
<p>Watson</p>
</li>
<li>
<p>tomoya ishida</p>
</li>
</ul>
<h2>REXML 3.3.2 - 2024-07-16</h2>
<h3>Improvements</h3>
<ul>
<li>
<p>Improved parse performance.</p>
<ul>
<li><a href="https://redirect.github.com/ruby/rexml/issues/160">GH-160</a></li>
<li>Patch by NAITOH Jun.</li>
</ul>
</li>
<li>
<p>Improved parse performance.</p>
<ul>
<li><a href="https://redirect.github.com/ruby/rexml/issues/169">GH-169</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/170">GH-170</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/171">GH-171</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/172">GH-172</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/173">GH-173</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/174">GH-174</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/175">GH-175</a></li>
<li><a href="https://redirect.github.com/ruby/rexml/issues/176">GH-176</a></li>
</ul>
</li>
</ul>
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a href="e4a067e112"><code>e4a067e</code></a> Add 3.3.3 entry</li>
<li><a href="17ff3e7874"><code>17ff3e7</code></a> test: add a performance test for attribute list declaration</li>
<li><a href="be86b3de0a"><code>be86b3d</code></a> test: fix wrong test name</li>
<li><a href="b93d790b36"><code>b93d790</code></a> test: use double quote for string literal</li>
<li><a href="0fbe7d5a0e"><code>0fbe7d5</code></a> test: don't use abbreviated name</li>
<li><a href="1599e8785f"><code>1599e87</code></a> test: add a performance test for PI with many tabs</li>
<li><a href="e2546e6eca"><code>e2546e6</code></a> parse pi: improve invalid case detection</li>
<li><a href="73661ef281"><code>73661ef</code></a> test: fix a typo</li>
<li><a href="850488abf2"><code>850488a</code></a> test: use double quote for string literal</li>
<li><a href="46c6397d5c"><code>46c6397</code></a> test: add performance tests for entity declaration</li>
<li>Additional commits viewable in <a href="https://github.com/ruby/rexml/compare/v3.2.8...v3.3.3">compare view</a></li>
</ul>
</details>
<br />
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
<details>
<summary>Dependabot commands and options</summary>
<br />
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot show <dependency name> ignore conditions` will show all of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/pytorch/pytorch/network/alerts).
</details>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132469
Approved by: https://github.com/ezyang
Summary:
## Why
utils.checkpoint doesn't support meta device:
```
File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 490, in checkpoint
next(gen)
File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 1359, in _checkpoint_without_reentrant_generator
device_module = _get_device_module(device)
File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 98, in _get_device_module
device_module = getattr(torch, device)
File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/__init__.py", line 1938, in __getattr__
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'meta'
```
This blocks us from running models with checkpointing enabled in meta mode.
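A minimal repro of the failure (hedged; this uses the non-reentrant path shown in the traceback above):

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    return x * 2

x = torch.randn(4, device="meta")
# Before this fix: AttributeError: module 'torch' has no attribute 'meta'
out = checkpoint(fn, x, use_reentrant=False)
```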
## What
This diff handles the case of meta device in checkpoint.py.
(In checkpoint.py, the device module is mainly used when preserve_rng_state=True, which doesn't apply to the meta case. So a more elegant fix might be to set preserve_rng_state=False when detecting that args are on a meta device, but I didn't find a minimal place to do this check. Let me know if you have ideas.)
Test Plan: Tested with a toy model that has checkpointing on its module: P1513716944
Differential Revision: D60749427
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132684
Approved by: https://github.com/kit1980