[fsmt tokenizer] support lowercase tokenizer (#8389)
* support lowercase tokenizer
* fix arg pos
This commit is contained in:
parent 1e2acd0dcf
commit 78d706f3ae
3 changed files with 23 additions and 1 deletion
@@ -133,6 +133,14 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     with open(src_vocab_file, "w", encoding="utf-8") as f:
         f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
 
+    # detect whether this is a do_lower_case situation, which can be derived by checking whether we
+    # have at least one uppercase letter in the source vocab
+    do_lower_case = True
+    for k in src_vocab.keys():
+        if not k.islower():
+            do_lower_case = False
+            break
+
     tgt_dict = Dictionary.load(tgt_dict_file)
     tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
     tgt_vocab_size = len(tgt_vocab)
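The detection loop is self-contained, so it can be sanity-checked in isolation. A minimal sketch on a toy vocab (the real script builds src_vocab from a fairseq Dictionary via rewrite_dict_keys; the toy entries here are illustrative, not from the diff):

# standalone sketch of the do_lower_case detection heuristic above;
# the toy vocab entries are made up for illustration
src_vocab = {"us": 0, "a</w>": 1, "of</w>": 2, "am": 3}

do_lower_case = True
for k in src_vocab.keys():
    # str.islower() is True only when every cased character is lowercase
    # and at least one cased character exists, so "USA" and "<s>" both fail it
    if not k.islower():
        do_lower_case = False
        break

print(do_lower_case)  # True: every key here is fully lowercase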
@@ -207,6 +215,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     tokenizer_conf = {
         "langs": [src_lang, tgt_lang],
         "model_max_length": 1024,
+        "do_lower_case": do_lower_case,
     }
 
     print(f"Generating {fsmt_tokenizer_config_file}")
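For a checkpoint whose source vocab is fully lowercase, the emitted tokenizer config now carries the detected flag. A hedged sketch of the write with illustrative values (the "ru"/"en" pair and the output filename are assumptions, not taken from the diff):

import json

tokenizer_conf = {
    "langs": ["ru", "en"],   # illustrative language pair
    "model_max_length": 1024,
    "do_lower_case": True,   # as detected from the source vocab
}
# the conversion script serializes this dict as the tokenizer config JSON
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=4))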
The FSMTTokenizer class picks up the new argument. Its documented default changes from True to False:

@@ -154,7 +154,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
             File containing the vocabulary for the target language.
         merges_file (:obj:`str`):
             File containing the merges.
-        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to lowercase the input when tokenizing.
         unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
@@ -186,6 +186,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
         src_vocab_file=None,
         tgt_vocab_file=None,
         merges_file=None,
+        do_lower_case=False,
         unk_token="<unk>",
         bos_token="<s>",
         sep_token="</s>",
@@ -197,6 +198,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
             src_vocab_file=src_vocab_file,
             tgt_vocab_file=tgt_vocab_file,
             merges_file=merges_file,
+            do_lower_case=do_lower_case,
             unk_token=unk_token,
             bos_token=bos_token,
             sep_token=sep_token,
@@ -207,6 +209,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
         self.src_vocab_file = src_vocab_file
         self.tgt_vocab_file = tgt_vocab_file
         self.merges_file = merges_file
+        self.do_lower_case = do_lower_case
 
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
@@ -351,6 +354,9 @@ class FSMTTokenizer(PreTrainedTokenizer):
             # raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
             lang = self.src_lang
 
+        if self.do_lower_case:
+            text = text.lower()
+
         if bypass_tokenizer:
             text = text.split()
         else:
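Taken together, the flag round-trips: the conversion script records it, from_pretrained picks it up, and the tokenization method lowercases the input before the Moses/BPE pipeline (the hunk at line 351 above). A usage sketch mirroring the slow test added below (it downloads the facebook/wmt19-ru-en checkpoint on first run):

from transformers import FSMTTokenizer

# passing do_lower_case=True lowercases the input text before tokenizing,
# so cased input maps onto the lowercase source vocab
tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
print(tokenizer.tokenize("USA is United States of America"))
# ['us', 'a</w>', 'is</w>', 'un', 'i', 'ted</w>', 'st', 'ates</w>', 'of</w>', 'am', 'er', 'ica</w>']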
A slow test exercises the new flag end to end:

@@ -151,6 +151,13 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
         self.assertEqual(decoded_text, src_text)
 
+    @slow
+    def test_tokenizer_lower(self):
+        tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
+        tokens = tokenizer.tokenize("USA is United States of America")
+        expected = ["us", "a</w>", "is</w>", "un", "i", "ted</w>", "st", "ates</w>", "of</w>", "am", "er", "ica</w>"]
+        self.assertListEqual(tokens, expected)
+
     @unittest.skip("FSMTConfig.__init__ requires non-optional args")
     def test_torch_encode_plus_sent_to_model(self):
         pass