mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add test for FC/BC for torchscript file.
Summary: title Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/75136 Approved by: https://github.com/gmagogsfm
This commit is contained in:
parent
7545e2a4d6
commit
d65414d145
3 changed files with 47 additions and 0 deletions
|
|
@ -450,6 +450,8 @@ test_xla() {
|
|||
# nightly version.
|
||||
test_forward_backward_compatibility() {
|
||||
set -x
|
||||
# create a dummy ts model at this version
|
||||
python test/create_dummy_torchscript_model.py /tmp/model_new.pt
|
||||
pushd test/forward_backward_compatibility
|
||||
python -m venv venv
|
||||
# shellcheck disable=SC1091
|
||||
|
|
@ -457,10 +459,21 @@ test_forward_backward_compatibility() {
|
|||
pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
pip show torch
|
||||
python dump_all_function_schemas.py --filename nightly_schemas.txt
|
||||
# FC: verify newmodel can be load with old code.
|
||||
if ! python ../load_torchscript_model.py /tmp/model_new.pt; then
|
||||
echo "FC check failed: new model cannot be load in old code"
|
||||
return 1
|
||||
fi
|
||||
python ../create_dummy_torchscript_model.py /tmp/model_old.pt
|
||||
deactivate
|
||||
rm -r venv
|
||||
pip show torch
|
||||
python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt
|
||||
# BC: verify old model can be load with new code
|
||||
if ! python ../load_torchscript_model.py /tmp/model_old.pt; then
|
||||
echo "BC check failed: old model cannot be load in new code"
|
||||
return 1
|
||||
fi
|
||||
popd
|
||||
set +x
|
||||
assert_git_not_dirty
|
||||
|
|
|
|||
28
test/create_dummy_torchscript_model.py
Normal file
28
test/create_dummy_torchscript_model.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# Usage: python create_dummy_model.py <name_of_the_file>
|
||||
import sys
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class NeuralNetwork(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(NeuralNetwork, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear_relu_stack = nn.Sequential(
|
||||
nn.Linear(28 * 28, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 10),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.flatten(x)
|
||||
logits = self.linear_relu_stack(x)
|
||||
return logits
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
jit_module = torch.jit.script(NeuralNetwork())
|
||||
torch.jit.save(jit_module, sys.argv[1])
|
||||
6
test/load_torchscript_model.py
Normal file
6
test/load_torchscript_model.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import sys
|
||||
import torch
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(torch.jit.load(sys.argv[1]))
|
||||
sys.exit(0)
|
||||
Loading…
Reference in a new issue