mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12021 TestPilot runs stress tests in parallel. These fail for serialized tests because extraction (and subsequent deletion) of binary data during the process isn't thread-safe. Extract zips into a tempfile directory to avoid this problem. Also remove some accidentally checked-in zips of a test that we didn't end up including for now. Reviewed By: houseroad Differential Revision: D10013682 fbshipit-source-id: 6e13b850b38dee4106d3c10a9372747d17b67c5a
256 lines
8.1 KiB
Python
256 lines
8.1 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import argparse
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import gradient_checker
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
import hypothesis as hy
|
|
import inspect
|
|
import numpy as np
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
import threading
|
|
from zipfile import ZipFile
|
|
|
|
operator_test_type = 'operator_test'
|
|
TOP_DIR = os.path.dirname(os.path.realpath(__file__))
|
|
DATA_SUFFIX = 'data'
|
|
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
|
|
_output_context = threading.local()
|
|
|
|
|
|
def given(*given_args, **given_kwargs):
    """Drop-in replacement for ``hypothesis.given`` for serialized tests.

    The decorated test method is executed twice per invocation:

    1. once through a hypothesis wrapper seeded with 0 and limited to a
       single example, with ``self.should_serialize`` set to True;
    2. once through a plain ``hypothesis.given`` wrapper, with
       ``self.should_serialize`` set to False.

    All positional and keyword arguments are forwarded unchanged to
    ``hypothesis.given``.
    """
    def wrapper(f):
        # Regular hypothesis-driven version of the test.
        randomized = hy.given(*given_args, **given_kwargs)(f)

        # Independent wrapping of the same function, restricted to one
        # example and seeded with 0.
        single_example = hy.settings(max_examples=1)(
            hy.given(*given_args, **given_kwargs)(f))
        deterministic = hy.seed(0)(single_example)

        def func(self, *args, **kwargs):
            # First pass: fixed-seed run with serialization enabled.
            self.should_serialize = True
            deterministic(self, *args, **kwargs)
            # Second pass: ordinary run with serialization disabled.
            self.should_serialize = False
            randomized(self, *args, **kwargs)

        return func

    return wrapper
|
|
|
|
|
|
def _getGradientOrNone(op_proto):
    """Return the gradient operators for ``op_proto``, or ``[]`` on failure.

    Any exception raised while asking ``gradient_checker`` for the gradient
    (or while unpacking its two-element result) is swallowed and an empty
    list is returned instead.
    """
    gradient_ops = []
    try:
        gradient_ops, _unused = gradient_checker.getGradientForOp(op_proto)
    except Exception:
        # Best effort: treat any failure as "no gradient operators".
        gradient_ops = []
    return gradient_ops
|
|
|
|
|
|
# necessary to support converting jagged lists into numpy arrays
def _transformList(l):
    """Pack the items of ``l`` into a 1-D numpy array of dtype ``object``.

    Each item is stored as-is in its own slot, so items of differing
    shapes (a "jagged" list) are supported and numpy never tries to
    stack them into a single multi-dimensional array.

    Args:
        l: a sized, iterable collection of arbitrary items.

    Returns:
        A numpy object array of shape ``(len(l),)`` whose i-th element is
        the i-th item of ``l``.
    """
    # Use the builtin `object`: the `np.object` alias was deprecated in
    # NumPy 1.20 and removed in 1.24; both name the same dtype.
    ret = np.empty(len(l), dtype=object)
    for (i, arr) in enumerate(l):
        ret[i] = arr
    return ret
|
|
|
|
|
|
def _prepare_dir(path):
    """Ensure ``path`` is a freshly created, empty directory.

    If something already exists at ``path`` it is removed with
    ``shutil.rmtree`` first; the directory (including any missing
    parents) is then created with ``os.makedirs``.
    """
    already_present = os.path.exists(path)
    if already_present:
        # Drop stale contents from a previous run.
        shutil.rmtree(path)
    os.makedirs(path)
