Adding Llama FastTokenizer support. (#22264)
* Adding Llama FastTokenizer support.

- Requires the https://github.com/huggingface/tokenizers/pull/1183 version of `tokenizers`.
- Only supports `byte_fallback` for llama; raises otherwise (safety net).
- Lots of open questions around special tokens.

How to test:

```python
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

if False:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")

strings = [
    "This is a test",
    "生活的真谛是",  # "The true meaning of life is"
    "生活的真谛是[MASK]。",  # same sentence with a [MASK] placeholder
    # XXX: This one is problematic because of special tokens
    # "<s> Something something",
]

for string in strings:
    encoded = tokenizer(string)["input_ids"]
    encoded2 = new_tokenizer.encode(string).ids

    assert encoded == encoded2, f"{encoded} != {encoded2}"

    decoded = tokenizer.decode(encoded)
    decoded2 = new_tokenizer.decode(encoded2)

    assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"
```

Squashed work-in-progress commits:

- The converter + some test script.
- The test script.
- Tmp save.
- Adding Fast tokenizer + tests.
- Adding the tokenization tests.
- Correct combination.
- Small fix.
- Fixing tests.
- Fixing with latest update.
- Rebased.
- fix copies + normalized added tokens + copies.
- Adding doc.
- TMP.
- Doc + split files.
- Doc.
- Versions + try import.
- Fix Camembert + warnings -> Error.
- Fix by ArthurZucker.
- Not a decorator.

* Fixing comments.

* Adding more to docstring.

* Doc rewriting.
This commit is contained in:
parent
1564189298
commit
1670be4bde
11 changed files with 267 additions and 25 deletions
@@ -336,7 +336,7 @@ Flax), PyTorch, and/or TensorFlow.
 | LED | ✅ | ✅ | ✅ | ✅ | ❌ |
 | LeViT | ❌ | ❌ | ✅ | ❌ | ❌ |
 | LiLT | ❌ | ❌ | ✅ | ❌ | ❌ |
-| LLaMA | ✅ | ❌ | ✅ | ❌ | ❌ |
+| LLaMA | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
 | LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ |
 | LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
@@ -59,6 +59,14 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr
     - create_token_type_ids_from_sequences
     - save_vocabulary
 
+## LlamaTokenizerFast
+
+[[autodoc]] LlamaTokenizerFast
+    - build_inputs_with_special_tokens
+    - get_special_tokens_mask
+    - create_token_type_ids_from_sequences
+    - save_vocabulary
+
 ## LlamaModel
 
 [[autodoc]] LlamaModel
setup.py
@@ -78,7 +78,7 @@ import re
 import shutil
 from pathlib import Path
 
-from setuptools import setup, Command
+from setuptools import Command, setup
 
 
 # Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
@@ -251,6 +251,7 @@ class DepsTableUpdateCommand(Command):
         with open(target, "w", encoding="utf-8", newline="\n") as f:
             f.write("\n".join(content))
 
 
 extras = {}
 
 extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "rhoknp")
@@ -740,6 +740,7 @@ else:
     _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
     _import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
     _import_structure["models.led"].append("LEDTokenizerFast")
+    _import_structure["models.llama"].append("LlamaTokenizerFast")
     _import_structure["models.longformer"].append("LongformerTokenizerFast")
     _import_structure["models.lxmert"].append("LxmertTokenizerFast")
     _import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
@@ -4388,6 +4389,7 @@ if TYPE_CHECKING:
         from .models.layoutlmv3 import LayoutLMv3TokenizerFast
         from .models.layoutxlm import LayoutXLMTokenizerFast
         from .models.led import LEDTokenizerFast
+        from .models.llama import LlamaTokenizerFast
         from .models.longformer import LongformerTokenizerFast
         from .models.lxmert import LxmertTokenizerFast
         from .models.markuplm import MarkupLMTokenizerFast
@@ -19,10 +19,9 @@ All the conversions are grouped here to gather SentencePiece dependencies outsid
 allow to make our dependency on SentencePiece optional.
 """
 
-import warnings
 from typing import Dict, List, Tuple
 
-from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE, Unigram, WordPiece
 
 from .utils import requires_backends
@@ -450,12 +449,13 @@ class SpmConverter(Converter):
         self.proto = m
 
         if self.proto.trainer_spec.byte_fallback:
-            warnings.warn(
-                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
-                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
-                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
-                "unknown tokens into a sequence of byte tokens matching the original piece of text."
-            )
+            if not getattr(self, "handle_byte_fallback", None):
+                raise RuntimeError(
+                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
+                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
+                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
+                )
 
     def vocab(self, proto):
         return [(piece.piece, piece.score) for piece in proto.pieces]
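For reference, a minimal sketch (not part of this diff) of what `byte_fallback` buys in `tokenizers>=0.13.3`; the four-token vocabulary is made up for illustration:

```python
from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE

# Toy vocab: "¡" itself is absent, but its UTF-8 bytes (0xC2 0xA1) exist as byte tokens.
vocab = {"<unk>": 0, "a": 1, "<0xC2>": 2, "<0xA1>": 3}
tok = Tokenizer(BPE(vocab, [], unk_token="<unk>", byte_fallback=True))
tok.decoder = decoders.Sequence([decoders.ByteFallback(), decoders.Fuse()])

enc = tok.encode("a¡")
print(enc.tokens)           # ['a', '<0xC2>', '<0xA1>'] — byte tokens instead of '<unk>'
print(tok.decode(enc.ids))  # 'a¡' — ByteFallback + Fuse restore the original text
```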
@@ -1094,6 +1094,78 @@ class XGLMConverter(SpmConverter):
         )
 
 
+class LlamaConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def vocab(self, proto):
+        vocab = [
+            ("<unk>", 0.0),
+            ("<s>", 0.0),
+            ("</s>", 0.0),
+        ]
+        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+        return vocab
+
+    def unk_id(self, proto):
+        unk_id = 0
+        return unk_id
+
+    def decoder(self, replacement, add_prefix_space):
+        return decoders.Sequence(
+            [
+                decoders.Replace("▁", " "),
+                decoders.ByteFallback(),
+                decoders.Fuse(),
+                decoders.Strip(content=" ", left=1),
+            ]
+        )
+
+    def tokenizer(self, proto):
+        model_type = proto.trainer_spec.model_type
+        vocab_scores = self.vocab(proto)
+        if model_type == 1:
+            raise RuntimeError("Llama is supposed to be a BPE model!")
+        elif model_type == 2:
+            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+            tokenizer = Tokenizer(
+                BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
+            )
+            tokenizer.add_special_tokens(
+                [
+                    AddedToken("<unk>", normalized=True),
+                    AddedToken("<s>", normalized=True),
+                    AddedToken("</s>", normalized=True),
+                ]
+            )
+        else:
+            raise Exception(
+                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
+            )
+
+        return tokenizer
+
+    def normalizer(self, proto):
+        return normalizers.Sequence(
+            [
+                normalizers.Prepend(prepend="▁"),
+                normalizers.Replace(pattern=" ", content="▁"),
+            ]
+        )
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        return None
+
+    def post_processor(self):
+        return processors.TemplateProcessing(
+            single="<s> $A",
+            pair="<s> $A $B",
+            special_tokens=[
+                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
+            ],
+        )
+
+
 class MarkupLMConverter(Converter):
     def converted(self) -> Tokenizer:
         ot = self.original_tokenizer
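To see what the new normalizer and decoder do, a small sketch (not from the diff) built from the same `tokenizers` components; the decoder here is a trimmed version of the one above, with `ByteFallback` omitted since the input is plain ASCII:

```python
from tokenizers import decoders, normalizers

# LlamaConverter's normalizer: prepend the SPM underline, then map spaces to it.
norm = normalizers.Sequence(
    [normalizers.Prepend(prepend="▁"), normalizers.Replace(pattern=" ", content="▁")]
)
print(norm.normalize_str("Hello world"))  # '▁Hello▁world'

# The matching decoder undoes it and strips the single leading space again.
dec = decoders.Sequence(
    [decoders.Replace("▁", " "), decoders.Fuse(), decoders.Strip(content=" ", left=1)]
)
print(dec.decode(["▁Hello", "▁world"]))  # 'Hello world'
```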
@@ -1183,6 +1255,7 @@ SLOW_TO_FAST_CONVERTERS = {
     "XLNetTokenizer": XLNetConverter,
     "SplinterTokenizer": SplinterConverter,
     "XGLMTokenizer": XGLMConverter,
+    "LlamaTokenizer": LlamaConverter,
 }
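This registry is what `convert_slow_tokenizer` dispatches on. A usage sketch, with the checkpoint name borrowed from the tests in this commit:

```python
from transformers import LlamaTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
fast_backend = convert_slow_tokenizer(slow)  # dispatches to LlamaConverter via the dict above
fast_backend.save("tok.json")                # a plain tokenizers.Tokenizer, serializable to JSON
```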
@@ -172,7 +172,13 @@ else:
         ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
         ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
         ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
-        ("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, None)),
+        (
+            "llama",
+            (
+                "LlamaTokenizer" if is_sentencepiece_available() else None,
+                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
         ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
         (
             "longt5",
@@ -17,6 +17,7 @@ from ...utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_sentencepiece_available,
+    is_tokenizers_available,
     is_torch_available,
 )
 
@@ -33,6 +34,14 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["tokenization_llama"] = ["LlamaTokenizer"]
 
+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
+
 try:
     if not is_torch_available():
         raise OptionalDependencyNotAvailable()
@@ -58,6 +67,14 @@ if TYPE_CHECKING:
     else:
         from .tokenization_llama import LlamaTokenizer
 
+    try:
+        if not is_tokenizers_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_llama_fast import LlamaTokenizerFast
+
     try:
         if not is_torch_available():
             raise OptionalDependencyNotAvailable()
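The guards above can be probed directly; a small sketch mirroring the availability checks:

```python
from transformers.utils import is_sentencepiece_available, is_tokenizers_available

print(is_sentencepiece_available())  # backend for the slow LlamaTokenizer
print(is_tokenizers_available())     # backend for the new LlamaTokenizerFast
```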
src/transformers/models/llama/tokenization_llama_fast.py (new file, 82 lines)
@@ -0,0 +1,82 @@
 # coding=utf-8
 # Copyright 2020 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils.versions import require_version
 
 
 require_version("tokenizers>=0.13.3")
 
 
 class LlamaTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
 
     This notably uses ByteFallback and no normalization.
 
     ```
     from transformers import LlamaTokenizerFast
 
     tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
     tokenizer.encode("Hello this is a test")
     >>> [1, 15043, 445, 338, 263, 1243]
     ```
 
     This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
     refer to this superclass for more information regarding those methods.
 
     Args:
         vocab_file (`str`):
             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
             contains the vocabulary necessary to instantiate a tokenizer.
         tokenizer_file (`str`):
             [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
             contains everything needed to load the tokenizer.
         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
             Whether to clean up spaces after decoding; cleanup consists in removing potential artifacts like extra
             spaces.
         bos_token (`str`, *optional*, defaults to `"<s>"`):
             The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
             token.
         eos_token (`str`, *optional*, defaults to `"</s>"`):
             The end of sequence token.
         unk_token (`str`, *optional*, defaults to `"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
             this token instead.
     """
 
     padding_side = "left"
 
     def __init__(
         self,
         vocab_file=None,
         tokenizer_file=None,
         clean_up_tokenization_spaces=False,
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
         **kwargs,
     ):
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
             unk_token=unk_token,
             bos_token=bos_token,
             eos_token=eos_token,
             **kwargs,
         )
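A usage sketch for the new class; the expected ids come from the docstring above, and the checkpoint name from the tests in this commit:

```python
from transformers import LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
print(tokenizer.encode("Hello this is a test"))  # [1, 15043, 445, 338, 263, 1243]
print(tokenizer.padding_side)  # 'left' — the class default, which decoder-only models want for batching
```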
@@ -219,6 +219,13 @@ class LEDTokenizerFast(metaclass=DummyObject):
         requires_backends(self, ["tokenizers"])
 
 
+class LlamaTokenizerFast(metaclass=DummyObject):
+    _backends = ["tokenizers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tokenizers"])
+
+
 class LongformerTokenizerFast(metaclass=DummyObject):
     _backends = ["tokenizers"]
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright 2023 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,8 +24,10 @@ from transformers import (
     SPIECE_UNDERLINE,
     AddedToken,
     LlamaTokenizer,
+    LlamaTokenizerFast,
     is_torch_available,
 )
+from transformers.convert_slow_tokenizer import convert_slow_tokenizer
 from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
@@ -287,13 +290,11 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class LlamaIntegrationTest(unittest.TestCase):
-    checkpoint_name = "hf-internal-testing/llama-tokenizer"
-
     @classmethod
     def setUpClass(cls):
-        cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(cls.checkpoint_name)
-        cls.rust_tokenizer = cls.tokenizer  # TODO @narsil replace with the rust one
-        cls.pad_token_id = 1
+        checkpoint_name = "hf-internal-testing/llama-tokenizer"
+        cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(checkpoint_name)
+        cls.rust_tokenizer = LlamaTokenizerFast.from_pretrained(checkpoint_name)
+        return cls
 
     @require_torch
@@ -314,6 +315,27 @@ class LlamaIntegrationTest(unittest.TestCase):
             },
         )
 
+    @slow
+    def test_conversion(self):
+        # This is excruciatingly slow since it has to recreate the entire merge
+        # list from the original vocabulary in spm
+        self.rust_tokenizer.save_pretrained("./out")
+        with tempfile.TemporaryDirectory() as dirname:
+            self.rust_tokenizer.save_pretrained(dirname)
+
+            with open(os.path.join(dirname, "tokenizer.json"), "r") as f:
+                old_serialized = f.read()
+
+        new_tokenizer = convert_slow_tokenizer(self.tokenizer)
+        with tempfile.NamedTemporaryFile() as f:
+            new_tokenizer.save(f.name)
+            # Re-opening since `f` is in bytes.
+            new_serialized = open(f.name, "r").read()
+            with open("out_tokenizer.json", "w") as g:
+                g.write(new_serialized)
+
+            self.assertEqual(old_serialized, new_serialized)
+
     def test_simple_encode_decode(self):
         pyth_tokenizer = self.tokenizer
         rust_tokenizer = self.rust_tokenizer
@@ -362,11 +384,27 @@ class LlamaIntegrationTest(unittest.TestCase):
         self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
         self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
 
+    def test_no_differences_showcase(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+        self.assertEqual(pyth_tokenizer.encode(""), [1])
+        self.assertEqual(rust_tokenizer.encode(""), [1])
+
+        self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
+        self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
+
+        self.assertEqual(pyth_tokenizer.encode("  "), [1, 1678])
+        self.assertEqual(rust_tokenizer.encode("  "), [1, 1678])
+
+        self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
+        self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
+
+        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
+        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
+
-        self.assertEqual(pyth_tokenizer.encode(""), [1])
-        self.assertEqual(rust_tokenizer.encode(""), [1])
     def test_no_differences_decode(self):
         pyth_tokenizer = self.tokenizer
         rust_tokenizer = self.rust_tokenizer
 
         self.assertEqual(pyth_tokenizer.decode([869]), ".")
         self.assertEqual(rust_tokenizer.decode([869]), ".")
@@ -374,6 +412,15 @@ class LlamaIntegrationTest(unittest.TestCase):
         self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
         self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")
 
+    def test_no_differences_special_tokens(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+        self.assertEqual(pyth_tokenizer.encode(""), [1])
+        self.assertEqual(rust_tokenizer.encode(""), [1])
+
+        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
+        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
+
     @unittest.skipIf(
         os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
         "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
@@ -392,8 +439,8 @@ class LlamaIntegrationTest(unittest.TestCase):
 
         self.assertEqual(encoded1, encoded2)
 
-        decoded1 = pyth_tokenizer.decode(encoded1)
-        decoded2 = rust_tokenizer.decode(encoded2)
+        decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
+        decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
 
         self.assertEqual(decoded1, decoded2)
@@ -406,7 +453,7 @@ class LlamaIntegrationTest(unittest.TestCase):
 
         self.assertEqual(encoded1, encoded2)
 
-        decoded1 = pyth_tokenizer.decode(encoded1)
-        decoded2 = rust_tokenizer.decode(encoded2)
+        decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
+        decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
 
         self.assertEqual(decoded1, decoded2)
@@ -24,11 +24,10 @@ class ConvertSlowTokenizerTest(unittest.TestCase):
 
         original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)
 
-        with warnings.catch_warnings(record=True) as w:
+        with self.assertRaises(RuntimeError) as cm:
             _ = SpmConverter(original_tokenizer_with_bytefallback)
-        self.assertEqual(len(w), 1)
 
         self.assertIn(
             "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
             " which is not implemented in the fast tokenizers.",
-            str(w[0].message),
+            str(cm.exception),
         )