diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index 57c436224..2349ce580 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -377,7 +377,7 @@ class EsmSelfAttention(nn.Module):
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 6cea1b551..494bf1382 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -121,7 +121,7 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         import bitsandbytes as bnb
 
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
             # Add here check for loaded components' dtypes once serialization is implemented
             return True
         elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 193da44d2..cc6942857 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -139,7 +139,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         import bitsandbytes as bnb
 
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
             if self.pre_quantized:
                 if param_name.replace("weight", "SCB") not in state_dict.keys():
                     raise ValueError("Missing quantization component `SCB`")
diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py
index d09326df6..7e99f86bb 100644
--- a/tests/models/esm/test_modeling_esm.py
+++ b/tests/models/esm/test_modeling_esm.py
@@ -18,7 +18,7 @@ import unittest
 
 from transformers import EsmConfig, is_torch_available
-from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
+from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -303,9 +303,9 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         pass
 
 
+@slow
 @require_torch
 class EsmModelIntegrationTest(TestCasePlus):
-    @slow
     def test_inference_masked_lm(self):
         with torch.no_grad():
             model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -323,7 +323,6 @@ class EsmModelIntegrationTest(TestCasePlus):
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
 
-    @slow
     def test_inference_no_head(self):
         with torch.no_grad():
             model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -336,3 +335,18 @@ class EsmModelIntegrationTest(TestCasePlus):
             [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @require_bitsandbytes
+    def test_inference_bitsandbytes(self):
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
+
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        # Just test if inference works
+        with torch.no_grad():
+            _ = model(input_ids)[0]
+
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)
+
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        # Just test if inference works
+        _ = model(input_ids)[0]
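
As a usage sketch of the path the new test exercises (assuming `bitsandbytes` is installed and a CUDA device is available): the checkpoint name and input ids below come from the test in the diff, and without the `attention_probs.to(value_layer.dtype)` cast an 8-bit forward pass can fail with a Half/Float dtype mismatch inside `EsmSelfAttention`.

import torch
from transformers import EsmForMaskedLM

# Load the 3B-parameter ESM-2 checkpoint with 8-bit bitsandbytes quantization.
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)

# Token ids copied from the new integration test.
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])

with torch.no_grad():
    logits = model(input_ids).logits  # (batch, seq_len, vocab_size)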