mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Self-speculation (Layer-Skip Llama) (#34240)
* 😅 * early exit (#34244) * mvp * docs and tests * a few fixes * no shared cache * Apply suggestions from code review Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org> * docs * make fix-copies * cohere fix * [test all] * [test all] consistent model code copies * [test all] make fix-copies :D * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org> * Update src/transformers/generation/candidate_generator.py * Update src/transformers/generation/configuration_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * [test all] don't use a stand-alone attribute; fix test --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
5de58d5955
commit
54739a320e
15 changed files with 185 additions and 58 deletions
|
|
@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
|
|||
Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
|
||||
To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
|
||||
|
||||
#### Universal Assisted Decoding
|
||||
|
||||
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
|
||||
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
|
||||
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
|
||||
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
|
||||
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
|
||||
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
|
||||
to ensure the new tokens include the correct prompt suffix.
|
||||
|
||||
To enable assisted decoding, set the `assistant_model` argument with a model.
|
||||
|
||||
```python
|
||||
|
|
@ -445,26 +435,6 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
|
|||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> prompt = "Alice and Bob"
|
||||
>>> checkpoint = "google/gemma-2-9b"
|
||||
>>> assistant_checkpoint = "double7/vicuna-68m"
|
||||
|
||||
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
|
||||
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
|
||||
just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
|
||||
|
||||
|
|
@ -486,9 +456,63 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
|
|||
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
|
||||
```
|
||||
|
||||
#### Universal Assisted Decoding
|
||||
|
||||
Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
|
||||
To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
|
||||
Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
|
||||
in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
|
||||
The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
|
||||
Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
|
||||
to ensure the new tokens include the correct prompt suffix.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> prompt = "Alice and Bob"
|
||||
>>> checkpoint = "google/gemma-2-9b"
|
||||
>>> assistant_checkpoint = "double7/vicuna-68m"
|
||||
|
||||
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
|
||||
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
#### Prompt Lookup
|
||||
|
||||
Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
|
||||
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
|
||||
|
||||
#### Self-Speculative Decoding
|
||||
|
||||
An LLM can be trained to also use its language modeling head with earlier hidden states as input, effectively
|
||||
skipping layers to yield a lower-quality output -- a technique called early exiting.
|
||||
We use the lower-quality early exit output as an assistant output, and apply self-speculation to fix the output using the remaining layers. The final generation of that self-speculative solution is the same (or has the same distribution) as the original model's generation.
|
||||
If the model you're using was trained to do early exit, you can pass
|
||||
`assistant_early_exit` (integer). In this case, the assistant model will be the same model but exiting early, hence the
|
||||
"self-speculative" name. Because the assistant model is a portion of the target model, caches and weights can be shared, which results in lower memory requirements. As in other assisted generation methods, the final generated result has the same quality as if no assistant had been used.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> prompt = "Alice and Bob"
|
||||
>>> checkpoint = "facebook/layerskip-llama3.2-1B"
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||
>>> outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
### DoLa Decoding
|
||||
|
||||
**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
|
||||
|
|
|
|||
|
|
@ -433,19 +433,22 @@ class DynamicCache(Cache):
|
|||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
# Update the cache
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
# There may be skipped layers, fill them with empty lists
|
||||
for _ in range(len(self.key_cache), layer_idx):
|
||||
self.key_cache.append([])
|
||||
self.value_cache.append([])
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
|
||||
self.key_cache[layer_idx] = key_states
|
||||
self.value_cache[layer_idx] = value_states
|
||||
else:
|
||||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
||||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
||||
if key_states is not None:
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
# There may be skipped layers, fill them with empty lists
|
||||
for _ in range(len(self.key_cache), layer_idx):
|
||||
self.key_cache.append([])
|
||||
self.value_cache.append([])
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
elif (
|
||||
len(self.key_cache[layer_idx]) == 0
|
||||
): # fills previously skipped layers; checking for tensor causes errors
|
||||
self.key_cache[layer_idx] = key_states
|
||||
self.value_cache[layer_idx] = value_states
|
||||
else:
|
||||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
||||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ else:
|
|||
_import_structure["candidate_generator"] = [
|
||||
"AssistedCandidateGenerator",
|
||||
"CandidateGenerator",
|
||||
"EarlyExitCandidateGenerator",
|
||||
"PromptLookupCandidateGenerator",
|
||||
]
|
||||
_import_structure["logits_process"] = [
|
||||
|
|
@ -206,7 +207,12 @@ if TYPE_CHECKING:
|
|||
else:
|
||||
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
|
||||
from .candidate_generator import (
|
||||
AssistedCandidateGenerator,
|
||||
CandidateGenerator,
|
||||
EarlyExitCandidateGenerator,
|
||||
PromptLookupCandidateGenerator,
|
||||
)
|
||||
from .logits_process import (
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
|
|
|
|||
|
|
@ -670,6 +670,62 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||
return
|
||||
|
||||
|
||||
class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
|
||||
"""
|
||||
`CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
|
||||
candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
|
||||
exit, e.g., `facebook/layerskip-llama3.2-1B`.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
||||
assistant_model (`PreTrainedModel`):
|
||||
The original model. This model must support early exit (i.e. is trained to compute logits in earlier
|
||||
layers).
|
||||
generation_config (`~generation.GenerationConfig`, *optional*):
|
||||
The generation configuration to be used as base parametrization for the generation call.
|
||||
logits_processor (`LogitsProcessorList`):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||
model_kwargs (`Dict`):
|
||||
The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
|
||||
model as well.
|
||||
inputs_tensor (`torch.Tensor`, *optional*):
|
||||
The model input tensor. In encoder-decoder models, this is the encoder input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
assistant_model: "PreTrainedModel",
|
||||
generation_config: "GenerationConfig",
|
||||
model_kwargs: Dict,
|
||||
inputs_tensor: Optional[torch.Tensor] = None,
|
||||
logits_processor: "LogitsProcessorList" = None,
|
||||
):
|
||||
super().__init__(
|
||||
input_ids=input_ids,
|
||||
assistant_model=assistant_model,
|
||||
generation_config=generation_config,
|
||||
model_kwargs=model_kwargs,
|
||||
inputs_tensor=inputs_tensor,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
# We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
|
||||
# with early exit
|
||||
self.assistant_early_exit = self.generation_config.assistant_early_exit
|
||||
self.generation_config.assistant_early_exit = None
|
||||
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||
# Temporarily sets the number of hidden layers to the early exit value
|
||||
base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
|
||||
original_num_hidden_layers = base_model.config.num_hidden_layers
|
||||
base_model.config.num_hidden_layers = self.assistant_early_exit
|
||||
candidate_ids, candidate_logits = super().get_candidates(input_ids)
|
||||
base_model.config.num_hidden_layers = original_num_hidden_layers
|
||||
return candidate_ids, candidate_logits
|
||||
|
||||
|
||||
def _crop_past_key_values(model, past_key_values, max_length):
|
||||
"""Crops the past key values up to a certain maximum length."""
|
||||
new_past = []
|
||||
|
|
|
|||
|
|
@ -353,10 +353,13 @@ class GenerationConfig(PushToHubMixin):
|
|||
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
|
||||
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
|
||||
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
|
||||
prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
|
||||
prompt_lookup_num_tokens (`int`, *optional*):
|
||||
The number of tokens to be output as candidate tokens.
|
||||
max_matching_ngram_size (`int`, *optional*, default to `None`):
|
||||
max_matching_ngram_size (`int`, *optional*):
|
||||
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
|
||||
assistant_early_exit(`int`, *optional*):
|
||||
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
|
||||
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
|
||||
|
||||
> Wild card
|
||||
|
||||
|
|
@ -454,10 +457,9 @@ class GenerationConfig(PushToHubMixin):
|
|||
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
|
||||
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
|
||||
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
|
||||
|
||||
# Prompt lookup decoding
|
||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
||||
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
|
||||
|
||||
# Wild card
|
||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
||||
|
|
@ -534,7 +536,11 @@ class GenerationConfig(PushToHubMixin):
|
|||
generation_mode = GenerationMode.BEAM_SEARCH
|
||||
|
||||
# Assisted generation may extend some generation modes
|
||||
if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
|
||||
if (
|
||||
assistant_model is not None
|
||||
or self.prompt_lookup_num_tokens is not None
|
||||
or self.assistant_early_exit is not None
|
||||
):
|
||||
if generation_mode in ("greedy_search", "sample"):
|
||||
generation_mode = GenerationMode.ASSISTED_GENERATION
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from .candidate_generator import (
|
|||
AssistedCandidateGenerator,
|
||||
AssistedCandidateGeneratorDifferentTokenizers,
|
||||
CandidateGenerator,
|
||||
EarlyExitCandidateGenerator,
|
||||
PromptLookupCandidateGenerator,
|
||||
_crop_past_key_values,
|
||||
_prepare_attention_mask,
|
||||
|
|
@ -822,7 +823,16 @@ class GenerationMixin:
|
|||
"""
|
||||
different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))
|
||||
|
||||
if generation_config.prompt_lookup_num_tokens is not None:
|
||||
if generation_config.assistant_early_exit is not None:
|
||||
candidate_generator = EarlyExitCandidateGenerator(
|
||||
input_ids=input_ids,
|
||||
assistant_model=self,
|
||||
generation_config=generation_config,
|
||||
model_kwargs=model_kwargs,
|
||||
inputs_tensor=inputs_tensor,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
elif generation_config.prompt_lookup_num_tokens is not None:
|
||||
candidate_generator = PromptLookupCandidateGenerator(
|
||||
eos_token_id=generation_config._eos_token_tensor,
|
||||
num_output_tokens=generation_config.prompt_lookup_num_tokens,
|
||||
|
|
|
|||
|
|
@ -890,7 +890,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -808,7 +808,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -886,7 +886,7 @@ class GemmaModel(LlamaModel):
|
|||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -823,7 +823,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -653,7 +653,7 @@ class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -789,7 +789,7 @@ class GlmModel(GlmPreTrainedModel):
|
|||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -893,7 +893,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -995,7 +995,7 @@ class OlmoeModel(OlmoePreTrainedModel):
|
|||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
|
|
|||
|
|
@ -4108,6 +4108,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
|
||||
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
|
||||
|
||||
def test_assisted_generation_early_exit(self):
|
||||
"""
|
||||
Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
|
||||
manipulation, which will cause the test to fail if something goes wrong there.
|
||||
"""
|
||||
expected_output = "Alice and Bob are playing a game of poker. Alice has a pair of 8s and Bob has a pair"
|
||||
|
||||
prompt = "Alice and Bob"
|
||||
checkpoint = "facebook/layerskip-llama3.2-1B"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
|
||||
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||
original_decoded = tokenizer.batch_decode(original_outputs, skip_special_tokens=True)
|
||||
self.assertEqual(original_decoded, [expected_output])
|
||||
|
||||
outputs_assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
|
||||
decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True)
|
||||
self.assertEqual(decoded_assisted, [expected_output])
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
|
|
|||
Loading…
Reference in a new issue