2024-12-30 nightly release (2ed4d65af0)

This commit is contained in:
pytorchbot 2024-12-30 07:34:20 +00:00
parent ad23baaa28
commit 40b0741e91
76 changed files with 569 additions and 187 deletions

View file

@ -7,7 +7,7 @@ import yaml
# Need to import modules that lie on an upward-relative path
sys.path.append(os.path.join(sys.path[0], ".."))
sys.path.append(os.path.dirname(sys.path[0]))
import cimodel.lib.miniyaml as miniyaml

View file

@ -22,7 +22,7 @@ TOKEN = os.environ["GITHUB_TOKEN"]
if not TOKEN:
raise Exception("GITHUB_TOKEN is not set") # noqa: TRY002
REPO_ROOT = Path(__file__).parent.parent.parent
REPO_ROOT = Path(__file__).parents[2]
# Query for all PRs instead of just closed/merged because it's faster
GRAPHQL_ALL_PRS_BY_UPDATED_AT = """

View file

@ -6,7 +6,7 @@ from pathlib import Path
import yaml
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
EXPECTED_GROUP_PREFIX = (
"${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}"

View file

@ -94,7 +94,7 @@ def get_nccl_submodule_version() -> str:
from pathlib import Path
nccl_version_mk = (
Path(__file__).absolute().parent.parent.parent
Path(__file__).absolute().parents[2]
/ "third_party"
/ "nccl"
/ "nccl"

View file

@ -32,7 +32,7 @@ def get_git_remote_name() -> str:
def get_git_repo_dir() -> str:
from pathlib import Path
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent))
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))
def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:

View file

@ -26,7 +26,7 @@ def fn(base: str) -> str:
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
with open(Path(__file__).parent.parent.parent / fn(".")) as f:
with open(Path(__file__).parents[2] / fn(".")) as f:
contents = f.read()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]

View file

@ -68,7 +68,7 @@ class TestRetriesDecorator(TestCase):
class TestGitRepo(TestCase):
def setUp(self) -> None:
repo_dir = BASE_DIR.parent.parent.absolute()
repo_dir = BASE_DIR.absolute().parent.parent
if not (repo_dir / ".git").is_dir():
raise SkipTest(
"Can't find git directory, make sure to run this test on real repo checkout"

View file

@ -71,7 +71,7 @@ ARTIFACTS_QUERY_URL = (
"c1cdfadc-6bb2-4a91-bbf9-3d19e1981cd4/run?format=JSON"
)
CSV_LINTER = str(
Path(__file__).absolute().parent.parent.parent.parent
Path(__file__).absolute().parents[3]
/ "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)

View file

@ -7,9 +7,9 @@ import torch._prims as prims
from torchgen.gen import parse_native_yaml
ROOT = Path(__file__).absolute().parent.parent.parent.parent
NATIVE_FUNCTION_YAML_PATH = ROOT / Path("aten/src/ATen/native/native_functions.yaml")
TAGS_YAML_PATH = ROOT / Path("aten/src/ATen/native/tags.yaml")
ROOT = Path(__file__).absolute().parents[3]
NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"
BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"

View file

@ -15,7 +15,7 @@ from torch.ao.quantization.backend_config.utils import (
# Create a directory for the images, if it doesn't exist
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
os.path.realpath(os.path.dirname(__file__)), "quantization_backend_configs"
)
if not os.path.exists(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH):

View file

@ -11,9 +11,9 @@ from torch.export import export
PWD = Path(__file__).absolute().parent
ROOT = Path(__file__).absolute().parent.parent.parent.parent
SOURCE = ROOT / Path("source")
EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb")
ROOT = Path(__file__).absolute().parents[3]
SOURCE = ROOT / "source"
EXPORTDB_SOURCE = SOURCE / "generated" / "exportdb"
def generate_example_rst(example_case: ExportCase):

View file

@ -194,7 +194,7 @@ if __name__ == "__main__":
"filename",
nargs="?",
default=str(
Path(__file__).absolute().parent.parent.parent
Path(__file__).absolute().parents[2]
/ "torch/testing/_internal/dynamo_test_failures.py"
),
help="Optional path to dynamo_test_failures.py",
@ -203,7 +203,7 @@ if __name__ == "__main__":
parser.add_argument(
"test_dir",
nargs="?",
default=str(Path(__file__).absolute().parent.parent.parent / "test"),
default=str(Path(__file__).absolute().parents[2] / "test"),
help="Optional path to test folder",
)
parser.add_argument(

View file

@ -41,7 +41,7 @@ Inherits most tests from TestNNAPI, which loads Android NNAPI models
without the delegate API.
"""
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
torch_root = Path(__file__).resolve().parent.parent.parent
torch_root = Path(__file__).resolve().parents[2]
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"

View file

@ -7,6 +7,7 @@ import dataclasses
import os
import sys
import unittest
from pathlib import Path
from typing import Tuple
import onnxruntime
@ -24,7 +25,8 @@ from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNNModuleInlined
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
import onnx_test_common

View file

@ -45,8 +45,7 @@ _InputArgsType = Optional[
_OutputsType = Sequence[_NumericType]
onnx_model_dir = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
"repos",
"onnx",
"onnx",
@ -54,11 +53,7 @@ onnx_model_dir = os.path.join(
"test",
"data",
)
pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")

View file

@ -1,8 +1,8 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations
import os
import sys
from pathlib import Path
import torch
import torch.onnx
@ -10,7 +10,8 @@ from torch.testing._internal import common_utils
from torch.utils import _pytree as torch_pytree
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
import onnx_test_common

View file

@ -76,12 +76,22 @@ def _multistep_backprop_diff_hyperparams_fn(
# This copy is necessary so the update on line 78 doesn't overwrite the original kwargs values
kwargs = kwargs.copy()
# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck
if "beta1" in kwargs or "beta2" in kwargs:
# Prevent just one beta kwarg from being passed in
assert (
"beta1" in kwargs and "beta2" in kwargs
), "Both betas should be defined in kwargs"
kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})
kwargs.update(
{k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
)
differentiable_kwargs = [
v for v in kwargs.values() if isinstance(v, torch.Tensor) and v.requires_grad
]
] + (list(kwargs["betas"]) if "betas" in kwargs else [])
criterion = nn.MSELoss()
@ -104,6 +114,10 @@ def _multistep_backprop_diff_hyperparams_fn(
meta_loss = loss
meta_loss.backward(inputs=(*differentiable_kwargs,), create_graph=True)
# Extra check to make sure the test properly computed a gradient for all kwargs
for kwarg in differentiable_kwargs:
assert kwarg.grad is not None
return (
(meta_loss,)
+ tuple(
@ -111,11 +125,7 @@ def _multistep_backprop_diff_hyperparams_fn(
for v in optimizer.state[params].values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
+ tuple(
v
for v in kwargs.values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
+ tuple(differentiable_kwargs)
)
@ -404,6 +414,276 @@ class TestDifferentiableOptimizer(TestCase):
),
)
def test_adam_differentiable_lr(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes lr
*state.values(),
*kwargs.values(),
),
)
def test_adam_differentiable_weight_decay(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes weight_decay
*state.values(),
*kwargs.values(),
),
)
def test_adam_differentiable_betas(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor([0.001], requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"beta1": betas[0],
"beta2": betas[1],
"lr": lr,
"differentiable": True,
}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)
def test_adam_differentiable_all_hyperparams(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"lr": lr,
"weight_decay": weight_decay,
"beta1": betas[0],
"beta2": betas[1],
"differentiable": True,
}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
Adam,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)
def test_adamw_differentiable_lr(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes lr
*state.values(),
*kwargs.values(),
),
)
def test_adamw_differentiable_weight_decay(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes weight_decay
*state.values(),
*kwargs.values(),
),
)
def test_adamw_differentiable_betas(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"beta1": betas[0],
"beta2": betas[1],
"differentiable": True,
}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)
def test_adamw_differentiable_all_hyperparams(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
betas = (
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
)
state = {}
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
# Have to pass in beta1 and beta2 separately
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
kwargs: dict[str, Any] = {
"lr": lr,
"weight_decay": weight_decay,
"beta1": betas[0],
"beta2": betas[1],
"differentiable": True,
}
gradcheck(
_multistep_backprop_diff_hyperparams_fn,
(
params,
grad,
state,
AdamW,
kwargs, # includes betas
*state.values(),
*kwargs.values(),
),
)
def test_differentiable_lr(self):
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)

View file

@ -2,8 +2,11 @@
import pickle
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf
import torch
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import run_tests
@ -265,6 +268,20 @@ class TestSaveLoad(PackageTestCase):
exporter.intern("**")
exporter.save_module("package_a.use_torch_package_importer")
@skipIf(version_info >= (3, 13), "https://github.com/pytorch/pytorch/issues/142170")
def test_save_load_fp8(self):
tensor = torch.rand(20, 20).to(torch.float8_e4m3fn)
buffer = BytesIO()
with PackageExporter(buffer) as exporter:
exporter.save_pickle("fp8_model", "model.pkl", tensor)
buffer.seek(0)
importer = PackageImporter(buffer)
loaded_tensor = importer.load_pickle("fp8_model", "model.pkl")
self.assertTrue(torch.equal(tensor, loaded_tensor))
if __name__ == "__main__":
run_tests()

View file

@ -51,7 +51,7 @@ class TestQuantizationDocs(QuantizationTestCase):
"been updated to have the correct relative path between "
"test_docs.py and the docs."
)
pytorch_root = core_dir.parent.parent.parent
pytorch_root = core_dir.parents[2]
return pytorch_root / path_from_pytorch
path_to_file = get_correct_path(path_from_pytorch)

View file

@ -30,7 +30,7 @@ DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "typing"))
REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
PASS_DIR = os.path.join(DATA_DIR, "pass")
FAIL_DIR = os.path.join(DATA_DIR, "fail")
MYPY_INI = os.path.join(DATA_DIR, os.pardir, os.pardir, "mypy.ini")
MYPY_INI = os.path.join(os.path.dirname(os.path.dirname(DATA_DIR)), "mypy.ini")
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")

2
third_party/xpu.txt vendored
View file

@ -1 +1 @@
7ecb0b1a56b65dec63837a30972a8ba6f8432477
214f33b9d969930a18656a82b5c5d8da53cdcb8e

View file

@ -4,15 +4,16 @@
import argparse
import os
import sys
from pathlib import Path
sys.path.append(
os.path.realpath(
os.path.join(
__file__, os.path.pardir, os.path.pardir, os.path.pardir, "torch", "utils"
)
)
)
# NOTE: `tools/amd_build/build_amd.py` could be a symlink.
# The behavior of `symlink / '..'` is different from `symlink.parent`.
# Use `pardir` three times rather than using `path.parents[2]`.
REPO_ROOT = (
Path(__file__).absolute() / os.path.pardir / os.path.pardir / os.path.pardir
).resolve()
sys.path.append(str(REPO_ROOT / "torch" / "utils"))
from hipify import hipify_python # type: ignore[import]
@ -53,8 +54,9 @@ parser.add_argument(
args = parser.parse_args()
# NOTE: `tools/amd_build/build_amd.py` could be a symlink.
amd_build_dir = os.path.dirname(os.path.realpath(__file__))
proj_dir = os.path.join(os.path.dirname(os.path.dirname(amd_build_dir)))
proj_dir = os.path.dirname(os.path.dirname(amd_build_dir))
if args.project_directory:
proj_dir = args.project_directory

View file

@ -1,13 +1,13 @@
import argparse
import sys
from os.path import abspath, dirname
from pathlib import Path
# By appending pytorch_root to sys.path, this module can import other torch
# By appending REPO_ROOT to sys.path, this module can import other torch
# modules even when run as a standalone script. i.e., it's okay either you
# do `python build_libtorch.py` or `python -m tools.build_libtorch`.
pytorch_root = dirname(dirname(abspath(__file__)))
sys.path.append(pytorch_root)
REPO_ROOT = Path(__file__).absolute().parent.parent
sys.path.append(str(REPO_ROOT))
from tools.build_pytorch_libs import build_pytorch
from tools.setup_helpers.cmake import CMake

View file

@ -43,9 +43,7 @@ def get_llvm_tool_path() -> str:
def get_pytorch_folder() -> str:
# TOOLS_FOLDER in oss: pytorch/tools/code_coverage
return os.path.abspath(
os.environ.get(
"PYTORCH_FOLDER", os.path.join(TOOLS_FOLDER, os.path.pardir, os.path.pardir)
)
os.environ.get("PYTORCH_FOLDER", os.path.dirname(os.path.dirname(TOOLS_FOLDER)))
)

View file

@ -2,13 +2,12 @@ from __future__ import annotations
import os
from enum import Enum
from pathlib import Path
# <project folder>
HOME_DIR = os.environ["HOME"]
TOOLS_FOLDER = os.path.join(
os.path.dirname(os.path.realpath(__file__)), os.path.pardir, os.path.pardir
)
TOOLS_FOLDER = str(Path(__file__).resolve().parents[2])
# <profile folder>

View file

@ -10,24 +10,28 @@ import glob
import io
import os
import re
import sys
from itertools import product
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import sys
import textwrap
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import Any
import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[assignment, misc]
REPO_ROOT = Path(__file__).absolute().parent.parent
sys.path.append(str(REPO_ROOT))
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

View file

@ -26,10 +26,7 @@ try:
PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except subprocess.CalledProcessError:
# If git is not installed, compute repo root as 3 folders up from this file
path_ = os.path.abspath(__file__)
for _ in range(4):
path_ = os.path.dirname(path_)
PYTORCH_ROOT = path_
PYTORCH_ROOT = str(Path(__file__).absolute().parents[3])
DRY_RUN = False

View file

@ -30,7 +30,7 @@ def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
def main(args: Any) -> None:
pytorch_dir = Path(__file__).parent.parent.parent.resolve()
pytorch_dir = Path(__file__).parents[2].resolve()
onnx_dir = pytorch_dir / "third_party" / "onnx"
os.chdir(onnx_dir)

View file

@ -8,6 +8,7 @@ import platform
import sys
import sysconfig
from distutils.version import LooseVersion
from pathlib import Path
from subprocess import CalledProcessError, check_call, check_output
from typing import Any, cast
@ -173,9 +174,7 @@ class CMake:
toolset_expr = ",".join([f"{k}={v}" for k, v in toolset_dict.items()])
args.append("-T" + toolset_expr)
base_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
base_dir = str(Path(__file__).absolute().parents[2])
install_dir = os.path.join(base_dir, "torch")
_mkdir_p(install_dir)

View file

@ -1,11 +1,11 @@
# Little stub file to get BUILD.bazel to play along
import os.path
import sys
from pathlib import Path
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, root)
REPO_ROOT = Path(__file__).absolute().parents[2]
sys.path.insert(0, str(REPO_ROOT))
import torchgen.gen

View file

@ -1,11 +1,11 @@
# Little stub file to get BUILD.bazel to play along
import os.path
import sys
from pathlib import Path
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, root)
REPO_ROOT = Path(__file__).absolute().parents[2]
sys.path.insert(0, str(REPO_ROOT))
import tools.jit.gen_unboxing

View file

@ -15,6 +15,7 @@ try:
except ImportError:
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
TAGS_PATH = "aten/src/ATen/native/tags.yaml"
@ -110,8 +111,9 @@ def get_selector(
operators_yaml_path: str | None,
) -> Any:
# cwrap depends on pyyaml, so we can't import it earlier
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, root)
REPO_ROOT = Path(__file__).absolute().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from torchgen.selective_build.selector import SelectiveBuilder
assert not (

View file

@ -2,8 +2,9 @@ import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.append(str(REPO_ROOT))
from tools.stats.import_test_stats import get_test_class_times, get_test_times

View file

@ -11,7 +11,7 @@ from typing import Any, Callable, cast, Dict
from urllib.request import urlopen
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
def get_disabled_issues() -> list[str]:

View file

@ -10,7 +10,7 @@ from typing import Any
from unittest import mock
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.append(str(REPO_ROOT))
from tools.test.heuristics.test_interface import TestTD

View file

@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.append(str(REPO_ROOT))
import tools.testing.target_determination.heuristics.interface as interface

View file

@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.append(str(REPO_ROOT))
import tools.testing.target_determination.heuristics.utils as utils

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import os
import tempfile
import unittest
@ -12,10 +11,6 @@ from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401
from torchgen.gen_backend_stubs import run
path = os.path.dirname(os.path.realpath(__file__))
gen_backend_stubs_path = os.path.join(path, "../torchgen/gen_backend_stubs.py")
# gen_backend_stubs.py is an integration point that is called directly by external backends.
# The tests here are to confirm that badly formed inputs result in reasonable error messages.
class TestGenBackendStubs(expecttest.TestCase):

View file

@ -3,7 +3,7 @@ import unittest
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))

View file

@ -8,7 +8,7 @@ from collections import defaultdict
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
try:
# using tools/ to optimize test run.
sys.path.append(str(REPO_ROOT))

View file

@ -10,7 +10,7 @@ from typing import Any
from unittest import mock
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.upload_metrics import add_global_metric, emit_metric, global_metrics

View file

@ -9,7 +9,7 @@ from pathlib import Path
CPP_TEST_PREFIX = "cpp"
CPP_TEST_PATH = "build/bin"
CPP_TESTS_DIR = os.path.abspath(os.getenv("CPP_TESTS_DIR", default=CPP_TEST_PATH))
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
def parse_test_module(test: str) -> str:

View file

@ -4,7 +4,7 @@ import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import (

View file

@ -12,7 +12,7 @@ from typing import Any
import yaml
REPO_ROOT = Path(__file__).parent.parent.parent
REPO_ROOT = Path(__file__).parents[2]
CONFIG_YML = REPO_ROOT / ".circleci" / "config.yml"
WORKFLOWS_DIR = REPO_ROOT / ".github" / "workflows"

View file

@ -8,7 +8,7 @@ from pathlib import Path
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
# These tests are slow enough that it's worth calculating whether the patch
# touched any related files first. This list was manually generated, but for every

View file

@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[3]
def gen_ci_artifact(included: list[Any], excluded: list[Any]) -> None:

View file

@ -17,7 +17,7 @@ from tools.testing.target_determination.heuristics.utils import (
from tools.testing.test_run import TestRun
REPO_ROOT = Path(__file__).parent.parent.parent.parent
REPO_ROOT = Path(__file__).parents[3]
keyword_synonyms: dict[str, list[str]] = {
"amp": ["mixed_precision"],

View file

@ -16,7 +16,7 @@ from tools.testing.target_determination.heuristics.utils import normalize_rating
from tools.testing.test_run import TestRun
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[4]
class LLM(HeuristicInterface):

View file

@ -20,7 +20,7 @@ from tools.testing.target_determination.heuristics.utils import (
from tools.testing.test_run import TestRun
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[4]
class PreviouslyFailedInPR(HeuristicInterface):

View file

@ -15,7 +15,8 @@ from warnings import warn
if TYPE_CHECKING:
from tools.testing.test_run import TestRun
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[4]
def python_test_file_to_test_name(tests: set[str]) -> set[str]:

View file

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from collections.abc import Sequence
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
REPO_ROOT = Path(__file__).resolve().parents[2]
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")

View file

@ -2,8 +2,8 @@
import functools
import warnings
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
from typing import Any, Callable, List, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import deprecated, ParamSpec
import torch
import torch.utils._pytree as pytree
@ -14,6 +14,9 @@ try:
except ModuleNotFoundError:
np = None # type: ignore[assignment]
_P = ParamSpec("_P")
_R = TypeVar("_R")
if TYPE_CHECKING:
# TorchScript does not support `@deprecated`
# This is a workaround to avoid breaking TorchScript
@ -35,13 +38,13 @@ else:
return torch.compiler.is_compiling()
def wrap_inline(fn: Callable[..., Any]) -> Callable[..., Any]:
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""
Create an extra frame around fn that is not in skipfiles.
"""
@functools.wraps(fn)
def inner(*args: Any, **kwargs: Any) -> Any:
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return fn(*args, **kwargs)
return inner
@ -61,7 +64,7 @@ def call_hook(
return result
def wrap_numpy(f: Callable[..., Any]) -> Callable[..., Any]:
def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
from ``torch.Tensor``s to ``torch.Tensor``s.
"""
@ -69,7 +72,7 @@ def wrap_numpy(f: Callable[..., Any]) -> Callable[..., Any]:
return f
@functools.wraps(f)
def wrap(*args: Any, **kwargs: Any) -> Any:
def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
args, kwargs = pytree.tree_map_only(
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
)

View file

@ -20,6 +20,7 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from unittest.mock import patch
import torch
@ -51,6 +52,8 @@ three = 3
log = logging.getLogger(__name__)
_P = ParamSpec("_P")
def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if x is None:
@ -407,9 +410,9 @@ def check_dynamic_shape_capture() -> bool:
return not config.assume_static_by_default
def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]:
@functools.wraps(fn)
def _fn(*args: Any, **kwargs: Any) -> _T:
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
with contextlib.ExitStack() as stack:
for module, attr, val in patches:
stack.enter_context(patch.object(module, attr, val))

View file

@ -5,6 +5,7 @@ import functools
import os
import sys
import warnings
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Dict
@ -51,15 +52,13 @@ def _reload_python_module(key, path):
def _set_triton_ptxas_path() -> None:
if os.environ.get("TRITON_PTXAS_PATH") is not None:
return
ptxas_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
)
if not os.path.exists(ptxas_path):
ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas"
if not ptxas.exists():
return
if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
os.environ["TRITON_PTXAS_PATH"] = ptxas_path
if ptxas.is_file() and os.access(ptxas, os.X_OK):
os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
else:
warnings.warn(f"{ptxas_path} exists but is not an executable")
warnings.warn(f"{ptxas} exists but is not an executable")
def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):

View file

@ -8,7 +8,8 @@ import subprocess
import time
from threading import Lock
from timeit import default_timer as timer
from typing import Any, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, TypeVar
from typing_extensions import ParamSpec
logger = logging.getLogger("strobelight_function_profiler")
@ -23,6 +24,9 @@ logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
logger.propagate = False
_P = ParamSpec("_P")
_R = TypeVar("_R")
class StrobelightCLIProfilerError(Exception):
"""
@ -250,7 +254,9 @@ class StrobelightCLIFunctionProfiler:
self._stop_strobelight_no_throw(collect_results=False)
return False
def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
def profile(
self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Optional[_R]:
self.current_run_id = None
self.profile_result = None
@ -288,6 +294,7 @@ class StrobelightCLIFunctionProfiler:
self._stop_strobelight_no_throw(collect_results=False)
StrobelightCLIFunctionProfiler._lock.release()
raise error
return None
# A function decorator that wraps profile, if no profiler is provided one with
@ -297,13 +304,15 @@ class StrobelightCLIFunctionProfiler:
# @strobelight(stop_at_error=True,...)
def strobelight(
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Any:
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
if not profiler:
profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner(work_function: Any) -> Any:
def strobelight_inner(
work_function: Callable[_P, _R]
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function

View file

@ -542,7 +542,7 @@ class FakeTensorConfig:
#
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to ensure access to FakeTensor during access, which is
# required for testing version counter and epoch validity.
# required for testing version counter and epoch validity.
class SymNumberMemoDescriptor:
_name: str
@ -763,7 +763,7 @@ class FakeTensor(Tensor):
@classmethod
@count
def __torch_dispatch__(
def __torch_dispatch__( # type: ignore[override] # TODO
cls,
func: OpOverload,
types: Sequence[Type],

View file

@ -1,7 +1,19 @@
from copy import deepcopy
from datetime import timedelta
from functools import partial, wraps
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
import torch.distributed as dist
@ -26,6 +38,10 @@ _TOTAL_KEY = "Total"
__all__ = ["FSDPMemTracker"]
_P = ParamSpec("_P")
_R = TypeVar("_R")
_Ts = TypeVarTuple("_Ts")
class _FSDPRefType(_RefType):
"""
@ -185,8 +201,8 @@ class FSDPMemTracker(MemTracker):
def _fsdp_state_pre_forward(
self,
fsdp_mod: FSDPModule,
orig_fsdp_state_pre_fw: Callable,
) -> Callable:
orig_fsdp_state_pre_fw: Callable[_P, Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]],
) -> Callable[_P, Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]]:
# We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params
# and `all_gather` buffers. There are three cases:
# Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats``
@ -201,7 +217,9 @@ class FSDPMemTracker(MemTracker):
# For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module.
# For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op.
@wraps(orig_fsdp_state_pre_fw)
def inner(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
def inner(
*args: _P.args, **kwargs: _P.kwargs
) -> Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]:
mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
assert mod_fqn is not None
if fsdp_mod not in self.memory_tracking:
@ -251,15 +269,15 @@ class FSDPMemTracker(MemTracker):
def _fsdp_state_post_forward(
self,
fsdp_mod: FSDPModule,
orig_fsdp_state_post_fw: Callable,
) -> Callable:
orig_fsdp_state_post_fw: Callable[_P, _R],
) -> Callable[_P, _R]:
# We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state
# if ``reshard_after_forward`` is not ``False``. There are two cases:
# Case 1: This is called in backward, which means we are in the AC region. If this is the top most module
# in the AC region, we set the flag ``_in_ac`` to False.
# Case 2: This is called in forward.
@wraps(orig_fsdp_state_post_fw)
def inner(*args: Any, **kwargs: Any) -> Any:
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
mod_stat = self.memory_tracking[fsdp_mod]
if self._mod_tracker.is_bw:
state = _FSDPModState.POST_FW_AC
@ -283,12 +301,12 @@ class FSDPMemTracker(MemTracker):
def _fsdp_param_group_pre_backward(
self,
fsdp_mod: FSDPModule,
orig_fsdp_param_group_pre_backward: Callable,
) -> Callable:
orig_fsdp_param_group_pre_backward: Callable[_P, Any],
) -> Callable[_P, None]:
# We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching
# and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module.
@wraps(orig_fsdp_param_group_pre_backward)
def inner(*args: Any, **kwargs: Any) -> None:
def inner(*args: _P.args, **kwargs: _P.kwargs) -> None:
mod_stat = self.memory_tracking[fsdp_mod]
snapshot = self.get_tracker_snapshot()
mod_stat.local_peak = {
@ -309,13 +327,13 @@ class FSDPMemTracker(MemTracker):
def _fsdp_param_group_post_backward(
self,
fsdp_mod: FSDPModule,
orig_fsdp_param_group_post_backward: Callable,
) -> Callable:
orig_fsdp_param_group_post_backward: Callable[_P, Any],
) -> Callable[_P, None]:
# We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute
# the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers
# after the post backward.
@wraps(orig_fsdp_param_group_post_backward)
def inner(*args: Any, **kwargs: Any) -> None:
def inner(*args: _P.args, **kwargs: _P.kwargs) -> None:
fsdp_state = fsdp_mod._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
for fsdp_param in fsdp_param_group.fsdp_params:

