diff --git a/examples/research_projects/token-healing/README.md b/examples/research_projects/token-healing/README.md
new file mode 100644
index 000000000..f3594f32d
--- /dev/null
+++ b/examples/research_projects/token-healing/README.md
@@ -0,0 +1,40 @@
+
+
+
+
+## What is token healing?
+
+Token healing rectifies the token boundary bias in greedy tokenization. It does this by trimming and regrowing the prompt to better align with the model's tokenizer, thus enhancing generation quality. The improvement is clearest with completion models.
+
+Example: given a completion prompt with a partial url ending with `:`, the model might have seen the expected completion `://` as a _single_ token in training. However, the prompt's tail token `:` tells it that the next token is not `//`, and so it looks for wrong completions. Such errors compound in auto-regressive language models.
+
+Debiasing token boundaries also addresses output sensitivity to prompts ending with whitespace.
+
+A more thorough explanation can be found on [The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg](https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38).
+
+## Usage
+
+```py
+prompt = 'The link is
+
+(back to top)
\ No newline at end of file
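For a quick end-to-end view of the API this patch introduces (the `token_healing` flag on `GenerationConfig`/`generate` and the new `tokenizer` keyword), here is a minimal sketch along the lines of the example script further down in this diff. The checkpoint and prompt are placeholders, not part of the patch:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default checkpoint of the example script below; any causal LM with a tokenizer should work.
model_name = "TheBloke/deepseek-llm-7B-base-GPTQ"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# A prompt whose tail token `:` steers the model away from the likely `://` continuation.
prompt = "The link is http:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# `token_healing=True` trims the tail token and regrows it before normal decoding;
# the tokenizer must be passed so the prompt can be re-tokenized internally.
output = model.generate(input_ids, max_new_tokens=8, token_healing=True, tokenizer=tokenizer)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])
```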

diff --git a/examples/research_projects/token-healing/run_token_healing.py b/examples/research_projects/token-healing/run_token_healing.py
new file mode 100644
index 000000000..2dd9148c1
--- /dev/null
+++ b/examples/research_projects/token-healing/run_token_healing.py
@@ -0,0 +1,62 @@
+import argparse
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+
+def generate(inputs, model, tokenizer, token_healing):
+    input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.to(model.device)
+    generation_config = GenerationConfig(
+        max_new_tokens=8,
+        token_healing=token_healing,
+        pad_token_id=model.config.pad_token_id,
+        repetition_penalty=1.1,
+    )
+    # the tokenizer is forwarded so that `heal_tokens` can re-tokenize the prompts
+    output = model.generate(inputs=input_ids, generation_config=generation_config, tokenizer=tokenizer)
+    return tokenizer.batch_decode(output, skip_special_tokens=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--prompt", type=str)
+    parser.add_argument("--model_name_or_path", type=str, default="TheBloke/deepseek-llm-7B-base-GPTQ")
+    args = parser.parse_args()
+
+    prompts = (
+        [args.prompt]
+        if args.prompt
+        else [
+            'An example ["like this"] and another example [',
+            'The link is https',
+            "I read a book about ",  # test trailing whitespace
+            "I read a book about",  # test nothing to heal
+        ]
+    )
+
+    model_name_or_path = args.model_name_or_path
+    completion_model = AutoModelForCausalLM.from_pretrained(
+        model_name_or_path,
+        device_map="auto",
+        use_cache=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    raw_output = generate(prompts, completion_model, tokenizer, token_healing=False)
+    healed_output = generate(prompts, completion_model, tokenizer, token_healing=True)
+
+    for p, a, b in zip(prompts, raw_output, healed_output):
+        print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")
+
+    # You can also use token healing in isolation.
+    # This can be useful if you have other work to do before the generation,
+    # or if you want to delegate generation to another process.
+    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cuda()
+    healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
+    healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
+    print("\nhealed prompts:")
+    for p in healed_prompts:
+        print(p)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 0d1eba0bd..7dbeb6cce 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -222,6 +222,9 @@ class GenerationConfig(PushToHubMixin):
             Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
             sequence being selected, while negative biases do the opposite. Check
             [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
+        token_healing (`bool`, *optional*, defaults to `False`):
+            Heal tail tokens of prompts by replacing them with their appropriate extensions.
+            This enhances the quality of completions for prompts affected by greedy tokenization bias.
         guidance_scale (`float`, *optional*):
             The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
@@ -360,6 +363,7 @@ class GenerationConfig(PushToHubMixin):
         self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
         self.sequence_bias = kwargs.pop("sequence_bias", None)
+        self.token_healing = kwargs.pop("token_healing", False)
         self.guidance_scale = kwargs.pop("guidance_scale", None)
         self.low_memory = kwargs.pop("low_memory", None)
         watermarking_config = kwargs.pop("watermarking_config", None)
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 84c9dd995..074f51b35 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -42,6 +42,7 @@ from ..models.auto import (
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
+from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
     ModelOutput,
     is_accelerate_available,
@@ -1591,6 +1592,8 @@ class GenerationMixin:
         else:
             synced_gpus = False
 
+        tokenizer = kwargs.pop("tokenizer", None)
+
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
@@ -1653,6 +1656,9 @@ class GenerationMixin:
         else:
             input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
+        if generation_config.token_healing:
+            input_ids = self.heal_tokens(input_ids, tokenizer)
+
         if streamer is not None:
             streamer.put(input_ids.cpu())
 
@@ -1989,6 +1995,75 @@ class GenerationMixin:
             return False
         return True
 
+    def heal_tokens(
+        self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
+    ) -> torch.LongTensor:
+        r"""
+        Heals the tail tokens of prompts by replacing them with their appropriate extensions, countering the boundary bias of greedy tokenization.
+        Parameters:
+            input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
+            tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
+        Return:
+            `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
+        """
+        if tokenizer is None:
+            raise ValueError(
+                "When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
+                "argument of `generate`."
+            )
+
+        bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
+        vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
+        generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
+
+        # assumption: leading/trailing whitespace is not meaningful, so the prompts are
+        # stripped before re-tokenizing to desensitize generation to whitespace artefacts
+        prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
+        input_ids = tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding=True,
+        ).input_ids.to(input_ids.device)
+
+        # replace bos with pad to not condition healing on it
+        input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
+
+        tail_ids = input_ids[:, -1].tolist()
+        space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
+        # tail tokens are used for a prefix search, thus, whitespaces are replaced with
+        # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
+        tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
+
+        for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
+            batch_ids = input_ids[batch_idx]
+            if torch.all(batch_ids == pad_token_id).item():
+                continue  # skip empty sequences (all pad ids)
+
+            # apply bias for alternatives (extensions) to the tail token
+            seq_bias = {
+                (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0
+                for alt_tok in vocab_trie.extensions(prefix=tail_tok)
+            }
+            if len(seq_bias) == 1:
+                continue  # skip if there are no token alternatives to heal with
+
+            # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
+            seq_bias[(tail_id,)] += 1.0
+            generation_config.update(sequence_bias=seq_bias)
+
+            trimmed_ids = batch_ids[:-1]
+            # if the prompt is a single (non-pad) token, regenerate from bos
+            if len(batch_ids[batch_ids != pad_token_id]) == 1:
+                trimmed_ids[-1] = bos_token_id
+
+            input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
+
+        return input_ids
+
+    def contrastive_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._contrastive_search(*args, **kwargs)
+
     @torch.no_grad()
     def _contrastive_search(
         self,
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index b7c023d95..423152626 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -56,14 +56,26 @@ class Trie:
     Loose reference https://en.wikipedia.org/wiki/Trie
     """
 
-    def __init__(self):
+    def __init__(self, *args):
         self.data = {}
         self._tokens = set()
+        self._termination_char = ""
+        self.update(*args)
+
+    def update(self, *args):
+        """
+        Updates the Trie with new tokens provided as arguments.
+
+        Args:
+            *args: Variable number of words to be added to the Trie.
+        """
+        for token in tuple(*args):
+            self.add(token)
 
     def add(self, word: str):
         """
         Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
-        The special key `""` is used to represent termination.
+        The special key `""`, stored in `self._termination_char`, is used to represent termination.
 
         This function is idempotent, adding twice the same word will leave the trie unchanged
 
@@ -87,9 +99,9 @@ class Trie:
             self._tokens.add(word)
         ref = self.data
         for char in word:
-            ref[char] = char in ref and ref[char] or {}
+            ref[char] = ref.setdefault(char, {})
             ref = ref[char]
-        ref[""] = 1
+        ref[self._termination_char] = 1
 
     def split(self, text: str) -> List[str]:
         """
@@ -269,6 +281,62 @@ class Trie:
         return tokens
 
 
+class ExtensionsTrie(Trie):
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def extensions(self, prefix: str):
+        """
+        Generates all extensions of a given prefix token in the Trie.
+
+        Example:
+
+        ```python
+        >>> trie = ExtensionsTrie()
+        >>> trie.add("apple")
+        >>> trie.add("app")
+        >>> trie.add("application")
+        >>> trie.extensions("app")
+        ['app', 'apple', 'application']
+        ```
+        """
+        prefix_node = self._get_node(prefix)
+        ret = self._collect_tokens(prefix_node)
+        return [prefix + token for token in ret]
+
+    def _get_node(self, token: str) -> dict:
+        """
+        Retrieves the node corresponding to the given token in the Trie.
+
+        Args:
+            token (str): The token for which the corresponding node needs to be retrieved.
+
+        Returns:
+            dict: The node in the Trie corresponding to the given token.
+ """ + node = self.data + for char in token: + node = node[char] + return node + + def _collect_tokens(self, node: dict) -> list: + """ + Generates all tokens in the Trie starting from a given node. + + Args: + node (dict): The node in the Trie from which tokens need to be generated. + + Returns: + list: List of tokens generated from the given node. + """ + tokens = [self._termination_char] if self._termination_char in node else [] + for token, subtrie_head in node.items(): + if token != self._termination_char: + subtokens = self._collect_tokens(subtrie_head) + tokens.extend([token + subtoken for subtoken in subtokens]) + return tokens + + def _is_whitespace(char): """Checks whether `char` is a whitespace character.""" # \t, \n, and \r are technically control characters but we treat them diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 57b6c6d18..1981f5a63 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed from transformers.testing_utils import ( is_flaky, require_accelerate, + require_auto_gptq, require_quanto, require_torch, require_torch_multi_accelerator, @@ -3066,6 +3067,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi self.assertTrue(y_prob > 0.001 and n_prob > 0.001) self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + +@require_torch +class TokenHealingTestCase(unittest.TestCase): + @parameterized.expand( + [ + ( + "square_bracket", + 'An example ["like this"] and another example [', + 'An example ["like this"] and another example ["', + ), + ("url", 'The link is