diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8e86862a62..646465ef8b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2580,8 +2580,30 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. - Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. - Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] #### Version @@ -2599,6 +2621,10 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
quant_type : int (required)
quantization data type. 0 for FP4, 1 for NF4.
+
training_mode : int
+
Indicate if the ops run in training_mode, by default, False.
+
transB : int
+
Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.
#### Inputs diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc index 2f3ede49c3..b898c956b6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -21,6 +21,8 @@ class MatMulBnb4 final : public OpKernel { ORT_ENFORCE( quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); } Status Compute(OpKernelContext* context) const override; @@ -30,6 +32,8 @@ class MatMulBnb4 final : public OpKernel { int64_t N_; int64_t block_size_; int64_t quant_type_; + bool is_training_mode_; + bool transB_; }; Status MatMulBnb4::Compute(OpKernelContext* ctx) const { @@ -58,7 +62,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const { thread_pool); constexpr bool transa = false; - constexpr bool transb = true; + const bool transb = transB_; TensorShape b_shape({N_, K_}); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bd5b6e0a8a..ecf332715d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -25,6 +25,9 @@ class MatMulBnb4 final : public CudaKernel { ORT_ENFORCE( quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1))); } Status ComputeInternal(OpKernelContext* context) const override; @@ -34,6 +37,8 @@ class MatMulBnb4 final : public CudaKernel { int64_t N_; int64_t block_size_; int64_t quant_type_; + bool is_training_mode_; + bool transB_; }; template @@ -59,7 +64,7 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { static_cast(ctx->GetComputeStream()->GetHandle()))); constexpr bool transa = false; - constexpr bool transb = true; + const bool transb = transB_; MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR( @@ -69,17 +74,18 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMulBnb4( - reinterpret_cast(quant_map_buffer_data), - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - b_quant_data, - reinterpret_cast(absmax_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle())); + bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode + && TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); if (!is_4bit_done) { IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); @@ -98,16 +104,16 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, + transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + CUBLAS_OP_N, // transA SafeInt(helper.N()), SafeInt(helper.M()), SafeInt(helper.K()), &alpha, reinterpret_cast(b_dequant_data), - SafeInt(K_), + helper.Ldb(transb), // ldb reinterpret_cast(a_data), - helper.Lda(transa), + helper.Lda(transa), // lda &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e757e39130..39449bea63 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2693,7 +2693,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, - int64_t N) { + int64_t N, + bool transB) { int input_a_idx = 0; if (!hasInputShape(ctx, input_a_idx)) { return; @@ -2707,15 +2708,15 @@ static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext // TODO: check B shape const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); - if (dim_last.has_dim_value() && dim_last.dim_value() != K) { + ONNX_NAMESPACE::TensorShapeProto resultShape; + if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) { fail_shape_inference("Incompatible dimensions for matrix multiplication"); } - ONNX_NAMESPACE::TensorShapeProto resultShape; for (int i = 0; i < a_shape.dim_size() - 1; ++i) { *resultShape.add_dim() = a_shape.dim(i); } - resultShape.add_dim()->set_dim_value(N); + resultShape.add_dim()->set_dim_value(transB ? N : K); *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; } @@ -3354,7 +3355,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored // Shape inference int64_t in_features = getAttribute(ctx, "K", -1); int64_t out_features = getAttribute(ctx, "N", -1); - MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true); }); static const char* MatMulBnb4_ver1_doc = R"DOC( @@ -3364,8 +3365,30 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. -Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. -Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + + 1. (Default value) transB=True (Majorly used for forward pass) + Shape of A: [D0, D1, ..., Dn, K] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + transposed_dequant_B = dequant_B^T + output = A @ transposed_dequant_B + + Shape of output: [D0, D1, ..., Dn, N] + + 2. transB=False (Majorly used for backward pass) + Shape of A: [D0, D1, ..., Dn, N] + Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. + + The computation math: + dequant_B = dequant(B, absmax, quant_type, block_size) + output = A @ dequant_B + + Shape of output: [D0, D1, ..., Dn, K] )DOC"; @@ -3377,6 +3400,12 @@ Input absmax is stored in same type as original type of B(float32, float16) with .Attr("N", "size of each output feature", AttributeProto::INT) .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Attr("training_mode", + "Indicate if the ops run in training_mode, by default, False.", + AttributeProto::INT, + static_cast(0)) + .Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.", + AttributeProto::INT, static_cast(1)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") @@ -3389,7 +3418,8 @@ Input absmax is stored in same type as original type of B(float32, float16) with // Shape inference int64_t in_features = getAttribute(ctx, "K", -1); int64_t out_features = getAttribute(ctx, "N", -1); - MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + bool transB = getAttribute(ctx, "transB", 1) != 0; + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB); }); #ifdef ENABLE_ATEN diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 8068d4825c..93ba836b53 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -70,6 +70,7 @@ static std::unordered_map> {"Split", {1}}, {"Clip", {1, 2}}, {"Pad", {1, 2}}, + {"MatMulBnb4", {1, 2}}, // quantified weight (non float) and absmax constant don't need gradients. {"Multinomial", {0}}, {"RandomNormalLike", {0}}, {"RandomUniformLike", {0}}, diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 6547f53a3c..7100cedaf7 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -494,6 +494,42 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { return result; } +IMPLEMENT_GRADIENT_BUILDER(GetMatmulBnb4Gradient) { + auto attributes = SrcNodeAttributes(); + std::vector attrs; + bool find_transB = false; + for (auto& attr : attributes) { + if (attr.first == "transB") { + int64_t transB_value = attr.second.i(); + transB_value = (transB_value + 1) % 2; // revert the transpose + attrs.push_back(MakeAttribute("transB", transB_value)); + find_transB = true; + } else { + attrs.push_back(attr.second); + } + } + + if (!find_transB) { + attrs.push_back(MakeAttribute("transB", int64_t(0))); // default is 1, so we need to set it to 0 + } + + std::vector result; + // Y = A * B + // dA = dY * B', dB = A' * dY + if (IsGradientRequiredForSrcNodeInput(0)) { + // B is 1-D, so don't need transpose here. + result.push_back(NodeDef(OpDef{"MatMulBnb4", kMSDomain, 1}, + {GO(0), I(1), I(2)}, + {GI(0)}, + attrs)); + } + + ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(1), "Gradient propagation to B is not supported yet."); + ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(2), "Gradient propagation to absmax is not supported yet."); + + return result; +} + IMPLEMENT_GRADIENT_BUILDER(GetSplitGradient) { std::vector result = {}; std::vector input_args; diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 28a316261e..08987a86eb 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -54,6 +54,7 @@ DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient) DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossInternalGradient) DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) DECLARE_GRADIENT_BUILDER(GetGemmGradient) +DECLARE_GRADIENT_BUILDER(GetMatmulBnb4Gradient) DECLARE_GRADIENT_BUILDER(GetDropoutGradient) DECLARE_GRADIENT_BUILDER(GetGatherNDGradient) DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4b8c68aef0..f280a02cb4 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -68,6 +68,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient); REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient); REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient); + REGISTER_GRADIENT_BUILDER("MatMulBnb4", GetMatmulBnb4Gradient); REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient); REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient); REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8c5469740d..de2c04d262 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -3,7 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import sys +from typing import ClassVar import torch import torch.utils.checkpoint @@ -18,13 +21,55 @@ from onnxruntime.capi._pybind_state import ( register_torch_autograd_function, ) from onnxruntime.training import ortmodule -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pytorch_type_to_onnx_dtype from ._custom_op_symbolic_registry import wrap_custom_export_function from ._fallback import ORTModuleONNXModelException, wrap_exception from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version +class _SpecialCustomFunctionHandler: + """A class to handle high priority export of torch.autograd.Function. + `register_high_priority_handler` can be used as function decorator to register a handler for a torch.autograd.Function. + """ + + _HIGH_PRIORITY_EXPORT_HANDLER_MAP: ClassVar[dict[str, callable]] = {} + + @staticmethod + def add_handler(func_name: str, handler: callable) -> None: + """Add a handler for a function name. + + Args: + func_name (str): The function name. + handler (callable): The handler. + + """ + _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP[func_name] = handler + + @staticmethod + def get_handler(func_name: str) -> callable | None: + """Get the handler for a function name. + + Args: + func_name (str): The function name. + + Returns: + callable | None: The handler. + + """ + return _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP.get(func_name, None) + + +def register_high_priority_handler(func_name): + """Register a handler for a torch.autograd.Function using its full qualified class name.""" + + def symbolic_wrapper(fn): + _SpecialCustomFunctionHandler.add_handler(func_name, fn) + return fn + + return symbolic_wrapper + + def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. @@ -96,6 +141,30 @@ _UNSUPPORTED_CKPT_FUNC_NAMES = frozenset( ) +def _get_training_mode() -> bool: + # TODO move to public API once the exporter team exposes that + training_mode = None + if get_runtime_pytorch_version() >= version.parse("1.12"): + # FIXME: using private modules + from torch.onnx import _globals + + # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9, + # training_mode is still a bool type + if isinstance(_globals.GLOBALS.training_mode, bool): + training_mode = _globals.GLOBALS.training_mode + else: + if _globals.GLOBALS.training_mode not in [ + torch.onnx.TrainingMode.EVAL, + torch.onnx.TrainingMode.TRAINING, + ]: + raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}") + training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING + else: + training_mode = symbolic_helper._training_mode + + return bool(training_mode) + + def _export_pt_1_10(g, n, *args, **kwargs): """Export torch.autograd.Function in ORT PythonOp. @@ -114,6 +183,15 @@ def _export_pt_1_10(g, n, *args, **kwargs): func_class = n.pyobj().__self__ func_full_qual_name = get_fully_qualified_class_name(func_class) + # Check if the function is handled by high priority exporter. + hi_pri_handler = _SpecialCustomFunctionHandler.get_handler(func_full_qual_name) + if hi_pri_handler: + try_export = hi_pri_handler(g, n, *args, **kwargs) + if try_export is not None: + return try_export + + # Fall back to common exporter if not handled by high priority exporter. + # Check if the checkpointing activation is allowed. is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1 if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES: @@ -123,26 +201,6 @@ def _export_pt_1_10(g, n, *args, **kwargs): "wrap exportable sub-nn.Module's as ORTModule." ) - # TODO move to public API once the exporter team exposes that - training_mode = None - if get_runtime_pytorch_version() >= version.parse("1.12"): - # FIXME: using private modules - from torch.onnx import _globals - - # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9, - # training_mode is still a bool type - if isinstance(_globals.GLOBALS.training_mode, bool): - training_mode = _globals.GLOBALS.training_mode - else: - if _globals.GLOBALS.training_mode not in [ - torch.onnx.TrainingMode.EVAL, - torch.onnx.TrainingMode.TRAINING, - ]: - raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}") - training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING - else: - training_mode = symbolic_helper._training_mode - cconv = n.cconv() input_tensor_types = [] @@ -179,7 +237,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): if call_type == "d": # Got a tensor variable. tensor_args.append(arg) - scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) input_tensor_types.append(scalar_type) input_tensor_ranks.append(arg.type().dim()) continue @@ -258,7 +316,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): output_tensor_ranks = [] for arg in n.outputs(): # Type of tensor's elements. - scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType()) + scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType()) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) @@ -270,7 +328,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): "input_tensor_ranks_i": input_tensor_ranks, "output_tensor_types_i": output_tensor_types, "output_tensor_ranks_i": output_tensor_ranks, - "training_mode_i": 1 if training_mode else 0, + "training_mode_i": 1 if _get_training_mode() else 0, "comment_s": debug_comment, } @@ -336,3 +394,55 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model index += 1 return exported_model + + +@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit") +def _matmul4bit_export(g, n, *args, **kwargs): + cconv = n.cconv() + can_converted = cconv[0] == "d" and cconv[1] == "d" and cconv[2] == "c" and cconv[3] == "c" and cconv[4] == "c" + can_converted = can_converted and (args[2] is None and args[3] is None and args[4] is not None) + if not can_converted: + return None + + quant_state = args[4] + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + + # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16 + if blocksize < 16 or blocksize & (blocksize - 1) != 0: + return None + + # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too) + if compressed_stats is not None: + return None + + # The PyTorch linear weight shape is [out_feature, in_feature] + in_feature = shape[1] + out_feature = shape[0] + if quant_type == "fp4": + quant_type = 0 + elif quant_type == "nf4": + quant_type = 1 + else: + return None + attrs = { + "K_i": in_feature, + "N_i": out_feature, + "block_size_i": blocksize, + "quant_type_i": quant_type, + "training_mode_i": 1 if _get_training_mode() else 0, + } + + # Make sure the quant weight can be flatten to 1D tensor safely, which com.microsoft::MatMulBnb4 requires. + found_dim1 = any(v == 1 for v in args[1].type().sizes()) + if not found_dim1: + return None + + absmax = g.op( + "Constant", + value_t=torch.tensor(absmax, dtype=pytorch_scalar_type_to_pytorch_dtype(args[0].type().scalarType())), + ) + quant_weight = g.op( + "Reshape", args[1], g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) # flatten to 1D + tensor_args = [args[0], quant_weight, absmax] + return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 6e694dcdf2..99e8851b6a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -12,7 +12,7 @@ from packaging.version import Version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype from ._utils import get_runtime_pytorch_version @@ -145,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index, weight_casted, ignore_index, reduction_s=reduction, - output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()), + output_type_i=pytorch_type_to_onnx_dtype(output_type.scalarType()), outputs=2, ) output.setType(output_type) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5e8805bfdd..ba61b2e4c8 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,7 +19,7 @@ from torch.utils.cpp_extension import ROCM_HOME import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils @@ -345,7 +345,7 @@ class GraphExecutionManager(GraphExecutionInterface): cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx" ) if os.path.exists(cache_dir) and os.path.isfile(filename): - self._logger.info( + self._logger.warning( f"Cached model detected! Cached model will be used to save export and initialization time." f"If you want the model to be re-exported then DELETE {filename}." ) @@ -627,7 +627,7 @@ class GraphExecutionManager(GraphExecutionInterface): kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, - dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), device=device, ).requires_grad_() diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index d0dea66fda..3a5d7da926 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -14,7 +14,7 @@ from onnxruntime.capi._pybind_state import ( register_shape_inference_function, register_torch_autograd_function, ) -from onnxruntime.training.utils import pytorch_dtype_to_onnx +from onnxruntime.training.utils import pytorch_type_to_onnx_dtype from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary from ._utils import get_fully_qualified_class_name @@ -149,7 +149,7 @@ def post_processing_enable_zero_stage3_compat( c, graph_input.name, len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank - pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + pytorch_type_to_onnx_dtype(zero_stage3_named_params[graph_input.name].dtype), # new data type ) # Delete exported_model.graph.input diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 0eb6790d7a..ff0cde3719 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -366,7 +366,7 @@ class _RuntimeOptions: # Cache exported model if "ORTMODULE_CACHE_DIR" in os.environ: - self._logger.info("ORTModule cache optimization is ON.") + self._logger.warning("ORTModule optimization for caching exported model is ON.") self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR") # Experimental features. diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index fa7c9f2750..d40a6ddf7d 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,7 +9,11 @@ from onnxruntime.training.utils.torch_io_helper import ( extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx +from onnxruntime.training.utils.torch_type_map import ( + onnx_dtype_to_pytorch_dtype, + pytorch_scalar_type_to_pytorch_dtype, + pytorch_type_to_onnx_dtype, +) __all__ = [ "PrimitiveType", @@ -17,6 +21,7 @@ __all__ = [ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", - "pytorch_dtype_to_onnx", - "onnx_dtype_to_pytorch", + "pytorch_type_to_onnx_dtype", + "onnx_dtype_to_pytorch_dtype", + "pytorch_scalar_type_to_pytorch_dtype", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index b1cb5c19e8..0d268a7a4a 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -17,7 +17,7 @@ import torch from onnxruntime.training.utils import ( ORTModelInputOutputType, extract_data_and_schema, - pytorch_dtype_to_onnx, + pytorch_type_to_onnx_dtype, unflatten_data_using_schema, ) @@ -324,7 +324,7 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): start_offset = len(tensor_input_shapes) - len(partitioned_params) for index, param in enumerate(partitioned_params): tensor_output_shapes[start_offset + index] = list(param.ds_shape) - tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype)) + tensor_output_dtypes[start_offset + index] = int(pytorch_type_to_onnx_dtype(param.dtype)) assert len(tensor_output_shapes) == len(tensor_input_shapes) assert len(tensor_output_dtypes) == len(tensor_input_dtypes) diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py index bdacab8ad0..2b429f3fd4 100644 --- a/orttraining/orttraining/python/training/utils/torch_type_map.py +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -36,8 +36,10 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _C _ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()} -def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: - """Converts a pytorch dtype or scalar type string to an onnx dtype.""" +def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: + """Converts a pytorch dtype or scalar type string to an onnx dtype. + PyTorch type can be either a dtype or a scalar type string. + """ dtype = dtype_or_scalar_type if isinstance(dtype, str): if dtype not in _CAST_PYTORCH_TO_ONNX: @@ -49,7 +51,15 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc return _DTYPE_TO_ONNX[dtype] -def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: +def pytorch_scalar_type_to_pytorch_dtype(dtype: str) -> torch.dtype: + """Converts a pytorch scalar type string to a pytorch dtype.""" + assert isinstance(dtype, str) + if dtype not in _CAST_PYTORCH_TO_ONNX: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _CAST_PYTORCH_TO_ONNX[dtype][1] + + +def onnx_dtype_to_pytorch_dtype(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: """Converts an onnx dtype to a pytorch dtype.""" if dtype not in _ONNX_TO_DTYPE: raise RuntimeError(f"Unsupported dtype {dtype}")