diff --git a/examples/research_projects/zero-shot-distillation/distill_classifier.py b/examples/research_projects/zero-shot-distillation/distill_classifier.py index f16038761..5012630a5 100644 --- a/examples/research_projects/zero-shot-distillation/distill_classifier.py +++ b/examples/research_projects/zero-shot-distillation/distill_classifier.py @@ -174,7 +174,7 @@ def get_teacher_predictions( model = AutoModelForSequenceClassification.from_pretrained(model_path) model_config = model.config if not no_cuda and torch.cuda.is_available(): - model = nn.DataParallel(model) + model = nn.DataParallel(model.cuda()) batch_size *= len(model.device_ids) tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)