Fix StopStringCriteria to handle tokens above len(tokenizer) (#35797)
* Fix StopStringCriteria to handle tokens above len(tokenizer)

  This fixes #35244 by clipping token IDs to be within the tokenizer's
  vocabulary size before performing the embedding lookup. This prevents
  index errors when model.config.vocab_size > len(tokenizer).

  The fix:
  1. Adds a clamp operation to ensure token IDs are within bounds
  2. Adds a test case to verify the behavior

* Use self.stop_strings instead of stop_strings
* Handle clipping correctly
* make fixup
* Update test to the new embedding vecs
* Use much bigger values in the mismatch test
* Typo fix
* Slight simplification

---------

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent e173ffd3ba
commit 9c3009a391

2 changed files with 27 additions and 10 deletions
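Note: the heart of the change is a clamp-to-dummy-row pattern, sketched below in isolation (names such as gather_table are illustrative, not the transformers source). The lookup table gets one extra all-minus-one row at the end, and any token ID beyond the real rows is clamped onto that row instead of indexing out of bounds.

import torch

vocab_size = 8
vec_size = 3

# Lookup table with one extra dummy row at the end, filled with -1 ("no match").
gather_table = torch.full((vocab_size + 1, vec_size), -1, dtype=torch.int32)
gather_table[3] = torch.tensor([2, -1, 3], dtype=torch.int32)  # a real token's entry

# A sequence containing an ID above the table height, as happens when
# model.config.vocab_size > len(tokenizer).
input_ids = torch.tensor([[3, 500, 1]])

# Without the clamp, gather_table[input_ids] raises an IndexError for 500.
clamped = torch.clamp(input_ids, max=gather_table.size(0) - 1)
print(gather_table[clamped])  # the row gathered for 500 is the all -1 dummy entry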
@@ -245,26 +245,26 @@ class StopStringCriteria(StoppingCriteria):
         vocab = tokenizer.get_vocab()
         token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values())
         self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache(
-            token_list, token_indices, self.stop_strings, tokenizer
+            token_list, token_indices, tokenizer
         )

         self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings])
         self.num_stop_strings = len(self.stop_strings)
         self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32)

-    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer):
+    def clean_and_embed_tokens_with_cache(self, token_list, token_indices, tokenizer):
         # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality
-        if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE:
+        if (token_list, token_indices, self.stop_strings) in STOP_STRING_EMBEDDING_CACHE:
             embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[
                 (token_list, token_indices, self.stop_strings)
             ]
-            STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings))
+            STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, self.stop_strings))
         else:
             clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer)
             embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec(
-                clean_token_list, clean_token_indices, stop_strings
+                clean_token_list, clean_token_indices, self.stop_strings
             )
-            STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = (
+            STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, self.stop_strings)] = (
                 embedding_vec,
                 max_valid_positions,
                 max_valid_end_lens,
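For context, STOP_STRING_EMBEDDING_CACHE behaves like a small LRU cache keyed on the vocab and stop strings; move_to_end refreshes an entry's recency on a hit. A standalone sketch of that pattern (the cache bound and names here are illustrative assumptions, not the transformers internals):

from collections import OrderedDict

MAX_ENTRIES = 8  # illustrative bound
cache = OrderedDict()

def get_or_build(key, build):
    if key in cache:
        cache.move_to_end(key)  # mark the entry as most recently used
        return cache[key]
    value = build()
    cache[key] = value
    if len(cache) > MAX_ENTRIES:
        cache.popitem(last=False)  # evict the least recently used entry
    return value

print(get_or_build(("tok",), lambda: "embedding"))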
@@ -357,7 +357,9 @@ class StopStringCriteria(StoppingCriteria):
         )
         max_valid_end_lens = max(valid_end_lens)
         vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1
-        gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1)
+        # We use +2 instead of +1 so we can have a dummy entry at the end. We will clamp all token values
+        # over the max to this, ensuring they do not contribute to stop string matching.
+        gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)

         for i, stop_string in enumerate(stop_strings):
             positions = token_valid_positions[stop_string]
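A quick sketch of why the table is sized max(token_indices) + 2 rather than len(token_list): vocab indices need not be dense, so the table must reach the highest index (+1), plus one extra dummy row for clamped out-of-vocab IDs (+1 more). Variable names below are illustrative.

import numpy as np

token_indices = (0, 1, 2, 5)  # tokenizer vocab indices may have gaps
vec_size = 3

gather_vec = np.full((max(token_indices) + 2, vec_size), dtype=np.int32, fill_value=-1)
for idx in token_indices:
    gather_vec[idx, 2] = 1  # real rows get real data; the last row stays all -1

assert (gather_vec[-1] == -1).all()  # the dummy row never matches anything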
@ -395,6 +397,9 @@ class StopStringCriteria(StoppingCriteria):
|
|||
# Flip input_ids because we're only matching strings at the end of the generated sequence
|
||||
flipped_ids = torch.flip(input_ids, (1,))
|
||||
|
||||
# Clip out-of-vocab values to the dummy value at the end of the embedding vector
|
||||
flipped_ids = torch.clamp(flipped_ids, max=self.embedding_vec.size(0) - 1)
|
||||
|
||||
# Size of the vector of positions a single token can match
|
||||
max_valid_positions = self.max_valid_positions
|
||||
|
||||
|
|
|
|||
|
|
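Put together at match time, the lookup reverses each sequence and clamps before gathering. A sketch assuming a table shaped like the one above (not the actual method body):

import torch
import torch.nn.functional as F

embedding_vec = torch.full((10, 3), -1, dtype=torch.int32)  # 9 real rows + 1 dummy
input_ids = torch.tensor([[7, 2, 9999]])

# Stop strings are matched from the end of the sequence, so reverse each row.
flipped_ids = torch.flip(input_ids, (1,))

# Clip out-of-vocab IDs onto the dummy row at the end of the table.
flipped_ids = torch.clamp(flipped_ids, max=embedding_vec.size(0) - 1)

rows = F.embedding(flipped_ids, embedding_vec)  # a row gather; 9999 is now safe
print(rows.shape)  # torch.Size([1, 3, 3])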
@ -176,6 +176,18 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||
for i in range(len(false_strings)):
|
||||
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||
|
||||
def test_stop_string_criteria_vocab_size_mismatch(self):
|
||||
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Create input_ids with tokens above len(tokenizer)
|
||||
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
|
||||
|
||||
# This should not raise an error and should return False since no stop string is matched
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
def test_stop_string_matching_positions(self):
|
||||
stop_string = "stop"
|
||||
token_list = ["last", "top", "topper", "s", "p"]
|
||||
|
|
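The regression this test guards against is easy to reproduce in isolation: before the fix, an ID at or above the table height made the raw gather fail (sizes below are illustrative; GPT-2's tokenizer has 50257 entries):

import torch

table = torch.zeros((50257, 3), dtype=torch.int32)  # a GPT-2-sized lookup table
bad_ids = torch.tensor([[50257 + 1024, 1, 2]])

try:
    table[bad_ids]  # out-of-range row index, as hit in issue #35244
except IndexError as exc:
    print("pre-fix behavior:", exc)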
@ -200,14 +212,14 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||
|
||||
# Positions inside the stop string where the token matches (excluding end overlaps)
|
||||
valid_positions = embedding_vec[:, 0].tolist()
|
||||
self.assertEqual(valid_positions, [2, -1, -1, 3, -1])
|
||||
self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
|
||||
|
||||
# Overlap lengths between end of stop string and start of token
|
||||
end_overlaps = embedding_vec[:, 1].tolist()
|
||||
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1])
|
||||
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
|
||||
|
||||
# Length of each token
|
||||
token_lengths = embedding_vec[:, 2].tolist()
|
||||
token_lengths = embedding_vec[:-1, 2].tolist()
|
||||
self.assertEqual(token_lengths, [len(token) for token in token_list])
|
||||
|
||||
def test_single_letter_stop_string(self):
|
||||
|
|
|
|||
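The assertion updates all follow from the extra dummy row: with five tokens in token_list, embedding_vec now has six rows, so the per-column reads gain a trailing -1 and the token-length check slices the dummy row off. A sketch of that slice (the values are the lengths of "last", "top", "topper", "s", "p"):

import numpy as np

embedding_vec = np.full((6, 3), -1, dtype=np.int32)
embedding_vec[:5, 2] = [4, 3, 6, 1, 1]  # real rows store token lengths

token_lengths = embedding_vec[:-1, 2].tolist()  # drop the dummy row first
assert token_lengths == [4, 3, 6, 1, 1]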