Liqun/a transformer example (#3845)
Add transformer glue test example to show how to use ORTTrainer to fine-tune a transformer model.

Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
parent a983509ed3
commit 6665d5e2bc
7 changed files with 582 additions and 3 deletions
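The new GLUE test can also be launched on its own; a minimal sketch (assuming a training-enabled GPU build of onnxruntime, the transformers package pinned as later in this commit, and the GLUE MRPC data under /bert_data/hf_data/glue_data/) that mirrors the run_training_python_frontend_e2e_tests hook added further down:

    # run the BERT/MRPC fine-tuning test on a single GPU, as the e2e hook below does
    import os, subprocess, sys

    env = {**os.environ, 'CUDA_VISIBLE_DEVICES': '0'}   # restrict the test to one GPU
    subprocess.run([sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_bert_with_mrpc', '-v'],
                   cwd='orttraining/orttraining/test/python', env=env, check=True)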
orttraining/orttraining/test/python/orttraining_run_glue.py (new file, 196 lines)

@@ -0,0 +1,196 @@
# adapted from run_glue.py of huggingface transformers

import dataclasses
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
import unittest
import numpy as np

from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    GlueDataset,
    GlueDataTrainingArguments,
    TrainingArguments,
    glue_compute_metrics,
    glue_output_modes,
    glue_tasks_num_labels,
    set_seed,
)

import onnxruntime
from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription

from orttraining_transformer_trainer import ORTTransformerTrainer

import torch

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

class ORTGlueTest(unittest.TestCase):

    def setUp(self):
        # configurations not to be changed across tests
        self.max_seq_length = 128
        self.train_batch_size = 8
        self.learning_rate = 2e-5
        self.num_train_epochs = 3.0
        self.local_rank = -1
        self.overwrite_output_dir = True
        self.gradient_accumulation_steps = 1
        self.data_dir = "/bert_data/hf_data/glue_data/"
        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/")
        self.cache_dir = '/tmp/glue/'
        self.logging_steps = 10

    def test_bert_with_mrpc(self):
        results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False)
        self.assertTrue(results['acc'] > 0.84)
        self.assertTrue(results['f1'] > 0.88)
        self.assertTrue(results['acc_and_f1'] > 0.86)
        self.assertTrue(results['loss'] < 0.47)

    def test_bert_fp16_with_mrpc(self):
        results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True)
        self.assertTrue(results['acc'] > 0.85)
        self.assertTrue(results['f1'] > 0.89)
        self.assertTrue(results['acc_and_f1'] > 0.87)
        self.assertTrue(results['loss'] < 0.46)

    def run_glue(self, model_name, task_name, fp16):
        model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir)
        data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=self.data_dir + "/" + task_name,
                                              max_seq_length=self.max_seq_length)

        training_args = TrainingArguments(output_dir=self.output_dir + "/" + task_name, do_train=True, do_eval=True,
                                          per_gpu_train_batch_size=self.train_batch_size,
                                          learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,
                                          local_rank=self.local_rank,
                                          overwrite_output_dir=self.overwrite_output_dir,
                                          gradient_accumulation_steps=self.gradient_accumulation_steps,
                                          fp16=fp16, logging_steps=self.logging_steps)

        # Setup logging
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
        )
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            bool(training_args.local_rank != -1),
            training_args.fp16,
        )
        logger.info("Training/evaluation parameters %s", training_args)

        set_seed(training_args.seed)
        onnxruntime.set_seed(training_args.seed)

        try:
            num_labels = glue_tasks_num_labels[data_args.task_name]
            output_mode = glue_output_modes[data_args.task_name]
        except KeyError:
            raise ValueError("Task not found: %s" % (data_args.task_name))

        config = AutoConfig.from_pretrained(
            model_args.config_name if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
            cache_dir=model_args.cache_dir,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )

        train_dataset = (
            GlueDataset(data_args, tokenizer=tokenizer)
            if training_args.do_train
            else None
        )

        print(data_args)
        print(training_args.local_rank)
        eval_dataset = (
            GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
            if training_args.do_eval
            else None
        )

        def compute_metrics(p: EvalPrediction) -> Dict:
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(data_args.task_name, preds, p.label_ids)

        model_desc = ModelDescription([
            IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=model.config.vocab_size),
            IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
            IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
            IODescription('labels', ['batch',], torch.int64, num_classes=2)], [
            IODescription('loss', [], torch.float32),
            IODescription('logits', ['batch', 2], torch.float32)])
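
        # For clarity (a hypothetical batch, not part of the original test): 'batch' and
        # 'max_seq_len_in_batch' are symbolic dimensions resolved at run time, and num_classes=2
        # matches the two MRPC labels, e.g.
        #     dummy = {"input_ids":      torch.randint(0, model.config.vocab_size, (8, 128), dtype=torch.int64),
        #              "attention_mask": torch.ones(8, 128, dtype=torch.int64),
        #              "token_type_ids": torch.zeros(8, 128, dtype=torch.int64),
        #              "labels":         torch.randint(0, 2, (8,), dtype=torch.int64)}
        # The wrapped model then produces a scalar float32 'loss' and float32 'logits' of shape (batch, 2).
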
        # Initialize the ORTTrainer within ORTTransformerTrainer
        trainer = ORTTransformerTrainer(
            model=model,
            model_desc=model_desc,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )

        # Training
        if training_args.do_train:
            trainer.train()
            trainer.save_model()

        # Evaluation
        results = {}
        if training_args.do_eval and training_args.local_rank in [-1, 0]:
            logger.info("*** Evaluate ***")

            result = trainer.evaluate()

            logger.info("***** Eval results {} *****".format(data_args.task_name))
            for key, value in result.items():
                logger.info("  %s = %s", key, value)

            results.update(result)

        return results


if __name__ == "__main__":
    unittest.main()

@@ -6,12 +6,14 @@ import unittest
import shutil
import pytest
import os

import random
import numpy as np
from transformers import (BertConfig, BertForPreTraining, BertModel)

from orttraining_test_data_loader import ids_tensor, BatchArgsOption
from orttraining_test_utils import run_test, get_lr

import onnxruntime
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler

import torch

@@ -141,6 +143,13 @@ class BertModelTest(unittest.TestCase):
            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            seed = 42
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            onnxruntime.set_seed(seed)

            model = BertForPreTraining(config=config)
            model.eval()
            loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,

@@ -102,6 +102,8 @@ def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
            kwargs[model.loss_scale_input_name] = loss_scale
        outputs = model(*args, **kwargs)

        print(outputs[0])

    # eval
    model.eval()
    if batch_args_option == BatchArgsOption.List:

@@ -0,0 +1,357 @@
# adapted from Trainer.py of huggingface transformers

import json
import logging
import os
import random

from typing import Callable, Dict, List, NamedTuple, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from tqdm import tqdm, trange

from transformers.data.data_collator import DataCollator, DefaultDataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.training_args import TrainingArguments

import onnxruntime
from orttraining_test_bert_postprocess import postprocess_model
from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription

try:
    from torch.utils.tensorboard import SummaryWriter

    _has_tensorboard = True
except ImportError:
    try:
        from tensorboardX import SummaryWriter

        _has_tensorboard = True
    except ImportError:
        _has_tensorboard = False


def is_tensorboard_available():
    return _has_tensorboard


logger = logging.getLogger(__name__)


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    onnxruntime.set_seed(seed)


class EvalPrediction(NamedTuple):
    predictions: np.ndarray
    label_ids: np.ndarray


class PredictionOutput(NamedTuple):
    predictions: np.ndarray
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]


class TrainOutput(NamedTuple):
    global_step: int
    training_loss: float


def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr):

    def lr_lambda_linear(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    def lambda_lr_get_lr(current_global_step):
        # LambdaLR increments self.last_epoch at every step()
        return base_lr * lr_lambda_linear(current_global_step)

    return lambda_lr_get_lr

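# A minimal sanity check of the schedule above (an illustrative sketch, not part of the original
# file), assuming 100 warmup steps, 1000 total steps and a base LR of 2e-5: the learning rate ramps
# up linearly during warmup, then decays linearly to zero.
#     get_lr = get_linear_schedule_with_warmup(num_warmup_steps=100, num_training_steps=1000, base_lr=2e-5)
#     get_lr(50)    # 1e-05  (halfway through warmup)
#     get_lr(100)   # 2e-05  (warmup finished)
#     get_lr(550)   # 1e-05  (halfway through the decay)
#     get_lr(1000)  # 0.0    (schedule exhausted)
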
class ORTTransformerTrainer:
    """
    Trainer wrapper that drives training and evaluation of a transformers
    model through onnxruntime's ORTTrainer.
    """

    model: PreTrainedModel
    args: TrainingArguments
    train_dataset: Dataset
    eval_dataset: Dataset
    compute_metrics: Callable[[EvalPrediction], Dict]

    def __init__(
        self,
        model: PreTrainedModel,
        model_desc: ModelDescription,
        args: TrainingArguments,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        compute_metrics: Callable[[EvalPrediction], Dict],
    ):
        """
        Initialize the trainer from a pre-trained model, its ORT model description,
        the training arguments, the train/eval datasets and a metrics callback.
        """

        self.model = model
        self.model_desc = model_desc
        self.args = args
        self.data_collator = DefaultDataCollator()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.args.local_rank in [-1, 0]:
            os.makedirs(self.args.output_dir, exist_ok=True)

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = (
            SequentialSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset)
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_eval_dataloader(self) -> DataLoader:
        return DataLoader(
            self.eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def train(self):
        """
        Main training entry point.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        get_lr_this_step = get_linear_schedule_with_warmup(self.args.warmup_steps, t_total, self.args.learning_rate)
        loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)

        def map_optimizer_attributes(name):
            # no_decay_keys = ["bias", "LayerNorm.weight"]
            no_decay = "bias" in name or "LayerNorm.weight" in name
            if no_decay:
                return {"weight_decay": 0.0, "weight_decay_mode": 1}
            else:
                return {"weight_decay": self.args.weight_decay, "weight_decay_mode": 1}
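
        # For example (illustration only, not in the original file): a LayerNorm weight such as
        # "bert.encoder.layer.0.attention.output.LayerNorm.weight" maps to
        # {"weight_decay": 0.0, "weight_decay_mode": 1}, while a regular weight such as
        # "bert.encoder.layer.0.attention.self.query.weight" keeps self.args.weight_decay.
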
        self.model = ORTTrainer(self.model, None,
            self.model_desc,
            "AdamOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
            device=self.args.device,
            postprocess_model=postprocess_model,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            world_rank=0, world_size=1,  # only support single GPU cases
            use_mixed_precision=self.args.fp16,
            allreduce_post_accumulation=True,
            get_lr_this_step=get_lr_this_step,
            loss_scaler=loss_scaler,
            enable_grad_norm_clip=False,
            _opset_version=12)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        tr_loss = 0.0
        logging_loss = 0.0
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0],
        )

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if step == 0:
                    self.model.eval()
                    for k, v in inputs.items():
                        inputs[k] = v.to(self.args.device)
                    outputs = self.model(**inputs)

                tr_loss += self._training_step(self.model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    global_step += 1

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
                            global_step == 1 and self.args.logging_first_step
                        ):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps
                            learning_rate_scalar = get_lr_this_step(global_step)
                            logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        logger.info("\n\nTraining completed. \n\n")
        return TrainOutput(global_step, tr_loss / global_step)

    def _training_step(
        self, model: ORTTrainer, inputs: Dict[str, torch.Tensor]
    ) -> float:
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        return loss.item()

    def save_model(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx"))

    def evaluate(self) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader()

        output = self._prediction_loop(eval_dataloader, description="Evaluation")
        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)
        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
        self, dataloader: DataLoader, description: str
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.
        """

        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", len(dataloader.dataset))
        logger.info("  Batch size = %d", dataloader.batch_size)
        eval_losses: List[float] = []
        preds: np.ndarray = None
        label_ids: np.ndarray = None
        self.model.eval()

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if preds is None:
                preds = logits.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            if inputs.get("labels") is not None:
                if label_ids is None:
                    label_ids = inputs["labels"].detach().cpu().numpy()
                else:
                    label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["loss"] = np.mean(eval_losses)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

@@ -1051,6 +1051,13 @@ def adb_shell(*args, **kwargs):
def run_training_python_frontend_e2e_tests(args, cwd):
    # frontend tests are to be added here:
    log.info("Running python frontend e2e tests.")

    # with orttraining_run_glue.py:
    # 1. we force a single GPU (via CUDA_VISIBLE_DEVICES) for the fine-tune tests.
    # 2. the tests need to run separately (fp16 and full-precision runs must not be mixed; this needs to be investigated).
    run_subprocess([sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_bert_with_mrpc', '-v'], cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'})
    run_subprocess([sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_bert_fp16_with_mrpc', '-v'], cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'})

    run_subprocess([sys.executable, 'orttraining_test_transformers.py'], cwd=cwd)

    run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
@ -116,8 +116,11 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
|
|||
if [[ $BUILD_EXTR_PAR = *--enable_training* ]]; then
|
||||
${PYTHON_EXE} -m pip install --upgrade --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
|
||||
fi
|
||||
if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
|
||||
${PYTHON_EXE} -m pip install transformers
|
||||
if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
|
||||
${PYTHON_EXE} -m pip install transformers==v2.10.0
|
||||
|
||||
# transformers requires sklearn
|
||||
${PYTHON_EXE} -m pip install sklearn
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,11 @@ if [ $BUILD_DEVICE = "openvino" ] && [[ $BUILD_EXTR_PAR == *"--use_openvino GPU_
|
|||
DOCKER_RUN_PARAMETER="$DOCKER_RUN_PARAMETER --device /dev/dri:/dev/dri"
|
||||
fi
|
||||
|
||||
if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
|
||||
DOCKER_RUN_PARAMETER="$DOCKER_RUN_PARAMETER --volume /bert_data/hf_data:/bert_data/hf_data"
|
||||
# DOCKER_RUN_PARAMETER="$DOCKER_RUN_PARAMETER -u0"
|
||||
fi
|
||||
|
||||
$DOCKER_CMD rm -f "onnxruntime-$BUILD_DEVICE" || true
|
||||
$DOCKER_CMD run $RUNTIME -h $HOSTNAME $DOCKER_RUN_PARAMETER \
|
||||
-e NIGHTLY_BUILD \
|
||||
|
|
|
|||