diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index fece24068..f3fee7d21 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -21,6 +21,7 @@ import os # See the License for the specific language governing permissions and # limitations under the License. import warnings +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from numpy import isin @@ -638,6 +639,8 @@ def pipeline( " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" " or a path/identifier to a pretrained model when providing feature_extractor." ) + if isinstance(model, Path): + model = str(model) # Config is the primordial information item. # Instantiate config if needed diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 314b18366..626c7c4d3 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase): self.assertEqual(pipe._batch_size, 2) self.assertEqual(pipe._num_workers, 1) + @require_torch + def test_pipeline_pathlike(self): + pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert") + with tempfile.TemporaryDirectory() as d: + pipe.save_pretrained(d) + path = Path(d) + newpipe = pipeline(task="text-classification", model=path) + self.assertIsInstance(newpipe, TextClassificationPipeline) + @require_torch def test_pipeline_override(self): class MyPipeline(TextClassificationPipeline):