This commit is contained in:
ydshieh 2024-11-23 22:42:07 +01:00
parent 4c86fb37ff
commit 9c4b4efbbd

View file

@ -135,7 +135,7 @@ def foo():
while True: while True:
item = q.get() item = q.get()
o, model, model_inputs = item o, model, model_inputs, put_output = item
if o['model_forward'] is None: if o['model_forward'] is None:
#if isinstance(model_kwargs.get("past_key_values"), StaticCache): #if isinstance(model_kwargs.get("past_key_values"), StaticCache):
@ -147,7 +147,10 @@ def foo():
o['model_forward'] = model_forward_3 o['model_forward'] = model_forward_3
else: else:
outputs = o['model_forward'](my_model, return_dict=True, **model_inputs) outputs = o['model_forward'](my_model, return_dict=True, **model_inputs)
o['outputs'] = outputs
# only put if necessary!
if put_output:
p.put(o)
@dataclass @dataclass
class GenerateDecoderOnlyOutput(ModelOutput): class GenerateDecoderOnlyOutput(ModelOutput):
@ -3289,12 +3292,15 @@ class GenerationMixin:
i += 1 i += 1
else: else:
if not already_compied: if not already_compied:
q.put((o, self, model_inputs)) q.put((o, self, model_inputs, False))
# use self # use self
outputs = self(**model_inputs, return_dict=True) outputs = self(**model_inputs, return_dict=True)
else: else:
# directly call (??) # directly call (??)
outputs = o['model_forward'](self, return_dict=True, **model_inputs) # outputs = o['model_forward'](self, return_dict=True, **model_inputs)
q.put((o, self, model_inputs, True))
item = p.get()
outputs = item['outputs']
# if i == 1 and o['model_forward'] is None: # if i == 1 and o['model_forward'] is None:
# # don't join # # don't join