mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[bnb] Let's make serialization of int8 models possible (#22177)
* make serialization of int8 models possible * make fixup * add docs * add ability to push to hub and save pretrained * fixes * more addition * more tests * fix issues * change variable * clearer message * adapt from suggestions * few fixes * remove unused function * Update src/transformers/utils/quantization_config.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address last comments * last warning * clarify doc * protect import * Update src/transformers/modeling_utils.py * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
523ca4e016
commit
370f0ca18c
6 changed files with 274 additions and 19 deletions
|
|
@ -52,6 +52,37 @@ Note that once a model has been loaded in 8-bit it is currently not possible to
|
|||
|
||||
</Tip>
|
||||
|
||||
### Push quantized models on the 🤗 Hub
|
||||
|
||||
You can push a quantized model on the Hub by naively using `push_to_hub` method. This will first push the quantization configuration file, then push the quantized model weights.
|
||||
Make sure to use `bitsandbytes>0.37.2` (at this time of writing, we tested it on `bitsandbytes==0.38.0.post1`) to be able to use this feature.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
|
||||
|
||||
model.push_to_hub("bloom-560m-8bit")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Pushing 8bit models on the Hub is strongely encouraged for large models. This will allow the community to benefit from the memory footprint reduction and loading for example large models on a Google Colab.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Load a quantized model from the 🤗 Hub
|
||||
|
||||
You can load a quantized model from the Hub by using `from_pretrained` method. Make sure that the pushed weights are quantized, by checking that the attribute `quantization_config` is present in the model configuration object.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("{your_username}/bloom-560m-8bit")
|
||||
```
|
||||
Note that in this case, you don't need to specify the arguments `load_in_8bit=True` and `device_map="auto"`, but you need to make sure that `bitsandbytes` and `accelerate` are installed.
|
||||
|
||||
### Advanced usecases
|
||||
|
||||
This section is intended to advanced users, that want to explore what it is possible to do beyond loading and running 8-bit models.
|
||||
|
|
|
|||
|
|
@ -801,6 +801,13 @@ class PretrainedConfig(PushToHubMixin):
|
|||
# Transformers version when serializing the model
|
||||
output["transformers_version"] = __version__
|
||||
|
||||
if hasattr(self, "quantization_config"):
|
||||
output["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
else self.quantization_config
|
||||
)
|
||||
|
||||
self.dict_torch_dtype_to_str(output)
|
||||
|
||||
return output
|
||||
|
|
|
|||
|
|
@ -697,7 +697,15 @@ def _load_state_dict_into_meta_model(
|
|||
# For backward compatibility with older versions of `accelerate`
|
||||
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
|
||||
else:
|
||||
set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
|
||||
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
|
||||
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
|
||||
else:
|
||||
fp16_statistics = None
|
||||
|
||||
if "SCB" not in param_name:
|
||||
set_module_8bit_tensor_to_device(
|
||||
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
|
||||
)
|
||||
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
|
|
@ -1700,10 +1708,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
# Checks if the model has been loaded in 8-bit
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
if getattr(self, "is_loaded_in_8bit", False) and getattr(self, "is_8bit_serializable", False):
|
||||
warnings.warn(
|
||||
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
||||
" behaviors. ",
|
||||
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
|
|
@ -2165,6 +2173,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse("0.37.2")
|
||||
else:
|
||||
is_8bit_serializable = False
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
|
||||
|
|
@ -2207,6 +2220,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
"`quantization_config` argument at the same time."
|
||||
)
|
||||
|
||||
# in the case a user loads an 8bit model from the Hub and assigns a new quantization_config
|
||||
if device_map is None:
|
||||
device_map = "auto"
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
|
||||
if load_in_8bit:
|
||||
if not (is_accelerate_available() and is_bitsandbytes_available()):
|
||||
raise ImportError(
|
||||
|
|
@ -2265,6 +2284,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
if is_8bit_serializable and quantization_config is not None and load_in_8bit:
|
||||
if hasattr(config, "quantization_config"):
|
||||
logger.warning(
|
||||
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a"
|
||||
" `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the"
|
||||
" one you passed to `from_pretrained`."
|
||||
)
|
||||
config.quantization_config = quantization_config
|
||||
elif is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
|
||||
quantization_config = config.quantization_config
|
||||
if isinstance(quantization_config, dict):
|
||||
quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False)
|
||||
elif isinstance(quantization_config, BitsAndBytesConfig):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for `quantization_config`: {type(quantization_config)}. Should be a `dict` or a"
|
||||
" `BitsAndBytesConfig` instance."
|
||||
)
|
||||
|
||||
load_in_8bit = quantization_config.load_in_8bit
|
||||
|
||||
if load_in_8bit:
|
||||
torch_dtype = torch.float16
|
||||
|
||||
if device_map is None:
|
||||
device_map = "auto"
|
||||
|
||||
if low_cpu_mem_usage is None:
|
||||
low_cpu_mem_usage = True
|
||||
elif not is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
|
||||
logger.warning(
|
||||
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
|
||||
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
|
||||
" `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
|
||||
if commit_hash is None:
|
||||
commit_hash = getattr(config, "_commit_hash", None)
|
||||
|
||||
|
|
@ -2621,6 +2677,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
importlib_metadata.version("bitsandbytes")
|
||||
) >= version.parse("0.37.0")
|
||||
|
||||
model.config.quantization_config = quantization_config
|
||||
model.is_8bit_serializable = is_8bit_serializable
|
||||
|
||||
if isinstance(device_map, str):
|
||||
special_dtypes = {}
|
||||
if load_in_8bit:
|
||||
|
|
@ -3113,6 +3172,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||
|
||||
if load_in_8bit:
|
||||
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
|
||||
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from copy import deepcopy
|
||||
|
||||
from .import_utils import is_accelerate_available, is_bitsandbytes_available
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available
|
||||
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
|
|
@ -13,7 +15,7 @@ if is_accelerate_available():
|
|||
from accelerate.utils import find_tied_parameters
|
||||
|
||||
|
||||
def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
|
||||
def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
|
||||
"""
|
||||
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
||||
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
|
||||
|
|
@ -29,6 +31,8 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
|
|||
The device on which to set the tensor.
|
||||
value (`torch.Tensor`, *optional*):
|
||||
The value of the tensor (useful when going from the meta device to any other device).
|
||||
fp16_statistics (`torch.HalfTensor`, *optional*):
|
||||
The list of fp16 statistics to set on the module, used for serialization.
|
||||
"""
|
||||
# Recurse if needed
|
||||
if "." in tensor_name:
|
||||
|
|
@ -61,14 +65,21 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
|
|||
elif isinstance(value, torch.Tensor):
|
||||
new_value = value.to("cpu")
|
||||
if value.dtype == torch.int8:
|
||||
raise ValueError(
|
||||
"You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are",
|
||||
" using `load_in_8bit=True` on float32/float16/bfloat16 weights.",
|
||||
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse(
|
||||
"0.37.2"
|
||||
)
|
||||
if not is_8bit_serializable:
|
||||
raise ValueError(
|
||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
else:
|
||||
new_value = torch.tensor(value, device="cpu")
|
||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
|
||||
module._parameters[tensor_name] = new_value
|
||||
|
||||
if fp16_statistics is not None:
|
||||
setattr(module.weight, "SCB", fp16_statistics.to(device))
|
||||
else:
|
||||
if value is None:
|
||||
new_value = old_value.to(device)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -49,6 +58,8 @@ class BitsAndBytesConfig:
|
|||
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
|
||||
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
|
||||
operations will not be run on CPU.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional parameters from which to initialize the configuration object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -57,6 +68,7 @@ class BitsAndBytesConfig:
|
|||
llm_int8_threshold=6.0,
|
||||
llm_int8_skip_modules=None,
|
||||
llm_int8_enable_fp32_cpu_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.llm_int8_threshold = llm_int8_threshold
|
||||
|
|
@ -81,17 +93,19 @@ class BitsAndBytesConfig:
|
|||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
|
||||
"""
|
||||
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
|
||||
Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters.
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`):
|
||||
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
|
||||
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
|
||||
Dictionary that will be used to instantiate the configuration object.
|
||||
return_unused_kwargs (`bool`):
|
||||
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
|
||||
`PreTrainedModel`.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
Additional parameters from which to initialize the configuration object.
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
|
||||
[`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
|
||||
"""
|
||||
config = cls(**config_dict)
|
||||
|
||||
|
|
@ -107,3 +121,28 @@ class BitsAndBytesConfig:
|
|||
return config, kwargs
|
||||
else:
|
||||
return config
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (`str` or `os.PathLike`):
|
||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default
|
||||
`BitsAndBytesConfig()` is serialized to JSON file.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
config_dict = self.to_dict()
|
||||
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
writer.write(json_string)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import unittest
|
|||
from packaging import version
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
|
|
@ -150,6 +151,13 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_raise_if_config_and_load_in_8bit(self):
|
||||
r"""
|
||||
Test that loading the model with the config and `load_in_8bit` raises an error
|
||||
|
|
@ -165,13 +173,6 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
|
||||
def test_warns_save_pretrained(self):
|
||||
r"""
|
||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
||||
"""
|
||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
|
|
@ -219,6 +220,77 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
|
||||
def test_int8_serialization(self):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 8-bit.
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname)
|
||||
|
||||
# check that the file `quantization_config` is present
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
|
||||
|
||||
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||
)
|
||||
|
||||
def test_int8_serialization_sharded(self):
|
||||
r"""
|
||||
Test whether it is possible to serialize a model in 8-bit - sharded version.
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.model_8bit.save_pretrained(tmpdirname, max_shard_size="200MB")
|
||||
|
||||
# check that the file `quantization_config` is present
|
||||
config = AutoConfig.from_pretrained(tmpdirname)
|
||||
self.assertTrue(hasattr(config, "quantization_config"))
|
||||
|
||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
|
||||
)
|
||||
|
||||
def test_int8_from_pretrained(self):
|
||||
r"""
|
||||
Test whether loading a 8bit model from the Hub works as expected
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
model_id = "ybelkada/bloom-1b7-8bit"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
|
|
@ -289,6 +361,38 @@ class MixedInt8T5Test(unittest.TestCase):
|
|||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
def test_inference_with_keep_in_fp32_serialized(self):
|
||||
r"""
|
||||
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly on
|
||||
a serialized model.
|
||||
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
|
||||
both cases.
|
||||
"""
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
# test with `t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
|
||||
|
||||
# there was a bug with decoders - this test checks that it is fixed
|
||||
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
|
||||
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
# test with `flan-t5-small`
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
|
||||
)
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
|
||||
_ = model.generate(**encoded_input)
|
||||
|
||||
|
||||
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue