New perf metric - e2e throughput (#4085)

* new metric

* on comments

* tab to spaces

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2020-06-01 12:11:34 -07:00 committed by GitHub
parent 70d91a8550
commit 72d508b7a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 202 additions and 200 deletions

View file

@ -923,6 +923,15 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
++epoch;
}
const double e2e_throughput = [&]() {
if (end_to_end_perf_start_step >= params_.num_train_steps) return 0.0;
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration_seconds = end - end_to_end_start;
const double total_e2e_time = duration_seconds.count();
const size_t end_to_end_step_count = params_.num_train_steps - std::max(step_start, end_to_end_perf_start_step);
return params_.batch_size * end_to_end_step_count / total_e2e_time;
}();
const size_t number_of_batches = step_ - step_start;
const size_t weight_update_steps = weight_update_step_count_ - weight_update_step_count_start;
const double avg_time_per_batch = total_time / (step_ - step_start) * 1000;
@ -937,19 +946,11 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
ORT_RETURN_IF_ERROR(Env::Default().CreateFolder(params_.perf_output_dir));
// saving json file
ORT_RETURN_IF_ERROR(SavePerfMetrics(number_of_batches, gradient_accumulation_step_count, weight_update_steps,
total_time, avg_time_per_batch, throughput, stabilized_throughput, mapped_dimensions,
total_time, avg_time_per_batch, throughput, stabilized_throughput,
e2e_throughput, mapped_dimensions,
average_cpu_usage, peak_workingset_size));
}
double e2e_throughput{0};
if (end_to_end_perf_start_step < params_.num_train_steps) {
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration_seconds = end - end_to_end_start;
const double total_e2e_time = duration_seconds.count();
const size_t end_to_end_step_count = params_.num_train_steps - std::max(step_start, end_to_end_perf_start_step);
e2e_throughput = params_.batch_size * end_to_end_step_count / total_e2e_time;
}
std::cout << "Round: " << round_ << "\n"
<< "Batch size: " << params_.batch_size << "\n"
<< "Number of Batches: " << number_of_batches << "\n"
@ -967,7 +968,7 @@ Status TrainingRunner::TrainingLoop(IDataLoader& training_data_loader, IDataLoad
Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const size_t gradient_accumulation_steps,
const size_t weight_update_steps, const double total_time,
const double avg_time_per_batch, const double throughput, const double stabilized_throughput,
const MapStringToString& mapped_dimensions,
const double e2e_throughput, const MapStringToString& mapped_dimensions,
const short average_cpu_usage, const size_t peak_workingset_size) {
// populate metrics for reporting
json perf_metrics;
@ -991,6 +992,7 @@ Status TrainingRunner::SavePerfMetrics(const size_t number_of_batches, const siz
perf_metrics["AvgTimePerBatch"] = avg_time_per_batch;
perf_metrics["Throughput"] = throughput;
perf_metrics["StabilizedThroughput"] = stabilized_throughput;
perf_metrics["EndToEndThroughput"] = e2e_throughput;
perf_metrics["UseMixedPrecision"] = params_.use_mixed_precision;
std::string optimizer = params_.training_optimizer_name;

View file

@ -216,7 +216,7 @@ class TrainingRunner {
Status SavePerfMetrics(const size_t number_of_batches, const size_t gradient_accumulation_steps,
const size_t weight_update_steps, const double total_time,
const double avg_time_per_batch, const double throughput, const double stabilized_throughput,
const MapStringToString& mapped_dimensions,
const double e2e_throughput, const MapStringToString& mapped_dimensions,
const short average_cpu_usage, const size_t peak_workingset_size);
size_t step_;

View file

@ -46,7 +46,7 @@ def main():
"--train_batch_size", str(c.batch_size),
"--mode", "train",
"--max_seq_length", str(c.max_seq_length),
"--num_train_steps", "100",
"--num_train_steps", "640",
"--display_loss_steps", "5",
"--optimizer", "Lamb",
"--learning_rate", "3e-3",

View file

@ -44,7 +44,7 @@ def main():
"--train_batch_size", str(c.batch_size),
"--mode", "train",
"--max_seq_length", str(c.max_seq_length),
"--num_train_steps", "200",
"--num_train_steps", "640",
"--gradient_accumulation_steps", "1",
"--perf_output_dir", os.path.join(SCRIPT_DIR, "results"),
]

View file

@ -1,56 +1,56 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.msft</groupId>
<artifactId>send_perf_metrics</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>jar</packaging>
<groupId>com.msft</groupId>
<artifactId>send_perf_metrics</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>jar</packaging>
<name>send_perf_metrics</name>
<url>http://maven.apache.org</url>
<build>
<name>send_perf_metrics</name>
<url>http://maven.apache.org</url>
<build>
<plugins>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.1.1</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id> <!-- this is used for inheritance merges -->
<phase>package</phase> <!-- bind to the packaging phase -->
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<plugins>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.1.1</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id> <!-- this is used for inheritance merges -->
<phase>package</phase> <!-- bind to the packaging phase -->
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<!-- https://mvnrepository.com/artifact/com.googlecode.json-simple/json-simple -->
<dependency>
<groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId>
<version>1.1.1</version>
</dependency>
<dependencies>
<!-- https://mvnrepository.com/artifact/com.googlecode.json-simple/json-simple -->
<dependency>
<groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId>
<version>1.1.1</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.15</version>
</dependency>
</dependencies>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.15</version>
</dependency>
</dependencies>
</project>

View file

@ -18,151 +18,151 @@ import java.util.*;
public class App {
static String exec_command(Path source_dir, String... commands) throws Exception {
ProcessBuilder sb = new ProcessBuilder(commands).directory(source_dir.toFile()).redirectErrorStream(true);
Process p = sb.start();
if (p.waitFor() != 0)
throw new RuntimeException("execute " + String.join(" ", commands) + " failed");
try (BufferedReader r = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
return r.readLine();
}
}
static String exec_command(Path source_dir, String... commands) throws Exception {
ProcessBuilder sb = new ProcessBuilder(commands).directory(source_dir.toFile()).redirectErrorStream(true);
Process p = sb.start();
if (p.waitFor() != 0)
throw new RuntimeException("execute " + String.join(" ", commands) + " failed");
try (BufferedReader r = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
return r.readLine();
}
}
public static void main(String[] args) throws Exception {
public static void main(String[] args) throws Exception {
final Path source_dir = Paths.get(args[0]);
final List<Path> perf_metrics = new ArrayList<Path>();
Files.walkFileTree(source_dir, new SimpleFileVisitor<Path>() {
final Path source_dir = Paths.get(args[0]);
final List<Path> perf_metrics = new ArrayList<Path>();
Files.walkFileTree(source_dir, new SimpleFileVisitor<Path>() {
@Override
public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) throws IOException {
String dirname = dir.getFileName().toString();
if (dirname != "." && dirname.startsWith("."))
return FileVisitResult.SKIP_SUBTREE;
return FileVisitResult.CONTINUE;
}
@Override
public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) throws IOException {
String dirname = dir.getFileName().toString();
if (dirname != "." && dirname.startsWith("."))
return FileVisitResult.SKIP_SUBTREE;
return FileVisitResult.CONTINUE;
}
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
String filename = file.getFileName().toString();
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
String filename = file.getFileName().toString();
if (!filename.startsWith(".") && filename.endsWith(".json")) {
perf_metrics.add(file);
System.out.println(filename);
}
return FileVisitResult.CONTINUE;
}
if (!filename.startsWith(".") && filename.endsWith(".json")) {
perf_metrics.add(file);
System.out.println(filename);
}
return FileVisitResult.CONTINUE;
}
});
});
final Path cwd_dir = Paths.get(System.getProperty("user.dir"));
// git rev-parse HEAD
String commit_id = exec_command(cwd_dir, "git", "rev-parse", "HEAD");
String date = exec_command(cwd_dir, "git", "show", "-s", "--format=%ci", commit_id);
final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss Z");
java.util.Date commitDate = sdf.parse(date);
final SimpleDateFormat simple_date_format = new SimpleDateFormat("yyyy-MM-dd");
String batch_id = simple_date_format.format(commitDate);
System.out.println(String.format("Commit change date: %s", batch_id));
final Path cwd_dir = Paths.get(System.getProperty("user.dir"));
// git rev-parse HEAD
String commit_id = exec_command(cwd_dir, "git", "rev-parse", "HEAD");
String date = exec_command(cwd_dir, "git", "show", "-s", "--format=%ci", commit_id);
final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss Z");
java.util.Date commitDate = sdf.parse(date);
final SimpleDateFormat simple_date_format = new SimpleDateFormat("yyyy-MM-dd");
String batch_id = simple_date_format.format(commitDate);
System.out.println(String.format("Commit change date: %s", batch_id));
// collect all json files list
processPerfMetrics(perf_metrics, commit_id, batch_id);
// collect all json files list
processPerfMetrics(perf_metrics, commit_id, batch_id);
// TODO - add e2e tests later, run it w/ process command
}
// TODO - add e2e tests later, run it w/ process command
}
private static void processPerfMetrics(final List<Path> perf_metrics, String commit_id,
String batch_id) throws Exception {
try {
Connection conn = JdbcUtil.GetConn();
System.out.println("MySQL DB connection established.\n");
// go thru each json file
JSONParser jsonParser = new JSONParser();
for (Path metrics_json : perf_metrics) {
try (FileReader reader = new FileReader(metrics_json.toAbsolutePath().toString())) {
// Read JSON file
Object obj = jsonParser.parse(reader);
loadMetricsIntoMySQL(conn, commit_id, batch_id, (JSONObject) obj);
}
}
} catch (Exception e) {
e.printStackTrace();
throw e;
}
}
private static void processPerfMetrics(final List<Path> perf_metrics, String commit_id,
String batch_id) throws Exception {
try {
Connection conn = JdbcUtil.GetConn();
System.out.println("MySQL DB connection established.\n");
// go thru each json file
JSONParser jsonParser = new JSONParser();
for (Path metrics_json : perf_metrics) {
try (FileReader reader = new FileReader(metrics_json.toAbsolutePath().toString())) {
// Read JSON file
Object obj = jsonParser.parse(reader);
loadMetricsIntoMySQL(conn, commit_id, batch_id, (JSONObject) obj);
}
}
} catch (Exception e) {
e.printStackTrace();
throw e;
}
}
static private void loadMetricsIntoMySQL(java.sql.Connection conn, String commit_id, String batch_id,
JSONObject json_object) throws Exception {
static private void loadMetricsIntoMySQL(java.sql.Connection conn, String commit_id, String batch_id,
JSONObject json_object) throws Exception {
// field name -> json value
Map<String, Object> field_mapping = new LinkedHashMap();
Set<String> update_on_duplicate_fields =
new LinkedHashSet<> (Arrays.asList("AvgTimePerBatch", "Throughput", "StabilizedThroughput", "TotalTime", "AvgCPU", "Memory"));
// field name -> json value
Map<String, Object> field_mapping = new LinkedHashMap();
Set<String> update_on_duplicate_fields =
new LinkedHashSet<> (Arrays.asList("AvgTimePerBatch", "Throughput", "StabilizedThroughput", "EndToEndThroughput", "TotalTime", "AvgCPU", "Memory"));
field_mapping.put("BatchId", batch_id);
field_mapping.put("CommitId", commit_id.substring(0, 8));
json_object.forEach((key, value) -> {
if (key.equals("DerivedProperties")) {
JSONObject properties = (JSONObject) json_object.get("DerivedProperties");
properties.forEach((sub_key, sub_value) -> {
field_mapping.put((String)sub_key, sub_value);
});
} else {
field_mapping.put((String)key, value);
}
});
field_mapping.put("BatchId", batch_id);
field_mapping.put("CommitId", commit_id.substring(0, 8));
json_object.forEach((key, value) -> {
if (key.equals("DerivedProperties")) {
JSONObject properties = (JSONObject) json_object.get("DerivedProperties");
properties.forEach((sub_key, sub_value) -> {
field_mapping.put((String)sub_key, sub_value);
});
} else {
field_mapping.put((String)key, value);
}
});
// building sql statement
StringBuilder sb = new StringBuilder("INSERT INTO perf_test_training_data (");
field_mapping.forEach((key, value) -> {
sb.append(key).append(",");
});
sb.append("Time) values (");
for(int i = 0; i < field_mapping.size(); i++) {
sb.append("?,");
}
sb.append("Now()) ON DUPLICATE KEY UPDATE ");
update_on_duplicate_fields.forEach((key) -> {
if(field_mapping.get(key) != null) {
sb.append(key).append("=?,");
}
});
// building sql statement
StringBuilder sb = new StringBuilder("INSERT INTO perf_test_training_data (");
field_mapping.forEach((key, value) -> {
sb.append(key).append(",");
});
sb.append("Time) values (");
for(int i = 0; i < field_mapping.size(); i++) {
sb.append("?,");
}
sb.append("Now()) ON DUPLICATE KEY UPDATE ");
update_on_duplicate_fields.forEach((key) -> {
if(field_mapping.get(key) != null) {
sb.append(key).append("=?,");
}
});
try (java.sql.PreparedStatement st = conn.prepareStatement(sb.substring(0, sb.length() - 1))) {
int i = 0; // param index
for (Map.Entry<String, Object> entry : field_mapping.entrySet()) {
setSqlParam(++i, st, entry.getValue());
}
try (java.sql.PreparedStatement st = conn.prepareStatement(sb.substring(0, sb.length() - 1))) {
int i = 0; // param index
for (Map.Entry<String, Object> entry : field_mapping.entrySet()) {
setSqlParam(++i, st, entry.getValue());
}
// update section
for(String key : update_on_duplicate_fields) {
Object value = field_mapping.get(key);
if(value != null) {
setSqlParam(++i, st, value);
}
}
// update section
for(String key : update_on_duplicate_fields) {
Object value = field_mapping.get(key);
if(value != null) {
setSqlParam(++i, st, value);
}
}
st.executeUpdate();
} catch (Exception e) {
e.printStackTrace();
throw e;
}
st.executeUpdate();
} catch (Exception e) {
e.printStackTrace();
throw e;
}
}
}
static void setSqlParam(int param_index, PreparedStatement st, Object value) throws Exception {
if (value instanceof String) {
st.setString(param_index, (String) value);
} else if (value instanceof Long) {
st.setInt(param_index, (int) (long) value);
} else if (value instanceof Double) {
st.setFloat(param_index, (float) (double) value);
} else if (value instanceof Boolean) {
st.setBoolean(param_index, (Boolean) value);
} else {
throw new Exception("Unsupported data type:" + value.getClass().getName());
}
}
static void setSqlParam(int param_index, PreparedStatement st, Object value) throws Exception {
if (value instanceof String) {
st.setString(param_index, (String) value);
} else if (value instanceof Long) {
st.setInt(param_index, (int) (long) value);
} else if (value instanceof Double) {
st.setFloat(param_index, (float) (double) value);
} else if (value instanceof Boolean) {
st.setBoolean(param_index, (Boolean) value);
} else {
throw new Exception("Unsupported data type:" + value.getClass().getName());
}
}
}

View file

@ -5,15 +5,15 @@ import java.util.Map;
import java.util.Properties;
public class JdbcUtil {
static java.sql.Connection GetConn() throws Exception {
try (java.io.InputStream in = App.class.getResourceAsStream("/jdbc.properties")) {
if (in == null)
throw new RuntimeException("Error reading jdbc properties");
Properties props = new Properties();
props.load(in);
// loading password via env variable
return DriverManager.getConnection(props.getProperty("url"), props.getProperty("user"),
System.getenv(props.getProperty("password_env")));
}
}
static java.sql.Connection GetConn() throws Exception {
try (java.io.InputStream in = App.class.getResourceAsStream("/jdbc.properties")) {
if (in == null)
throw new RuntimeException("Error reading jdbc properties");
Properties props = new Properties();
props.load(in);
// loading password via env variable
return DriverManager.getConnection(props.getProperty("url"), props.getProperty("user"),
System.getenv(props.getProperty("password_env")));
}
}
}