mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129764
Approved by: https://github.com/ezyang
26 lines
764 B
Python
26 lines
764 B
Python
# Owner(s): ["oncall: jit"]
|
|
# flake8: noqa
|
|
|
|
import sys
|
|
import unittest
|
|
from enum import Enum
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from jit.myfunction_a import my_function_a
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
class TestDecorator(JitTestCase):
    """TorchScript test for ``my_function_a``, which is defined with a decorator."""

    def test_decorator(self):
        # JitTestCase.checkScript() does not work with decorators, so it is
        # not used here. The call
        #     self.checkScript(my_function_a, (1.0,))
        # fails with:
        #     RuntimeError: expected def but found '@' here:
        #     @my_decorator
        #     ~ <--- HERE
        #     def my_function_a(x: float) -> float:
        #
        # Instead, script the function directly with torch.jit.script() and
        # compare the eager result against the scripted result.
        scripted = torch.jit.script(my_function_a)
        sample = 1.0
        self.assertEqual(my_function_a(sample), scripted(sample))
|