mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
[Bug Fix] Include python training apis when enable_training is enabled (#14485)
This commit is contained in:
parent
d1533c27eb
commit
d06ad9462b
2 changed files with 43 additions and 31 deletions
69
setup.py
69
setup.py
|
|
@ -520,41 +520,48 @@ classifiers = [
|
|||
if not enable_training:
|
||||
classifiers.extend(["Operating System :: Microsoft :: Windows", "Operating System :: MacOS"])
|
||||
|
||||
if enable_training:
|
||||
if enable_training or enable_training_apis:
|
||||
packages.append("onnxruntime.training")
|
||||
if enable_training:
|
||||
packages.extend(
|
||||
[
|
||||
"onnxruntime.training.amp",
|
||||
"onnxruntime.training.experimental",
|
||||
"onnxruntime.training.experimental.gradient_graph",
|
||||
"onnxruntime.training.optim",
|
||||
"onnxruntime.training.torchdynamo",
|
||||
"onnxruntime.training.ortmodule",
|
||||
"onnxruntime.training.ortmodule.experimental",
|
||||
"onnxruntime.training.ortmodule.experimental.json_config",
|
||||
"onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
|
||||
"onnxruntime.training.utils.data",
|
||||
]
|
||||
)
|
||||
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
|
||||
"*.cpp",
|
||||
"*.cu",
|
||||
"*.cuh",
|
||||
"*.h",
|
||||
]
|
||||
|
||||
packages.extend(
|
||||
[
|
||||
"onnxruntime.training",
|
||||
"onnxruntime.training.amp",
|
||||
"onnxruntime.training.experimental",
|
||||
"onnxruntime.training.experimental.gradient_graph",
|
||||
"onnxruntime.training.optim",
|
||||
"onnxruntime.training.torchdynamo",
|
||||
"onnxruntime.training.ortmodule",
|
||||
"onnxruntime.training.ortmodule.experimental",
|
||||
"onnxruntime.training.ortmodule.experimental.json_config",
|
||||
"onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
|
||||
"onnxruntime.training.utils.data",
|
||||
"onnxruntime.training.api",
|
||||
"onnxruntime.training.onnxblock",
|
||||
"onnxruntime.training.onnxblock.loss",
|
||||
"onnxruntime.training.onnxblock.optim",
|
||||
]
|
||||
)
|
||||
if enable_training_apis:
|
||||
packages.append("onnxruntime.training.api")
|
||||
packages.append("onnxruntime.training.onnxblock")
|
||||
packages.append("onnxruntime.training.onnxblock.loss")
|
||||
packages.append("onnxruntime.training.onnxblock.optim")
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
|
||||
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
|
||||
"*.cpp",
|
||||
"*.cu",
|
||||
"*.cuh",
|
||||
"*.h",
|
||||
]
|
||||
|
||||
requirements_file = "requirements-training.txt"
|
||||
# with training, we want to follow this naming convention:
|
||||
# stable:
|
||||
|
|
|
|||
|
|
@ -2375,6 +2375,11 @@ def main():
|
|||
if args.use_gdk:
|
||||
args.test = False
|
||||
|
||||
# enable_training is a higher level flag that enables all training functionality.
|
||||
if args.enable_training:
|
||||
args.enable_training_apis = True
|
||||
args.enable_training_ops = True
|
||||
|
||||
configs = set(args.config)
|
||||
|
||||
# setup paths and directories
|
||||
|
|
|
|||
Loading…
Reference in a new issue