mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ONNX][bench] Deepcopy model to another device before export to avoid OOM (#118710)
Prior to onnx export, the model is deepcopied to avoid modifications that may affect later performance profiling. However this increases the memory requirement on the device. This PR modifies the script to deepcopy and export the model on another device when possible. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118710 Approved by: https://github.com/thiagocrepaldi
This commit is contained in:
parent
21ce53b9c5
commit
30f43e3d89
1 changed files with 37 additions and 2 deletions
|
|
@ -1222,6 +1222,31 @@ class OnnxModel(abc.ABC):
|
|||
self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
|
||||
)
|
||||
|
||||
def _determine_deepcopy_target_device(self):
|
||||
if current_device == "cpu":
|
||||
target_device = "cpu"
|
||||
else:
|
||||
if torch.cuda.device_count() > 1:
|
||||
# Copy to another cuda device to avoid OOM.
|
||||
target_device = "cuda:1"
|
||||
else:
|
||||
target_device = "cuda"
|
||||
return target_device
|
||||
|
||||
def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
|
||||
# Deepcopy model before export to avoid modification to baseline model.
|
||||
# To avoid OOM, the model is first moved to CPU. Both models are then moved to device.
|
||||
model_device = next(model.parameters()).device
|
||||
model.to("cpu")
|
||||
model_copy = copy.deepcopy(model).to(target_device)
|
||||
model.to(model_device)
|
||||
|
||||
target_device_example_inputs = tree_map_only(
|
||||
torch.Tensor, lambda x: x.to(device=target_device), example_inputs
|
||||
)
|
||||
|
||||
return model_copy, target_device_example_inputs
|
||||
|
||||
@classmethod
|
||||
def _generate_onnx_model_directory(
|
||||
cls, output_directory: str, compiler_name: str, model_name: str
|
||||
|
|
@ -1404,7 +1429,9 @@ class OnnxModelFromTorchScript(OnnxModel):
|
|||
def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
|
||||
if self.copy_before_export:
|
||||
# Deepcopy model before export to avoid modification to baseline model.
|
||||
model = copy.deepcopy(model)
|
||||
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
|
||||
model, example_inputs, self._determine_deepcopy_target_device()
|
||||
)
|
||||
|
||||
# Hack for huggingface models (kwargs only).
|
||||
if isinstance(example_inputs, dict):
|
||||
|
|
@ -1486,7 +1513,9 @@ class OnnxModelFromDynamo(OnnxModel):
|
|||
) -> torch.onnx.ONNXProgram:
|
||||
if self.copy_before_export:
|
||||
# Deepcopy model before export to avoid modification to baseline model.
|
||||
model = copy.deepcopy(model)
|
||||
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
|
||||
model, example_inputs, self._determine_deepcopy_target_device()
|
||||
)
|
||||
|
||||
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
|
||||
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
|
||||
|
|
@ -1513,6 +1542,12 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
|
|||
def _export(
|
||||
self, model, example_inputs, output_path: str
|
||||
) -> torch.onnx.ONNXProgram:
|
||||
if self.copy_before_export:
|
||||
# Deepcopy model before export to avoid modification to baseline model.
|
||||
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
|
||||
model, example_inputs, self._determine_deepcopy_target_device()
|
||||
)
|
||||
|
||||
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
|
||||
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
|
|
|
|||
Loading…
Reference in a new issue