diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 8e86862a62..646465ef8b 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2580,8 +2580,30 @@ This version of the operator has been available since version 1 of the 'com.micr
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.
- Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
- Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
+ Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
+ Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
+
+
+ 1. (Default value) transB=True (Majorly used for forward pass)
+ Shape of A: [D0, D1, ..., Dn, K]
+ Shape of dequantized B: [N, K]. This is aligned with how PyTorch defines the linear weight, e.g. [out_features, in_features].
+
+ The computation math:
+ dequant_B = dequant(B, absmax, quant_type, block_size)
+ transposed_dequant_B = dequant_B^T
+ output = A @ transposed_dequant_B
+
+ Shape of output: [D0, D1, ..., Dn, N]
+
+ 2. transB=False (Majorly used for backward pass)
+ Shape of A: [D0, D1, ..., Dn, N]
+ Shape of dequantized B: [N, K]. This is aligned with how PyTorch defines the linear weight, e.g. [out_features, in_features].
+
+ The computation math:
+ dequant_B = dequant(B, absmax, quant_type, block_size)
+ output = A @ dequant_B
+
+ Shape of output: [D0, D1, ..., Dn, K]
#### Version
@@ -2599,6 +2621,10 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
quant_type : int (required)
quantization data type. 0 for FP4, 1 for NF4.
+training_mode : int
+Indicate if the ops run in training_mode, by default, False.
+transB : int
+Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.
#### Inputs
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
index 2f3ede49c3..b898c956b6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
@@ -21,6 +21,8 @@ class MatMulBnb4 final : public OpKernel {
ORT_ENFORCE(
quant_type_ == FP4 || quant_type_ == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+ is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0)));
+ transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1)));
}
Status Compute(OpKernelContext* context) const override;
@@ -30,6 +32,8 @@ class MatMulBnb4 final : public OpKernel {
int64_t N_;
int64_t block_size_;
int64_t quant_type_;
+ bool is_training_mode_;
+ bool transB_;
};
Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
@@ -58,7 +62,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
thread_pool);
constexpr bool transa = false;
- constexpr bool transb = true;
+ const bool transb = transB_;
TensorShape b_shape({N_, K_});
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
index bd5b6e0a8a..ecf332715d 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
@@ -25,6 +25,9 @@ class MatMulBnb4 final : public CudaKernel {
ORT_ENFORCE(
quant_type_ == FP4 || quant_type_ == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0)));
+ transB_ = static_cast(info.GetAttrOrDefault("transB", static_cast(1)));
}
Status ComputeInternal(OpKernelContext* context) const override;
@@ -34,6 +37,8 @@ class MatMulBnb4 final : public CudaKernel {
int64_t N_;
int64_t block_size_;
int64_t quant_type_;
+ bool is_training_mode_;
+ bool transB_;
};
template
@@ -59,7 +64,7 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const {
static_cast(ctx->GetComputeStream()->GetHandle())));
constexpr bool transa = false;
- constexpr bool transb = true;
+ const bool transb = transB_;
MatMulComputeHelper helper;
TensorShape b_shape({N_, K_});
ORT_RETURN_IF_ERROR(
@@ -69,17 +74,18 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();
- bool is_4bit_done = TryMatMulBnb4(
- reinterpret_cast(quant_map_buffer_data),
- reinterpret_cast(Y->MutableData()),
- reinterpret_cast(a_data),
- b_quant_data,
- reinterpret_cast(absmax_data),
- SafeInt(helper.M()),
- SafeInt(helper.N()),
- SafeInt(helper.K()),
- SafeInt(block_size_),
- static_cast(ctx->GetComputeStream()->GetHandle()));
+ bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode
+ && TryMatMulBnb4(
+ reinterpret_cast(quant_map_buffer_data),
+ reinterpret_cast(Y->MutableData()),
+ reinterpret_cast(a_data),
+ b_quant_data,
+ reinterpret_cast(absmax_data),
+ SafeInt(helper.M()),
+ SafeInt(helper.N()),
+ SafeInt(helper.K()),
+ SafeInt(block_size_),
+ static_cast(ctx->GetComputeStream()->GetHandle()));
if (!is_4bit_done) {
IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream());
@@ -98,16 +104,16 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const {
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
GetCublasHandle(ctx),
- CUBLAS_OP_T,
- CUBLAS_OP_N,
+ transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB
+ CUBLAS_OP_N, // transA
SafeInt(helper.N()),
SafeInt(helper.M()),
SafeInt(helper.K()),
&alpha,
reinterpret_cast(b_dequant_data),
- SafeInt(K_),
+ helper.Ldb(transb), // ldb
reinterpret_cast(a_data),
- helper.Lda(transa),
+ helper.Lda(transa), // lda
&zero,
reinterpret_cast(Y->MutableData()),
helper.Ldc(),
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index e757e39130..39449bea63 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2693,7 +2693,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1,
static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
int64_t K,
- int64_t N) {
+ int64_t N,
+ bool transB) {
int input_a_idx = 0;
if (!hasInputShape(ctx, input_a_idx)) {
return;
@@ -2707,15 +2708,15 @@ static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext
// TODO: check B shape
const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1);
- if (dim_last.has_dim_value() && dim_last.dim_value() != K) {
+ ONNX_NAMESPACE::TensorShapeProto resultShape;
+ if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) {
fail_shape_inference("Incompatible dimensions for matrix multiplication");
}
- ONNX_NAMESPACE::TensorShapeProto resultShape;
for (int i = 0; i < a_shape.dim_size() - 1; ++i) {
*resultShape.add_dim() = a_shape.dim(i);
}
- resultShape.add_dim()->set_dim_value(N);
+ resultShape.add_dim()->set_dim_value(transB ? N : K);
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape;
}
@@ -3354,7 +3355,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
// Shape inference
int64_t in_features = getAttribute(ctx, "K", -1);
int64_t out_features = getAttribute(ctx, "N", -1);
- MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
+ MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true);
});
static const char* MatMulBnb4_ver1_doc = R"DOC(
@@ -3364,8 +3365,30 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.
-Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
-Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
+ Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
+ Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
+
+
+ 1. (Default value) transB=True (Majorly used for forward pass)
+ Shape of A: [D0, D1, ..., Dn, K]
+ Shape of dequantized B: [N, K]. This is aligned with how PyTorch defines the linear weight, e.g. [out_features, in_features].
+
+ The computation math:
+ dequant_B = dequant(B, absmax, quant_type, block_size)
+ transposed_dequant_B = dequant_B^T
+ output = A @ transposed_dequant_B
+
+ Shape of output: [D0, D1, ..., Dn, N]
+
+ 2. transB=False (Majorly used for backward pass)
+ Shape of A: [D0, D1, ..., Dn, N]
+ Shape of dequantized B: [N, K]. This is aligned with how PyTorch defines the linear weight, e.g. [out_features, in_features].
+
+ The computation math:
+ dequant_B = dequant(B, absmax, quant_type, block_size)
+ output = A @ dequant_B
+
+ Shape of output: [D0, D1, ..., Dn, K]
)DOC";
@@ -3377,6 +3400,12 @@ Input absmax is stored in same type as original type of B(float32, float16) with
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
.Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT)
+ .Attr("training_mode",
+ "Indicate if the ops run in training_mode, by default, False.",
+ AttributeProto::INT,
+ static_cast(0))
+ .Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.",
+ AttributeProto::INT, static_cast(1))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional quantized data for weight", "T2")
.Input(2, "absmax", "quantization constants", "T1")
@@ -3389,7 +3418,8 @@ Input absmax is stored in same type as original type of B(float32, float16) with
// Shape inference
int64_t in_features = getAttribute(ctx, "K", -1);
int64_t out_features = getAttribute(ctx, "N", -1);
- MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
+ bool transB = getAttribute(ctx, "transB", 1) != 0;
+ MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB);
});
#ifdef ENABLE_ATEN
diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h
index 8068d4825c..93ba836b53 100644
--- a/orttraining/orttraining/core/framework/gradient_graph_builder.h
+++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h
@@ -70,6 +70,7 @@ static std::unordered_map>
{"Split", {1}},
{"Clip", {1, 2}},
{"Pad", {1, 2}},
+    {"MatMulBnb4", {1, 2}},  // quantized weight (non-float) and absmax constant don't need gradients.
{"Multinomial", {0}},
{"RandomNormalLike", {0}},
{"RandomUniformLike", {0}},
diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 6547f53a3c..7100cedaf7 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -494,6 +494,42 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
return result;
}
+IMPLEMENT_GRADIENT_BUILDER(GetMatmulBnb4Gradient) {
+  // The gradient node is another MatMulBnb4 carrying the forward attributes with transB inverted.
+  std::vector<AttributeProto> attrs;
+  bool find_transB = false;
+  for (const auto& attr : SrcNodeAttributes()) {
+    if (attr.first == "transB") {
+      // transB is consumed as a boolean (non-zero means transposed), so invert it as one: 2 must map to 0.
+      const int64_t inverted_transB = attr.second.i() != 0 ? 0 : 1;
+      attrs.push_back(MakeAttribute("transB", inverted_transB));
+      find_transB = true;
+    } else {
+      attrs.push_back(attr.second);
+    }
+  }
+
+  if (!find_transB) {
+    attrs.push_back(MakeAttribute("transB", int64_t(0)));  // default is 1, so we need to set it to 0
+  }
+
+  std::vector<NodeDef> result;
+  // Y = A * B
+  // dA = dY * B', dB = A' * dY
+  if (IsGradientRequiredForSrcNodeInput(0)) {
+    // B is 1-D, so don't need transpose here.
+    result.push_back(NodeDef(OpDef{"MatMulBnb4", kMSDomain, 1},
+                             {GO(0), I(1), I(2)},
+                             {GI(0)},
+                             attrs));
+  }
+
+  ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(1), "Gradient propagation to B is not supported yet.");
+  ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(2), "Gradient propagation to absmax is not supported yet.");
+
+  return result;
+}
+
IMPLEMENT_GRADIENT_BUILDER(GetSplitGradient) {
std::vector result = {};
std::vector input_args;
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 28a316261e..08987a86eb 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -54,6 +54,7 @@ DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossInternalGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetGemmGradient)
+DECLARE_GRADIENT_BUILDER(GetMatmulBnb4Gradient)
DECLARE_GRADIENT_BUILDER(GetDropoutGradient)
DECLARE_GRADIENT_BUILDER(GetGatherNDGradient)
DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient)
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index 4b8c68aef0..f280a02cb4 100755
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -68,6 +68,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient);
REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient);
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
+ REGISTER_GRADIENT_BUILDER("MatMulBnb4", GetMatmulBnb4Gradient);
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
index 8c5469740d..de2c04d262 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
@@ -3,7 +3,10 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
+from __future__ import annotations
+
import sys
+from typing import ClassVar
import torch
import torch.utils.checkpoint
@@ -18,13 +21,55 @@ from onnxruntime.capi._pybind_state import (
register_torch_autograd_function,
)
from onnxruntime.training import ortmodule
-from onnxruntime.training.utils import pytorch_dtype_to_onnx
+from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pytorch_type_to_onnx_dtype
from ._custom_op_symbolic_registry import wrap_custom_export_function
from ._fallback import ORTModuleONNXModelException, wrap_exception
from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version
+class _SpecialCustomFunctionHandler:
+    """Registry of export handlers that are tried before the common torch.autograd.Function exporter.
+    Entries are usually added through the `register_high_priority_handler` function decorator.
+    """
+
+    _HIGH_PRIORITY_EXPORT_HANDLER_MAP: ClassVar[dict[str, callable]] = {}
+
+    @staticmethod
+    def add_handler(func_name: str, handler: callable) -> None:
+        """Store `handler` under `func_name`, replacing any handler already stored under that name.
+
+        Args:
+            func_name (str): The name used as the lookup key.
+            handler (callable): The handler to store.
+
+        """
+        registry = _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP
+        registry[func_name] = handler
+
+    @staticmethod
+    def get_handler(func_name: str) -> callable | None:
+        """Look up the handler stored under `func_name`.
+
+        Args:
+            func_name (str): The name used as the lookup key.
+
+        Returns:
+            callable | None: The stored handler, or None when nothing is stored under that name.
+        """
+        return _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP.get(func_name)
+
+
+def register_high_priority_handler(func_name):
+    """Decorator that registers a handler for a torch.autograd.Function, keyed by its fully qualified class name."""
+
+    def _register(handler):
+        _SpecialCustomFunctionHandler.add_handler(func_name, handler)
+        return handler
+
+    return _register
+
+
def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None:
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod
"infer_shape" defined.
@@ -96,6 +141,30 @@ _UNSUPPORTED_CKPT_FUNC_NAMES = frozenset(
)
+def _get_training_mode() -> bool:
+    """Return whether the torch.onnx exporter state currently reports training mode."""
+    # TODO move to public API once the exporter team exposes that
+    if get_runtime_pytorch_version() < version.parse("1.12"):
+        return bool(symbolic_helper._training_mode)
+
+    # FIXME: using private modules
+    from torch.onnx import _globals
+
+    # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9,
+    # training_mode is still a bool type
+    if isinstance(_globals.GLOBALS.training_mode, bool):
+        return bool(_globals.GLOBALS.training_mode)
+
+    if _globals.GLOBALS.training_mode not in [
+        torch.onnx.TrainingMode.EVAL,
+        torch.onnx.TrainingMode.TRAINING,
+    ]:
+        raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}")
+
+    is_training = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING
+    return bool(is_training)
+
+
def _export_pt_1_10(g, n, *args, **kwargs):
"""Export torch.autograd.Function in ORT PythonOp.
@@ -114,6 +183,15 @@ def _export_pt_1_10(g, n, *args, **kwargs):
func_class = n.pyobj().__self__
func_full_qual_name = get_fully_qualified_class_name(func_class)
+ # Check if the function is handled by high priority exporter.
+ hi_pri_handler = _SpecialCustomFunctionHandler.get_handler(func_full_qual_name)
+ if hi_pri_handler:
+ try_export = hi_pri_handler(g, n, *args, **kwargs)
+ if try_export is not None:
+ return try_export
+
+ # Fall back to common exporter if not handled by high priority exporter.
+
# Check if the checkpointing activation is allowed.
is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1
if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES:
@@ -123,26 +201,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"wrap exportable sub-nn.Module's as ORTModule."
)
- # TODO move to public API once the exporter team exposes that
- training_mode = None
- if get_runtime_pytorch_version() >= version.parse("1.12"):
- # FIXME: using private modules
- from torch.onnx import _globals
-
- # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9,
- # training_mode is still a bool type
- if isinstance(_globals.GLOBALS.training_mode, bool):
- training_mode = _globals.GLOBALS.training_mode
- else:
- if _globals.GLOBALS.training_mode not in [
- torch.onnx.TrainingMode.EVAL,
- torch.onnx.TrainingMode.TRAINING,
- ]:
- raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}")
- training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING
- else:
- training_mode = symbolic_helper._training_mode
-
cconv = n.cconv()
input_tensor_types = []
@@ -179,7 +237,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
if call_type == "d":
# Got a tensor variable.
tensor_args.append(arg)
- scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
+ scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType())
input_tensor_types.append(scalar_type)
input_tensor_ranks.append(arg.type().dim())
continue
@@ -258,7 +316,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
output_tensor_ranks = []
for arg in n.outputs():
# Type of tensor's elements.
- scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
+ scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType())
output_tensor_types.append(scalar_type)
output_tensor_ranks.append(arg.type().dim())
@@ -270,7 +328,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"input_tensor_ranks_i": input_tensor_ranks,
"output_tensor_types_i": output_tensor_types,
"output_tensor_ranks_i": output_tensor_ranks,
- "training_mode_i": 1 if training_mode else 0,
+ "training_mode_i": 1 if _get_training_mode() else 0,
"comment_s": debug_comment,
}
@@ -336,3 +394,55 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
index += 1
return exported_model
+
+
+@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit")
+def _matmul4bit_export(g, n, *args, **kwargs):
+    """Export MatMul4Bit as com.microsoft::MatMulBnb4; return None to fall back to the common exporter."""
+    # Convertible only when called with two tensor arguments followed by three non-tensor arguments.
+    cconv = n.cconv()
+    if len(args) < 5 or len(cconv) < 5 or list(cconv[:5]) != ["d", "d", "c", "c", "c"]:
+        return None
+    if args[2] is not None or args[3] is not None or args[4] is None:
+        return None
+
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = args[4]
+
+    # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16
+    if blocksize < 16 or blocksize & (blocksize - 1) != 0:
+        return None
+
+    # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too)
+    if compressed_stats is not None:
+        return None
+
+    # The PyTorch linear weight shape is [out_feature, in_feature]
+    in_feature = shape[1]
+    out_feature = shape[0]
+    if quant_type not in ("fp4", "nf4"):
+        return None
+
+    attrs = {
+        "K_i": in_feature,
+        "N_i": out_feature,
+        "block_size_i": blocksize,
+        "quant_type_i": 0 if quant_type == "fp4" else 1,  # 0 for FP4, 1 for NF4
+        "training_mode_i": 1 if _get_training_mode() else 0,
+    }
+
+    # Make sure the quant weight can be flatten to 1D tensor safely, which com.microsoft::MatMulBnb4 requires.
+    # sizes() gives None when no static shape is recorded for the weight; treat that as "cannot convert".
+    weight_sizes = args[1].type().sizes()
+    if weight_sizes is None or not any(v == 1 for v in weight_sizes):
+        return None
+
+    # absmax becomes a Constant input, converted to the same scalar type as the first tensor argument.
+    absmax = g.op(
+        "Constant",
+        value_t=torch.tensor(absmax, dtype=pytorch_scalar_type_to_pytorch_dtype(args[0].type().scalarType())),
+    )
+    quant_weight = g.op(
+        "Reshape", args[1], g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
+    )  # flatten to 1D
+    tensor_args = [args[0], quant_weight, absmax]
+    return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs)
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 6e694dcdf2..99e8851b6a 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -12,7 +12,7 @@ from packaging.version import Version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
-from onnxruntime.training.utils import pytorch_dtype_to_onnx
+from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
from ._utils import get_runtime_pytorch_version
@@ -145,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index,
weight_casted,
ignore_index,
reduction_s=reduction,
- output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()),
+ output_type_i=pytorch_type_to_onnx_dtype(output_type.scalarType()),
outputs=2,
)
output.setType(output_type)
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 5e8805bfdd..ba61b2e4c8 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -19,7 +19,7 @@ from torch.utils.cpp_extension import ROCM_HOME
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
-from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
+from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
@@ -345,7 +345,7 @@ class GraphExecutionManager(GraphExecutionInterface):
cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx"
)
if os.path.exists(cache_dir) and os.path.isfile(filename):
- self._logger.info(
+ self._logger.warning(
f"Cached model detected! Cached model will be used to save export and initialization time."
f"If you want the model to be re-exported then DELETE {filename}."
)
@@ -627,7 +627,7 @@ class GraphExecutionManager(GraphExecutionInterface):
kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
- dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
+ dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()
diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
index d0dea66fda..3a5d7da926 100644
--- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
+++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py
@@ -14,7 +14,7 @@ from onnxruntime.capi._pybind_state import (
register_shape_inference_function,
register_torch_autograd_function,
)
-from onnxruntime.training.utils import pytorch_dtype_to_onnx
+from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary
from ._utils import get_fully_qualified_class_name
@@ -149,7 +149,7 @@ def post_processing_enable_zero_stage3_compat(
c,
graph_input.name,
len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank
- pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type
+ pytorch_type_to_onnx_dtype(zero_stage3_named_params[graph_input.name].dtype), # new data type
)
# Delete exported_model.graph.input
diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py
index 0eb6790d7a..ff0cde3719 100644
--- a/orttraining/orttraining/python/training/ortmodule/options.py
+++ b/orttraining/orttraining/python/training/ortmodule/options.py
@@ -366,7 +366,7 @@ class _RuntimeOptions:
# Cache exported model
if "ORTMODULE_CACHE_DIR" in os.environ:
- self._logger.info("ORTModule cache optimization is ON.")
+ self._logger.warning("ORTModule optimization for caching exported model is ON.")
self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR")
# Experimental features.
diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py
index fa7c9f2750..d40a6ddf7d 100644
--- a/orttraining/orttraining/python/training/utils/__init__.py
+++ b/orttraining/orttraining/python/training/utils/__init__.py
@@ -9,7 +9,11 @@ from onnxruntime.training.utils.torch_io_helper import (
extract_data_and_schema,
unflatten_data_using_schema,
)
-from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx
+from onnxruntime.training.utils.torch_type_map import (
+ onnx_dtype_to_pytorch_dtype,
+ pytorch_scalar_type_to_pytorch_dtype,
+ pytorch_type_to_onnx_dtype,
+)
__all__ = [
"PrimitiveType",
@@ -17,6 +21,7 @@ __all__ = [
"ORTModelInputOutputSchemaType",
"extract_data_and_schema",
"unflatten_data_using_schema",
- "pytorch_dtype_to_onnx",
- "onnx_dtype_to_pytorch",
+ "pytorch_type_to_onnx_dtype",
+ "onnx_dtype_to_pytorch_dtype",
+ "pytorch_scalar_type_to_pytorch_dtype",
]
diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
index b1cb5c19e8..0d268a7a4a 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py
@@ -17,7 +17,7 @@ import torch
from onnxruntime.training.utils import (
ORTModelInputOutputType,
extract_data_and_schema,
- pytorch_dtype_to_onnx,
+ pytorch_type_to_onnx_dtype,
unflatten_data_using_schema,
)
@@ -324,7 +324,7 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
start_offset = len(tensor_input_shapes) - len(partitioned_params)
for index, param in enumerate(partitioned_params):
tensor_output_shapes[start_offset + index] = list(param.ds_shape)
- tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype))
+ tensor_output_dtypes[start_offset + index] = int(pytorch_type_to_onnx_dtype(param.dtype))
assert len(tensor_output_shapes) == len(tensor_input_shapes)
assert len(tensor_output_dtypes) == len(tensor_input_dtypes)
diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py
index bdacab8ad0..2b429f3fd4 100644
--- a/orttraining/orttraining/python/training/utils/torch_type_map.py
+++ b/orttraining/orttraining/python/training/utils/torch_type_map.py
@@ -36,8 +36,10 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _C
_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()}
-def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
- """Converts a pytorch dtype or scalar type string to an onnx dtype."""
+def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
+ """Converts a pytorch dtype or scalar type string to an onnx dtype.
+ PyTorch type can be either a dtype or a scalar type string.
+ """
dtype = dtype_or_scalar_type
if isinstance(dtype, str):
if dtype not in _CAST_PYTORCH_TO_ONNX:
@@ -49,7 +51,15 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
return _DTYPE_TO_ONNX[dtype]
-def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
+def pytorch_scalar_type_to_pytorch_dtype(dtype: str) -> torch.dtype:
+    """Map a pytorch scalar type string to its pytorch dtype; raise RuntimeError when it is not supported."""
+    assert isinstance(dtype, str)
+    if dtype in _CAST_PYTORCH_TO_ONNX:
+        return _CAST_PYTORCH_TO_ONNX[dtype][1]
+    raise RuntimeError(f"Unsupported dtype {dtype}")
+
+
+def onnx_dtype_to_pytorch_dtype(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
"""Converts an onnx dtype to a pytorch dtype."""
if dtype not in _ONNX_TO_DTYPE:
raise RuntimeError(f"Unsupported dtype {dtype}")