From 72d508b7a088ea942e1c8d98a70a665763b900de Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Mon, 1 Jun 2020 12:11:34 -0700 Subject: [PATCH] New perf metric - e2e throughput (#4085) * new metric * on comments * tab to spaces Co-authored-by: Ethan Tao --- .../models/runner/training_runner.cc | 24 +- .../models/runner/training_runner.h | 2 +- .../tools/ci_test/run_bert_perf_test.py | 2 +- .../tools/ci_test/run_gpt2_perf_test.py | 2 +- tools/perf_util/pom.xml | 98 +++---- .../java/com/msft/send_perf_metrics/App.java | 252 +++++++++--------- .../com/msft/send_perf_metrics/JdbcUtil.java | 22 +- 7 files changed, 202 insertions(+), 200 deletions(-) diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index ea8864e10a..e7c19d8008 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -923,6 +923,15 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad ++epoch; } + const double e2e_throughput = [&]() { + if (end_to_end_perf_start_step >= params_.num_train_steps) return 0.0; + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration_seconds = end - end_to_end_start; + const double total_e2e_time = duration_seconds.count(); + const size_t end_to_end_step_count = params_.num_train_steps - std::max(step_start, end_to_end_perf_start_step); + return params_.batch_size * end_to_end_step_count / total_e2e_time; + }(); + const size_t number_of_batches = step_ - step_start; const size_t weight_update_steps = weight_update_step_count_ - weight_update_step_count_start; const double avg_time_per_batch = total_time / (step_ - step_start) * 1000; @@ -937,19 +946,11 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad ORT_RETURN_IF_ERROR(Env::Default().CreateFolder(params_.perf_output_dir)); // saving json file ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps, - total_time, avg_time_per_batch, throughput, stabilized_throughput, mapped_dimensions, + total_time, avg_time_per_batch, throughput, stabilized_throughput, + e2e_throughput, mapped_dimensions, average_cpu_usage, peak_workingset_size)); } - double e2e_throughput{0}; - if (end_to_end_perf_start_step < params_.num_train_steps) { - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration duration_seconds = end - end_to_end_start; - const double total_e2e_time = duration_seconds.count(); - const size_t end_to_end_step_count = params_.num_train_steps - std::max(step_start, end_to_end_perf_start_step); - e2e_throughput = params_.batch_size * end_to_end_step_count / total_e2e_time; - } - std::cout << "Round: " << round_ << "\n" << "Batch size: " << params_.batch_size << "\n" << "Number of Batches: " << number_of_batches << "\n" @@ -967,7 +968,7 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const size_t gradient_accumulation_steps, const size_t weight_update_steps, const double total_time, const double avg_time_per_batch, const double throughput, const double stabilized_throughput, - const MapStringToString& mapped_dimensions, + const double e2e_throughput, const MapStringToString& mapped_dimensions, const short average_cpu_usage, const size_t peak_workingset_size) { // populate metrics for reporting json perf_metrics; @@ -991,6 +992,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz perf_metrics["AvgTimePerBatch"] = avg_time_per_batch; perf_metrics["Throughput"] = throughput; perf_metrics["StabilizedThroughput"] = stabilized_throughput; + perf_metrics["EndToEndThroughput"] = e2e_throughput; perf_metrics["UseMixedPrecision"] = params_.use_mixed_precision; std::string optimizer = params_.training_optimizer_name; diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index 6f861d099c..fc3a4d7fed 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -216,7 +216,7 @@ class TrainingRunner { Status SavePerfMetrics(const size_t number_of_batches, const size_t gradient_accumulation_steps, const size_t weight_update_steps, const double total_time, const double avg_time_per_batch, const double throughput, const double stabilized_throughput, - const MapStringToString& mapped_dimensions, + const double e2e_throughput, const MapStringToString& mapped_dimensions, const short average_cpu_usage, const size_t peak_workingset_size); size_t step_; diff --git a/orttraining/tools/ci_test/run_bert_perf_test.py b/orttraining/tools/ci_test/run_bert_perf_test.py index 4290d9edc9..b11aad9127 100644 --- a/orttraining/tools/ci_test/run_bert_perf_test.py +++ b/orttraining/tools/ci_test/run_bert_perf_test.py @@ -46,7 +46,7 @@ def main(): "--train_batch_size", str(c.batch_size), "--mode", "train", "--max_seq_length", str(c.max_seq_length), - "--num_train_steps", "100", + "--num_train_steps", "640", "--display_loss_steps", "5", "--optimizer", "Lamb", "--learning_rate", "3e-3", diff --git a/orttraining/tools/ci_test/run_gpt2_perf_test.py b/orttraining/tools/ci_test/run_gpt2_perf_test.py index 75de8a83a5..8f4594a4d1 100644 --- a/orttraining/tools/ci_test/run_gpt2_perf_test.py +++ b/orttraining/tools/ci_test/run_gpt2_perf_test.py @@ -44,7 +44,7 @@ def main(): "--train_batch_size", str(c.batch_size), "--mode", "train", "--max_seq_length", str(c.max_seq_length), - "--num_train_steps", "200", + "--num_train_steps", "640", "--gradient_accumulation_steps", "1", "--perf_output_dir", os.path.join(SCRIPT_DIR, "results"), ] diff --git a/tools/perf_util/pom.xml b/tools/perf_util/pom.xml index 1a2d79079b..1f8b7badbe 100644 --- a/tools/perf_util/pom.xml +++ b/tools/perf_util/pom.xml @@ -1,56 +1,56 @@ - 4.0.0 + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + 4.0.0 - com.msft - send_perf_metrics - 0.0.1-SNAPSHOT - jar + com.msft + send_perf_metrics + 0.0.1-SNAPSHOT + jar - send_perf_metrics - http://maven.apache.org - + send_perf_metrics + http://maven.apache.org + - - - maven-assembly-plugin - 3.1.1 - - - jar-with-dependencies - - - - - make-assembly - package - - single - - - - - - - - UTF-8 - 1.8 - 1.8 - + + + maven-assembly-plugin + 3.1.1 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + + UTF-8 + 1.8 + 1.8 + - - - - com.googlecode.json-simple - json-simple - 1.1.1 - + + + + com.googlecode.json-simple + json-simple + 1.1.1 + - - mysql - mysql-connector-java - 8.0.15 - - + + mysql + mysql-connector-java + 8.0.15 + + diff --git a/tools/perf_util/src/main/java/com/msft/send_perf_metrics/App.java b/tools/perf_util/src/main/java/com/msft/send_perf_metrics/App.java index 665501bc56..a0a04211f6 100644 --- a/tools/perf_util/src/main/java/com/msft/send_perf_metrics/App.java +++ b/tools/perf_util/src/main/java/com/msft/send_perf_metrics/App.java @@ -18,151 +18,151 @@ import java.util.*; public class App { - static String exec_command(Path source_dir, String... commands) throws Exception { - ProcessBuilder sb = new ProcessBuilder(commands).directory(source_dir.toFile()).redirectErrorStream(true); - Process p = sb.start(); - if (p.waitFor() != 0) - throw new RuntimeException("execute " + String.join(" ", commands) + " failed"); - try (BufferedReader r = new BufferedReader(new InputStreamReader(p.getInputStream()))) { - return r.readLine(); - } - } + static String exec_command(Path source_dir, String... commands) throws Exception { + ProcessBuilder sb = new ProcessBuilder(commands).directory(source_dir.toFile()).redirectErrorStream(true); + Process p = sb.start(); + if (p.waitFor() != 0) + throw new RuntimeException("execute " + String.join(" ", commands) + " failed"); + try (BufferedReader r = new BufferedReader(new InputStreamReader(p.getInputStream()))) { + return r.readLine(); + } + } - public static void main(String[] args) throws Exception { + public static void main(String[] args) throws Exception { - final Path source_dir = Paths.get(args[0]); - final List perf_metrics = new ArrayList(); - Files.walkFileTree(source_dir, new SimpleFileVisitor() { + final Path source_dir = Paths.get(args[0]); + final List perf_metrics = new ArrayList(); + Files.walkFileTree(source_dir, new SimpleFileVisitor() { - @Override - public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) throws IOException { - String dirname = dir.getFileName().toString(); - if (dirname != "." && dirname.startsWith(".")) - return FileVisitResult.SKIP_SUBTREE; - return FileVisitResult.CONTINUE; - } + @Override + public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) throws IOException { + String dirname = dir.getFileName().toString(); + if (dirname != "." && dirname.startsWith(".")) + return FileVisitResult.SKIP_SUBTREE; + return FileVisitResult.CONTINUE; + } - @Override - public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { - String filename = file.getFileName().toString(); + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { + String filename = file.getFileName().toString(); - if (!filename.startsWith(".") && filename.endsWith(".json")) { - perf_metrics.add(file); - System.out.println(filename); - } - return FileVisitResult.CONTINUE; - } + if (!filename.startsWith(".") && filename.endsWith(".json")) { + perf_metrics.add(file); + System.out.println(filename); + } + return FileVisitResult.CONTINUE; + } - }); + }); - final Path cwd_dir = Paths.get(System.getProperty("user.dir")); - // git rev-parse HEAD - String commit_id = exec_command(cwd_dir, "git", "rev-parse", "HEAD"); - String date = exec_command(cwd_dir, "git", "show", "-s", "--format=%ci", commit_id); - final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss Z"); - java.util.Date commitDate = sdf.parse(date); - final SimpleDateFormat simple_date_format = new SimpleDateFormat("yyyy-MM-dd"); - String batch_id = simple_date_format.format(commitDate); - System.out.println(String.format("Commit change date: %s", batch_id)); + final Path cwd_dir = Paths.get(System.getProperty("user.dir")); + // git rev-parse HEAD + String commit_id = exec_command(cwd_dir, "git", "rev-parse", "HEAD"); + String date = exec_command(cwd_dir, "git", "show", "-s", "--format=%ci", commit_id); + final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss Z"); + java.util.Date commitDate = sdf.parse(date); + final SimpleDateFormat simple_date_format = new SimpleDateFormat("yyyy-MM-dd"); + String batch_id = simple_date_format.format(commitDate); + System.out.println(String.format("Commit change date: %s", batch_id)); - // collect all json files list - processPerfMetrics(perf_metrics, commit_id, batch_id); + // collect all json files list + processPerfMetrics(perf_metrics, commit_id, batch_id); - // TODO - add e2e tests later, run it w/ process command - } + // TODO - add e2e tests later, run it w/ process command + } - private static void processPerfMetrics(final List perf_metrics, String commit_id, - String batch_id) throws Exception { - try { - Connection conn = JdbcUtil.GetConn(); - System.out.println("MySQL DB connection established.\n"); - // go thru each json file - JSONParser jsonParser = new JSONParser(); - for (Path metrics_json : perf_metrics) { - try (FileReader reader = new FileReader(metrics_json.toAbsolutePath().toString())) { - // Read JSON file - Object obj = jsonParser.parse(reader); - loadMetricsIntoMySQL(conn, commit_id, batch_id, (JSONObject) obj); - } - } - } catch (Exception e) { - e.printStackTrace(); - throw e; - } - } + private static void processPerfMetrics(final List perf_metrics, String commit_id, + String batch_id) throws Exception { + try { + Connection conn = JdbcUtil.GetConn(); + System.out.println("MySQL DB connection established.\n"); + // go thru each json file + JSONParser jsonParser = new JSONParser(); + for (Path metrics_json : perf_metrics) { + try (FileReader reader = new FileReader(metrics_json.toAbsolutePath().toString())) { + // Read JSON file + Object obj = jsonParser.parse(reader); + loadMetricsIntoMySQL(conn, commit_id, batch_id, (JSONObject) obj); + } + } + } catch (Exception e) { + e.printStackTrace(); + throw e; + } + } - static private void loadMetricsIntoMySQL(java.sql.Connection conn, String commit_id, String batch_id, - JSONObject json_object) throws Exception { + static private void loadMetricsIntoMySQL(java.sql.Connection conn, String commit_id, String batch_id, + JSONObject json_object) throws Exception { - // field name -> json value - Map field_mapping = new LinkedHashMap(); - Set update_on_duplicate_fields = - new LinkedHashSet<> (Arrays.asList("AvgTimePerBatch", "Throughput", "StabilizedThroughput", "TotalTime", "AvgCPU", "Memory")); + // field name -> json value + Map field_mapping = new LinkedHashMap(); + Set update_on_duplicate_fields = + new LinkedHashSet<> (Arrays.asList("AvgTimePerBatch", "Throughput", "StabilizedThroughput", "EndToEndThroughput", "TotalTime", "AvgCPU", "Memory")); - field_mapping.put("BatchId", batch_id); - field_mapping.put("CommitId", commit_id.substring(0, 8)); - json_object.forEach((key, value) -> { - if (key.equals("DerivedProperties")) { - JSONObject properties = (JSONObject) json_object.get("DerivedProperties"); - properties.forEach((sub_key, sub_value) -> { - field_mapping.put((String)sub_key, sub_value); - }); - } else { - field_mapping.put((String)key, value); - } - }); + field_mapping.put("BatchId", batch_id); + field_mapping.put("CommitId", commit_id.substring(0, 8)); + json_object.forEach((key, value) -> { + if (key.equals("DerivedProperties")) { + JSONObject properties = (JSONObject) json_object.get("DerivedProperties"); + properties.forEach((sub_key, sub_value) -> { + field_mapping.put((String)sub_key, sub_value); + }); + } else { + field_mapping.put((String)key, value); + } + }); - // building sql statement - StringBuilder sb = new StringBuilder("INSERT INTO perf_test_training_data ("); - field_mapping.forEach((key, value) -> { - sb.append(key).append(","); - }); - sb.append("Time) values ("); - for(int i = 0; i < field_mapping.size(); i++) { - sb.append("?,"); - } - sb.append("Now()) ON DUPLICATE KEY UPDATE "); - update_on_duplicate_fields.forEach((key) -> { - if(field_mapping.get(key) != null) { - sb.append(key).append("=?,"); - } - }); + // building sql statement + StringBuilder sb = new StringBuilder("INSERT INTO perf_test_training_data ("); + field_mapping.forEach((key, value) -> { + sb.append(key).append(","); + }); + sb.append("Time) values ("); + for(int i = 0; i < field_mapping.size(); i++) { + sb.append("?,"); + } + sb.append("Now()) ON DUPLICATE KEY UPDATE "); + update_on_duplicate_fields.forEach((key) -> { + if(field_mapping.get(key) != null) { + sb.append(key).append("=?,"); + } + }); - try (java.sql.PreparedStatement st = conn.prepareStatement(sb.substring(0, sb.length() - 1))) { - int i = 0; // param index - for (Map.Entry entry : field_mapping.entrySet()) { - setSqlParam(++i, st, entry.getValue()); - } + try (java.sql.PreparedStatement st = conn.prepareStatement(sb.substring(0, sb.length() - 1))) { + int i = 0; // param index + for (Map.Entry entry : field_mapping.entrySet()) { + setSqlParam(++i, st, entry.getValue()); + } - // update section - for(String key : update_on_duplicate_fields) { - Object value = field_mapping.get(key); - if(value != null) { - setSqlParam(++i, st, value); - } - } + // update section + for(String key : update_on_duplicate_fields) { + Object value = field_mapping.get(key); + if(value != null) { + setSqlParam(++i, st, value); + } + } - st.executeUpdate(); - } catch (Exception e) { - e.printStackTrace(); - throw e; - } + st.executeUpdate(); + } catch (Exception e) { + e.printStackTrace(); + throw e; + } - } + } - static void setSqlParam(int param_index, PreparedStatement st, Object value) throws Exception { - if (value instanceof String) { - st.setString(param_index, (String) value); - } else if (value instanceof Long) { - st.setInt(param_index, (int) (long) value); - } else if (value instanceof Double) { - st.setFloat(param_index, (float) (double) value); - } else if (value instanceof Boolean) { - st.setBoolean(param_index, (Boolean) value); - } else { - throw new Exception("Unsupported data type:" + value.getClass().getName()); - } - } + static void setSqlParam(int param_index, PreparedStatement st, Object value) throws Exception { + if (value instanceof String) { + st.setString(param_index, (String) value); + } else if (value instanceof Long) { + st.setInt(param_index, (int) (long) value); + } else if (value instanceof Double) { + st.setFloat(param_index, (float) (double) value); + } else if (value instanceof Boolean) { + st.setBoolean(param_index, (Boolean) value); + } else { + throw new Exception("Unsupported data type:" + value.getClass().getName()); + } + } } diff --git a/tools/perf_util/src/main/java/com/msft/send_perf_metrics/JdbcUtil.java b/tools/perf_util/src/main/java/com/msft/send_perf_metrics/JdbcUtil.java index 58cedbeee5..d7132f7f26 100644 --- a/tools/perf_util/src/main/java/com/msft/send_perf_metrics/JdbcUtil.java +++ b/tools/perf_util/src/main/java/com/msft/send_perf_metrics/JdbcUtil.java @@ -5,15 +5,15 @@ import java.util.Map; import java.util.Properties; public class JdbcUtil { - static java.sql.Connection GetConn() throws Exception { - try (java.io.InputStream in = App.class.getResourceAsStream("/jdbc.properties")) { - if (in == null) - throw new RuntimeException("Error reading jdbc properties"); - Properties props = new Properties(); - props.load(in); - // loading password via env variable - return DriverManager.getConnection(props.getProperty("url"), props.getProperty("user"), - System.getenv(props.getProperty("password_env"))); - } - } + static java.sql.Connection GetConn() throws Exception { + try (java.io.InputStream in = App.class.getResourceAsStream("/jdbc.properties")) { + if (in == null) + throw new RuntimeException("Error reading jdbc properties"); + Properties props = new Properties(); + props.load(in); + // loading password via env variable + return DriverManager.getConnection(props.getProperty("url"), props.getProperty("user"), + System.getenv(props.getProperty("password_env"))); + } + } }