diff --git a/benchmarks/operator_benchmark/benchmark_test_generator.py b/benchmarks/operator_benchmark/benchmark_test_generator.py index 9ad3e722606..5b6e948e1f9 100644 --- a/benchmarks/operator_benchmark/benchmark_test_generator.py +++ b/benchmarks/operator_benchmark/benchmark_test_generator.py @@ -8,6 +8,8 @@ from operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase from operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase from operator_benchmark.benchmark_utils import * # noqa +import torch + def generate_test(configs, map_config, ops, OperatorTestCase): """