From 9c3009a391c909b8993db6b22ca8cc19b0073e2f Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Feb 2025 16:53:28 +0000 Subject: [PATCH] Fix StopStringCriteria to handle tokens above len(tokenizer) (#35797) * Fix StopStringCriteria to handle tokens above len(tokenizer) This fixes #35244 by clipping token IDs to be within the tokenizer's vocabulary size before performing the embedding lookup. This prevents index errors when model.config.vocab_size > len(tokenizer). The fix: 1. Adds a clamp operation to ensure token IDs are within bounds 2. Adds a test case to verify the behavior * Use self.stop_strings instead of stop_strings * Handle clipping correctly * make fixup * Update test to the new embedding vecs * Use much bigger values in the mismatch test * Typo fix * Slight simplification --------- Co-authored-by: openhands --- .../generation/stopping_criteria.py | 19 ++++++++++++------- tests/generation/test_stopping_criteria.py | 18 +++++++++++++++--- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index b950a69f8..4627aeb97 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -245,26 +245,26 @@ class StopStringCriteria(StoppingCriteria): vocab = tokenizer.get_vocab() token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values()) self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache( - token_list, token_indices, self.stop_strings, tokenizer + token_list, token_indices, tokenizer ) self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) self.num_stop_strings = len(self.stop_strings) self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) - def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer): + def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer): # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality - if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE: + if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE: embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[ (token_list, token_indices, self.stop_strings) ] - STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings)) + STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings)) else: clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer) embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec( - clean_token_list, clean_token_indices, stop_strings + clean_token_list, clean_token_indices, self.stop_strings ) - STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = ( + STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = ( embedding_vec, max_valid_positions, max_valid_end_lens, @@ -357,7 +357,9 @@ class StopStringCriteria(StoppingCriteria): ) max_valid_end_lens = max(valid_end_lens) vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 - gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1) + # We use +2 instead of +1 so we can have a dummy entry at the end. We will clamp all token values + # over the max to this, ensuring they do not contribute to stop string matching. + gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1) for i, stop_string in enumerate(stop_strings): positions = token_valid_positions[stop_string] @@ -395,6 +397,9 @@ class StopStringCriteria(StoppingCriteria): # Flip input_ids because we're only matching strings at the end of the generated sequence flipped_ids = torch.flip(input_ids, (1,)) + # Clip out-of-vocab values to the dummy value at the end of the embedding vector + flipped_ids = torch.clamp(flipped_ids, max=self.embedding_vec.size(0) - 1) + # Size of the vector of positions a single token can match max_valid_positions = self.max_valid_positions diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index e8594dcdb..ace7d496d 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -176,6 +176,18 @@ class StoppingCriteriaTestCase(unittest.TestCase): for i in range(len(false_strings)): self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) + def test_stop_string_criteria_vocab_size_mismatch(self): + """Test that StopStringCriteria handles tokens above len(tokenizer) correctly.""" + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + # Create input_ids with tokens above len(tokenizer) + input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device) + scores = None + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"]) + + # This should not raise an error and should return False since no stop string is matched + self.assertFalse(criteria(input_ids, scores)) + def test_stop_string_matching_positions(self): stop_string = "stop" token_list = ["last", "top", "topper", "s", "p"] @@ -200,14 +212,14 @@ class StoppingCriteriaTestCase(unittest.TestCase): # Positions inside the stop string where the token matches (excluding end overlaps) valid_positions = embedding_vec[:, 0].tolist() - self.assertEqual(valid_positions, [2, -1, -1, 3, -1]) + self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1]) # Overlap lengths between end of stop string and start of token end_overlaps = embedding_vec[:, 1].tolist() - self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1]) + self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1]) # Length of each token - token_lengths = embedding_vec[:, 2].tolist() + token_lengths = embedding_vec[:-1, 2].tolist() self.assertEqual(token_lengths, [len(token) for token in token_list]) def test_single_letter_stop_string(self):