From d65414d1450b52740d83b10283e40d4130ed3f32 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 13 Apr 2022 23:23:13 +0000 Subject: [PATCH] Add test for FC/BC for torchscript file. Summary: title Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/75136 Approved by: https://github.com/gmagogsfm --- .jenkins/pytorch/test.sh | 13 ++++++++++++ test/create_dummy_torchscript_model.py | 28 ++++++++++++++++++++++++++ test/load_torchscript_model.py | 6 ++++++ 3 files changed, 47 insertions(+) create mode 100644 test/create_dummy_torchscript_model.py create mode 100644 test/load_torchscript_model.py diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 1e4a9a52066..a74a3e81580 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -450,6 +450,8 @@ test_xla() { # nightly version. test_forward_backward_compatibility() { set -x + # create a dummy ts model at this version + python test/create_dummy_torchscript_model.py /tmp/model_new.pt pushd test/forward_backward_compatibility python -m venv venv # shellcheck disable=SC1091 @@ -457,10 +459,21 @@ test_forward_backward_compatibility() { pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip show torch python dump_all_function_schemas.py --filename nightly_schemas.txt + # FC: verify newmodel can be load with old code. + if ! python ../load_torchscript_model.py /tmp/model_new.pt; then + echo "FC check failed: new model cannot be load in old code" + return 1 + fi + python ../create_dummy_torchscript_model.py /tmp/model_old.pt deactivate rm -r venv pip show torch python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt + # BC: verify old model can be load with new code + if ! python ../load_torchscript_model.py /tmp/model_old.pt; then + echo "BC check failed: old model cannot be load in new code" + return 1 + fi popd set +x assert_git_not_dirty diff --git a/test/create_dummy_torchscript_model.py b/test/create_dummy_torchscript_model.py new file mode 100644 index 00000000000..ffd869e27f0 --- /dev/null +++ b/test/create_dummy_torchscript_model.py @@ -0,0 +1,28 @@ +# Usage: python create_dummy_model.py +import sys +import torch +from torch import nn + + +class NeuralNetwork(nn.Module): + + def __init__(self): + super(NeuralNetwork, self).__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10), + ) + + def forward(self, x): + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits + + +if __name__ == '__main__': + jit_module = torch.jit.script(NeuralNetwork()) + torch.jit.save(jit_module, sys.argv[1]) diff --git a/test/load_torchscript_model.py b/test/load_torchscript_model.py new file mode 100644 index 00000000000..dc8d4159d7f --- /dev/null +++ b/test/load_torchscript_model.py @@ -0,0 +1,6 @@ +import sys +import torch + +if __name__ == '__main__': + print(torch.jit.load(sys.argv[1])) + sys.exit(0)