[Bug Fix] Include python training apis when enable_training is enabled (#14485)

This commit is contained in:
Baiju Meswani 2023-01-31 17:17:26 -08:00 committed by GitHub
parent d1533c27eb
commit d06ad9462b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 31 deletions

View file

@ -520,41 +520,48 @@ classifiers = [
if not enable_training:
classifiers.extend(["Operating System :: Microsoft :: Windows", "Operating System :: MacOS"])
if enable_training:
if enable_training or enable_training_apis:
packages.append("onnxruntime.training")
if enable_training:
packages.extend(
[
"onnxruntime.training.amp",
"onnxruntime.training.experimental",
"onnxruntime.training.experimental.gradient_graph",
"onnxruntime.training.optim",
"onnxruntime.training.torchdynamo",
"onnxruntime.training.ortmodule",
"onnxruntime.training.ortmodule.experimental",
"onnxruntime.training.ortmodule.experimental.json_config",
"onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
"onnxruntime.training.ortmodule.torch_cpp_extensions",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.utils.data",
]
)
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
"*.cpp",
"*.cu",
"*.cuh",
"*.h",
]
packages.extend(
[
"onnxruntime.training",
"onnxruntime.training.amp",
"onnxruntime.training.experimental",
"onnxruntime.training.experimental.gradient_graph",
"onnxruntime.training.optim",
"onnxruntime.training.torchdynamo",
"onnxruntime.training.ortmodule",
"onnxruntime.training.ortmodule.experimental",
"onnxruntime.training.ortmodule.experimental.json_config",
"onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
"onnxruntime.training.ortmodule.torch_cpp_extensions",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.utils.data",
"onnxruntime.training.api",
"onnxruntime.training.onnxblock",
"onnxruntime.training.onnxblock.loss",
"onnxruntime.training.onnxblock.optim",
]
)
if enable_training_apis:
packages.append("onnxruntime.training.api")
packages.append("onnxruntime.training.onnxblock")
packages.append("onnxruntime.training.onnxblock.loss")
packages.append("onnxruntime.training.onnxblock.optim")
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
"*.cpp",
"*.cu",
"*.cuh",
"*.h",
]
requirements_file = "requirements-training.txt"
# with training, we want to follow this naming convention:
# stable:

View file

@ -2375,6 +2375,11 @@ def main():
if args.use_gdk:
args.test = False
# enable_training is a higher level flag that enables all training functionality.
if args.enable_training:
args.enable_training_apis = True
args.enable_training_ops = True
configs = set(args.config)
# setup paths and directories