diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 524e15580..0911b2491 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -27,8 +27,6 @@ from contextlib import contextmanager from os.path import abspath, exists from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from packaging import version - from ..dynamic_module_utils import custom_object_save from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..image_processing_utils import BaseImageProcessor @@ -1015,12 +1013,7 @@ class Pipeline(_ScikitCompat): raise NotImplementedError("postprocess not implemented") def get_inference_context(self): - inference_context = ( - torch.inference_mode - if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.9.0") - else torch.no_grad - ) - return inference_context + return torch.no_grad def forward(self, model_inputs, **forward_params): with self.device_placement():