|
|
|
|
|
|
class SerializedTestCase(hu.HypothesisTestCase):
    """Hypothesis test case that records or checks serialized operator runs.

    When ``should_serialize`` is True, ``assertReferenceChecks`` either
    writes the operator, its inputs/outputs and its gradient operators to
    ``<output dir>/<test file>.<test function>.zip`` (generate mode) or
    compares the current run against that previously written zip
    (compare mode, the default). The mode and the output directory are
    read from the module's thread-local ``_output_context``.
    """

    # Set to True by the module's `given` decorator only for the duration
    # of the fixed-seed run; the regular hypothesis run sets it to False.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory holding the serialized operator test zips.

        The base directory is the thread-local ``output_dir`` if set,
        otherwise ``DATA_DIR``. If ``<base>/operator_test`` exists it is
        returned. Otherwise a fallback path is built from the current
        working directory and this module's dotted package path; the
        fallback is returned without checking that it exists.
        """
        base_dir = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(base_dir, operator_test_type)
        if os.path.exists(output_dir):
            return output_dir

        # fall back to pwd
        # Drop the last component (this module's own name) to get the
        # package path, e.g. 'a.b.mod' -> 'a/b'.
        package_components = __name__.split('.')[:-1]
        package_dir = '/'.join(package_components)
        return os.path.join(
            os.getcwd(), package_dir, DATA_SUFFIX, operator_test_type)

    def get_output_filename(self):
        """Return ``<test file stem>.<test method name>`` for this test.

        The file stem is the part of the defining file's basename before
        its first '.'; the method name is the last dotted component of
        ``self.id()``.
        """
        class_path = inspect.getfile(self.__class__)
        test_file = os.path.basename(class_path).split('.')[0]
        test_function = self.id().split('.')[-1]
        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write the operator run to ``<output dir>/<test name>.zip``.

        The zip contains ``op.pb`` (the serialized operator),
        ``inout.npz`` (inputs, outputs and the integer device type) and
        one ``grad_<i>.pb`` per gradient operator. Files are first staged
        in ``<output dir>/<test name>/``, which is removed afterwards.

        Args:
            inputs: sequence of operator inputs.
            outputs: sequence of operator outputs.
            grad_ops: sequence of gradient operator protos.
            op: operator proto exposing ``SerializeToString()``.
            device_option: object exposing ``device_type``.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)

        try:
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            grad_paths = []
            inout_path = os.path.join(full_dir, 'inout')

            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())
            for (i, grad) in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            # savez_compressed appends '.npz' to the given path.
            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            zip_path = os.path.join(output_dir, test_name + '.zip')
            with ZipFile(zip_path, 'w') as z:
                z.write(op_path, 'op.pb')
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # Always remove the staging directory, even if writing failed,
            # so a broken run does not leave stray files next to the zips.
            shutil.rmtree(full_dir)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Compare the current operator run against its serialized zip.

        The zip is extracted into a private temporary directory. If the
        current ``inputs`` differ from the serialized inputs, the
        serialized operator is re-run on the serialized inputs and its
        outputs and gradient operators are used for the comparison
        instead of the ones passed in.

        Args:
            inputs: current operator inputs.
            outputs: current operator outputs.
            grad_ops: current gradient operator protos.
            atol: absolute tolerance for the output comparison.
            rtol: relative tolerance for the output comparison.

        Raises:
            AssertionError: if an output or gradient operator differs
                from the serialized one.
        """

        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        temp_dir = tempfile.mkdtemp()
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)

            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # load serialized input and output
            # The arrays were saved with dtype=object, which numpy stores
            # pickled; since NumPy 1.16.3 loading them requires an explicit
            # allow_pickle=True. NOTE: unpickling executes code from the
            # file, so only load serialized test data from a trusted source.
            # The context manager closes the archive before the temporary
            # directory is removed.
            with np.load(
                    inout_path, encoding='bytes', allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                loaded_outputs = loaded['outputs'].tolist()
                inputs_equal = all(
                    np.array_equal(x, y)
                    for (x, y) in zip(inputs, loaded_inputs))
                # The device type is only needed when the op is re-run.
                if not inputs_equal:
                    device_type = int(loaded['device_type'])

            # if inputs are not the same, run serialized input through
            # serialized op
            if not inputs_equal:
                # load operator
                with open(op_path, 'rb') as f:
                    op_proto = parse_proto(f.read())

                device_option = caffe2_pb2.DeviceOption(
                    device_type=device_type)

                outputs = hu.runOpOnInput(
                    device_option, op_proto, loaded_inputs)
                grad_ops = _getGradientOrNone(op_proto)

            # assert outputs are equal
            for (x, y) in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # assert gradient op is equal
            for (i, grad_op) in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    grad_proto = parse_proto(f.read())
                self.assertEqual(grad_proto, grad_op)
        finally:
            # Remove the extracted files even when an assertion above
            # fails; otherwise every failing test leaks a temp directory.
            shutil.rmtree(temp_dir)

    def assertSerializedOperatorChecks(
        self,
        inputs,
        outputs,
        gradient_operator,
        op,
        device_option,
    ):
        """Serialize or compare the run, but only while serializing is on.

        Does nothing unless ``self.should_serialize`` is True. Then, if the
        thread-local ``should_generate_output`` flag is set, the run is
        written with ``serialize_test``; otherwise it is checked with
        ``compare_test`` using that method's default tolerances.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
            else:
                self.compare_test(inputs, outputs, gradient_operator)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
    ):
        """Run the base class reference checks, then the serialized checks.

        All arguments are forwarded positionally to the parent
        ``assertReferenceChecks``. Unless the thread-local
        ``disable_serialized_check`` flag is set, the outputs it returns
        are then passed to ``assertSerializedOperatorChecks`` together
        with the operator's gradient operators.
        """
        outs = super(SerializedTestCase, self).assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
            )
|
|
|
|
|
|
def testWithArgs():
    """Parse serialized-test command-line options, then run ``unittest.main``.

    Recognized options:
      -G / --generate-serialized      generate output files instead of
                                      comparing against existing ones
      -O / --output                   output directory (default: DATA_DIR)
      -D / --disable-serialized_check skip the serialized checks

    The parsed values are stored on the thread-local ``_output_context``.
    Remaining positional arguments replace ``sys.argv[1:]`` so that
    ``unittest.main`` sees only its own arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-G', '--generate-serialized', action='store_true', dest='generate',
        help='generate output files (default=false, compares to current files)')
    parser.add_argument(
        '-O', '--output', default=DATA_DIR,
        help='output directory (default: %(default)s)')
    parser.add_argument(
        '-D', '--disable-serialized_check', action='store_true', dest='disable',
        help='disable checking serialized tests')
    parser.add_argument('unittest_args', nargs='*')
    parsed = parser.parse_args()

    # Hand only the leftover positional arguments on to unittest.
    sys.argv[1:] = parsed.unittest_args

    _output_context.should_generate_output = parsed.generate
    _output_context.output_dir = parsed.output
    _output_context.disable_serialized_check = parsed.disable

    import unittest
    unittest.main()
|