diff --git a/docs/source/en/tasks/semantic_segmentation.md b/docs/source/en/tasks/semantic_segmentation.md index 267d0083b..6bb25e3e2 100644 --- a/docs/source/en/tasks/semantic_segmentation.md +++ b/docs/source/en/tasks/semantic_segmentation.md @@ -221,6 +221,10 @@ logits first, and then reshaped to match the size of the labels before you can c ```py +>>> import numpy as np +>>> import torch +>>> from torch import nn + >>> def compute_metrics(eval_pred): ... with torch.no_grad(): ... logits, labels = eval_pred