[ONNX][bench] Deepcopy model to another device before export to avoid OOM (#118710)

Prior to ONNX export, the model is deep-copied to avoid modifications that may affect later performance profiling. However, this increases the memory requirement on the device.
This PR modifies the script to deepcopy and export the model on another device when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118710
Approved by: https://github.com/thiagocrepaldi
This commit is contained in:
BowenBao 2024-01-30 17:37:07 -08:00 committed by PyTorch MergeBot
parent 21ce53b9c5
commit 30f43e3d89

View file

@ -1222,6 +1222,31 @@ class OnnxModel(abc.ABC):
self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
)
def _determine_deepcopy_target_device(self):
    """Choose the device that will hold the deep-copied model for export.

    Returns:
        ``"cpu"`` when ``current_device`` (a name resolved from the
        enclosing scope) equals ``"cpu"``. Otherwise ``"cuda:1"`` if
        ``torch.cuda.device_count()`` reports more than one device, and
        ``"cuda"`` if not.
    """
    if current_device == "cpu":
        return "cpu"
    # With a second CUDA device available, place the copy there to avoid OOM.
    has_second_gpu = torch.cuda.device_count() > 1
    return "cuda:1" if has_second_gpu else "cuda"
def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
    """Deep-copy ``model`` onto ``target_device`` and move the inputs there too.

    The copy exists so that export cannot modify the baseline model. To
    avoid OOM, the baseline model is first moved to CPU, deep-copied there,
    and the copy is moved to ``target_device``; the baseline model is then
    moved back to the device of its first parameter.

    Args:
        model: Module to copy. It is temporarily moved to CPU in place and
            restored afterwards, even if the copy fails.
        example_inputs: Arbitrary pytree; every ``torch.Tensor`` leaf is
            moved to ``target_device``, other leaves are left untouched.
        target_device: Device passed to ``.to(...)`` for the copy and inputs.

    Returns:
        A ``(model_copy, target_device_example_inputs)`` tuple.
    """
    # A module without parameters has no device to restore, so the CPU
    # round trip is skipped instead of failing with StopIteration.
    first_param = next(model.parameters(), None)
    if first_param is None:
        model_copy = copy.deepcopy(model).to(target_device)
    else:
        model_device = first_param.device
        model.to("cpu")
        try:
            model_copy = copy.deepcopy(model).to(target_device)
        finally:
            # Always put the baseline model back, even when the copy or the
            # transfer raised, so later profiling sees it on its own device.
            model.to(model_device)
    target_device_example_inputs = tree_map_only(
        torch.Tensor, lambda x: x.to(device=target_device), example_inputs
    )
    return model_copy, target_device_example_inputs
@classmethod
def _generate_onnx_model_directory(
cls, output_directory: str, compiler_name: str, model_name: str
@ -1404,7 +1429,9 @@ class OnnxModelFromTorchScript(OnnxModel):
def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model = copy.deepcopy(model)
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)
# Hack for huggingface models (kwargs only).
if isinstance(example_inputs, dict):
@ -1486,7 +1513,9 @@ class OnnxModelFromDynamo(OnnxModel):
) -> torch.onnx.ONNXProgram:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model = copy.deepcopy(model)
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
@ -1513,6 +1542,12 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ONNXProgram:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
onnx_program = torch.onnx.dynamo_export(