mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Zero shot distillation script cuda patch (#10284)
This commit is contained in:
parent
f1299f5038
commit
cbadb5243c
1 changed file with 1 addition and 1 deletion
|
|
@@ -174,7 +174,7 @@ def get_teacher_predictions(
|
|||
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
||||
model_config = model.config
|
||||
if not no_cuda and torch.cuda.is_available():
|
||||
model = nn.DataParallel(model)
|
||||
model = nn.DataParallel(model.cuda())
|
||||
batch_size *= len(model.device_ids)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue