diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index f7480da35..971c0211b 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1148,6 +1148,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin): elif self.device.type == "musa": with torch.musa.device(self.device): yield + elif self.device.type == "xpu": + with torch.xpu.device(self.device): + yield else: yield