Enable jit tracing to parametrization and add jit tests (#60969)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60969

This PR fixes the tracing in the parametrizations.
The current resolution is that when tracing is performed while caching is enabled, we throw an error.
Without caching, the tracing should work properly (tests added).

Currently, the parametrizations don't support scripting.
This PR introduces the same logic as with the tracing (throw error if caching).
However, the scripting itself cannot enabled due to the use of the generator expressions in the parametrizations.
Added TODO to fix it.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29462887

Pulled By: z-a-f

fbshipit-source-id: 49721d3059be58f36055d1c374080df41a748d66
This commit is contained in:
Zafar 2021-06-30 23:52:31 -07:00 committed by Facebook GitHub Bot
parent 4e181dfc35
commit c1499a9933
3 changed files with 86 additions and 7 deletions

View file

@ -0,0 +1,67 @@
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestParametrization(JitTestCase):
# Define some parametrization
class Symmetric(nn.Module):
def forward(self, X):
return X.triu() + X.triu(1).transpose(-1, -2)
def test_traceable(self):
r"""Test the jit scripting and tracing of a parametrized model."""
model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", self.Symmetric())
x = torch.randn(3, 5)
y = model(x)
# Check the tracing works. Because traced functions cannot be called
# directly, we run the comparison on the activations.
traced_model = torch.jit.trace_module(model, {'forward': x})
y_hat = traced_model(x)
self.assertEqual(y, y_hat)
# Check traced model works with caching
with parametrize.cached():
y_hat = traced_model(x)
self.assertEqual(y, y_hat)
# Check the tracing throws an error when caching
with self.assertRaisesRegex(RuntimeError,
'Cannot trace a model while caching'):
with parametrize.cached():
traced_model = torch.jit.trace_module(model, {'forward': x})
def test_scriptable(self):
# TODO: Need to fix the scripting in parametrizations
# Currently, all the tests below will throw UnsupportedNodeError
model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", self.Symmetric())
x = torch.randn(3, 5)
y = model(x)
with self.assertRaises(torch.jit.frontend.UnsupportedNodeError):
# Check scripting works
scripted_model = torch.jit.script(model)
y_hat = scripted_model(x)
self.assertEqual(y, y_hat)
with parametrize.cached():
# Check scripted model works when caching
y_hat = scripted_model(x)
self.assertEqual(y, y_hat)
# Check the scripting process throws an error when caching
with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
scripted_model = torch.jit.trace_module(model)

View file

@ -57,6 +57,7 @@ from jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401
from jit.test_module_apis import TestModuleAPIs # noqa: F401
from jit.test_script_profile import TestScriptProfile # noqa: F401
from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation # noqa: F401
from jit.test_parametrization import TestParametrization # noqa: F401
# Torch
from torch import Tensor

View file

@ -307,17 +307,28 @@ def _inject_property(module: Module, tensor_name: str) -> None:
# This should never fire if register_parametrization is correctly implemented
assert not hasattr(module, tensor_name)
def get_parametrized(self) -> Tensor:
@torch.jit.unused
def get_cached_parametrization(parametrization) -> Tensor:
global _cache
key = (id(module), tensor_name)
tensor = _cache.get(key)
if tensor is None:
tensor = parametrization()
_cache[key] = tensor
return tensor
def get_parametrized(self) -> Tensor:
parametrization = self.parametrizations[tensor_name]
if _cache_enabled:
key = (id(module), tensor_name)
tensor = _cache.get(key)
if tensor is None:
tensor = parametrization()
_cache[key] = tensor
return tensor
if torch.jit.is_scripting():
# Scripting
raise RuntimeError('Caching is not implemented for scripting. '
'Either disable caching or avoid scripting.')
elif torch._C._get_tracing_state() is not None:
# Tracing
raise RuntimeError('Cannot trace a model while caching parametrizations.')
else:
return get_cached_parametrization(parametrization)
else:
# If caching is not active, this function just evaluates the parametrization
return parametrization()