From c8e1038eab6c07392fb0891d0fec45e43f1a7487 Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 3 Nov 2023 00:46:11 +0800 Subject: [PATCH] Optimize 4bit Qlora training (#18131) ### Optimize 4bit Qlora training Extent existing `MatmulBnb4bit` to its usage in training scenarios. The PR includes following changes: 1. Add special `torch.autograd.Function` export logic for `bitsandbytes.autograd._functions.MatMul4Bit` that is preferred before common PythonOp exporter. 2. Add `training_mode` optional attribute for op `MatmulBnb4bit`, which help skip some inference specific logic in implementation. 3. Add `transB` optional attribute, which is by default be 1; setting it to be 0 is needed by backward usage. Changing from `PythonOp` to this `MatmulBnb4bit` brings roughly ~2.9% throughput gains. The reason is: `bitsandbytes.autograd._functions.MatMul4Bit` has logic `ctx.save_for_backward`, which would need an additional copy in PythonOp, otherwise, the tensor might be released by ORT, while backward op still references it. Removing the clones also reduce the peak memory consumptions because `bitsandbytes.autograd._functions.MatMul4Bit` saved tensors that are not needed in backward compute. --- docs/ContribOperators.md | 30 +++- .../cpu/quantization/matmul_bnb4.cc | 6 +- .../cuda/quantization/matmul_bnb4.cc | 38 +++-- .../core/graph/contrib_ops/contrib_defs.cc | 46 ++++- .../core/framework/gradient_graph_builder.h | 1 + .../core/graph/gradient_builder.cc | 36 ++++ .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../_custom_autograd_function_exporter.py | 158 +++++++++++++++--- .../ortmodule/_custom_op_symbolic_registry.py | 4 +- .../ortmodule/_graph_execution_manager.py | 6 +- .../ortmodule/_zero_stage3_compatibility.py | 4 +- .../python/training/ortmodule/options.py | 2 +- .../python/training/utils/__init__.py | 11 +- .../utils/hooks/_zero_offload_subscriber.py | 4 +- .../python/training/utils/torch_type_map.py | 16 +- 16 files changed, 297 insertions(+), 67 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8e86862a62..646465ef8b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2580,8 +2580,30 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. - Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. - Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] #### Version @@ -2599,6 +2621,10 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
quant_type : int (required)
quantization data type. 0 for FP4, 1 for NF4.
+
training_mode : int
+
Indicate if the ops run in training_mode, by default, False.
+
transB : int
+
Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.
#### Inputs diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc index 2f3ede49c3..b898c956b6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -21,6 +21,8 @@ class MatMulBnb4 final : public OpKernel { ORT_ENFORCE( quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); } Status Compute(OpKernelContext* context) const override; @@ -30,6 +32,8 @@ class MatMulBnb4 final : public OpKernel { int64_t N_; int64_t block_size_; int64_t quant_type_; + bool is_training_mode_; + bool transB_; }; Status MatMulBnb4::Compute(OpKernelContext* ctx) const { @@ -58,7 +62,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const { thread_pool); constexpr bool transa = false; - constexpr bool transb = true; + const bool transb = transB_; TensorShape b_shape({N_, K_}); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bd5b6e0a8a..ecf332715d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -25,6 +25,9 @@ class MatMulBnb4 final : public CudaKernel { ORT_ENFORCE( quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); } Status ComputeInternal(OpKernelContext* context) const override; @@ -34,6 +37,8 @@ class MatMulBnb4 final : public CudaKernel { int64_t N_; int64_t block_size_; int64_t quant_type_; + bool is_training_mode_; + bool transB_; }; template @@ -59,7 +64,7 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { static_cast(ctx->GetComputeStream()->GetHandle()))); constexpr bool transa = false; - constexpr bool transb = true; + const bool transb = transB_; MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR( @@ -69,17 +74,18 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMulBnb4( - reinterpret_cast(quant_map_buffer_data), - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - b_quant_data, - reinterpret_cast(absmax_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle())); + bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode + && TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); if (!is_4bit_done) { IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); @@ -98,16 +104,16 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, + transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + CUBLAS_OP_N, // transA SafeInt(helper.N()), SafeInt(helper.M()), SafeInt(helper.K()), &alpha, reinterpret_cast(b_dequant_data), - SafeInt(K_), + helper.Ldb(transb), // ldb reinterpret_cast(a_data), - helper.Lda(transa), + helper.Lda(transa), // lda &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e757e39130..39449bea63 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2693,7 +2693,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, - int64_t N) { + int64_t N, + bool transB) { int input_a_idx = 0; if (!hasInputShape(ctx, input_a_idx)) { return; @@ -2707,15 +2708,15 @@ static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext // TODO: check B shape const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); - if (dim_last.has_dim_value() && dim_last.dim_value() != K) { + ONNX_NAMESPACE::TensorShapeProto resultShape; + if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) { fail_shape_inference("Incompatible dimensions for matrix multiplication"); } - ONNX_NAMESPACE::TensorShapeProto resultShape; for (int i = 0; i < a_shape.dim_size() - 1; ++i) { *resultShape.add_dim() = a_shape.dim(i); } - resultShape.add_dim()->set_dim_value(N); + resultShape.add_dim()->set_dim_value(transB ? N : K); *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; } @@ -3354,7 +3355,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored // Shape inference int64_t in_features = getAttribute(ctx, "K", -1); int64_t out_features = getAttribute(ctx, "N", -1); - MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); }); static const char* MatMulBnb4_ver1_doc = R"DOC( @@ -3364,8 +3365,30 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. -Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. -Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] )DOC"; @@ -3377,6 +3400,12 @@ Input absmax is stored in same type as original type of B(float32, float16) with .Attr("N", "size of each output feature", AttributeProto::INT) .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Attr("training_mode", + "Indicate if the ops run in training_mode, by default, False.", + AttributeProto::INT, + static_cast(0)) + .Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.", + AttributeProto::INT, static_cast(1)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") @@ -3389,7 +3418,8 @@ Input absmax is stored in same type as original type of B(float32, float16) with // Shape inference int64_t in_features = getAttribute(ctx, "K", -1); int64_t out_features = getAttribute(ctx, "N", -1); - MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + bool transB = getAttribute(ctx, "transB", 1) != 0; + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB); }); #ifdef ENABLE_ATEN diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 8068d4825c..93ba836b53 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -70,6 +70,7 @@ static std::unordered_map> {"Split", {1}}, {"Clip", {1, 2}}, {"Pad", {1, 2}}, + {"MatMulBnb4", {1, 2}}, // quantified weight (non float) and absmax constant don't need gradients. {"Multinomial", {0}}, {"RandomNormalLike", {0}}, {"RandomUniformLike", {0}}, diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 6547f53a3c..7100cedaf7 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -494,6 +494,42 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { return result; } +IMPLEMENT_GRADIENT_BUILDER(GetMatmulBnb4Gradient) { + auto attributes = SrcNodeAttributes(); + std::vector attrs; + bool find_transB = false; + for (auto& attr : attributes) { + if (attr.first == "transB") { + int64_t transB_value = attr.second.i(); + transB_value = (transB_value + 1) % 2; // revert the transpose + attrs.push_back(MakeAttribute("transB", transB_value)); + find_transB = true; + } else { + attrs.push_back(attr.second); + } + } + + if (!find_transB) { + attrs.push_back(MakeAttribute("transB", int64_t(0))); // default is 1, so we need to set it to 0 + } + + std::vector result; + // Y = A * B + // dA = dY * B', dB = A' * dY + if (IsGradientRequiredForSrcNodeInput(0)) { + // B is 1-D, so don't need transpose here. + result.push_back(NodeDef(OpDef{"MatMulBnb4", kMSDomain, 1}, + {GO(0), I(1), I(2)}, + {GI(0)}, + attrs)); + } + + ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(1), "Gradient propagation to B is not supported yet."); + ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(2), "Gradient propagation to absmax is not supported yet."); + + return result; +} + IMPLEMENT_GRADIENT_BUILDER(GetSplitGradient) { std::vector result = {}; std::vector input_args; diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 28a316261e..08987a86eb 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -54,6 +54,7 @@ DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient) DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossInternalGradient) DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) DECLARE_GRADIENT_BUILDER(GetGemmGradient) +DECLARE_GRADIENT_BUILDER(GetMatmulBnb4Gradient) DECLARE_GRADIENT_BUILDER(GetDropoutGradient) DECLARE_GRADIENT_BUILDER(GetGatherNDGradient) DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4b8c68aef0..f280a02cb4 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -68,6 +68,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient); REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient); REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient); + REGISTER_GRADIENT_BUILDER("MatMulBnb4", GetMatmulBnb4Gradient); REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient); REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient); REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8c5469740d..de2c04d262 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -3,7 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import sys +from typing import ClassVar import torch import torch.utils.checkpoint @@ -18,13 +21,55 @@ from onnxruntime.capi._pybind_state import ( register_torch_autograd_function, ) from onnxruntime.training import ortmodule -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pytorch_type_to_onnx_dtype from ._custom_op_symbolic_registry import wrap_custom_export_function from ._fallback import ORTModuleONNXModelException, wrap_exception from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version +class _SpecialCustomFunctionHandler: + """A class to handle high priority export of torch.autograd.Function. + `register_high_priority_handler` can be used as function decorator to register a handler for a torch.autograd.Function. + """ + + _HIGH_PRIORITY_EXPORT_HANDLER_MAP: ClassVar[dict[str, callable]] = {} + + @staticmethod + def add_handler(func_name: str, handler: callable) -> None: + """Add a handler for a function name. + + Args: + func_name (str): The function name. + handler (callable): The handler. + + """ + _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP[func_name] = handler + + @staticmethod + def get_handler(func_name: str) -> callable | None: + """Get the handler for a function name. + + Args: + func_name (str): The function name. + + Returns: + callable | None: The handler. + + """ + return _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP.get(func_name, None) + + +def register_high_priority_handler(func_name): + """Register a handler for a torch.autograd.Function using its full qualified class name.""" + + def symbolic_wrapper(fn): + _SpecialCustomFunctionHandler.add_handler(func_name, fn) + return fn + + return symbolic_wrapper + + def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. @@ -96,6 +141,30 @@ _UNSUPPORTED_CKPT_FUNC_NAMES = frozenset( ) +def _get_training_mode() -> bool: + # TODO move to public API once the exporter team exposes that + training_mode = None + if get_runtime_pytorch_version() >= version.parse("1.12"): + # FIXME: using private modules + from torch.onnx import _globals + + # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9, + # training_mode is still a bool type + if isinstance(_globals.GLOBALS.training_mode, bool): + training_mode = _globals.GLOBALS.training_mode + else: + if _globals.GLOBALS.training_mode not in [ + torch.onnx.TrainingMode.EVAL, + torch.onnx.TrainingMode.TRAINING, + ]: + raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}") + training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING + else: + training_mode = symbolic_helper._training_mode + + return bool(training_mode) + + def _export_pt_1_10(g, n, *args, **kwargs): """Export torch.autograd.Function in ORT PythonOp. @@ -114,6 +183,15 @@ def _export_pt_1_10(g, n, *args, **kwargs): func_class = n.pyobj().__self__ func_full_qual_name = get_fully_qualified_class_name(func_class) + # Check if the function is handled by high priority exporter. + hi_pri_handler = _SpecialCustomFunctionHandler.get_handler(func_full_qual_name) + if hi_pri_handler: + try_export = hi_pri_handler(g, n, *args, **kwargs) + if try_export is not None: + return try_export + + # Fall back to common exporter if not handled by high priority exporter. + # Check if the checkpointing activation is allowed. is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1 if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES: @@ -123,26 +201,6 @@ def _export_pt_1_10(g, n, *args, **kwargs): "wrap exportable sub-nn.Module's as ORTModule." ) - # TODO move to public API once the exporter team exposes that - training_mode = None - if get_runtime_pytorch_version() >= version.parse("1.12"): - # FIXME: using private modules - from torch.onnx import _globals - - # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9, - # training_mode is still a bool type - if isinstance(_globals.GLOBALS.training_mode, bool): - training_mode = _globals.GLOBALS.training_mode - else: - if _globals.GLOBALS.training_mode not in [ - torch.onnx.TrainingMode.EVAL, - torch.onnx.TrainingMode.TRAINING, - ]: - raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}") - training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING - else: - training_mode = symbolic_helper._training_mode - cconv = n.cconv() input_tensor_types = [] @@ -179,7 +237,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): if call_type == "d": # Got a tensor variable. tensor_args.append(arg) - scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) input_tensor_types.append(scalar_type) input_tensor_ranks.append(arg.type().dim()) continue @@ -258,7 +316,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): output_tensor_ranks = [] for arg in n.outputs(): # Type of tensor's elements. - scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) @@ -270,7 +328,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): "input_tensor_ranks_i": input_tensor_ranks, "output_tensor_types_i": output_tensor_types, "output_tensor_ranks_i": output_tensor_ranks, - "training_mode_i": 1 if training_mode else 0, + "training_mode_i": 1 if _get_training_mode() else 0, "comment_s": debug_comment, } @@ -336,3 +394,55 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model index += 1 return exported_model + + +@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit") +def _matmul4bit_export(g, n, *args, **kwargs): + cconv = n.cconv() + can_converted = cconv[0] == "d" and cconv[1] == "d" and cconv[2] == "c" and cconv[3] == "c" and cconv[4] == "c" + can_converted = can_converted and (args[2] is None and args[3] is None and args[4] is not None) + if not can_converted: + return None + + quant_state = args[4] + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + + # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16 + if blocksize < 16 or blocksize & (blocksize - 1) != 0: + return None + + # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too) + if compressed_stats is not None: + return None + + # The PyTorch linear weight shape is [out_feature, in_feature] + in_feature = shape[1] + out_feature = shape[0] + if quant_type == "fp4": + quant_type = 0 + elif quant_type == "nf4": + quant_type = 1 + else: + return None + attrs = { + "K_i": in_feature, + "N_i": out_feature, + "block_size_i": blocksize, + "quant_type_i": quant_type, + "training_mode_i": 1 if _get_training_mode() else 0, + } + + # Make sure the quant weight can be flatten to 1D tensor safely, which com.microsoft::MatMulBnb4 requires. + found_dim1 = any(v == 1 for v in args[1].type().sizes()) + if not found_dim1: + return None + + absmax = g.op( + "Constant", + value_t=torch.tensor(absmax, dtype=pytorch_scalar_type_to_pytorch_dtype(args[0].type().scalarType())), + ) + quant_weight = g.op( + "Reshape", args[1], g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) # flatten to 1D + tensor_args = [args[0], quant_weight, absmax] + return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 6e694dcdf2..99e8851b6a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -12,7 +12,7 @@ from packaging.version import Version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype from ._utils import get_runtime_pytorch_version @@ -145,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index, weight_casted, ignore_index, reduction_s=reduction, - output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()), + output_type_i=pytorch_type_to_onnx_dtype(output_type.scalarType()), outputs=2, ) output.setType(output_type) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5e8805bfdd..ba61b2e4c8 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,7 +19,7 @@ from torch.utils.cpp_extension import ROCM_HOME import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils @@ -345,7 +345,7 @@ class GraphExecutionManager(GraphExecutionInterface): cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" ) if os.path.exists(cache_dir) and os.path.isfile(filename): - self._logger.info( + self._logger.warning( f"Cached model detected! Cached model will be used to save export and initialization time." f"If you want the model to be re-exported then DELETE {filename}." ) @@ -627,7 +627,7 @@ class GraphExecutionManager(GraphExecutionInterface): kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), device=device, ).requires_grad_() diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index d0dea66fda..3a5d7da926 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -14,7 +14,7 @@ from onnxruntime.capi._pybind_state import ( register_shape_inference_function, register_torch_autograd_function, ) -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary from ._utils import get_fully_qualified_class_name @@ -149,7 +149,7 @@ def post_processing_enable_zero_stage3_compat( c, graph_input.name, len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank - pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + pytorch_type_to_onnx_dtype(zero_stage3_named_params[graph_input.name].dtype), # new data type ) # Delete exported_model.graph.input diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 0eb6790d7a..ff0cde3719 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -366,7 +366,7 @@ class _RuntimeOptions: # Cache exported model if "ORTMODULE_CACHE_DIR" in os.environ: - self._logger.info("ORTModule cache optimization is ON.") + self._logger.warning("ORTModule optimization for caching exported model is ON.") self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR") # Experimental features. diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index fa7c9f2750..d40a6ddf7d 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,7 +9,11 @@ from onnxruntime.training.utils.torch_io_helper import ( extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx +from onnxruntime.training.utils.torch_type_map import ( + onnx_dtype_to_pytorch_dtype, + pytorch_scalar_type_to_pytorch_dtype, + pytorch_type_to_onnx_dtype, +) __all__ = [ "PrimitiveType", @@ -17,6 +21,7 @@ __all__ = [ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", - "pytorch_dtype_to_onnx", - "onnx_dtype_to_pytorch", + "pytorch_type_to_onnx_dtype", + "onnx_dtype_to_pytorch_dtype", + "pytorch_scalar_type_to_pytorch_dtype", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index b1cb5c19e8..0d268a7a4a 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -17,7 +17,7 @@ import torch from onnxruntime.training.utils import ( ORTModelInputOutputType, extract_data_and_schema, - pytorch_dtype_to_onnx, + pytorch_type_to_onnx_dtype, unflatten_data_using_schema, ) @@ -324,7 +324,7 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): start_offset = len(tensor_input_shapes) - len(partitioned_params) for index, param in enumerate(partitioned_params): tensor_output_shapes[start_offset + index] = list(param.ds_shape) - tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype)) + tensor_output_dtypes[start_offset + index] = int(pytorch_type_to_onnx_dtype(param.dtype)) assert len(tensor_output_shapes) == len(tensor_input_shapes) assert len(tensor_output_dtypes) == len(tensor_input_dtypes) diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py index bdacab8ad0..2b429f3fd4 100644 --- a/orttraining/orttraining/python/training/utils/torch_type_map.py +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -36,8 +36,10 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _C _ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()} -def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: - """Converts a pytorch dtype or scalar type string to an onnx dtype.""" +def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: + """Converts a pytorch dtype or scalar type string to an onnx dtype. + PyTorch type can be either a dtype or a scalar type string. + """ dtype = dtype_or_scalar_type if isinstance(dtype, str): if dtype not in _CAST_PYTORCH_TO_ONNX: @@ -49,7 +51,15 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc return _DTYPE_TO_ONNX[dtype] -def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: +def pytorch_scalar_type_to_pytorch_dtype(dtype: str) -> torch.dtype: + """Converts a pytorch scalar type string to a pytorch dtype.""" + assert isinstance(dtype, str) + if dtype not in _CAST_PYTORCH_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _CAST_PYTORCH_TO_ONNX[dtype][1] + + +def onnx_dtype_to_pytorch_dtype(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: """Converts an onnx dtype to a pytorch dtype.""" if dtype not in _ONNX_TO_DTYPE: raise RuntimeError(f"Unsupported dtype {dtype}")