Fix StopStringCriteria to handle tokens above len(tokenizer) (#35797)

* Fix StopStringCriteria to handle tokens above len(tokenizer)

This fixes #35244 by clipping token IDs to be within the tokenizer's vocabulary size before performing the embedding lookup. This prevents index errors when model.config.vocab_size > len(tokenizer).

The fix:
1. Adds a clamp operation to ensure token IDs are within bounds
2. Adds a test case to verify the behavior

* Use self.stop_strings instead of stop_strings

* Handle clipping correctly

* make fixup

* Update test to the new embedding vecs

* Use much bigger values in the mismatch test

* Typo fix

* Slight simplification

---------

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Matt 2025-02-06 16:53:28 +00:00 committed by GitHub
parent 28f73bc307
commit 4563ba2c6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 10 deletions

View file

@ -245,26 +245,26 @@ class StopStringCriteria(StoppingCriteria):
vocab = tokenizer.get_vocab()
token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
token_list, token_indices, self.stop_strings, tokenizer
token_list, token_indices, tokenizer
)
self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
self.num_stop_strings = len(self.stop_strings)
self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
# We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE:
embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
(token_list, token_indices, self.stop_strings)
]
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings))
else:
clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
clean_token_list, clean_token_indices, stop_strings
clean_token_list, clean_token_indices, self.stop_strings
)
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = (
embedding_vec,
max_valid_positions,
max_valid_end_lens,
@ -357,7 +357,9 @@ class StopStringCriteria(StoppingCriteria):
)
max_valid_end_lens = max(valid_end_lens)
vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
# We use +2 instead of +1 so we can have a dummy entry at the end. We will clamp all token values
# over the max to this, ensuring they do not contribute to stop string matching.
gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)
for i, stop_string in enumerate(stop_strings):
positions = token_valid_positions[stop_string]
@ -395,6 +397,9 @@ class StopStringCriteria(StoppingCriteria):
# Flip input_ids because we're only matching strings at the end of the generated sequence
flipped_ids = torch.flip(input_ids, (1,))
# Clip out-of-vocab values to the dummy value at the end of the embedding vector
flipped_ids = torch.clamp(flipped_ids, max=self.embedding_vec.size(0) - 1)
# Size of the vector of positions a single token can match
max_valid_positions = self.max_valid_positions

View file

@ -176,6 +176,18 @@ class StoppingCriteriaTestCase(unittest.TestCase):
for i in range(len(false_strings)):
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
def test_stop_string_criteria_vocab_size_mismatch(self):
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# Create input_ids with tokens above len(tokenizer)
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
scores = None
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
# This should not raise an error and should return False since no stop string is matched
self.assertFalse(criteria(input_ids, scores))
def test_stop_string_matching_positions(self):
stop_string = "stop"
token_list = ["last", "top", "topper", "s", "p"]
@ -200,14 +212,14 @@ class StoppingCriteriaTestCase(unittest.TestCase):
# Positions inside the stop string where the token matches (excluding end overlaps)
valid_positions = embedding_vec[:, 0].tolist()
self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
# Overlap lengths between end of stop string and start of token
end_overlaps = embedding_vec[:, 1].tolist()
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
# Length of each token
token_lengths = embedding_vec[:, 2].tolist()
token_lengths = embedding_vec[:-1, 2].tolist()
self.assertEqual(token_lengths, [len(token) for token in token_list])
def test_single_letter_stop_string(self):