diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index 0f8bd2480..4677af124 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -249,7 +249,7 @@ def squad_convert_example_to_features( else: p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0 - pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id) + pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id)) special_token_indices = np.asarray( tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True) ).nonzero()