From bcb841f0073fcd7a4fb88ea8064313c17dcab04a Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Sat, 25 Jan 2025 02:13:07 +0800 Subject: [PATCH] add xpu device check in device_placement (#35865) add xpu device --- src/transformers/pipelines/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index f7480da35..971c0211b 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1148,6 +1148,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin): elif self.device.type == "musa": with torch.musa.device(self.device): yield + elif self.device.type == "xpu": + with torch.xpu.device(self.device): + yield else: yield