Mirror of https://github.com/saymrwulf/pytorch.git
Enable jit tracing for parametrizations and add jit tests (#60969)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60969

This PR fixes tracing of parametrized modules. The resolution: if tracing is attempted while caching is enabled, a RuntimeError is thrown; without caching, tracing works properly (tests added). Parametrizations still do not support scripting. This PR introduces the same guard for scripting as for tracing (throw an error when caching is enabled), but scripting itself cannot be enabled yet because the parametrizations use generator expressions, which TorchScript does not support. A TODO was added to fix this.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29462887

Pulled By: z-a-f

fbshipit-source-id: 49721d3059be58f36055d1c374080df41a748d66
Parent: 4e181dfc35
Commit: c1499a9933

3 changed files with 86 additions and 7 deletions
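As an illustration of the behavior this commit establishes, here is a minimal sketch using the public torch.nn.utils.parametrize API. The Symmetric module mirrors the one defined in the new tests below; everything else is standard usage.

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # Rebuild a symmetric weight from its upper triangle
        return X.triu() + X.triu(1).transpose(-1, -2)

model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Symmetric())
x = torch.randn(3, 5)

# Tracing a parametrized module works, as long as caching is disabled
traced = torch.jit.trace_module(model, {'forward': x})

# Tracing while caching is enabled now raises instead of silently baking
# the cached weight into the traced graph as a constant
try:
    with parametrize.cached():
        torch.jit.trace_module(model, {'forward': x})
except RuntimeError as err:
    print(err)  # Cannot trace a model while caching parametrizations.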
test/jit/test_parametrization.py (new file, 67 lines)

@@ -0,0 +1,67 @@
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

from torch.testing._internal.jit_utils import JitTestCase


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")


class TestParametrization(JitTestCase):
    # Define some parametrization
    class Symmetric(nn.Module):
        def forward(self, X):
            return X.triu() + X.triu(1).transpose(-1, -2)

    def test_traceable(self):
        r"""Test the jit scripting and tracing of a parametrized model."""
        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", self.Symmetric())

        x = torch.randn(3, 5)
        y = model(x)

        # Check the tracing works. Because traced functions cannot be called
        # directly, we run the comparison on the activations.
        traced_model = torch.jit.trace_module(model, {'forward': x})
        y_hat = traced_model(x)
        self.assertEqual(y, y_hat)

        # Check traced model works with caching
        with parametrize.cached():
            y_hat = traced_model(x)
            self.assertEqual(y, y_hat)

        # Check the tracing throws an error when caching
        with self.assertRaisesRegex(RuntimeError,
                                    'Cannot trace a model while caching'):
            with parametrize.cached():
                traced_model = torch.jit.trace_module(model, {'forward': x})

    def test_scriptable(self):
        # TODO: Need to fix the scripting in parametrizations.
        #       Currently, all the tests below will throw UnsupportedNodeError.
        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", self.Symmetric())

        x = torch.randn(3, 5)
        y = model(x)

        with self.assertRaises(torch.jit.frontend.UnsupportedNodeError):
            # Check scripting works
            scripted_model = torch.jit.script(model)
            y_hat = scripted_model(x)
            self.assertEqual(y, y_hat)

            with parametrize.cached():
                # Check scripted model works when caching
                y_hat = scripted_model(x)
                self.assertEqual(y, y_hat)

            # Check the scripting process throws an error when caching
            with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
                scripted_model = torch.jit.script(model)
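For context on why the tracing error in these tests exists, here is a minimal eager-mode sketch of what parametrize.cached() changes. The Counting parametrization is hypothetical, invented here for illustration; register_parametrization and cached() are the real API.

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class Counting(nn.Module):
    # Hypothetical parametrization that counts how often it is evaluated
    calls = 0

    def forward(self, X):
        Counting.calls += 1
        return X

m = nn.Linear(3, 3)
parametrize.register_parametrization(m, "weight", Counting())
Counting.calls = 0  # registration itself may evaluate the parametrization once

x = torch.randn(2, 3)
m(x)
m(x)
print(Counting.calls)  # 2: without caching, every weight access recomputes

Counting.calls = 0
with parametrize.cached():
    m(x)
    m(x)
print(Counting.calls)  # 1: computed once, then served from the cache

A trace recorded inside cached() would therefore capture the stored tensor rather than the computation, which is why the new code refuses to trace in that state.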
test/test_jit.py

@@ -57,6 +57,7 @@ from jit.test_tensor_creation_ops import TestTensorCreationOps  # noqa: F401
 from jit.test_module_apis import TestModuleAPIs  # noqa: F401
 from jit.test_script_profile import TestScriptProfile  # noqa: F401
 from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation  # noqa: F401
+from jit.test_parametrization import TestParametrization  # noqa: F401
 
 # Torch
 from torch import Tensor
torch/nn/utils/parametrize.py

@@ -307,17 +307,28 @@ def _inject_property(module: Module, tensor_name: str) -> None:
     # This should never fire if register_parametrization is correctly implemented
     assert not hasattr(module, tensor_name)
 
+    @torch.jit.unused
+    def get_cached_parametrization(parametrization) -> Tensor:
+        global _cache
+        key = (id(module), tensor_name)
+        tensor = _cache.get(key)
+        if tensor is None:
+            tensor = parametrization()
+            _cache[key] = tensor
+        return tensor
+
     def get_parametrized(self) -> Tensor:
         parametrization = self.parametrizations[tensor_name]
         if _cache_enabled:
-            key = (id(module), tensor_name)
-            tensor = _cache.get(key)
-            if tensor is None:
-                tensor = parametrization()
-                _cache[key] = tensor
-            return tensor
+            if torch.jit.is_scripting():
+                # Scripting
+                raise RuntimeError('Caching is not implemented for scripting. '
+                                   'Either disable caching or avoid scripting.')
+            elif torch._C._get_tracing_state() is not None:
+                # Tracing
+                raise RuntimeError('Cannot trace a model while caching parametrizations.')
+            else:
+                return get_cached_parametrization(parametrization)
         else:
             # If caching is not active, this function just evaluates the parametrization
             return parametrization()
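Why pull the cache lookup into a separate get_cached_parametrization helper at all? torch.jit.unused tells the TorchScript compiler to skip compiling the decorated function (calls to it from scripted code become a stub that raises at runtime), so the Python-dict cache never has to be expressible in TorchScript. A minimal sketch of that pattern, with a hypothetical Wrapper module not taken from this commit:

import torch

class Wrapper(torch.nn.Module):
    @torch.jit.unused
    def cached_path(self, x: torch.Tensor) -> torch.Tensor:
        # Arbitrary Python (e.g. a module-level dict cache) is allowed here,
        # because TorchScript never compiles this body.
        return x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.jit.is_scripting():
            return self.cached_path(x)  # eager execution takes the Python path
        return x * 2                    # scripted execution takes this branch

scripted = torch.jit.script(Wrapper())
print(scripted(torch.ones(2)))  # tensor([2., 2.])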