mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
try 1
This commit is contained in:
parent
4c86fb37ff
commit
9c4b4efbbd
1 changed files with 10 additions and 4 deletions
|
|
@ -135,7 +135,7 @@ def foo():
|
|||
|
||||
while True:
|
||||
item = q.get()
|
||||
o, model, model_inputs = item
|
||||
o, model, model_inputs, put_output = item
|
||||
|
||||
if o['model_forward'] is None:
|
||||
#if isinstance(model_kwargs.get("past_key_values"), StaticCache):
|
||||
|
|
@ -147,7 +147,10 @@ def foo():
|
|||
o['model_forward'] = model_forward_3
|
||||
else:
|
||||
outputs = o['model_forward'](my_model, return_dict=True, **model_inputs)
|
||||
|
||||
o['outputs'] = outputs
|
||||
# only put if necessary!
|
||||
if put_output:
|
||||
p.put(o)
|
||||
|
||||
@dataclass
|
||||
class GenerateDecoderOnlyOutput(ModelOutput):
|
||||
|
|
@ -3289,12 +3292,15 @@ class GenerationMixin:
|
|||
i += 1
|
||||
else:
|
||||
if not already_compied:
|
||||
q.put((o, self, model_inputs))
|
||||
q.put((o, self, model_inputs, False))
|
||||
# use self
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
else:
|
||||
# directly call (??)
|
||||
outputs = o['model_forward'](self, return_dict=True, **model_inputs)
|
||||
# outputs = o['model_forward'](self, return_dict=True, **model_inputs)
|
||||
q.put((o, self, model_inputs, True))
|
||||
item = p.get()
|
||||
outputs = item['outputs']
|
||||
|
||||
# if i == 1 and o['model_forward'] is None:
|
||||
# # don't join
|
||||
|
|
|
|||
Loading…
Reference in a new issue