pytorch/benchmarks/framework_overhead_benchmark/SimpleAddModule.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

20 lines
371 B
Python
Raw Normal View History

from utils import NUM_LOOP_ITERS
import torch
def add_tensors_loop(x, y):
    """Add ``x`` and ``y``, then add ``x`` to the result repeatedly.

    Performs one initial ``torch.add(x, y)`` followed by ``NUM_LOOP_ITERS``
    further ``torch.add`` calls, each adding ``x`` to the running result.
    ``torch.add`` is therefore called ``NUM_LOOP_ITERS + 1`` times in total.

    Args:
        x: First operand; also the value re-added on every loop iteration.
        y: Second operand, used only in the initial addition.

    Returns:
        The result of the final ``torch.add`` call.
    """
    total = torch.add(x, y)
    # NOTE(review): presumably the chain of separate adds is intentional
    # (this file lives under a framework-overhead benchmark), so it is kept
    # as individual torch.add calls rather than folded into one expression.
    for _ in range(NUM_LOOP_ITERS):
        total = torch.add(total, x)
    return total
class SimpleAddModule(torch.nn.Module):
    """``torch.nn.Module`` wrapper that delegates ``forward`` to a callable.

    The callable is supplied at construction time and stored as ``add_op``;
    ``forward`` passes its two arguments straight through to it.
    """

    def __init__(self, add_op):
        """Store the two-argument callable used by ``forward``.

        Args:
            add_op: Callable invoked as ``add_op(x, y)``.
        """
        super().__init__()
        self.add_op = add_op

    def forward(self, x, y):
        """Return ``self.add_op(x, y)`` unchanged."""
        return self.add_op(x, y)