onnxruntime/orttraining/tools/scripts/watch_experiment.py
Justin Chu fdce4fa6af
Format all python files under onnxruntime with black and isort (#11324)
Description: Format all python files under onnxruntime with black and isort.

After checking in, we can use .git-blame-ignore-revs to ignore the formatting PR in git blame.

#11315, #11316
2022-04-26 09:35:16 -07:00

86 lines
3.2 KiB
Python

import argparse
import sys
import os
from concurrent.futures import ThreadPoolExecutor
from requests import Session
from threading import Event, Thread
from azureml.core import Workspace, Experiment, Run
from azureml._run_impl.run_watcher import RunWatcher
# Command-line options: which AzureML run to watch and, optionally, a
# remote/local directory pair to synchronize files between.
parser = argparse.ArgumentParser(
    description="Watch an AzureML Experiment run, optionally syncing its files to a local directory"
)
parser.add_argument(
    "--subscription",
    type=str,
    default="ea482afa-3a32-437c-aa10-7de928a9e793",  # AI Platform GPU - MLPerf
    help="Azure subscription ID containing the AzureML Workspace",
)
parser.add_argument(
    "--resource_group", type=str, default="onnx_training", help="Azure resource group containing the AzureML Workspace"
)
parser.add_argument(
    "--workspace", type=str, default="ort_training_dev", help="AzureML Workspace to run the Experiment in"
)
parser.add_argument("--experiment", type=str, default="BERT-ONNX", help="Name of the AzureML Experiment")
parser.add_argument("--run", type=str, default=None, help="The Experiment run to watch (defaults to the latest run)")
# The two directory options are only meaningful as a pair; the script checks
# that after parsing.
parser.add_argument("--remote_dir", type=str, default=None, help="Specify a remote directory to sync (read) from")
parser.add_argument("--local_dir", type=str, default=None, help="Specify a local directory to sync (write) to")
args = parser.parse_args()
# Validate: syncing needs both a source and a destination, so --remote_dir and
# --local_dir must be given together or not at all.
if bool(args.remote_dir) != bool(args.local_dir):
    # Passing the message to sys.exit() writes it to stderr and exits with
    # status 1, so callers can detect the failure (a bare sys.exit() exits 0).
    sys.exit("Must specify both remote_dir and local_dir to sync files from Experiment")
# Get the AzureML Workspace the Experiment is running in
ws = Workspace.get(name=args.workspace, subscription_id=args.subscription, resource_group=args.resource_group)

# Find the Experiment
experiment = Experiment(workspace=ws, name=args.experiment)

# Find the Run
runs = list(experiment.get_runs())
if not runs:
    # A string argument makes sys.exit() print to stderr and exit with status 1,
    # so a missing run is reported as a failure rather than a success.
    sys.exit("No runs found in Experiment '{}'".format(args.experiment))

if args.run is None:
    # No run id requested: use the first run returned by get_runs(), which the
    # --run help text describes as the latest run.
    run = runs[0]
else:
    # Look the requested id up among the Experiment's runs.
    run = next((r for r in runs if r.id == args.run), None)
    if run is None:
        sys.exit("Run id '{}' not found in Experiment '{}'".format(args.run, args.experiment))
# Optionally start synchronizing files from Run
sync_requested = args.remote_dir and args.local_dir
if sync_requested:
    local_root = os.path.normpath(args.local_dir)
    remote_root = args.remote_dir

    # Statuses after which the files are downloaded once instead of streamed.
    finished_states = ("Completed", "Failed", "Canceled")
    still_running = run.get_status() not in finished_states

    if still_running:
        # Hand a watcher to a background thread pool so files keep being
        # refreshed while the main thread waits on the run below.
        pool = ThreadPoolExecutor()
        stop_event = Event()
        http_session = Session()

        print(
            "Streaming Experiment files from remote directory: '{}' to local directory: '{}'".format(
                remote_root, local_root
            )
        )
        watcher = RunWatcher(
            run,
            local_root=local_root,
            remote_root=remote_root,
            executor=pool,
            event=stop_event,
            session=http_session,
        )
        pool.submit(watcher.refresh_requeue)
    else:
        print(
            "Downloading Experiment files from remote directory: '{}' to local directory: '{}'".format(
                remote_root, local_root
            )
        )
        # Plain string-prefix match on the run's file names.
        remote_paths = [name for name in run.get_file_names() if name.startswith(remote_root)]
        for remote_path in remote_paths:
            # NOTE: only the base name is kept, so every matching file lands
            # directly in local_root regardless of its remote subdirectory.
            destination = os.path.join(local_root, os.path.basename(remote_path))
            run.download_file(remote_path, destination)

# Block until run completes, to keep updating the files (if streaming)
run.wait_for_completion(show_output=True)