Add include_loss_for_metrics (#33088)

* Add include_loss_for_metrics

* Fix styling

* Initialize inputs and losses to avoid AttributeError

* Ruff styling

* Refactor compute_metrics and update EvalPrediction

* Change Naming

* Added include_for_metrics to group both args

* Fix style

* Change warnings to logger

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Authored by Manal ML on 2024-10-01 15:51:41 +01:00, committed by GitHub
parent 5f9f58fc59
commit 0c4c2d7e07
3 changed files with 65 additions and 51 deletions
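Taken together, the diff below deprecates the boolean `include_inputs_for_metrics` in favor of an `include_for_metrics` list that can request inputs and/or per-sample losses for `compute_metrics`. As a rough usage sketch (the metric body, metric names, and `output_dir` value are illustrative, not part of this commit):

```python
import numpy as np

from transformers import TrainingArguments


def compute_metrics(eval_pred):
    # When "loss" is requested, EvalPrediction also carries the evaluation
    # losses in `eval_pred.losses`. Assumes predictions are a single logits array.
    preds = np.argmax(eval_pred.predictions, axis=-1)
    metrics = {"accuracy": float((preds == eval_pred.label_ids).mean())}
    if eval_pred.losses is not None:
        metrics["mean_eval_loss"] = float(np.mean(eval_pred.losses))
    return metrics


args = TrainingArguments(
    output_dir="out",  # placeholder
    # Replaces include_inputs_for_metrics=True; "inputs" and "loss" can be
    # requested independently or together.
    include_for_metrics=["inputs", "loss"],
)
```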

src/transformers/trainer.py

@@ -4047,6 +4047,7 @@ class Trainer:
         all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
 
         metrics = None
+        eval_set_kwargs = {}
 
         # Will be useful when we have an iterable dataset so don't know its length.
         observed_num_examples = 0
@@ -4064,7 +4065,9 @@ class Trainer:
             # Prediction step
             losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
             main_input_name = getattr(self.model, "main_input_name", "input_ids")
-            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
+            inputs_decode = (
+                self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+            )
 
             if is_torch_xla_available():
                 xm.mark_step()
@@ -4098,16 +4101,13 @@ class Trainer:
             if self.args.batch_eval_metrics:
                 if self.compute_metrics is not None and logits is not None and labels is not None:
                     is_last_step = self.accelerator.gradient_state.end_of_dataloader
-                    if args.include_inputs_for_metrics:
-                        metrics = self.compute_metrics(
-                            EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs),
-                            compute_result=is_last_step,
-                        )
-                    else:
-                        metrics = self.compute_metrics(
-                            EvalPrediction(predictions=logits, label_ids=labels),
-                            compute_result=is_last_step,
-                        )
+                    batch_kwargs = {}
+                    batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
+                    batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
+                    metrics = self.compute_metrics(
+                        EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
+                        compute_result=is_last_step,
+                    )
 
                 del losses, logits, labels, inputs
                 torch.cuda.empty_cache()
@@ -4156,12 +4156,11 @@ class Trainer:
             and all_labels is not None
             and not self.args.batch_eval_metrics
         ):
-            if args.include_inputs_for_metrics:
-                metrics = self.compute_metrics(
-                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
-                )
-            else:
-                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+            eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
+            eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
+            metrics = self.compute_metrics(
+                EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
+            )
         elif metrics is None:
             metrics = {}
@@ -4634,6 +4633,7 @@ class Trainer:
         labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
         inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
         metrics: Optional[dict] = None
+        eval_set_kwargs: dict = {}
 
         world_size = max(1, args.world_size)
@@ -4660,7 +4660,9 @@ class Trainer:
         for step, inputs in enumerate(dataloader):
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
             main_input_name = getattr(self.model, "main_input_name", "input_ids")
-            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
+            inputs_decode = (
+                self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+            )
 
             if loss is not None:
                 losses = loss.repeat(batch_size)
@@ -4680,16 +4682,13 @@ class Trainer:
             if self.args.batch_eval_metrics:
                 if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
                     is_last_step = self.accelerator.gradient_state.end_of_dataloader
-                    if args.include_inputs_for_metrics:
-                        metrics = self.compute_metrics(
-                            EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
-                            compute_result=is_last_step,
-                        )
-                    else:
-                        metrics = self.compute_metrics(
-                            EvalPrediction(predictions=preds_host, label_ids=labels_host),
-                            compute_result=is_last_step,
-                        )
+                    batch_kwargs = {}
+                    batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
+                    batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
+                    metrics = self.compute_metrics(
+                        EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
+                        compute_result=is_last_step,
+                    )
 
             if self.args.batch_eval_metrics or (
                 args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
@@ -4728,12 +4727,9 @@ class Trainer:
             and label_ids is not None
             and not self.args.batch_eval_metrics
         ):
-            if args.include_inputs_for_metrics:
-                metrics = self.compute_metrics(
-                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
-                )
-            else:
-                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
+            eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
+            eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
+            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
         elif metrics is None:
             metrics = {}
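With `batch_eval_metrics` enabled, `compute_metrics` is now always called once per batch, receiving whatever the `include_for_metrics` list requested. A minimal sketch of a stateful batchwise metric compatible with that calling convention (the class name and metric are invented for illustration, and predictions are assumed to be a single tensor of logits):

```python
import torch


class BatchwiseAccuracy:
    """Accumulates counts per batch; returns metrics only on the final call."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, eval_pred, compute_result=False):
        # Under batch_eval_metrics, predictions/label_ids are per-batch tensors.
        preds = eval_pred.predictions.argmax(dim=-1)
        self.correct += (preds == eval_pred.label_ids).sum().item()
        self.total += eval_pred.label_ids.numel()
        if compute_result:  # set from gradient_state.end_of_dataloader above
            result = {"accuracy": self.correct / max(self.total, 1)}
            self.correct = self.total = 0  # reset for the next evaluation
            return result


# Hypothetical wiring:
# trainer = Trainer(..., compute_metrics=BatchwiseAccuracy(),
#                   args=TrainingArguments(..., batch_eval_metrics=True))
```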

src/transformers/trainer_utils.py

@@ -156,7 +156,8 @@ class EvalPrediction:
     Parameters:
         predictions (`np.ndarray`): Predictions of the model.
         label_ids (`np.ndarray`): Targets to be matched.
-        inputs (`np.ndarray`, *optional*):
+        inputs (`np.ndarray`, *optional*): Input data passed to the model.
+        losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
     """
 
     def __init__(
@@ -164,28 +165,25 @@
         predictions: Union[np.ndarray, Tuple[np.ndarray]],
         label_ids: Union[np.ndarray, Tuple[np.ndarray]],
         inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
+        losses: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
     ):
         self.predictions = predictions
         self.label_ids = label_ids
         self.inputs = inputs
+        self.losses = losses
+        self.elements = (self.predictions, self.label_ids)
+        if self.inputs is not None:
+            self.elements += (self.inputs,)
+        if self.losses is not None:
+            self.elements += (self.losses,)
 
     def __iter__(self):
-        if self.inputs is not None:
-            return iter((self.predictions, self.label_ids, self.inputs))
-        else:
-            return iter((self.predictions, self.label_ids))
+        return iter(self.elements)
 
     def __getitem__(self, idx):
-        if idx < 0 or idx > 2:
+        if idx < 0 or idx >= len(self.elements):
             raise IndexError("tuple index out of range")
-        if idx == 2 and self.inputs is None:
-            raise IndexError("tuple index out of range")
-        if idx == 0:
-            return self.predictions
-        elif idx == 1:
-            return self.label_ids
-        elif idx == 2:
-            return self.inputs
+        return self.elements[idx]
 
 
 class EvalLoopOutput(NamedTuple):
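A quick illustration of the refactored tuple behavior, using dummy arrays: `elements` only contains the optional fields that were actually set, so positional indices shift accordingly (here, with `inputs` omitted, index 2 is the losses).

```python
import numpy as np

from transformers.trainer_utils import EvalPrediction

pred = EvalPrediction(
    predictions=np.array([[0.1, 0.9], [0.8, 0.2]]),  # dummy logits
    label_ids=np.array([1, 0]),
    losses=np.array([0.105, 0.223]),  # dummy per-sample losses
)

# Unpacking follows self.elements: (predictions, label_ids, losses).
predictions, label_ids, losses = pred
assert pred[2] is pred.losses
```

Because the positional layout depends on which optional fields are present, attribute access (`pred.losses`, `pred.inputs`) is the safer way to consume these inside a `compute_metrics` function.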

src/transformers/training_args.py

@@ -707,8 +707,12 @@ class TrainingArguments:
         gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
             Key word arguments to be passed to the `gradient_checkpointing_enable` method.
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
-            Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
-            that need inputs, predictions and references for scoring calculation in Metric class.
+            This argument is deprecated. Use `include_for_metrics` instead, e.g., `include_for_metrics = ["inputs"]`.
+        include_for_metrics (`List[str]`, *optional*, defaults to `[]`):
+            Include additional data in the `compute_metrics` function if needed for metrics computation.
+            Possible options to add to the `include_for_metrics` list:
+            - `"inputs"`: Input data passed to the model, intended for calculating input-dependent metrics.
+            - `"loss"`: Loss values computed during evaluation, intended for calculating loss-dependent metrics.
         eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
             Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
             will instead store them as lists, with each batch kept separate.
@@ -1362,7 +1366,17 @@
         },
     )
     include_inputs_for_metrics: bool = field(
-        default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
+        default=False,
+        metadata={
+            "help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead."
+        },
+    )
+    include_for_metrics: List[str] = field(
+        default_factory=list,
+        metadata={
+            "help": "List of strings to specify additional data to include in the `compute_metrics` function. "
+            "Options: 'inputs', 'loss'."
+        },
     )
     eval_do_concat_batches: bool = field(
         default=True,
@@ -2064,6 +2078,12 @@ class TrainingArguments:
                 "This is not supported and we recommend you to update your version."
             )
 
+        if self.include_inputs_for_metrics:
+            logger.warning(
+                "Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
+            )
+            self.include_for_metrics.append("inputs")
+
     def __str__(self):
         self_as_dict = asdict(self)
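The `__post_init__` shim above keeps old configurations working by folding the deprecated flag into the new list after logging a warning. A small sketch of the expected equivalence (the `output_dir` values are placeholders):

```python
from transformers import TrainingArguments

# Deprecated spelling: logs a warning, then appends "inputs" to the new list.
old_style = TrainingArguments(output_dir="out-a", include_inputs_for_metrics=True)

# New spelling, equivalent after the shim runs.
new_style = TrainingArguments(output_dir="out-b", include_for_metrics=["inputs"])

assert old_style.include_for_metrics == new_style.include_for_metrics == ["inputs"]
```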