diff --git a/tests/test_onnx.py b/tests/test_onnx.py index d397a149c..6308bc523 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -1,7 +1,5 @@ import unittest -from os.path import dirname, exists from pathlib import Path -from shutil import rmtree from tempfile import NamedTemporaryFile, TemporaryDirectory from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline @@ -72,7 +70,7 @@ class OnnxExportTestCase(unittest.TestCase): def test_quantize_pytorch(self): for model in OnnxExportTestCase.MODEL_TO_TEST: path = self._test_export(model, "pt", 12) - quantized_path = quantize(Path(path)) + quantized_path = quantize(path) # Ensure the actual quantized model is not bigger than the original one if quantized_path.stat().st_size >= Path(path).stat().st_size: @@ -82,16 +80,16 @@ class OnnxExportTestCase(unittest.TestCase): try: # Compute path with TemporaryDirectory() as tempdir: - path = tempdir + "/model.onnx" + path = Path(tempdir).joinpath("model.onnx") # Remove folder if exists - if exists(dirname(path)): - rmtree(dirname(path)) + if path.parent.exists(): + path.parent.rmdir() - # Export - convert(framework, model, path, opset, tokenizer) + # Export + convert(framework, model, path, opset, tokenizer) - return path + return path except Exception as e: self.fail(e)