pytorch/caffe2/python/serialized_test/serialized_test_util.py
Ansha Yu 8ff435c8f6 Use tempfile during serialized test comparison (#12021)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12021

TestPilot runs stress tests in parallel. These fail for serialized tests because extracting (and subsequently deleting) binary data during the process isn't thread-safe. Extract zips into a temporary directory to avoid this problem.

Also remove some accidentally checked-in zips for a test that we decided not to include for now.

Reviewed By: houseroad

Differential Revision: D10013682

fbshipit-source-id: 6e13b850b38dee4106d3c10a9372747d17b67c5a
2018-09-25 20:55:45 -07:00

256 lines
8.1 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
from caffe2.proto import caffe2_pb2
from caffe2.python import gradient_checker
import caffe2.python.hypothesis_test_util as hu
import hypothesis as hy
import inspect
import numpy as np
import os
import shutil
import sys
import tempfile
import threading
from zipfile import ZipFile
# Subdirectory (below the data directory) that holds serialized operator tests.
operator_test_type = 'operator_test'
# Directory containing this module, with symlinks resolved.
TOP_DIR = os.path.split(os.path.realpath(__file__))[0]
DATA_SUFFIX = 'data'
# Default location for serialized test data: <module dir>/data.
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
# Holds command-line settings written by testWithArgs(); readers use getattr()
# with defaults. It is a threading.local, so values are per-thread.
_output_context = threading.local()
def given(*given_args, **given_kwargs):
    """Replacement for ``hypothesis.given`` used by serialized tests.

    Accepts the same arguments as ``hypothesis.given``. The decorated test
    body is executed twice per call:

    1. under hypothesis with ``seed(0)`` and ``max_examples=1``, while
       ``self.should_serialize`` is True;
    2. as a regular hypothesis test, after ``self.should_serialize`` has been
       set back to False.
    """
    def decorator(test_fn):
        randomized = hy.given(*given_args, **given_kwargs)(test_fn)

        # Build the deterministic variant step by step: given -> settings -> seed.
        deterministic = hy.given(*given_args, **given_kwargs)(test_fn)
        deterministic = hy.settings(max_examples=1)(deterministic)
        deterministic = hy.seed(0)(deterministic)

        def func(self, *args, **kwargs):
            self.should_serialize = True
            deterministic(self, *args, **kwargs)
            self.should_serialize = False
            randomized(self, *args, **kwargs)

        return func

    return decorator
def _getGradientOrNone(op_proto):
    """Return the gradient operators of *op_proto*, or an empty list.

    Despite the name, the fallback value is ``[]`` (not ``None``). Any
    exception raised while obtaining the gradient is swallowed; presumably
    this covers operators without a registered gradient (TODO confirm).
    """
    try:
        gradient_result = gradient_checker.getGradientForOp(op_proto)
        grad_ops, _unused = gradient_result
    except Exception:
        grad_ops = []
    return grad_ops
# necessary to support converting jagged lists into numpy arrays
def _transformList(l):
ret = np.empty(len(l), dtype=np.object)
for (i, arr) in enumerate(l):
ret[i] = arr
return ret
def _prepare_dir(path):
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
class SerializedTestCase(hu.HypothesisTestCase):
    """HypothesisTestCase that also records or checks serialized operator runs.

    While ``should_serialize`` is True, every ``assertReferenceChecks`` call
    (unless disabled via the ``disable_serialized_check`` setting) does one
    of two things. It either writes the operator, its inputs/outputs and its
    gradient operators to ``<output dir>/<test file>.<test method>.zip``
    (when the ``should_generate_output`` setting is on), or compares the
    current run against that zip.
    """

    # Flipped to True by this module's ``given`` decorator for its fixed-seed
    # run; serialization/comparison only happens while it is True.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory holding the serialized ``.zip`` files.

        Prefers ``<output_dir>/operator_test``, where ``output_dir`` is the
        ``output_dir`` setting on ``_output_context`` (default ``DATA_DIR``).
        If that directory does not exist, it instead returns
        ``<cwd>/<this module's package path>/data/operator_test``, without
        checking that this path exists.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)
        if os.path.exists(output_dir):
            return output_dir
        # fall back to pwd: drop the last component of the dotted module
        # name and treat the remaining package components as a relative path.
        package_components = __name__.split('.')[:-1]
        serialized_dir = '/'.join(package_components)
        output_dir_fallback = os.path.join(
            os.getcwd(), serialized_dir, DATA_SUFFIX)
        return os.path.join(output_dir_fallback, operator_test_type)

    def get_output_filename(self):
        """Return ``<test file stem>.<test method name>``.

        The file stem is the part of the test class's source file name before
        the first ``.``; the method name is the last dotted part of
        ``self.id()``. The result is used as the zip file stem.
        """
        class_path = inspect.getfile(self.__class__)
        test_file = os.path.basename(class_path).split('.')[0]
        test_function = self.id().split('.')[-1]
        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write *op*, *grad_ops*, *inputs* and *outputs* to a zip file.

        The zip ``<output dir>/<test name>.zip`` contains ``op.pb``,
        ``inout.npz`` (arrays ``inputs``, ``outputs``, ``device_type``) and
        one ``grad_<i>.pb`` per gradient operator. The files are staged in
        ``<output dir>/<test name>/``, which is removed after zipping.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)
        # Object arrays let inputs/outputs of differing shapes be saved.
        inputs = _transformList(inputs)
        outputs = _transformList(outputs)
        device_type = int(device_option.device_type)

        op_path = os.path.join(full_dir, 'op.pb')
        inout_path = os.path.join(full_dir, 'inout')
        with open(op_path, 'wb') as f:
            f.write(op.SerializeToString())

        grad_paths = []
        for i, grad in enumerate(grad_ops):
            grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
            grad_paths.append(grad_path)
            with open(grad_path, 'wb') as f:
                f.write(grad.SerializeToString())

        # savez_compressed appends the '.npz' extension to inout_path.
        np.savez_compressed(
            inout_path,
            inputs=inputs,
            outputs=outputs,
            device_type=device_type)

        with ZipFile(os.path.join(output_dir, test_name + '.zip'), 'w') as z:
            z.write(op_path, 'op.pb')
            z.write(inout_path + '.npz', 'inout.npz')
            for path in grad_paths:
                z.write(path, os.path.basename(path))
        shutil.rmtree(full_dir)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Check the current run against its serialized ``.zip``.

        If *inputs* match the serialized inputs (same count, element-wise
        equal), then *outputs* and *grad_ops* from the current run are
        compared with the serialized ones. Otherwise the serialized operator
        is re-run on the serialized inputs, and its outputs and gradient
        operators are compared instead.

        Args:
            inputs: inputs of the current run.
            outputs: outputs of the current run.
            grad_ops: gradient operators of the current run.
            atol: absolute tolerance for ``np.testing.assert_allclose``.
            rtol: relative tolerance for ``np.testing.assert_allclose``.

        Raises:
            AssertionError: if outputs or gradient operators differ.
        """
        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        # Each comparison extracts into its own temporary directory, so
        # parallel runs never read or delete each other's files.
        temp_dir = tempfile.mkdtemp()
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)
            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # load serialized input and output. inputs/outputs were saved as
            # object arrays (see _transformList), which np.load only reads
            # with allow_pickle=True (its default became False in NumPy
            # 1.16.3). The archive is this repository's own test data, not
            # untrusted input. The NpzFile is closed before temp_dir is
            # removed.
            with np.load(inout_path, encoding='bytes',
                         allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                loaded_outputs = loaded['outputs'].tolist()
                device_type = int(loaded['device_type'])

            # zip() alone would silently ignore a differing number of inputs.
            inputs_equal = (
                len(inputs) == len(loaded_inputs) and
                all(np.array_equal(x, y)
                    for x, y in zip(inputs, loaded_inputs)))

            # if inputs are not the same, run serialized input through
            # serialized op
            if not inputs_equal:
                with open(op_path, 'rb') as f:
                    op_proto = parse_proto(f.read())
                device_option = caffe2_pb2.DeviceOption(
                    device_type=device_type)
                outputs = hu.runOpOnInput(
                    device_option, op_proto, loaded_inputs)
                grad_ops = _getGradientOrNone(op_proto)

            # assert outputs are equal
            for x, y in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # assert gradient ops are equal
            for i, grad_op in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    grad_proto = parse_proto(f.read())
                self.assertEqual(grad_proto, grad_op)
        finally:
            # Always clean up, even when an assertion above fails. Cleanup
            # errors must not mask the real test failure.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def assertSerializedOperatorChecks(
        self,
        inputs,
        outputs,
        gradient_operator,
        op,
        device_option,
    ):
        """Serialize or compare this run, but only while ``should_serialize``.

        Writes a new zip when the ``should_generate_output`` setting is on;
        otherwise compares against the existing zip.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
            else:
                self.compare_test(inputs, outputs, gradient_operator)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
    ):
        """Run the base reference check, then the serialized-test check.

        The serialized check is skipped when the ``disable_serialized_check``
        setting is on. Returns whatever the base ``assertReferenceChecks``
        returns.
        """
        outs = super(SerializedTestCase, self).assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
            )
        # Preserve the base class's return value for callers that use it.
        return outs
def testWithArgs():
    """Parse the serialized-test flags, store them, then run ``unittest.main()``.

    Flags:
        -G/--generate-serialized: write new serialized outputs instead of
            comparing against the existing ones.
        -O/--output DIR: directory for serialized outputs (default DATA_DIR).
        -D/--disable-serialized_check: skip serialized checks entirely.

    Remaining positional arguments replace ``sys.argv[1:]`` so that unittest
    sees them. The settings are stored on ``_output_context``, which is a
    ``threading.local``, so they are visible only in the calling thread.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-G', '--generate-serialized', action='store_true', dest='generate',
        help='generate output files (default=false, compares to current files)')
    parser.add_argument(
        '-O', '--output', default=DATA_DIR,
        help='output directory (default: %(default)s)')
    parser.add_argument(
        '-D', '--disable-serialized_check', action='store_true', dest='disable',
        help='disable checking serialized tests')
    parser.add_argument('unittest_args', nargs='*')
    args = parser.parse_args()
    sys.argv[1:] = args.unittest_args

    # Plain attribute assignment instead of calling __setattr__ directly.
    _output_context.should_generate_output = args.generate
    _output_context.output_dir = args.output
    _output_context.disable_serialized_check = args.disable

    import unittest
    unittest.main()