Fix preprocess_function in run_summarization_flax.py (#14769)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2021-12-15 11:36:28 +01:00 committed by GitHub
parent 7e61d56a45
commit a94105f95f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -533,7 +533,7 @@ def main():
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)