From e2bffcfafdbc69e6f1ec4ee34ac321b589b6925e Mon Sep 17 00:00:00 2001
From: Isaac Chung <48971969+isaac-chung@users.noreply.github.com>
Date: Fri, 27 Oct 2023 13:07:33 +0300
Subject: [PATCH] Add early stopping for Bark generation via logits processor
 (#26675)

* add early stopping logits processor

* black formatted

* indent

* follow method signature

* actual logic

* check for None

* address comments on docstrings and method signature

* add unit test under `LogitsProcessorTest` wip

* unit test passing

* black formatted

* condition per sample

* add to BarkModelIntegrationTests

* wip BarkSemanticModelTest

* rename and add to kwargs handling

* not add to BarkSemanticModelTest

* correct logic and assert last outputs tokens different in test

* doc-builder style

* read from kwargs as well

* assert len of with less than that of without

* ruff

* add back seed and test case

* add original impl default suggestion

* doc-builder

* rename and use softmax

* switch back to LogitsProcessor and update docs wording

* camelCase and spelling and saving compute

* assert strictly less than

* assert less than

* expand test_generate_semantic_early_stop instead

---
 src/transformers/generation/logits_process.py | 33 ++++++++++
 .../bark/generation_configuration_bark.py     |  6 ++
 src/transformers/models/bark/modeling_bark.py | 16 ++++-
 tests/generation/test_logits_process.py       | 17 +++++
 tests/models/bark/test_modeling_bark.py       | 66 ++++++++++++++++---
 5 files changed, 126 insertions(+), 12 deletions(-)

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 0ea1d2bdb..f9af4f7ff 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1749,3 +1749,36 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
         unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
         out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
         return out
+
+
+class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
+    r"""This processor ensures that the EOS token is selected if its probability is greater than `min_eos_p`.
+
+    Args:
+        eos_token_id (`Union[int, List[int]]`):
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+        min_eos_p (`float`, *optional*):
+            Minimum end-of-speech probability threshold above which EOS is forced.
+    """
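For review, a minimal standalone sketch of the new processor's behavior (not part of the patch; the toy vocabulary and logit values are made up): sample 0's EOS probability stays below `min_eos_p`, so its scores pass through unchanged; sample 1's EOS probability (0.25) clears the threshold, so every non-EOS logit is masked to `-inf` and EOS is guaranteed to be selected.

```python
import torch

from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

# Batch of two samples over a toy 4-token vocabulary; token id 2 is EOS.
processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=2, min_eos_p=0.1)

scores = torch.tensor(
    [
        [1.0, 1.0, -6.0, 1.0],  # EOS prob ~0.0003 < 0.1 -> scores left untouched
        [1.0, 1.0, 1.0, 1.0],  # EOS prob 0.25 > 0.1 -> EOS is forced
    ]
)
out = processor(input_ids=None, scores=scores)

print(out[0])  # tensor([ 1.,  1., -6.,  1.])
print(out[1])  # tensor([-inf, -inf, 1., -inf])
```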
+ """ + + def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + if min_eos_p is not None and min_eos_p <= 0: + raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") + self.min_eos_p = min_eos_p + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self.min_eos_p: + probs = torch.nn.functional.softmax(scores.float(), dim=-1) + # create scores full of -inf except for the eos_token_id + early_stop_scores = torch.ones_like(scores) * -float("inf") + early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] + + do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p + scores = torch.where(do_early_stop, early_stop_scores, scores) + + return scores diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py index 92d836333..7d7d98449 100644 --- a/src/transformers/models/bark/generation_configuration_bark.py +++ b/src/transformers/models/bark/generation_configuration_bark.py @@ -44,6 +44,7 @@ class BarkSemanticGenerationConfig(GenerationConfig): semantic_vocab_size=10_000, max_input_semantic_length=256, semantic_rate_hz=49.9, + min_eos_p=None, **kwargs, ): """Class that holds a generation configuration for [`BarkSemanticModel`]. @@ -86,6 +87,10 @@ class BarkSemanticGenerationConfig(GenerationConfig): Max length of semantic input vector. semantic_rate_hz (`float`, *optional*, defaults to 49.9): Semantic rate in Hertz. + min_eos_p (`float`, *optional*): + Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping + strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation + suggests a default value of 0.2. 
""" super().__init__( temperature=temperature, @@ -107,6 +112,7 @@ class BarkSemanticGenerationConfig(GenerationConfig): self.semantic_vocab_size = semantic_vocab_size self.max_input_semantic_length = max_input_semantic_length self.semantic_rate_hz = semantic_rate_hz + self.min_eos_p = min_eos_p class BarkCoarseGenerationConfig(GenerationConfig): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 2708b00d0..8e5cf0d84 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -21,7 +21,11 @@ import torch from torch import nn from torch.nn import functional as F -from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor +from ...generation.logits_process import ( + AlternatingCodebooksLogitsProcessor, + BarkEosPrioritizerLogitsProcessor, + SuppressTokensLogitsProcessor, +) from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( @@ -798,12 +802,17 @@ class BarkSemanticModel(BarkCausalModel): suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress) + min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p) + early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor( + eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p + ) + # pass input_ids in order to stay consistent with the transformers generate method even though it is not used # (except to get the input seq_len - that's why we keep the first 257 tokens) semantic_output = super().generate( torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device), input_embeds=input_embeds, - logits_processor=[suppress_tokens_logits_processor], + logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor], generation_config=semantic_generation_config, **kwargs, ) # size: 10048 @@ -1559,7 +1568,8 @@ class BarkModel(BarkPreTrainedModel): kwargs_semantic = { # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel - "attention_mask": kwargs.pop("attention_mask", None) + "attention_mask": kwargs.pop("attention_mask", None), + "min_eos_p": kwargs.pop("min_eos_p", None), } kwargs_coarse = {} kwargs_fine = {} diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 32bd02936..15f5cf1e4 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -53,6 +53,7 @@ if is_torch_available(): TypicalLogitsWarper, UnbatchedClassifierFreeGuidanceLogitsProcessor, ) + from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor @require_torch @@ -800,3 +801,19 @@ class LogitsProcessorTest(unittest.TestCase): self.assertAlmostEqual(out[0].item(), res[0].item()) self.assertAlmostEqual(out[1].item(), res[1].item()) self.assertAlmostEqual(out[2].item(), res[2].item()) + + def test_early_stop_processor(self): + input_ids = None + eos_token_id = 2 + min_eos_p = 0.1 ## some small float + + scores = self._get_uniform_logits(2, 4) + scores[0][eos_token_id] = -6 ## less than log(min_eos_p) + + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + actual_scores = esp(input_ids, scores) + expected_scores_list = [ + scores[0].tolist(), + [float("-inf"), float("-inf"), scores[0][0], float("-inf")], + ] + 
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 3a5de3014..d80ee24a1 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -917,7 +917,51 @@ class BarkModelIntegrationTests(unittest.TestCase):
                 temperature=1.0,
                 semantic_generation_config=self.semantic_generation_config,
             )
+            self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
 
+    @slow
+    def test_generate_semantic_early_stop(self):
+        input_ids = self.inputs
+        min_eos_p = 0.01
+
+        # fmt: off
+        # check first ids
+        expected_output_ids = [7363, 321, 41, 1461, 6915, 952, 326, 41, 41, 927,]
+        # fmt: on
+
+        # Should be able to read min_eos_p from kwargs
+        with torch.no_grad():
+            torch.manual_seed(0)
+            output_ids_without_min_eos_p = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+            )
+            torch.manual_seed(0)
+            output_ids_kwargs = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+                min_eos_p=min_eos_p,
+            )
+        self.assertListEqual(output_ids_without_min_eos_p[0, : len(expected_output_ids)].tolist(), expected_output_ids)
+        self.assertLess(len(output_ids_kwargs[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
+
+        # Should be able to read min_eos_p from the semantic generation config
+        self.semantic_generation_config.min_eos_p = min_eos_p
+        with torch.no_grad():
+            torch.manual_seed(0)
+            output_ids = self.model.semantic.generate(
+                **input_ids,
+                do_sample=False,
+                temperature=0.9,
+                semantic_generation_config=self.semantic_generation_config,
+            )
+
+        self.assertEqual(output_ids.shape, output_ids_kwargs.shape)
+        self.assertLess(len(output_ids[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist()))
         self.assertListEqual(output_ids[0, : len(expected_output_ids)].tolist(), expected_output_ids)
 
     @slow
@@ -1022,26 +1066,30 @@
         input_ids = self.inputs
 
         with torch.no_grad():
+            torch.manual_seed(0)
             self.model.generate(
                 **input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
             )
-            self.model.generate(
+            output_ids_without_min_eos_p = self.model.generate(
                 **input_ids,
-                do_sample=False,
-                temperature=1.0,
+                do_sample=True,
+                temperature=0.9,
                 coarse_do_sample=True,
                 coarse_temperature=0.7,
                 fine_temperature=0.3,
             )
-            self.model.generate(
+
+            output_ids_with_min_eos_p = self.model.generate(
                 **input_ids,
                 do_sample=True,
-                temperature=0.6,
-                penalty_alpha=0.6,
-                semantic_temperature=0.9,
-                coarse_temperature=0.2,
-                fine_temperature=0.1,
+                temperature=0.9,
+                coarse_temperature=0.7,
+                fine_temperature=0.3,
+                min_eos_p=0.1,
             )
+        self.assertLess(
+            len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
+        )
 
     @require_torch_gpu
     @slow
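For completeness, an end-to-end sketch of the user-facing behavior this patch enables, assuming the `suno/bark-small` checkpoint and the documented Bark processor usage; the shorter waveform is the typical outcome, not a guarantee:

```python
import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")

with torch.no_grad():
    # Default behavior: the semantic stage may keep generating after the speech ends.
    audio = model.generate(**inputs)
    # With early stopping: the semantic stage halts once the EOS probability
    # exceeds min_eos_p, usually producing a shorter waveform.
    audio_early = model.generate(**inputs, min_eos_p=0.2)

print(audio.shape, audio_early.shape)
```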