Fix GPU OOM for mistral.py::Mask4DTestHard (#31212)

* build

* build

* build

* build

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-06-03 19:25:15 +02:00 committed by GitHub
parent df5abae894
commit 8a1a23ae4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -734,15 +734,24 @@ class MistralIntegrationTest(unittest.TestCase):
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
model_name = "mistralai/Mistral-7B-v0.1"
_model = None
def tearDown(self):
    """Free memory after each test.

    Runs Python's garbage collector and then asks PyTorch to empty its CUDA
    cache.
    """
    # Order is kept as-is: collect Python garbage first, then empty the
    # CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()
@property
def model(self):
    """Return the model shared by the whole test class, loading it on first use.

    The loaded instance is stored on the class attribute ``_model``; while
    that attribute is not ``None`` it is returned directly and
    ``from_pretrained`` is not called again.
    """
    owner = self.__class__
    if owner._model is not None:
        return owner._model

    # First access: load with the dtype chosen on the instance
    # (``self.model_dtype``) and move the result to the test device.
    checkpoint = MistralForCausalLM.from_pretrained(
        self.model_name, torch_dtype=self.model_dtype
    )
    owner._model = checkpoint.to(torch_device)
    return owner._model
def setUp(self):
    """Prepare per-test state: the model dtype and the tokenizer.

    The model itself is deliberately NOT loaded here. It is exposed through
    the read-only ``model`` property, which loads it once using
    ``self.model_dtype`` and caches it on the class; assigning
    ``self.model`` in this method would fail because the property has no
    setter, and reloading the checkpoint for every test is what this
    change (a GPU OOM fix) is meant to avoid.
    """
    # Half precision; read by the ``model`` property when it loads the
    # checkpoint.
    self.model_dtype = torch.float16
    # Use the class-level checkpoint name so tokenizer and model stay in sync.
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
def get_test_data(self):
template = "my favorite {}"