Fix use_parallel_residual and qkv_bias for StableLM GGUF config extraction (#34450)

* fix stablelm qkv_bias

* fix stablelm qkv_bias and use_parallel_residual

* remove original_model.config for stablelm gguf test
This commit is contained in:
Isotr0py 2024-11-06 01:26:20 +08:00 committed by GitHub
parent 9f28d0c5d0
commit e83aaaa86b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 8 deletions

View file

@ -106,6 +106,17 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
if "qwen2moe" in architecture:
updated_architecture = "qwen2_moe"
# For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors
# If `qkv_bias=True`, qkv_proj with bias will be present in the tensors
# If `use_parallel_residual=False`, ffn_norm will be present in the tensors
if "stablelm" in architecture:
attn_bias_name = {"attn_q.bias", "attn_k.bias", "attn_v.bias"}
ffn_norm_name = "ffn_norm"
qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name)
use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors)
parsed_parameters["config"]["qkv_bias"] = qkv_bias
parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual
model_size = ""
# extract the number of params from file name as architectures can differ ;
# eg. for falcon : `...falcon-7b-...`

View file

@ -673,10 +673,6 @@ class GgufIntegrationTests(unittest.TestCase):
self.stablelm2_model_id,
gguf_file=self.fp16_stablelm2_model_id,
torch_dtype=torch.float16,
# for precise comparison it is required to use the original model config
# as quantized one is different in parameters: use_parallel_residual and use_qkv_bias
# and it highly influences on the output results
config=original_model.config,
)
tokenizer = AutoTokenizer.from_pretrained(self.stablelm2_model_id, gguf_file=self.fp16_stablelm2_model_id)
@ -703,10 +699,6 @@ class GgufIntegrationTests(unittest.TestCase):
gguf_file=self.fp16_stablelm2_model_id,
device_map="auto",
torch_dtype=torch.float16,
# for precise comparison it is required to use the original model config
# as quantized one is different in parameters: use_parallel_residual and use_qkv_bias
# and it highly influences on the output results
config=original_model.config,
)
converted_state_dict = converted_model.state_dict()