diff --git a/docs/source/training.rst b/docs/source/training.rst index 7fe149847..5d0cbe982 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -282,7 +282,7 @@ your own ``compute_metrics`` function and pass it to the trainer. .. code-block:: python - from sklearn.metrics import precision_recall_fscore_support + from sklearn.metrics import accuracy_score, precision_recall_fscore_support def compute_metrics(pred): labels = pred.label_ids