mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fixes Pull Request resolved: https://github.com/pytorch/pytorch/pull/101585 Approved by: https://github.com/ngimel, https://github.com/desertfire
28 lines
707 B
Python
28 lines
707 B
Python
import shutil
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._inductor
|
|
|
|
|
|
class Net(torch.nn.Module):
    """Tiny model that applies one linear layer to ``sin(x) + cos(y)``.

    The layer maps 64 input features to 10 output features, so the
    broadcast sum of ``x`` and ``y`` must have a trailing dimension of 64.
    """

    def __init__(self):
        super().__init__()
        # The only parameterized layer: 64 -> 10 features.
        self.fc = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, x, y):
        """Return ``self.fc(torch.sin(x) + torch.cos(y))``."""
        mixed = torch.sin(x) + torch.cos(y)
        return self.fc(mixed)
|
|
|
|
|
|
# Export ``Net`` with TorchDynamo and AOT-compile it with Inductor twice:
# once with dynamic shapes enabled and once without. The resulting shared
# libraries are copied into the current working directory as
# ``libaot_inductor_output_dynamic.so`` and ``libaot_inductor_output.so``.
# Both compilations use the same example inputs (32 x 64 CUDA tensors).
x = torch.randn((32, 64), device="cuda")
y = torch.randn((32, 64), device="cuda")

for dynamic in [True, False]:
    # Scope the setting with ``config.patch`` instead of assigning the
    # process-global ``dynamic_shapes`` directly, so the previous value is
    # restored when the iteration finishes rather than leaking out.
    with torch._dynamo.config.patch(dynamic_shapes=dynamic):
        # Reset dynamo state so this export does not reuse anything cached
        # by the previous iteration.
        torch._dynamo.reset()

        with torch.no_grad():
            # A fresh module instance is exported on each iteration.
            module, _ = torch._dynamo.export(Net().cuda(), x, y)
            lib_path = torch._inductor.aot_compile(module, [x, y])

    suffix = "_dynamic" if dynamic else ""
    shutil.copy(lib_path, f"libaot_inductor_output{suffix}.so")
|