From 258f47fc0b3f58aebc9ab2b6e7657a8a5ba87ced Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Wed, 7 Aug 2024 15:53:04 +0000 Subject: [PATCH] Add `padding_side` to `pad_sequence` with `"left"` and `"right"` options (`"right"` as default) (#131884) Fixes #10536 Reattempt of #61467. Thank you so much to @mskoh52 for your excellent work! As I was trying to create a more efficient LLM data collator, I realized that `pad_sequence` only supports right padding, even though left padding is a very common format for LLMs, like Llama and Mistral. The proposed alternative implementation was to use multiple flips, which tends to be 1.5x-2x slower. Instead we can add a [`padding_side` parameter as there is for for Hugging Face tokenizers](https://github.com/huggingface/transformers/blob/9d6c0641c4a3c2c5ecf4d49d7609edd5b745d9bc/src/transformers/tokenization_utils_base.py#L1565), which requires only a very small change in the C++ code. Here are the benchmarks of the new implementation! `float32`: ![eaaa95ef-9384-45d2-be56-6898bc1d3514](https://github.com/user-attachments/assets/3b0eb309-e5a0-4a4d-97bb-4e3298783dbb) `bool`: ![892f32da-8d9a-492b-9507-18d3f0a41e8e](https://github.com/user-attachments/assets/6824ea15-7d4e-4b89-95f0-8546635f0c2e) Code: ```python from __future__ import annotations import random import time from typing import Literal import numpy as np import torch def pad_sequence_with_flips( sequences: list[torch.Tensor], batch_first: bool = False, padding_value: int | float | bool = 0.0, padding_side: Literal["left", "right"] | str = "left", ) -> torch.Tensor: if padding_side == 'right': padded_sequence = torch._C._nn.pad_sequence([t.flatten() for t in sequences], batch_first=batch_first, padding_value=padding_value) elif padding_side=='left': padded_sequence = torch._C._nn.pad_sequence([t.flatten().flip(0) for t in sequences], batch_first=batch_first, padding_value=padding_value) # pyright: ignore[reportArgumentType] padded_sequence = padded_sequence.flip(int(batch_first)) else: raise ValueError(f"padding_side should be either 'right' or 'left', but got {padding_side}") return padded_sequence sequence_lengths: list[int] = [] flip_left_pad_times: list[float] = [] flip_left_pad_times_std: list[float] = [] left_pad_times: list[float] = [] left_pad_times_std: list[float] = [] RUNS_PER_LOOP: int = 100 for i in range(1, 7): sequence_length = i * int(1e6) // 6 sequence_lengths.append(sequence_length) sequences = [torch.randint(0, 2, (random.randint(1, sequence_length),), dtype=torch.bool) for _ in range(64)] inner_left_pad_times: list[float] = [] inner_right_pad_times: list[float] = [] inner_flip_left_pad_times: list[float] = [] inner_flip_right_pad_times: list[float] = [] for _ in range(RUNS_PER_LOOP): start = time.perf_counter() torch._C._nn.pad_sequence(sequences, batch_first=True, padding_value=False, padding_side="left") end = time.perf_counter() inner_left_pad_times.append(end - start) start = time.perf_counter() pad_sequence_with_flips(sequences, batch_first=True, padding_value=False, padding_side="left") end = time.perf_counter() inner_flip_left_pad_times.append(end - start) left_pad_times.append(sum(inner_left_pad_times) / len(inner_left_pad_times)) left_pad_times_std.append(np.std(inner_left_pad_times)) flip_left_pad_times.append(sum(inner_flip_left_pad_times) / len(inner_flip_left_pad_times)) flip_left_pad_times_std.append(np.std(inner_flip_left_pad_times)) print(f"Sequence Length: {sequence_length}, Left Pad Time: {left_pad_times[-1]}, Left with Flips Pad Time: {flip_left_pad_times[-1]}") import matplotlib.pyplot as plt plt.plot(sequence_lengths, left_pad_times, label="new pad_sequence left") plt.scatter(sequence_lengths, left_pad_times) plt.errorbar(sequence_lengths, left_pad_times, yerr=left_pad_times_std, linestyle='None', marker='^') plt.plot(sequence_lengths, flip_left_pad_times, label="old pad_sequence left (2 flips)") plt.scatter(sequence_lengths, flip_left_pad_times) plt.errorbar(sequence_lengths, flip_left_pad_times, yerr=flip_left_pad_times_std, linestyle='None', marker='^') plt.xlabel("Sequence Length") plt.ylabel("Time (s)") plt.legend(loc="upper right") # Sequence Length: 166666, Left Pad Time: 0.06147645162009212, Left with Flips Pad Time: 0.09842291727001794 # Sequence Length: 333333, Left Pad Time: 0.08933195920990329, Left with Flips Pad Time: 0.15597836187991562 # Sequence Length: 500000, Left Pad Time: 0.08863158334006585, Left with Flips Pad Time: 0.15224887342999863 # Sequence Length: 666666, Left Pad Time: 0.10524682551997103, Left with Flips Pad Time: 0.18177212480995877 # Sequence Length: 833333, Left Pad Time: 0.11801802741003485, Left with Flips Pad Time: 0.20821274195001024 # Sequence Length: 1000000, Left Pad Time: 0.131894061660023, Left with Flips Pad Time: 0.23223503091008751 ``` Co-authored-by: mskoh52 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131884 Approved by: https://github.com/ezyang --- aten/src/ATen/native/PackedSequence.cpp | 9 +++-- aten/src/ATen/native/native_functions.yaml | 2 +- test/cpp/api/nn_utils.cpp | 26 ++++++++++++ test/nn/test_packed_sequence.py | 44 ++++++++++++++++++--- test/test_jit.py | 10 +++-- torch/_C/_nn.pyi.in | 3 +- torch/csrc/api/include/torch/nn/utils/rnn.h | 7 +++- torch/nn/utils/rnn.py | 7 +++- 8 files changed, 92 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp index f45d5967bb2..85e24d2275a 100644 --- a/aten/src/ATen/native/PackedSequence.cpp +++ b/aten/src/ATen/native/PackedSequence.cpp @@ -202,9 +202,11 @@ std::tuple _pad_packed_sequence(const Tensor& data, const Tensor return std::make_tuple(output, lengths_t); } -Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value) { +Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value, const c10::string_view padding_side) { const int64_t sequences_size = sequences.size(); TORCH_CHECK(sequences_size > 0, "received an empty list of sequences"); + TORCH_CHECK(padding_side == "left" || padding_side == "right", + "Expected padding_side to be one of left or right, but got ", padding_side, "."); IntArrayRef max_size = sequences[0].sizes(); IntArrayRef trailing_dims = max_size.slice(1); int64_t max_len = std::max_element( @@ -227,11 +229,12 @@ Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value for (const auto i : c10::irange(sequences_size)) { const Tensor& currseq = sequences[i]; const int64_t length_i = currseq.size(0); + const int64_t start = padding_side == "left" ? max_len - length_i : 0; // use index notation to prevent duplicate references to the tensor if (batch_first) { - out.select(0, i).narrow(0, 0, length_i).copy_(currseq); + out.select(0, i).narrow(0, start, length_i).copy_(currseq); } else { - out.narrow(0, 0, length_i).select(1, i).copy_(currseq); + out.narrow(0, start, length_i).select(1, i).copy_(currseq); } } return out; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 30ef4ef80ba..afe3b3814ea 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14414,7 +14414,7 @@ CPU, CUDA: _segment_reduce_backward_kernel autogen: _segment_reduce_backward.out -- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor +- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor python_module: nn variants: function diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp index 76aab44ac29..e45786b4d33 100644 --- a/test/cpp/api/nn_utils.cpp +++ b/test/cpp/api/nn_utils.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -830,6 +831,15 @@ TEST_F(NNUtilsTest, PadSequence) { padded = rnn_utils::pad_sequence({b, a, c}); ASSERT_TRUE(padded.allclose(expected.transpose(0, 1))); + // padding_side = "left", batch_first = true + expected = torch::tensor({{0, 4, 5}, {1, 2, 3}, {0, 0, 6}}); + padded = rnn_utils::pad_sequence({b, a, c}, true, 0, "left"); + ASSERT_TRUE(padded.allclose(expected)); + + // padding_side = "left", batch_first = false + padded = rnn_utils::pad_sequence({b, a, c}, false, 0, "left"); + ASSERT_TRUE(padded.allclose(expected.transpose(0, 1))); + // pad with non-zero value expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}}); padded = rnn_utils::pad_sequence({b, a, c}, true, 1); @@ -870,5 +880,21 @@ TEST_F(NNUtilsTest, PadSequence) { // batch first = false padded = rnn_utils::pad_sequence(sequences); ASSERT_TRUE(padded.allclose(expected.transpose(0, 1))); + + // reset expected_tensors for padding_side + expected_tensors.clear(); + for (const torch::Tensor& seq : sequences) { + // NOLINTNEXTLINE(performance-inefficient-vector-operation) + expected_tensors.emplace_back( + torch::flip(pad(torch::flip(seq, {0}), maxlen * maxlen), {0})); + } + expected = torch::stack(expected_tensors); + // padding_side = "left", batch_first = true + padded = rnn_utils::pad_sequence(sequences, true, 0, "left"); + ASSERT_TRUE(padded.allclose(expected)); + + // padding_side = "left", batch_first = false + padded = rnn_utils::pad_sequence(sequences, false, 0, "left"); + ASSERT_TRUE(padded.allclose(expected.transpose(0, 1))); } } diff --git a/test/nn/test_packed_sequence.py b/test/nn/test_packed_sequence.py index a9496edd366..8bb6ff64e4f 100644 --- a/test/nn/test_packed_sequence.py +++ b/test/nn/test_packed_sequence.py @@ -2,6 +2,7 @@ import itertools import random +from typing import List import torch import torch.nn.utils.rnn as rnn_utils @@ -188,6 +189,23 @@ class PackedSequenceTest(TestCase): padded = rnn_utils.pad_sequence([b, a, c]) self.assertEqual(padded, expected.transpose(0, 1)) + # padding_side = "left", batch_first=True + expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]]) + padded = rnn_utils.pad_sequence( + [b, a, c], + batch_first=True, + padding_side="left", + ) + self.assertEqual(padded, expected) + + # padding_side = "left", batch_first=False + padded = rnn_utils.pad_sequence( + [b, a, c], + batch_first=False, + padding_side="left", + ) + self.assertEqual(padded, expected.transpose(0, 1)) + # pad with non-zero value expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]]) padded = rnn_utils.pad_sequence([b, a, c], True, 1) @@ -201,17 +219,14 @@ class PackedSequenceTest(TestCase): # more dimensions maxlen = 9 for num_dim in (0, 1, 2, 3): - sequences = [] + sequences: List[torch.Tensor] = [] trailing_dims = [4] * num_dim for i in range(1, maxlen + 1): seq_len = i * i sequences.append(torch.rand(seq_len, 5, *trailing_dims)) random.shuffle(sequences) - expected = [] - for seq in sequences: - expected.append(pad(seq, maxlen * maxlen)) # batch first = true - expected = torch.stack(expected) + expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences]) padded = rnn_utils.pad_sequence(sequences, True) self.assertEqual(padded, expected) @@ -219,6 +234,25 @@ class PackedSequenceTest(TestCase): padded = rnn_utils.pad_sequence(sequences) self.assertEqual(padded, expected.transpose(0, 1)) + # padding_side = "left", batch_first=True + expected = torch.stack( + [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences] + ) + padded = rnn_utils.pad_sequence( + sequences, + batch_first=True, + padding_side="left", + ) + self.assertEqual(padded, expected) + + # padding_side = "left", batch_first=False + padded = rnn_utils.pad_sequence( + sequences, + batch_first=False, + padding_side="left", + ) + self.assertEqual(padded, expected.transpose(0, 1)) + def test_unpad_sequence(self): # single dimensional a = torch.tensor([1, 2, 3]) diff --git a/test/test_jit.py b/test/test_jit.py index adc3c71b4b9..5b3a40973ac 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9706,9 +9706,9 @@ dedent """ def test_script_pad_sequence_pack_sequence(self): from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence - def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0): - # type: (List[Tensor], bool, float) -> Tensor - return pad_sequence(tensor_list, batch_first, padding_value) + def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"): + # type: (List[Tensor], bool, float, str) -> Tensor + return pad_sequence(tensor_list, batch_first, padding_value, padding_side) def pack_sequence_func(tensor_list, enforce_sorted=True): # type: (List[Tensor], bool) -> Tensor @@ -9727,6 +9727,10 @@ dedent """ ([ones3, ones4, ones5], True)) self.checkScript(pad_sequence_func, ([ones3, ones4, ones5], True, 2.5)) + self.checkScript(pad_sequence_func, + ([ones3, ones4, ones5], True, 2.5, "left")) + self.checkScript(pad_sequence_func, + ([ones3, ones4, ones5], False, 2.5, "left")) self.checkScript(pack_sequence_func, ([tensor1, tensor2, tensor3],)) self.checkScript(pack_sequence_func, diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in index 8aa56953bec..336190443a8 100644 --- a/torch/_C/_nn.pyi.in +++ b/torch/_C/_nn.pyi.in @@ -1,7 +1,7 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" -from typing import List, Optional, overload, Sequence, Tuple, Union +from typing import List, Literal, Optional, overload, Sequence, Tuple, Union from torch import memory_format, Tensor from torch.types import _bool, _device, _dtype, _int, _size @@ -64,6 +64,7 @@ def pad_sequence( sequences: Union[List[Tensor], Tuple[Tensor, ...]], batch_first: bool = False, padding_value: float = 0.0, + padding_side: Union[Literal["left", "right"], str] = "right", ) -> Tensor: ... def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ... def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ... diff --git a/torch/csrc/api/include/torch/nn/utils/rnn.h b/torch/csrc/api/include/torch/nn/utils/rnn.h index ba8b0db4271..6f2a68984c8 100644 --- a/torch/csrc/api/include/torch/nn/utils/rnn.h +++ b/torch/csrc/api/include/torch/nn/utils/rnn.h @@ -300,6 +300,8 @@ inline std::tuple pad_packed_sequence( /// or in /// ``T x B x *`` otherwise /// padding_value (double, optional): value for padded elements. Default: 0. +/// padding_side (str, optional): the side to pad the sequences on. Default: +/// "right". /// /// Returns: /// Tensor of size ``T x B x *`` if `batch_first` is ``false``. @@ -307,8 +309,9 @@ inline std::tuple pad_packed_sequence( inline Tensor pad_sequence( ArrayRef sequences, bool batch_first = false, - double padding_value = 0) { - return at::pad_sequence(sequences, batch_first, padding_value); + double padding_value = 0, + c10::string_view padding_side = "right") { + return at::pad_sequence(sequences, batch_first, padding_value, padding_side); } /// Packs a list of variable length Tensors diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 2304c5c64b5..09a38e1119d 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -419,6 +419,7 @@ def pad_sequence( sequences: Union[Tensor, List[Tensor]], batch_first: bool = False, padding_value: float = 0.0, + padding_side: str = "right", ) -> Tensor: r"""Pad a list of variable length Tensors with :attr:`padding_value`. @@ -448,6 +449,8 @@ def pad_sequence( batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` format, ``T x B x *`` otherwise. padding_value (float, optional): value for padded elements. Default: 0. + padding_side (str, optional): the side to pad the sequences on. + Default: "right". Returns: Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. @@ -472,7 +475,9 @@ def pad_sequence( # assuming trailing dimensions and type of all the Tensors # in sequences are same and fetching those from sequences[0] - return torch._C._nn.pad_sequence(sequences, batch_first, padding_value) # type: ignore[arg-type] + return torch._C._nn.pad_sequence( + sequences, batch_first, padding_value, padding_side # type: ignore[arg-type] + ) def unpad_sequence(