diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py
index 4e5fe8370..44adbc6b5 100644
--- a/pytorch_transformers/tests/tokenization_tests_commons.py
+++ b/pytorch_transformers/tests/tokenization_tests_commons.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals
 
+import os
 import sys
 from io import open
 import tempfile
@@ -49,15 +50,19 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, *
 
 def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
     tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+    tester.assertIsNotNone(tokenizer)
 
     text = u"Munich and Berlin are nice cities"
-    filename = u"/tmp/tokenizer.bin"
-    subwords = tokenizer.tokenize(text)
-    pickle.dump(tokenizer, open(filename, "wb"))
+    subwords = tokenizer.tokenize(text)
+    with TemporaryDirectory() as tmpdirname:
+
+        filename = os.path.join(tmpdirname, u"tokenizer.bin")
+        pickle.dump(tokenizer, open(filename, "wb"))
+
+        tokenizer_new = pickle.load(open(filename, "rb"))
 
-    tokenizer_new = pickle.load(open(filename, "rb"))
     subwords_loaded = tokenizer_new.tokenize(text)
 
     tester.assertListEqual(subwords, subwords_loaded)
 
 