Add test for FC/BC for torchscript file.

Summary:
title

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75136
Approved by: https://github.com/gmagogsfm
This commit is contained in:
Han Qi 2022-04-13 23:23:13 +00:00 committed by PyTorch MergeBot
parent 7545e2a4d6
commit d65414d145
3 changed files with 47 additions and 0 deletions

View file

@ -450,6 +450,8 @@ test_xla() {
# nightly version.
test_forward_backward_compatibility() {
set -x
# create a dummy ts model at this version
python test/create_dummy_torchscript_model.py /tmp/model_new.pt
pushd test/forward_backward_compatibility
python -m venv venv
# shellcheck disable=SC1091
@ -457,10 +459,21 @@ test_forward_backward_compatibility() {
pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip show torch
python dump_all_function_schemas.py --filename nightly_schemas.txt
# FC: verify newmodel can be load with old code.
if ! python ../load_torchscript_model.py /tmp/model_new.pt; then
echo "FC check failed: new model cannot be load in old code"
return 1
fi
python ../create_dummy_torchscript_model.py /tmp/model_old.pt
deactivate
rm -r venv
pip show torch
python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt
# BC: verify old model can be load with new code
if ! python ../load_torchscript_model.py /tmp/model_old.pt; then
echo "BC check failed: old model cannot be load in new code"
return 1
fi
popd
set +x
assert_git_not_dirty

View file

@ -0,0 +1,28 @@
# Usage: python create_dummy_model.py <name_of_the_file>
import sys
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
if __name__ == '__main__':
jit_module = torch.jit.script(NeuralNetwork())
torch.jit.save(jit_module, sys.argv[1])

View file

@ -0,0 +1,6 @@
import sys
import torch
if __name__ == '__main__':
print(torch.jit.load(sys.argv[1]))
sys.exit(0)