Update PT/Flax weight conversion after #24030 (#24556)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-28 19:44:31 +02:00 committed by GitHub
parent 33b5ef5cdf
commit faae8d8255
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -120,6 +120,16 @@ def rename_key_and_reshape_tensor(
if pt_tuple_key[-1] == "beta":
return renamed_pt_tuple_key, pt_tensor
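# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030: keys ending in `parametrizations.<name>.original0` / `.original1` are renamed to `<name>_g` / `<name>_v`.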
name = None
if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
name = pt_tuple_key[-2] + "_g"
elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
name = pt_tuple_key[-2] + "_v"
if name is not None:
renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
return renamed_pt_tuple_key, pt_tensor
return pt_tuple_key, pt_tensor
@ -372,6 +382,24 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
else:
flax_key = ".".join(flax_key_tuple)
# We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
special_pt_names = {}
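# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030: map the `<name>_g` / `<name>_v` form of each key to the actual PT key ending in `parametrizations.<name>.original0` / `.original1`.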
for key in pt_model_dict:
key_components = key.split(".")
name = None
if key_components[-3::2] == ["parametrizations", "original0"]:
name = key_components[-2] + "_g"
elif key_components[-3::2] == ["parametrizations", "original1"]:
name = key_components[-2] + "_v"
if name is not None:
key_components = key_components[:-3] + [name]
key_to_check = ".".join(key_components)
special_pt_names[key_to_check] = key
if flax_key in special_pt_names:
flax_key = special_pt_names[flax_key]
if flax_key in pt_model_dict:
if flax_tensor.shape != pt_model_dict[flax_key].shape:
raise ValueError(