mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix preprocess_function in run_summarization_flax.py (#14769)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
7e61d56a45
commit
a94105f95f
1 changed files with 1 additions and 1 deletions
|
|
@ -533,7 +533,7 @@ def main():
|
|||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
decoder_input_ids = shift_tokens_right_fn(
|
||||
jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
|
||||
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
||||
)
|
||||
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue