mirror of https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00

2024-12-30 nightly release (2ed4d65af0)
Parent: ad23baaa28
Commit: 40b0741e91
76 changed files with 569 additions and 187 deletions
@@ -7,7 +7,7 @@ import yaml

 # Need to import modules that lie on an upward-relative path
-sys.path.append(os.path.join(sys.path[0], ".."))
+sys.path.append(os.path.dirname(sys.path[0]))

 import cimodel.lib.miniyaml as miniyaml
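The replacement resolves the parent directory eagerly instead of appending a trailing ".." component; a minimal sketch with a hypothetical value of sys.path[0]:

import os

sys_path0 = "/repo/.circleci/scripts"        # hypothetical
print(os.path.join(sys_path0, ".."))         # /repo/.circleci/scripts/..
print(os.path.dirname(sys_path0))            # /repo/.circleci

Both forms make the parent importable, but the dirname form keeps sys.path entries normalized.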
2  .github/scripts/delete_old_branches.py  vendored
@@ -22,7 +22,7 @@ TOKEN = os.environ["GITHUB_TOKEN"]
 if not TOKEN:
     raise Exception("GITHUB_TOKEN is not set")  # noqa: TRY002

-REPO_ROOT = Path(__file__).parent.parent.parent
+REPO_ROOT = Path(__file__).parents[2]

 # Query for all PRs instead of just closed/merged because it's faster
 GRAPHQL_ALL_PRS_BY_UPDATED_AT = """
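Most hunks in this commit make the same substitution: a chain of .parent attributes becomes an index into pathlib.Path.parents. A minimal sketch of the equivalence, using a hypothetical path:

from pathlib import Path

p = Path("/repo/.github/scripts/delete_old_branches.py")
assert p.parent.parent.parent == p.parents[2]   # both are Path('/repo')
# parents is zero-indexed: parents[0] is the containing directory.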
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
import yaml
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
|
||||
EXPECTED_GROUP_PREFIX = (
|
||||
"${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}"
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ def get_nccl_submodule_version() -> str:
|
|||
from pathlib import Path
|
||||
|
||||
nccl_version_mk = (
|
||||
Path(__file__).absolute().parent.parent.parent
|
||||
Path(__file__).absolute().parents[2]
|
||||
/ "third_party"
|
||||
/ "nccl"
|
||||
/ "nccl"
|
||||
|
|
|
|||
2
.github/scripts/gitutils.py
vendored
2
.github/scripts/gitutils.py
vendored
|
|
@ -32,7 +32,7 @@ def get_git_remote_name() -> str:
|
|||
def get_git_repo_dir() -> str:
|
||||
from pathlib import Path
|
||||
|
||||
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent))
|
||||
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
|
||||
def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
|
||||
|
|
|
|||
2
.github/scripts/lint_native_functions.py
vendored
2
.github/scripts/lint_native_functions.py
vendored
|
|
@ -26,7 +26,7 @@ def fn(base: str) -> str:
|
|||
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
|
||||
|
||||
|
||||
with open(Path(__file__).parent.parent.parent / fn(".")) as f:
|
||||
with open(Path(__file__).parents[2] / fn(".")) as f:
|
||||
contents = f.read()
|
||||
|
||||
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
|
||||
|
|
|
|||
2  .github/scripts/test_gitutils.py  vendored
@@ -68,7 +68,7 @@ class TestRetriesDecorator(TestCase):

 class TestGitRepo(TestCase):
     def setUp(self) -> None:
-        repo_dir = BASE_DIR.parent.parent.absolute()
+        repo_dir = BASE_DIR.absolute().parent.parent
         if not (repo_dir / ".git").is_dir():
             raise SkipTest(
                 "Can't find git directory, make sure to run this test on real repo checkout"
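Calling .absolute() before climbing matters when the starting path is relative and shallow, since .parent cannot go above '.'; a small sketch, assuming the current directory is the repo checkout:

from pathlib import Path

p = Path("test_gitutils.py")          # a bare relative path
print(p.parent.parent)                # '.'  -- the chain stops at the current directory
print(p.absolute().parent.parent)     # one level above the current working directory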
|
|
|||
|
|
@ -71,7 +71,7 @@ ARTIFACTS_QUERY_URL = (
|
|||
"c1cdfadc-6bb2-4a91-bbf9-3d19e1981cd4/run?format=JSON"
|
||||
)
|
||||
CSV_LINTER = str(
|
||||
Path(__file__).absolute().parent.parent.parent.parent
|
||||
Path(__file__).absolute().parents[3]
|
||||
/ "tools/linter/adapters/no_merge_conflict_csv_linter.py"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ import torch._prims as prims
|
|||
from torchgen.gen import parse_native_yaml
|
||||
|
||||
|
||||
ROOT = Path(__file__).absolute().parent.parent.parent.parent
|
||||
NATIVE_FUNCTION_YAML_PATH = ROOT / Path("aten/src/ATen/native/native_functions.yaml")
|
||||
TAGS_YAML_PATH = ROOT / Path("aten/src/ATen/native/tags.yaml")
|
||||
ROOT = Path(__file__).absolute().parents[3]
|
||||
NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
|
||||
TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"
|
||||
|
||||
BUILD_DIR = "build/ir"
|
||||
ATEN_OPS_CSV_FILE = "aten_ops.csv"
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from torch.ao.quantization.backend_config.utils import (
|
|||
|
||||
# Create a directory for the images, if it doesn't exist
|
||||
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
|
||||
os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
|
||||
os.path.realpath(os.path.dirname(__file__)), "quantization_backend_configs"
|
||||
)
|
||||
|
||||
if not os.path.exists(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH):
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ from torch.export import export
|
|||
|
||||
|
||||
PWD = Path(__file__).absolute().parent
|
||||
ROOT = Path(__file__).absolute().parent.parent.parent.parent
|
||||
SOURCE = ROOT / Path("source")
|
||||
EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb")
|
||||
ROOT = Path(__file__).absolute().parents[3]
|
||||
SOURCE = ROOT / "source"
|
||||
EXPORTDB_SOURCE = SOURCE / "generated" / "exportdb"
|
||||
|
||||
|
||||
def generate_example_rst(example_case: ExportCase):
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ if __name__ == "__main__":
|
|||
"filename",
|
||||
nargs="?",
|
||||
default=str(
|
||||
Path(__file__).absolute().parent.parent.parent
|
||||
Path(__file__).absolute().parents[2]
|
||||
/ "torch/testing/_internal/dynamo_test_failures.py"
|
||||
),
|
||||
help="Optional path to dynamo_test_failures.py",
|
||||
|
|
@ -203,7 +203,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument(
|
||||
"test_dir",
|
||||
nargs="?",
|
||||
default=str(Path(__file__).absolute().parent.parent.parent / "test"),
|
||||
default=str(Path(__file__).absolute().parents[2] / "test"),
|
||||
help="Optional path to test folder",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ Inherits most tests from TestNNAPI, which loads Android NNAPI models
|
|||
without the delegate API.
|
||||
"""
|
||||
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
|
||||
torch_root = Path(__file__).resolve().parent.parent.parent
|
||||
torch_root = Path(__file__).resolve().parents[2]
|
||||
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import dataclasses
|
|||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import onnxruntime
|
||||
|
|
@ -24,7 +25,8 @@ from torch.testing._internal import common_utils
|
|||
from torch.testing._internal.common_utils import skipIfNNModuleInlined
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
import onnx_test_common
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -45,8 +45,7 @@ _InputArgsType = Optional[
|
|||
_OutputsType = Sequence[_NumericType]
|
||||
|
||||
onnx_model_dir = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
os.pardir,
|
||||
os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
|
||||
"repos",
|
||||
"onnx",
|
||||
"onnx",
|
||||
|
|
@ -54,11 +53,7 @@ onnx_model_dir = os.path.join(
|
|||
"test",
|
||||
"data",
|
||||
)
|
||||
|
||||
|
||||
pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
|
||||
|
||||
|
||||
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# Owner(s): ["module: onnx"]
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.onnx
|
||||
|
|
@ -10,7 +10,8 @@ from torch.testing._internal import common_utils
|
|||
from torch.utils import _pytree as torch_pytree
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
import onnx_test_common
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -76,12 +76,22 @@ def _multistep_backprop_diff_hyperparams_fn(
|
|||
|
||||
# This copy is necessary so the update on line 78 doesn't overwrite the original kwargs values
|
||||
kwargs = kwargs.copy()
|
||||
|
||||
# Have to pass in beta1 and beta2 separately
|
||||
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck
|
||||
if "beta1" in kwargs or "beta2" in kwargs:
|
||||
# Prevent just one beta kwarg from being passed in
|
||||
assert (
|
||||
"beta1" in kwargs and "beta2" in kwargs
|
||||
), "Both betas should be defined in kwargs"
|
||||
kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})
|
||||
|
||||
kwargs.update(
|
||||
{k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
)
|
||||
differentiable_kwargs = [
|
||||
v for v in kwargs.values() if isinstance(v, torch.Tensor) and v.requires_grad
|
||||
]
|
||||
] + (list(kwargs["betas"]) if "betas" in kwargs else [])
|
||||
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
|
|
@ -104,6 +114,10 @@ def _multistep_backprop_diff_hyperparams_fn(
|
|||
meta_loss = loss
|
||||
meta_loss.backward(inputs=(*differentiable_kwargs,), create_graph=True)
|
||||
|
||||
# Extra check to make sure the test properly computed a gradient for all kwargs
|
||||
for kwarg in differentiable_kwargs:
|
||||
assert kwarg.grad is not None
|
||||
|
||||
return (
|
||||
(meta_loss,)
|
||||
+ tuple(
|
||||
|
|
@ -111,11 +125,7 @@ def _multistep_backprop_diff_hyperparams_fn(
|
|||
for v in optimizer.state[params].values()
|
||||
if isinstance(v, torch.Tensor) and v.requires_grad
|
||||
)
|
||||
+ tuple(
|
||||
v
|
||||
for v in kwargs.values()
|
||||
if isinstance(v, torch.Tensor) and v.requires_grad
|
||||
)
|
||||
+ tuple(differentiable_kwargs)
|
||||
)
|
||||
|
||||
|
||||
|
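gradcheck only perturbs arguments that are themselves Tensors with requires_grad=True, which is why the helper above receives beta1/beta2 as separate tensors before folding them back into a betas tuple. A minimal sketch of that calling convention with a hypothetical function:

import torch
from torch.autograd import gradcheck

def f(x, scale):
    # both inputs are Tensors, so gradcheck perturbs and checks both
    return (x * scale).sum()

x = torch.rand(3, dtype=torch.float64, requires_grad=True)
scale = torch.tensor(0.9, dtype=torch.float64, requires_grad=True)
assert gradcheck(f, (x, scale))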
|
@ -404,6 +414,276 @@ class TestDifferentiableOptimizer(TestCase):
|
|||
),
|
||||
)
|
||||
|
||||
def test_adam_differentiable_lr(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
Adam,
|
||||
kwargs, # includes lr
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adam_differentiable_weight_decay(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
Adam,
|
||||
kwargs, # includes weight_decay
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adam_differentiable_betas(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
lr = torch.tensor([0.001], requires_grad=True, dtype=torch.float64)
|
||||
betas = (
|
||||
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
|
||||
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
|
||||
)
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
|
||||
# Have to pass in beta1 and beta2 separately
|
||||
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
|
||||
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
|
||||
kwargs: dict[str, Any] = {
|
||||
"beta1": betas[0],
|
||||
"beta2": betas[1],
|
||||
"lr": lr,
|
||||
"differentiable": True,
|
||||
}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
Adam,
|
||||
kwargs, # includes betas
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adam_differentiable_all_hyperparams(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
|
||||
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
|
||||
betas = (
|
||||
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
|
||||
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
|
||||
)
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
|
||||
# Have to pass in beta1 and beta2 separately
|
||||
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
|
||||
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
|
||||
kwargs: dict[str, Any] = {
|
||||
"lr": lr,
|
||||
"weight_decay": weight_decay,
|
||||
"beta1": betas[0],
|
||||
"beta2": betas[1],
|
||||
"differentiable": True,
|
||||
}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
Adam,
|
||||
kwargs, # includes betas
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adamw_differentiable_lr(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
AdamW,
|
||||
kwargs, # includes lr
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adamw_differentiable_weight_decay(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
AdamW,
|
||||
kwargs, # includes weight_decay
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adamw_differentiable_betas(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
betas = (
|
||||
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
|
||||
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
|
||||
)
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
|
||||
# Have to pass in beta1 and beta2 separately
|
||||
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
|
||||
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
|
||||
kwargs: dict[str, Any] = {
|
||||
"beta1": betas[0],
|
||||
"beta2": betas[1],
|
||||
"differentiable": True,
|
||||
}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
AdamW,
|
||||
kwargs, # includes betas
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_adamw_differentiable_all_hyperparams(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
|
||||
lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
|
||||
weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
|
||||
betas = (
|
||||
torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
|
||||
torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
|
||||
)
|
||||
state = {}
|
||||
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
|
||||
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
state["max_exp_avg_sq"] = torch.rand(
|
||||
10, requires_grad=True, dtype=torch.float64
|
||||
)
|
||||
|
||||
# Have to pass in beta1 and beta2 separately
|
||||
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
|
||||
# In the test, this is called: kwargs.update({betas: (beta1, beta2)})
|
||||
kwargs: dict[str, Any] = {
|
||||
"lr": lr,
|
||||
"weight_decay": weight_decay,
|
||||
"beta1": betas[0],
|
||||
"beta2": betas[1],
|
||||
"differentiable": True,
|
||||
}
|
||||
|
||||
gradcheck(
|
||||
_multistep_backprop_diff_hyperparams_fn,
|
||||
(
|
||||
params,
|
||||
grad,
|
||||
state,
|
||||
AdamW,
|
||||
kwargs, # includes betas
|
||||
*state.values(),
|
||||
*kwargs.values(),
|
||||
),
|
||||
)
|
||||
|
||||
def test_differentiable_lr(self):
|
||||
params = torch.rand(10, requires_grad=True, dtype=torch.float64)
|
||||
grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,11 @@
|
|||
|
||||
import pickle
|
||||
from io import BytesIO
|
||||
from sys import version_info
|
||||
from textwrap import dedent
|
||||
from unittest import skipIf
|
||||
|
||||
import torch
|
||||
from torch.package import PackageExporter, PackageImporter, sys_importer
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
|
@ -265,6 +268,20 @@ class TestSaveLoad(PackageTestCase):
|
|||
exporter.intern("**")
|
||||
exporter.save_module("package_a.use_torch_package_importer")
|
||||
|
||||
@skipIf(version_info >= (3, 13), "https://github.com/pytorch/pytorch/issues/142170")
|
||||
def test_save_load_fp8(self):
|
||||
tensor = torch.rand(20, 20).to(torch.float8_e4m3fn)
|
||||
|
||||
buffer = BytesIO()
|
||||
with PackageExporter(buffer) as exporter:
|
||||
exporter.save_pickle("fp8_model", "model.pkl", tensor)
|
||||
|
||||
buffer.seek(0)
|
||||
|
||||
importer = PackageImporter(buffer)
|
||||
loaded_tensor = importer.load_pickle("fp8_model", "model.pkl")
|
||||
self.assertTrue(torch.equal(tensor, loaded_tensor))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class TestQuantizationDocs(QuantizationTestCase):
|
|||
"been updated to have the correct relative path between "
|
||||
"test_docs.py and the docs."
|
||||
)
|
||||
pytorch_root = core_dir.parent.parent.parent
|
||||
pytorch_root = core_dir.parents[2]
|
||||
return pytorch_root / path_from_pytorch
|
||||
|
||||
path_to_file = get_correct_path(path_from_pytorch)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "typing"))
|
|||
REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
|
||||
PASS_DIR = os.path.join(DATA_DIR, "pass")
|
||||
FAIL_DIR = os.path.join(DATA_DIR, "fail")
|
||||
MYPY_INI = os.path.join(DATA_DIR, os.pardir, os.pardir, "mypy.ini")
|
||||
MYPY_INI = os.path.join(os.path.dirname(os.path.dirname(DATA_DIR)), "mypy.ini")
|
||||
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")
|
||||
|
||||
|
||||
|
|
|
|||
2  third_party/xpu.txt  vendored
@@ -1 +1 @@
-7ecb0b1a56b65dec63837a30972a8ba6f8432477
+214f33b9d969930a18656a82b5c5d8da53cdcb8e
|
|
|
|||
|
|
@ -4,15 +4,16 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(
|
||||
os.path.realpath(
|
||||
os.path.join(
|
||||
__file__, os.path.pardir, os.path.pardir, os.path.pardir, "torch", "utils"
|
||||
)
|
||||
)
|
||||
)
|
||||
# NOTE: `tools/amd_build/build_amd.py` could be a symlink.
|
||||
# The behavior of `symlink / '..'` is different from `symlink.parent`.
|
||||
# Use `pardir` three times rather than using `path.parents[2]`.
|
||||
REPO_ROOT = (
|
||||
Path(__file__).absolute() / os.path.pardir / os.path.pardir / os.path.pardir
|
||||
).resolve()
|
||||
sys.path.append(str(REPO_ROOT / "torch" / "utils"))
|
||||
|
||||
from hipify import hipify_python # type: ignore[import]
|
||||
|
||||
|
|
@ -53,8 +54,9 @@ parser.add_argument(
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
# NOTE: `tools/amd_build/build_amd.py` could be a symlink.
|
||||
amd_build_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
proj_dir = os.path.join(os.path.dirname(os.path.dirname(amd_build_dir)))
|
||||
proj_dir = os.path.dirname(os.path.dirname(amd_build_dir))
|
||||
|
||||
if args.project_directory:
|
||||
proj_dir = args.project_directory
|
||||
|
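The note about symlinks in the hunk above is why this file keeps os.path.pardir instead of parents[2]: .parent is purely lexical (the directory containing the symlink), while resolving a path with ".." appended follows the link first. A sketch of the difference, assuming a POSIX filesystem and a hypothetical layout:

import os, tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
(root / "real" / "pkg").mkdir(parents=True)
(root / "real" / "pkg" / "script.py").touch()
(root / "link").mkdir()
link = root / "link" / "script.py"
link.symlink_to(root / "real" / "pkg" / "script.py")

print(link.parent)                          # .../link      (where the symlink lives)
print((link / os.path.pardir).resolve())    # .../real/pkg  (symlink resolved before '..')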
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import argparse
|
||||
import sys
|
||||
from os.path import abspath, dirname
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# By appending pytorch_root to sys.path, this module can import other torch
|
||||
# By appending REPO_ROOT to sys.path, this module can import other torch
|
||||
# modules even when run as a standalone script. i.e., it's okay either you
|
||||
# do `python build_libtorch.py` or `python -m tools.build_libtorch`.
|
||||
pytorch_root = dirname(dirname(abspath(__file__)))
|
||||
sys.path.append(pytorch_root)
|
||||
REPO_ROOT = Path(__file__).absolute().parent.parent
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
from tools.build_pytorch_libs import build_pytorch
|
||||
from tools.setup_helpers.cmake import CMake
|
||||
|
|
|
|||
|
|
@ -43,9 +43,7 @@ def get_llvm_tool_path() -> str:
|
|||
def get_pytorch_folder() -> str:
|
||||
# TOOLS_FOLDER in oss: pytorch/tools/code_coverage
|
||||
return os.path.abspath(
|
||||
os.environ.get(
|
||||
"PYTORCH_FOLDER", os.path.join(TOOLS_FOLDER, os.path.pardir, os.path.pardir)
|
||||
)
|
||||
os.environ.get("PYTORCH_FOLDER", os.path.dirname(os.path.dirname(TOOLS_FOLDER)))
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# <project folder>
|
||||
HOME_DIR = os.environ["HOME"]
|
||||
TOOLS_FOLDER = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), os.path.pardir, os.path.pardir
|
||||
)
|
||||
TOOLS_FOLDER = str(Path(__file__).resolve().parents[2])
|
||||
|
||||
|
||||
# <profile folder>
|
||||
|
|
|
|||
|
|
@ -10,24 +10,28 @@ import glob
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from itertools import product
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from yaml.constructor import ConstructorError
|
||||
from yaml.nodes import MappingNode
|
||||
|
||||
|
||||
try:
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).absolute().parent.parent
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
CPP_H_NAME = "spv.h"
|
||||
CPP_SRC_NAME = "spv.cpp"
|
||||
|
||||
|
|
|
|||
|
|
@ -26,10 +26,7 @@ try:
|
|||
PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
|
||||
except subprocess.CalledProcessError:
|
||||
# If git is not installed, compute repo root as 3 folders up from this file
|
||||
path_ = os.path.abspath(__file__)
|
||||
for _ in range(4):
|
||||
path_ = os.path.dirname(path_)
|
||||
PYTORCH_ROOT = path_
|
||||
PYTORCH_ROOT = str(Path(__file__).absolute().parents[3])
|
||||
|
||||
DRY_RUN = False
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
|
|||
|
||||
|
||||
def main(args: Any) -> None:
|
||||
pytorch_dir = Path(__file__).parent.parent.parent.resolve()
|
||||
pytorch_dir = Path(__file__).parents[2].resolve()
|
||||
onnx_dir = pytorch_dir / "third_party" / "onnx"
|
||||
os.chdir(onnx_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import platform
|
|||
import sys
|
||||
import sysconfig
|
||||
from distutils.version import LooseVersion
|
||||
from pathlib import Path
|
||||
from subprocess import CalledProcessError, check_call, check_output
|
||||
from typing import Any, cast
|
||||
|
||||
|
|
@ -173,9 +174,7 @@ class CMake:
|
|||
toolset_expr = ",".join([f"{k}={v}" for k, v in toolset_dict.items()])
|
||||
args.append("-T" + toolset_expr)
|
||||
|
||||
base_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
base_dir = str(Path(__file__).absolute().parents[2])
|
||||
install_dir = os.path.join(base_dir, "torch")
|
||||
|
||||
_mkdir_p(install_dir)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# Little stub file to get BUILD.bazel to play along
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
REPO_ROOT = Path(__file__).absolute().parents[2]
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
import torchgen.gen
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# Little stub file to get BUILD.bazel to play along
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
REPO_ROOT = Path(__file__).absolute().parents[2]
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
import tools.jit.gen_unboxing
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ try:
|
|||
except ImportError:
|
||||
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
|
||||
TAGS_PATH = "aten/src/ATen/native/tags.yaml"
|
||||
|
||||
|
|
@ -110,8 +111,9 @@ def get_selector(
|
|||
operators_yaml_path: str | None,
|
||||
) -> Any:
|
||||
# cwrap depends on pyyaml, so we can't import it earlier
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
REPO_ROOT = Path(__file__).absolute().parents[2]
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
assert not (
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ import sys
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
from tools.stats.import_test_stats import get_test_class_times, get_test_times
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Any, Callable, cast, Dict
|
|||
from urllib.request import urlopen
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def get_disabled_issues() -> list[str]:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import Any
|
|||
from unittest import mock
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
from tools.test.heuristics.test_interface import TestTD
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
import tools.testing.target_determination.heuristics.interface as interface
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
import tools.testing.target_determination.heuristics.utils as utils
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
|
|
@ -12,10 +11,6 @@ from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401
|
|||
from torchgen.gen_backend_stubs import run
|
||||
|
||||
|
||||
path = os.path.dirname(os.path.realpath(__file__))
|
||||
gen_backend_stubs_path = os.path.join(path, "../torchgen/gen_backend_stubs.py")
|
||||
|
||||
|
||||
# gen_backend_stubs.py is an integration point that is called directly by external backends.
|
||||
# The tests here are to confirm that badly formed inputs result in reasonable error messages.
|
||||
class TestGenBackendStubs(expecttest.TestCase):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import unittest
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
try:
|
||||
# using tools/ to optimize test run.
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import Any
|
|||
from unittest import mock
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from tools.stats.upload_metrics import add_global_metric, emit_metric, global_metrics
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||
CPP_TEST_PREFIX = "cpp"
|
||||
CPP_TEST_PATH = "build/bin"
|
||||
CPP_TESTS_DIR = os.path.abspath(os.getenv("CPP_TESTS_DIR", default=CPP_TEST_PATH))
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def parse_test_module(test: str) -> str:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sys
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from tools.stats.import_test_stats import (
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).parents[2]
|
||||
CONFIG_YML = REPO_ROOT / ".circleci" / "config.yml"
|
||||
WORKFLOWS_DIR = REPO_ROOT / ".github" / "workflows"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
# These tests are slow enough that it's worth calculating whether the patch
|
||||
# touched any related files first. This list was manually generated, but for every
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def gen_ci_artifact(included: list[Any], excluded: list[Any]) -> None:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from tools.testing.target_determination.heuristics.utils import (
|
|||
from tools.testing.test_run import TestRun
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).parents[3]
|
||||
|
||||
keyword_synonyms: dict[str, list[str]] = {
|
||||
"amp": ["mixed_precision"],
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from tools.testing.target_determination.heuristics.utils import normalize_rating
|
|||
from tools.testing.test_run import TestRun
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
class LLM(HeuristicInterface):
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from tools.testing.target_determination.heuristics.utils import (
|
|||
from tools.testing.test_run import TestRun
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
class PreviouslyFailedInPR(HeuristicInterface):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from warnings import warn
|
|||
if TYPE_CHECKING:
|
||||
from tools.testing.test_run import TestRun
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
def python_test_file_to_test_name(tests: set[str]) -> set[str]:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||
from collections.abc import Sequence
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
||||
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
import functools
|
||||
import warnings
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import deprecated
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import deprecated, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
|
@ -14,6 +14,9 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# TorchScript does not support `@deprecated`
|
||||
# This is a workaround to avoid breaking TorchScript
|
||||
|
|
@ -35,13 +38,13 @@ else:
|
|||
return torch.compiler.is_compiling()
|
||||
|
||||
|
||||
def wrap_inline(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
"""
|
||||
Create an extra frame around fn that is not in skipfiles.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
|
@ -61,7 +64,7 @@ def call_hook(
|
|||
return result
|
||||
|
||||
|
||||
def wrap_numpy(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
|
||||
from ``torch.Tensor``s to ``torch.Tensor``s.
|
||||
"""
|
||||
|
|
@ -69,7 +72,7 @@ def wrap_numpy(f: Callable[..., Any]) -> Callable[..., Any]:
|
|||
return f
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrap(*args: Any, **kwargs: Any) -> Any:
|
||||
def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
|
||||
args, kwargs = pytree.tree_map_only(
|
||||
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
|
||||
)
|
||||
|
|
|
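The typing change above is the recurring ParamSpec pattern in this commit: the decorator advertises the wrapped function's exact signature instead of Callable[..., Any]. The wrap_inline hunk, restated as a self-contained sketch:

import functools
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")

def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Create an extra frame around fn that is not in skipfiles."""
    @functools.wraps(fn)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return fn(*args, **kwargs)
    return inner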
|||
|
|
@ -20,6 +20,7 @@ from typing import (
|
|||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
|
@ -51,6 +52,8 @@ three = 3
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
if x is None:
|
||||
|
|
@ -407,9 +410,9 @@ def check_dynamic_shape_capture() -> bool:
|
|||
return not config.assume_static_by_default
|
||||
|
||||
|
||||
def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
|
||||
def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]:
|
||||
@functools.wraps(fn)
|
||||
def _fn(*args: Any, **kwargs: Any) -> _T:
|
||||
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
with contextlib.ExitStack() as stack:
|
||||
for module, attr, val in patches:
|
||||
stack.enter_context(patch.object(module, attr, val))
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import functools
|
|||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
|
|
@ -51,15 +52,13 @@ def _reload_python_module(key, path):
|
|||
def _set_triton_ptxas_path() -> None:
|
||||
if os.environ.get("TRITON_PTXAS_PATH") is not None:
|
||||
return
|
||||
ptxas_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
|
||||
)
|
||||
if not os.path.exists(ptxas_path):
|
||||
ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas"
|
||||
if not ptxas.exists():
|
||||
return
|
||||
if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
|
||||
os.environ["TRITON_PTXAS_PATH"] = ptxas_path
|
||||
if ptxas.is_file() and os.access(ptxas, os.X_OK):
|
||||
os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
|
||||
else:
|
||||
warnings.warn(f"{ptxas_path} exists but is not an executable")
|
||||
warnings.warn(f"{ptxas} exists but is not an executable")
|
||||
|
||||
|
||||
def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import subprocess
|
|||
import time
|
||||
from threading import Lock
|
||||
from timeit import default_timer as timer
|
||||
from typing import Any, List, Optional, Sequence
|
||||
from typing import Any, Callable, List, Optional, Sequence, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
logger = logging.getLogger("strobelight_function_profiler")
|
||||
|
|
@ -23,6 +24,9 @@ logger.addHandler(console_handler)
|
|||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
class StrobelightCLIProfilerError(Exception):
|
||||
"""
|
||||
|
|
@ -250,7 +254,9 @@ class StrobelightCLIFunctionProfiler:
|
|||
self._stop_strobelight_no_throw(collect_results=False)
|
||||
return False
|
||||
|
||||
def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
def profile(
|
||||
self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
|
||||
) -> Optional[_R]:
|
||||
self.current_run_id = None
|
||||
self.profile_result = None
|
||||
|
||||
|
|
@ -288,6 +294,7 @@ class StrobelightCLIFunctionProfiler:
|
|||
self._stop_strobelight_no_throw(collect_results=False)
|
||||
StrobelightCLIFunctionProfiler._lock.release()
|
||||
raise error
|
||||
return None
|
||||
|
||||
|
||||
# A function decorator that wraps profile, if no profiler is provided one with
|
||||
|
|
@ -297,13 +304,15 @@ class StrobelightCLIFunctionProfiler:
|
|||
# @strobelight(stop_at_error=True,...)
|
||||
def strobelight(
|
||||
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
|
||||
if not profiler:
|
||||
profiler = StrobelightCLIFunctionProfiler(**kwargs)
|
||||
|
||||
def strobelight_inner(work_function: Any) -> Any:
|
||||
def strobelight_inner(
|
||||
work_function: Callable[_P, _R]
|
||||
) -> Callable[_P, Optional[_R]]:
|
||||
@functools.wraps(work_function)
|
||||
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
|
||||
return profiler.profile(work_function, *args, **kwargs)
|
||||
|
||||
return wrapper_function
|
||||
|
|
|
|||
|
|
@ -542,7 +542,7 @@ class FakeTensorConfig:
|
|||
#
|
||||
# Making this a descriptor may seem overly fancy, but actually it's the most
|
||||
# convenient way to ensure access to FakeTensor during access, which is
|
||||
# required for testing version counter and epoch validity.
|
||||
# required for testing version counter and epoch validity.
|
||||
class SymNumberMemoDescriptor:
|
||||
_name: str
|
||||
|
||||
|
|
@ -763,7 +763,7 @@ class FakeTensor(Tensor):
|
|||
|
||||
@classmethod
|
||||
@count
|
||||
def __torch_dispatch__(
|
||||
def __torch_dispatch__( # type: ignore[override] # TODO
|
||||
cls,
|
||||
func: OpOverload,
|
||||
types: Sequence[Type],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,19 @@
|
|||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -26,6 +38,10 @@ _TOTAL_KEY = "Total"
|
|||
|
||||
__all__ = ["FSDPMemTracker"]
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
_Ts = TypeVarTuple("_Ts")
|
||||
|
||||
|
||||
class _FSDPRefType(_RefType):
|
||||
"""
|
||||
|
|
@ -185,8 +201,8 @@ class FSDPMemTracker(MemTracker):
|
|||
def _fsdp_state_pre_forward(
|
||||
self,
|
||||
fsdp_mod: FSDPModule,
|
||||
orig_fsdp_state_pre_fw: Callable,
|
||||
) -> Callable:
|
||||
orig_fsdp_state_pre_fw: Callable[_P, Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]],
|
||||
) -> Callable[_P, Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]]:
|
||||
# We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params
|
||||
# and `all_gather` buffers. There are three cases:
|
||||
# Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats``
|
||||
|
|
@ -201,7 +217,9 @@ class FSDPMemTracker(MemTracker):
|
|||
# For Case 1 and 3, we also initialiaze the ``local_peak`` and ``PEAK_FW`` snapshot for the module.
|
||||
# For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op.
|
||||
@wraps(orig_fsdp_state_pre_fw)
|
||||
def inner(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
|
||||
def inner(
|
||||
*args: _P.args, **kwargs: _P.kwargs
|
||||
) -> Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]:
|
||||
mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
|
||||
assert mod_fqn is not None
|
||||
if fsdp_mod not in self.memory_tracking:
|
||||
|
|
@ -251,15 +269,15 @@ class FSDPMemTracker(MemTracker):
|
|||
def _fsdp_state_post_forward(
|
||||
self,
|
||||
fsdp_mod: FSDPModule,
|
||||
orig_fsdp_state_post_fw: Callable,
|
||||
) -> Callable:
|
||||
orig_fsdp_state_post_fw: Callable[_P, _R],
|
||||
) -> Callable[_P, _R]:
|
||||
# We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state
|
||||
# if ``reshard_after_forward`` is not ``False``. There are two cases:
|
||||
# Case 1: This is called in backward, which means we are in the AC region. If this is the top most module
|
||||
# in the AC region, we set the flag ``_in_ac`` to False.
|
||||
# Case 2: This is called in forward.
|
||||
@wraps(orig_fsdp_state_post_fw)
|
||||
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
mod_stat = self.memory_tracking[fsdp_mod]
|
||||
if self._mod_tracker.is_bw:
|
||||
state = _FSDPModState.POST_FW_AC
|
||||
|
|
@ -283,12 +301,12 @@ class FSDPMemTracker(MemTracker):
|
|||
def _fsdp_param_group_pre_backward(
|
||||
self,
|
||||
fsdp_mod: FSDPModule,
|
||||
orig_fsdp_param_group_pre_backward: Callable,
|
||||
) -> Callable:
|
||||
orig_fsdp_param_group_pre_backward: Callable[_P, Any],
|
||||
) -> Callable[_P, None]:
|
||||
# We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching
|
||||
# and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module.
|
||||
@wraps(orig_fsdp_param_group_pre_backward)
|
||||
def inner(*args: Any, **kwargs: Any) -> None:
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
mod_stat = self.memory_tracking[fsdp_mod]
|
||||
snapshot = self.get_tracker_snapshot()
|
||||
mod_stat.local_peak = {
|
||||
|
|
@ -309,13 +327,13 @@ class FSDPMemTracker(MemTracker):
|
|||
def _fsdp_param_group_post_backward(
|
||||
self,
|
||||
fsdp_mod: FSDPModule,
|
||||
orig_fsdp_param_group_post_backward: Callable,
|
||||
) -> Callable:
|
||||
orig_fsdp_param_group_post_backward: Callable[_P, Any],
|
||||
) -> Callable[_P, None]:
|
||||
# We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute
|
||||
# the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers
|
||||
# after the post backward.
|
||||
@wraps(orig_fsdp_param_group_post_backward)
|
||||
def inner(*args: Any, **kwargs: Any) -> None:
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
fsdp_state = fsdp_mod._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
for fsdp_param in fsdp_param_group.fsdp_params:
|
||||
|
|
|
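The FSDP tracker hooks above go one step further and use TypeVarTuple so the (args, kwargs) tuple returned by the pre-forward hook keeps its element types; a minimal sketch of that annotation pattern, independent of FSDP:

from functools import wraps
from typing import Any, Callable, Dict, Tuple
from typing_extensions import ParamSpec, TypeVarTuple, Unpack

_P = ParamSpec("_P")
_Ts = TypeVarTuple("_Ts")

def passthrough(
    fn: Callable[_P, Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]]
) -> Callable[_P, Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]]:
    @wraps(fn)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> Tuple[Tuple[Unpack[_Ts]], Dict[str, Any]]:
        return fn(*args, **kwargs)
    return inner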
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Utility to lazily import modules."""
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
|
@ -17,7 +16,7 @@ class _LazyModule:
|
|||
def __repr__(self) -> str:
|
||||
return f"<lazy module '{self._name}'>"
|
||||
|
||||
def __getattr__(self, attr):
|
||||
def __getattr__(self, attr: str) -> object:
|
||||
if self._module is None:
|
||||
self._module = importlib.import_module(".", self._name)
|
||||
return getattr(self._module, attr)
|
||||
|
|
|
|||
|
|
@ -402,7 +402,14 @@ def _single_tensor_adam(
|
|||
# Perform stepweight decay
|
||||
param.mul_(1 - lr * weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
# Nested if is necessary to bypass jitscript rules
|
||||
if differentiable and isinstance(weight_decay, Tensor):
|
||||
if weight_decay.requires_grad:
|
||||
grad = grad.addcmul_(param.clone(), weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
|
|
@ -429,13 +436,43 @@ def _single_tensor_adam(
|
|||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.lerp_(grad, 1 - device_beta1)
|
||||
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||
# Nested if is necessary to bypass jitscript rules
|
||||
if differentiable and isinstance(beta2, Tensor):
|
||||
if beta2.requires_grad:
|
||||
# Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
|
||||
# Showing equivalence of differentiable path and nondifferentiable path
|
||||
# expavg * b2 + grad^2 * (1-b2)
|
||||
# add expavg * (1-b2) - expavg * (1-b2) = 0
|
||||
# expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
|
||||
# expavg - expavg * (1-b2) + grad^2 * (1-b2)
|
||||
# expavg + (grad^2 - expavg) * (1-b2)
|
||||
# expavg.lerp(grad^2, 1-beta2)
|
||||
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
|
||||
else:
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
else:
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
if capturable or differentiable:
|
||||
step = step_t
|
||||
|
||||
bias_correction1 = 1 - beta1**step
|
||||
bias_correction2 = 1 - beta2**step
|
||||
# Nested if is necessary to bypass jitscript rules
|
||||
if differentiable and isinstance(beta1, Tensor):
|
||||
if beta1.requires_grad:
|
||||
bias_correction1 = 1 - beta1 ** step.clone()
|
||||
else:
|
||||
bias_correction1 = 1 - beta1**step
|
||||
else:
|
||||
bias_correction1 = 1 - beta1**step
|
||||
|
||||
# Nested if is necessary to bypass jitscript rules
|
||||
if differentiable and isinstance(beta2, Tensor):
|
||||
if beta2.requires_grad:
|
||||
bias_correction2 = 1 - beta2 ** step.clone()
|
||||
else:
|
||||
bias_correction2 = 1 - beta2**step
|
||||
else:
|
||||
bias_correction2 = 1 - beta2**step
|
||||
|
||||
step_size = lr / bias_correction1
|
||||
step_size_neg = step_size.neg()
|
||||
|
|
@ -462,7 +499,10 @@ def _single_tensor_adam(
|
|||
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
|
||||
).add_(eps / step_size_neg)
|
||||
|
||||
param.addcdiv_(exp_avg, denom)
|
||||
if differentiable:
|
||||
param.addcdiv_(exp_avg.clone(), denom)
|
||||
else:
|
||||
param.addcdiv_(exp_avg, denom)
|
||||
else:
|
||||
step = _get_value(step_t)
|
||||
|
||||
|
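The derivation in the comment above shows why the differentiable branch can use a single lerp_: v*b2 + g*g*(1-b2) == v + (g*g - v)*(1-b2). A quick numeric check of that identity with hypothetical values:

import torch

exp_avg_sq = torch.rand(4, dtype=torch.float64)
grad = torch.rand(4, dtype=torch.float64)
beta2 = torch.tensor(0.999, dtype=torch.float64)

reference = exp_avg_sq * beta2 + grad * grad * (1 - beta2)      # mul_/addcmul_ form
via_lerp = exp_avg_sq.lerp(grad.square(), 1 - beta2)            # differentiable lerp form
torch.testing.assert_close(reference, via_lerp)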
|
|
|||
|
|
@ -260,7 +260,10 @@ class PackageImporter(Importer):
|
|||
|
||||
if typename == "storage":
|
||||
storage_type, key, location, size = data
|
||||
dtype = storage_type.dtype
|
||||
if storage_type is torch.UntypedStorage:
|
||||
dtype = torch.uint8
|
||||
else:
|
||||
dtype = storage_type.dtype
|
||||
|
||||
if key not in loaded_storages:
|
||||
load_tensor(
|
||||
|
|
|
|||
|
|
@ -5008,7 +5008,7 @@ def find_library_location(lib_name: str) -> Path:
|
|||
path = torch_root / 'lib' / lib_name
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
torch_root = Path(__file__).resolve().parent.parent.parent
|
||||
torch_root = Path(__file__).resolve().parents[2]
|
||||
return torch_root / 'build' / 'lib' / lib_name
|
||||
|
||||
def skip_but_pass_in_sandcastle(reason):
|
||||
|
|
|
|||
|
|
@ -3,17 +3,23 @@
|
|||
# AND SCRUB AWAY TORCH NOTIONS THERE.
|
||||
import collections
|
||||
import functools
|
||||
from typing import Any, Callable, OrderedDict
|
||||
from typing import Callable, OrderedDict, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def count_label(label: str) -> None:
|
||||
prev = simple_call_counter.setdefault(label, 0)
|
||||
simple_call_counter[label] = prev + 1
|
||||
|
||||
def count(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if fn.__qualname__ not in simple_call_counter:
|
||||
simple_call_counter[fn.__qualname__] = 0
|
||||
simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ import re
|
|||
import subprocess
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Any, List, Optional, Sequence
|
||||
from typing import Any, Callable, List, Optional, Sequence, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
|
||||
logger = logging.getLogger("strobelight_function_profiler")
|
||||
|
|
@ -22,6 +23,9 @@ logger.addHandler(console_handler)
|
|||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
class StrobelightCLIProfilerError(Exception):
|
||||
"""
|
||||
|
|
@ -246,7 +250,9 @@ class StrobelightCLIFunctionProfiler:
|
|||
self._stop_strobelight_no_throw(collect_results=False)
|
||||
return False
|
||||
|
||||
def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
def profile(
|
||||
self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
|
||||
) -> Optional[_R]:
|
||||
self.current_run_id = None
|
||||
|
||||
if locked := StrobelightCLIFunctionProfiler._lock.acquire(False):
|
||||
|
|
@ -279,6 +285,7 @@ class StrobelightCLIFunctionProfiler:
|
|||
self._stop_strobelight_no_throw(collect_results=False)
|
||||
StrobelightCLIFunctionProfiler._lock.release()
|
||||
raise error
|
||||
return None
|
||||
|
||||
|
||||
# A function decorator that wraps profile, if no profiler is provided one with
|
||||
|
|
@ -288,13 +295,15 @@ class StrobelightCLIFunctionProfiler:
|
|||
# @strobelight(stop_at_error=True,...)
|
||||
def strobelight(
|
||||
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
|
||||
if not profiler:
|
||||
profiler = StrobelightCLIFunctionProfiler(**kwargs)
|
||||
|
||||
def strobelight_inner(work_function: Any) -> Any:
|
||||
def strobelight_inner(
|
||||
work_function: Callable[_P, _R]
|
||||
) -> Callable[_P, Optional[_R]]:
|
||||
@functools.wraps(work_function)
|
||||
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
|
||||
return profiler.profile(work_function, *args, **kwargs)
|
||||
|
||||
return wrapper_function
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import math
|
|||
import operator
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
|
|
@ -14,6 +13,7 @@ from typing import (
|
|||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
import sympy
|
||||
from sympy import S
|
||||
|
|
@ -32,6 +32,7 @@ from .numbers import int_oo
|
|||
|
||||
|
||||
_T = TypeVar("_T", bound=SupportsFloat)
|
||||
_Ts = TypeVarTuple("_Ts")
|
||||
|
||||
# Portions of this file are adapted from the Sympy codebase, which was
|
||||
# licensed as follows:
|
||||
|
|
@ -101,9 +102,11 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
|
||||
def _keep_float(
|
||||
f: Callable[[Unpack[_Ts]], _T]
|
||||
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
|
||||
@functools.wraps(f)
|
||||
def inner(*args: Any) -> Union[_T, sympy.Float]:
|
||||
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
|
||||
r: Union[_T, sympy.Float] = f(*args)
|
||||
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
|
||||
r, sympy.Float
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
# mypy: ignore-errors
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
|
||||
from benchmark_utils import ( # type: ignore[import-not-found]
|
||||
fits_in_memory,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from pathlib import Path
|
||||
|
||||
from expecttest import TestCase
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from test_utils import read_file_to_string, run_bash # type: ignore[import-not-found]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# mypy: ignore-errors
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from train_decision import AHTrainDecisionTree
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
import itertools
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
|
||||
from benchmark_utils import ( # type: ignore[import-not-found]
|
||||
fits_in_memory,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# mypy: ignore-errors
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from train_decision import AHTrainDecisionTree
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from benchmark_runner import BenchmarkRunner # type: ignore[import-not-found]
|
||||
from benchmark_utils import ( # type: ignore[import-not-found]
|
||||
fits_in_memory,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from pathlib import Path
|
||||
|
||||
from expecttest import TestCase
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from test_utils import read_file_to_string, run_bash # type: ignore[import-not-found]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# mypy: ignore-errors
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from train_decision import AHTrainDecisionTree
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# mypy: ignore-errors
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from train_regression import AHTrainRegressionTree
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# mypy: ignore-errors
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(str(Path(__file__).absolute().parents[1]))
|
||||
|
||||
from train_regression import AHTrainRegressionTree
|
||||
|
||||
|
|
|
|||
|
|
@ -533,7 +533,7 @@ def run(
|
|||
source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
|
||||
) -> None:
|
||||
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
|
||||
pytorch_root = Path(__file__).parent.parent.absolute()
|
||||
pytorch_root = Path(__file__).absolute().parent.parent
|
||||
template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
|
||||
|
||||
def make_file_manager(install_dir: str) -> FileManager:
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ def main() -> None:
|
|||
options = parser.parse_args()
|
||||
|
||||
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
|
||||
torch_root = Path(__file__).parent.parent.parent.absolute()
|
||||
torch_root = Path(__file__).absolute().parents[2]
|
||||
aten_path = str(torch_root / "aten" / "src" / "ATen")
|
||||
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
|
||||
if options.gen_ts_lowerings:
|
||||