mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Add include_loss_for_metrics (#33088)
* Add include_loss_for_metrics * Fix styling * Initialize inputs and losses to avoid AttributeError * Ruff styling * Refactor compute_metrics and update EvalPrediction * Change Naming * Added include_for_metrics to group both args * Fix style * Change warnings to logger Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
5f9f58fc59
commit
0c4c2d7e07
3 changed files with 65 additions and 51 deletions
|
|
@ -4047,6 +4047,7 @@ class Trainer:
|
|||
all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
||||
|
||||
metrics = None
|
||||
eval_set_kwargs = {}
|
||||
|
||||
# Will be useful when we have an iterable dataset so don't know its length.
|
||||
observed_num_examples = 0
|
||||
|
|
@ -4064,7 +4065,9 @@ class Trainer:
|
|||
# Prediction step
|
||||
losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||
main_input_name = getattr(self.model, "main_input_name", "input_ids")
|
||||
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
|
||||
inputs_decode = (
|
||||
self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
|
||||
)
|
||||
|
||||
if is_torch_xla_available():
|
||||
xm.mark_step()
|
||||
|
|
@ -4098,16 +4101,13 @@ class Trainer:
|
|||
if self.args.batch_eval_metrics:
|
||||
if self.compute_metrics is not None and logits is not None and labels is not None:
|
||||
is_last_step = self.accelerator.gradient_state.end_of_dataloader
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=logits, label_ids=labels),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
batch_kwargs = {}
|
||||
batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
|
||||
batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
|
||||
del losses, logits, labels, inputs
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -4156,12 +4156,11 @@ class Trainer:
|
|||
and all_labels is not None
|
||||
and not self.args.batch_eval_metrics
|
||||
):
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
|
||||
eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
|
||||
eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
|
||||
)
|
||||
elif metrics is None:
|
||||
metrics = {}
|
||||
|
||||
|
|
@ -4634,6 +4633,7 @@ class Trainer:
|
|||
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
metrics: Optional[dict] = None
|
||||
eval_set_kwargs: dict = {}
|
||||
|
||||
world_size = max(1, args.world_size)
|
||||
|
||||
|
|
@ -4660,7 +4660,9 @@ class Trainer:
|
|||
for step, inputs in enumerate(dataloader):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||
main_input_name = getattr(self.model, "main_input_name", "input_ids")
|
||||
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
|
||||
inputs_decode = (
|
||||
self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
|
||||
)
|
||||
|
||||
if loss is not None:
|
||||
losses = loss.repeat(batch_size)
|
||||
|
|
@ -4680,16 +4682,13 @@ class Trainer:
|
|||
if self.args.batch_eval_metrics:
|
||||
if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
|
||||
is_last_step = self.accelerator.gradient_state.end_of_dataloader
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds_host, label_ids=labels_host),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
batch_kwargs = {}
|
||||
batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
|
||||
batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
|
||||
if self.args.batch_eval_metrics or (
|
||||
args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
|
||||
|
|
@ -4728,12 +4727,9 @@ class Trainer:
|
|||
and label_ids is not None
|
||||
and not self.args.batch_eval_metrics
|
||||
):
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||
eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
|
||||
eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
|
||||
elif metrics is None:
|
||||
metrics = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -156,7 +156,8 @@ class EvalPrediction:
|
|||
Parameters:
|
||||
predictions (`np.ndarray`): Predictions of the model.
|
||||
label_ids (`np.ndarray`): Targets to be matched.
|
||||
inputs (`np.ndarray`, *optional*):
|
||||
inputs (`np.ndarray`, *optional*): Input data passed to the model.
|
||||
losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -164,28 +165,25 @@ class EvalPrediction:
|
|||
predictions: Union[np.ndarray, Tuple[np.ndarray]],
|
||||
label_ids: Union[np.ndarray, Tuple[np.ndarray]],
|
||||
inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
|
||||
losses: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
|
||||
):
|
||||
self.predictions = predictions
|
||||
self.label_ids = label_ids
|
||||
self.inputs = inputs
|
||||
self.losses = losses
|
||||
self.elements = (self.predictions, self.label_ids)
|
||||
if self.inputs is not None:
|
||||
self.elements += (self.inputs,)
|
||||
if self.losses is not None:
|
||||
self.elements += (self.losses,)
|
||||
|
||||
def __iter__(self):
|
||||
if self.inputs is not None:
|
||||
return iter((self.predictions, self.label_ids, self.inputs))
|
||||
else:
|
||||
return iter((self.predictions, self.label_ids))
|
||||
return iter(self.elements)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < 0 or idx > 2:
|
||||
if idx < 0 or idx >= len(self.elements):
|
||||
raise IndexError("tuple index out of range")
|
||||
if idx == 2 and self.inputs is None:
|
||||
raise IndexError("tuple index out of range")
|
||||
if idx == 0:
|
||||
return self.predictions
|
||||
elif idx == 1:
|
||||
return self.label_ids
|
||||
elif idx == 2:
|
||||
return self.inputs
|
||||
return self.elements[idx]
|
||||
|
||||
|
||||
class EvalLoopOutput(NamedTuple):
|
||||
|
|
|
|||
|
|
@ -707,8 +707,12 @@ class TrainingArguments:
|
|||
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
Key word arguments to be passed to the `gradient_checkpointing_enable` method.
|
||||
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
|
||||
that need inputs, predictions and references for scoring calculation in Metric class.
|
||||
This argument is deprecated. Use `include_for_metrics` instead, e.g, `include_for_metrics = ["inputs"]`.
|
||||
include_for_metrics (`List[str]`, *optional*, defaults to `[]`):
|
||||
Include additional data in the `compute_metrics` function if needed for metrics computation.
|
||||
Possible options to add to `include_for_metrics` list:
|
||||
- `"inputs"`: Input data passed to the model, intended for calculating input dependent metrics.
|
||||
- `"loss"`: Loss values computed during evaluation, intended for calculating loss dependent metrics.
|
||||
eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
|
||||
Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
|
||||
will instead store them as lists, with each batch kept separate.
|
||||
|
|
@ -1362,7 +1366,17 @@ class TrainingArguments:
|
|||
},
|
||||
)
|
||||
include_inputs_for_metrics: bool = field(
|
||||
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead."
|
||||
},
|
||||
)
|
||||
include_for_metrics: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata={
|
||||
"help": "List of strings to specify additional data to include in the `compute_metrics` function."
|
||||
"Options: 'inputs', 'loss'."
|
||||
},
|
||||
)
|
||||
eval_do_concat_batches: bool = field(
|
||||
default=True,
|
||||
|
|
@ -2064,6 +2078,12 @@ class TrainingArguments:
|
|||
"This is not supported and we recommend you to update your version."
|
||||
)
|
||||
|
||||
if self.include_inputs_for_metrics:
|
||||
logger.warning(
|
||||
"Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
|
||||
)
|
||||
self.include_for_metrics.append("inputs")
|
||||
|
||||
def __str__(self):
|
||||
self_as_dict = asdict(self)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue