From 9c4b4efbbd2ddb5461b6346c6be7019bd22ebded Mon Sep 17 00:00:00 2001 From: ydshieh Date: Sat, 23 Nov 2024 22:42:07 +0100 Subject: [PATCH] try 1 --- src/transformers/generation/utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 55f8c9196..ee7348b80 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -135,7 +135,7 @@ def foo(): while True: item = q.get() - o, model, model_inputs = item + o, model, model_inputs, put_output = item if o['model_forward'] is None: #if isinstance(model_kwargs.get("past_key_values"), StaticCache): @@ -147,7 +147,10 @@ def foo(): o['model_forward'] = model_forward_3 else: outputs = o['model_forward'](my_model, return_dict=True, **model_inputs) - + o['outputs'] = outputs + # only put if necessary! + if put_output: + p.put(o) @dataclass class GenerateDecoderOnlyOutput(ModelOutput): @@ -3289,12 +3292,15 @@ class GenerationMixin: i += 1 else: if not already_compied: - q.put((o, self, model_inputs)) + q.put((o, self, model_inputs, False)) # use self outputs = self(**model_inputs, return_dict=True) else: # directly call (??) - outputs = o['model_forward'](self, return_dict=True, **model_inputs) + # outputs = o['model_forward'](self, return_dict=True, **model_inputs) + q.put((o, self, model_inputs, True)) + item = p.get() + outputs = item['outputs'] # if i == 1 and o['model_forward'] is None: # # don't join