View file

@ -1,6 +1,5 @@
"""Utility to lazily import modules."""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
@ -17,7 +16,7 @@ class _LazyModule:
def __repr__(self) -> str:
return f"<lazy module '{self._name}'>"
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> object:
if self._module is None:
self._module = importlib.import_module(".", self._name)
return getattr(self._module, attr)

View file

@ -402,7 +402,14 @@ def _single_tensor_adam(
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(weight_decay, Tensor):
if weight_decay.requires_grad:
grad = grad.addcmul_(param.clone(), weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
@ -429,13 +436,43 @@ def _single_tensor_adam(
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - device_beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta2, Tensor):
if beta2.requires_grad:
# Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
# Showing equivalence of differentiable path and nondifferentiable path
# expavg * b2 + grad^2 * (1-b2)
# add expavg * (1-b2) - expavg * (1-b2) = 0
# expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
# expavg - expavg * (1-b2) + grad^2 * (1-b2)
# expavg + (grad^2 - expavg) * (1-b2)
# expavg.lerp(grad^2, 1-beta2)
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if capturable or differentiable:
step = step_t
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta1, Tensor):
if beta1.requires_grad:
bias_correction1 = 1 - beta1 ** step.clone()
else:
bias_correction1 = 1 - beta1**step
else:
bias_correction1 = 1 - beta1**step
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta2, Tensor):
if beta2.requires_grad:
bias_correction2 = 1 - beta2 ** step.clone()
else:
bias_correction2 = 1 - beta2**step
else:
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
step_size_neg = step_size.neg()
@ -462,7 +499,10 @@ def _single_tensor_adam(
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
param.addcdiv_(exp_avg, denom)
if differentiable:
param.addcdiv_(exp_avg.clone(), denom)
else:
param.addcdiv_(exp_avg, denom)
else:
step = _get_value(step_t)

View file

@ -260,7 +260,10 @@ class PackageImporter(Importer):
if typename == "storage":
storage_type, key, location, size = data
dtype = storage_type.dtype
if storage_type is torch.UntypedStorage:
dtype = torch.uint8
else:
dtype = storage_type.dtype
if key not in loaded_storages:
load_tensor(

View file

@ -5008,7 +5008,7 @@ def find_library_location(lib_name: str) -> Path:
path = torch_root / 'lib' / lib_name
if os.path.exists(path):
return path
torch_root = Path(__file__).resolve().parent.parent.parent
torch_root = Path(__file__).resolve().parents[2]
return torch_root / 'build' / 'lib' / lib_name
def skip_but_pass_in_sandcastle(reason):

View file

@ -3,17 +3,23 @@
# AND SCRUB AWAY TORCH NOTIONS THERE.
import collections
import functools
from typing import Any, Callable, OrderedDict
from typing import Callable, OrderedDict, TypeVar
from typing_extensions import ParamSpec
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
_P = ParamSpec("_P")
_R = TypeVar("_R")
def count_label(label: str) -> None:
prev = simple_call_counter.setdefault(label, 0)
simple_call_counter[label] = prev + 1
def count(fn: Callable[..., Any]) -> Callable[..., Any]:
def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if fn.__qualname__ not in simple_call_counter:
simple_call_counter[fn.__qualname__] = 0
simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1

View file

@ -7,7 +7,8 @@ import re
import subprocess
import time
from threading import Lock
from typing import Any, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, TypeVar
from typing_extensions import ParamSpec
logger = logging.getLogger("strobelight_function_profiler")
@ -22,6 +23,9 @@ logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
logger.propagate = False
_P = ParamSpec("_P")
_R = TypeVar("_R")
class StrobelightCLIProfilerError(Exception):
"""
@ -246,7 +250,9 @@ class StrobelightCLIFunctionProfiler:
self._stop_strobelight_no_throw(collect_results=False)
return False
def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
def profile(
self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Optional[_R]:
self.current_run_id = None
if locked := StrobelightCLIFunctionProfiler._lock.acquire(False):
@ -279,6 +285,7 @@ class StrobelightCLIFunctionProfiler:
self._stop_strobelight_no_throw(collect_results=False)
StrobelightCLIFunctionProfiler._lock.release()
raise error
return None
# A function decorator that wraps profile, if no profiler is provided one with
@ -288,13 +295,15 @@ class StrobelightCLIFunctionProfiler:
# @strobelight(stop_at_error=True,...)
def strobelight(
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Any:
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
if not profiler:
profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner(work_function: Any) -> Any:
def strobelight_inner(
work_function: Callable[_P, _R]
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function

View file

@ -4,7 +4,6 @@ import math
import operator
import sys
from typing import (
Any,
Callable,
Iterable,
List,
@ -14,6 +13,7 @@ from typing import (
TypeVar,
Union,
)
from typing_extensions import TypeVarTuple, Unpack
import sympy
from sympy import S
@ -32,6 +32,7 @@ from .numbers import int_oo
_T = TypeVar("_T", bound=SupportsFloat)
_Ts = TypeVarTuple("_Ts")
# Portions of this file are adapted from the Sympy codebase, which was
# licensed as follows:
@ -101,9 +102,11 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
)
def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
def _keep_float(
f: Callable[[Unpack[_Ts]], _T]
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Any) -> Union[_T, sympy.Float]:
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float

View file

@ -1,13 +1,12 @@
# mypy: ignore-errors
import os
import random
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
from typing import Any
sys.path.append(str(Path(__file__).absolute().parents[1]))
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
from benchmark_utils import ( # type: ignore[import-not-found]
fits_in_memory,

View file

@ -1,12 +1,12 @@
import os
import sys
import unittest
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
from expecttest import TestCase
sys.path.append(str(Path(__file__).absolute().parents[1]))
from test_utils import read_file_to_string, run_bash # type: ignore[import-not-found]

View file

@ -1,9 +1,9 @@
# mypy: ignore-errors
import os
import sys
from pathlib import Path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
from train_decision import AHTrainDecisionTree

View file

@ -1,13 +1,12 @@
import itertools
import os
import random
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
from typing import Any
sys.path.append(str(Path(__file__).absolute().parents[1]))
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
from benchmark_utils import ( # type: ignore[import-not-found]
fits_in_memory,

View file

@ -1,11 +1,11 @@
# mypy: ignore-errors
import os
import sys
from pathlib import Path
import pandas as pd # type: ignore[import-untyped]
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
from train_decision import AHTrainDecisionTree

View file

@ -1,12 +1,11 @@
import os
import random
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
from typing import Any
sys.path.append(str(Path(__file__).absolute().parents[1]))
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
from benchmark_utils import ( # type: ignore[import-not-found]
fits_in_memory,

View file

@ -1,12 +1,12 @@
import os
import sys
import unittest
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pathlib import Path
from expecttest import TestCase
sys.path.append(str(Path(__file__).absolute().parents[1]))
from test_utils import read_file_to_string, run_bash # type: ignore[import-not-found]

View file

@ -1,9 +1,9 @@
# mypy: ignore-errors
import os
import sys
from pathlib import Path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
from train_decision import AHTrainDecisionTree

View file

@ -1,9 +1,9 @@
# mypy: ignore-errors
import os
import sys
from pathlib import Path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
from train_regression import AHTrainRegressionTree

View file

@ -1,9 +1,9 @@
# mypy: ignore-errors
import os
import sys
from pathlib import Path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(str(Path(__file__).absolute().parents[1]))
from train_regression import AHTrainRegressionTree

View file

@ -533,7 +533,7 @@ def run(
source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
) -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
pytorch_root = Path(__file__).parent.parent.absolute()
pytorch_root = Path(__file__).absolute().parent.parent
template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
def make_file_manager(install_dir: str) -> FileManager:

View file

@ -256,7 +256,7 @@ def main() -> None:
options = parser.parse_args()
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
torch_root = Path(__file__).parent.parent.parent.absolute()
torch_root = Path(__file__).absolute().parents[2]
aten_path = str(torch_root / "aten" / "src" / "ATen")
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings: