onnxruntime/orttraining/orttraining/python/training/utils/ptable.py
pengwa ccf3b2054b
Allow layer-wise recompute (#18566)
### Allow layer-wise recompute 

Previously, users/developers had to specify the subgraphs to recompute;
now we introduce a more user-friendly way to enable recompute for all
detected stashed activation recomputation subgraphs. This sacrifices
getting the best configs, but makes it easier to support user
requirements when they switch from PyTorch per-layer gradient
checkpointing to ORTModule.

`ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage. By
default, it is 0, i.e. `USER_SPECIFIED`: all subgraphs defined in
`ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible
with existing recompute usage in ORTModule-integrated models.

Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans
detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be
respected any more.


Add Unit Tests using 3 layer blooms. 



https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md
2023-12-12 08:44:05 +08:00

73 lines
2.7 KiB
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from typing import List
class Row:
    """A single row of a PTable: its cell values plus an optional nested table."""

    def __init__(self, columns: List[str]) -> None:
        # Optional PTable holding detailed information about this row. It stays
        # None until append_annotation_table() is called; PTable.get_string()
        # renders it directly below the row.
        self._annotation_table = None
        # Cell values, one entry per column. The list is stored as passed in
        # (not copied).
        self._columns: List[str] = columns

    def append_annotation_table(self, ptable) -> None:
        """Attach ``ptable`` as this row's annotation table, replacing any previous one."""
        self._annotation_table = ptable
class PTable:
    """A table that can be printed to the console.

    Rows are added with :meth:`add_row` and rendered with :meth:`get_string`.
    Each cell is left-justified and padded to the widest cell of its column
    plus two spaces. A row may carry an annotation table, which is rendered
    directly below that row.
    """

    def __init__(self, sortable=False) -> None:
        self._rows: List["Row"] = []
        # Number of columns, fixed by the first add_row() call; None while empty.
        self._column_count = None
        self._sortable = sortable  # allow the rows to be sorted by the first column

    def add_row(self, columns: List[str]) -> "Row":
        """Add a row to the table and return it.

        The number of columns must match the number of columns in the table;
        the first row added defines that number.

        Args:
            columns: The cell values of the new row.

        Returns:
            The newly created Row, so that an annotation table can be attached.
        """
        if self._column_count is None:
            self._column_count = len(columns)
        assert self._column_count == len(
            columns
        ), f"expected {self._column_count} columns, got {len(columns)}"
        row = Row(columns)
        self._rows.append(row)
        return row

    def get_string(self, first_column_width=None, second_column_width=None) -> str:
        """Serialize the table to a string.

        Args:
            first_column_width: Optional minimum width for the first column.
            second_column_width: Optional minimum width for the column at
                index 2 (the third column, despite the parameter name; kept
                so existing output does not change). Ignored when the table
                has fewer than three columns.

        Returns:
            One newline-terminated line per row, each followed by the rendering
            of the row's annotation table if it has one. An empty string when
            the table has no rows.
        """
        if not self._rows:
            return ""

        # Collect the max width of each column
        column_widths: List[int] = []
        for row in self._rows:
            if column_widths:
                assert len(column_widths) == len(row._columns)
            else:
                column_widths = [0] * len(row._columns)
            for i, column in enumerate(row._columns):
                column_widths[i] = max(column_widths[i], len(str(column)))

        # Only touch the columns that actually exist: narrower tables used to
        # raise IndexError here.
        has_first_column = len(column_widths) > 0
        has_third_column = len(column_widths) > 2

        if first_column_width and has_first_column:
            column_widths[0] = max(first_column_width, column_widths[0])

        if second_column_width and has_third_column:
            column_widths[2] = max(second_column_width, column_widths[2])

        if self._sortable:
            ordered_rows = sorted(self._rows, key=lambda row: row._columns[0])
        else:
            ordered_rows = self._rows

        # Widths handed down to annotation tables so their columns line up
        # with this table's; None means "no minimum".
        nested_first_width = column_widths[0] if has_first_column else None
        nested_second_width = column_widths[2] if has_third_column else None

        pieces: List[str] = []
        for row in ordered_rows:
            for i, column in enumerate(row._columns):
                pieces.append(str(column).ljust(column_widths[i] + 2))
            pieces.append("\n")
            if row._annotation_table:
                pieces.append(
                    row._annotation_table.get_string(
                        first_column_width=nested_first_width,
                        second_column_width=nested_second_width,
                    )
                )

        return "".join(pieces)