mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Changes to support TNLRV3 fine-tuning (#4639)
* added reducesumlogexp gradient added test fixed type mismatch when calling cudnnreduce kernel fixed python frontend to remove redundant states to match pytorch state dict
This commit is contained in:
parent
d8f3e46d45
commit
f90a2d46ae
7 changed files with 128 additions and 136 deletions
|
|
@ -395,10 +395,12 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
}
|
||||
|
||||
CudnnReduceDescriptor reduce_desc;
|
||||
if (std::is_same<T, MLFloat16>::value)
|
||||
if (std::is_same<T, MLFloat16>::value) {
|
||||
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, CudnnTensor::GetDataType<float>(), ReduceTensorIndices));
|
||||
else
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, ReduceTensorIndices));
|
||||
}
|
||||
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
CudnnTensor input_tensor;
|
||||
|
|
@ -437,7 +439,11 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
} else {
|
||||
// Reduce max -- Max/Min will output indices data
|
||||
CudnnReduceDescriptor reduce_max_desc;
|
||||
ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES));
|
||||
cudnnDataType_t cudnn_reduce_max_type = cudnn_type_X;
|
||||
if((std::is_same<T, MLFloat16>::value)) {
|
||||
cudnn_reduce_max_type = CUDNN_DATA_FLOAT;
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_reduce_max_type, CUDNN_REDUCE_TENSOR_NO_INDICES));
|
||||
size_t indices_bytes_max = 0;
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cuda_ep.PerThreadCudnnHandle(), reduce_max_desc,
|
||||
input_tensor, output_tensor, &indices_bytes_max));
|
||||
|
|
|
|||
|
|
@ -905,6 +905,40 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Reference computation is pytorch's logsumexp_backward
|
||||
// dx_i = exp(xi) / reduceSum(exp(xi))
|
||||
// O(0) = log(reduceSum(exp(xi)))
|
||||
// Self_Sub_Result = I(0) - O(0) = xi - log(sum(exp(xi))) = log( xi / reduceSum(exp(xi)))
|
||||
// Gradient computation is re-using output and input from forward op, can be a recomputation candidate.
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
|
||||
std::vector<NodeDef> result;
|
||||
auto attributes = SrcNodeAttributes();
|
||||
bool keepdims = true;
|
||||
if (attributes.find("keepdims") != attributes.end() &&
|
||||
attributes.at("keepdims").has_i()) {
|
||||
keepdims = static_cast<bool>(attributes.at("keepdims").i());
|
||||
}
|
||||
|
||||
ArgDef grad = GO(0);
|
||||
if (!keepdims && attributes.find("axes") != attributes.end()) {
|
||||
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
|
||||
grad = IA("Unsqueezed_Grad");
|
||||
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
|
||||
|
||||
result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
|
||||
result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
|
||||
}
|
||||
else {
|
||||
result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
|
||||
}
|
||||
|
||||
result.push_back(NodeDef("Exp", {IA("Self_Sub_Result")}, {IA("Self_Sub_Result_Exp")}));
|
||||
|
||||
result.push_back(NodeDef("Mul", {IA("Self_Sub_Result_Exp"), grad}, {GI(0)}));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
|
||||
std::vector<NodeDef> result;
|
||||
auto attributes = SrcNodeAttributes();
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ DECLARE_GRADIENT_BUILDER(GetMulGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetDivGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPowGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReshapeGradient)
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Pow", GetPowGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
|
||||
|
|
|
|||
|
|
@ -629,6 +629,8 @@ class ORTTrainer():
|
|||
self.world_size = world_size
|
||||
self.use_mixed_precision = use_mixed_precision
|
||||
|
||||
self.original_model_state_keys = list(model.state_dict().keys()) if hasattr(model, 'state_dict') else []
|
||||
|
||||
self.session = None
|
||||
self.device_ = device
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
|
|
@ -773,7 +775,11 @@ class ORTTrainer():
|
|||
if n.name not in torch_state:
|
||||
torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
|
||||
|
||||
return torch_state
|
||||
# Need to remove redundant initializers and name suffices to map back to original torch state names
|
||||
torch_state_to_return = {key: torch_state[key] for key in self.original_model_state_keys if key in torch_state} \
|
||||
if self.original_model_state_keys \
|
||||
else torch_state
|
||||
return torch_state_to_return
|
||||
|
||||
def load_state_dict(self, state_dict, strict=False):
|
||||
# Note: It may happen ONNX model has not yet been initialized
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
using TestDataVector = std::tuple<std::vector<std::vector<TensorInfo>>, // Input data
|
||||
std::vector<std::vector<TensorInfo>>, // output data
|
||||
std::vector<std::vector<onnx::AttributeProto>>>; //attribute
|
||||
|
||||
class GradientOpTester : public OpTester {
|
||||
public:
|
||||
|
|
@ -39,3 +42,4 @@ class GradientOpTester : public OpTester {
|
|||
};
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,70 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
|
|||
#define EXPECT_IS_TINY(max_error) \
|
||||
EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)
|
||||
|
||||
static void RunReductionTests(const OpDef& op_def) {
|
||||
|
||||
TestDataVector test_data(
|
||||
// Input X
|
||||
{
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
{{4, 3, 2}},
|
||||
},
|
||||
// Input Y
|
||||
{
|
||||
{{1, 1, 1}},
|
||||
{{}},
|
||||
{{1, 3, 1}},
|
||||
{{2}},
|
||||
{{4, 1, 2}},
|
||||
{{4, 3}},
|
||||
{{4, 1, 2}},
|
||||
{{4}}
|
||||
},
|
||||
// Attributes
|
||||
{
|
||||
// default
|
||||
{},
|
||||
// axes = [0, 1, 2], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [0, 2], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
|
||||
// axes = [0, 1], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [1], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{1}),
|
||||
MakeAttribute("keepdims", int64_t(1))},
|
||||
// axes = [2], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{2}),
|
||||
MakeAttribute("keepdims", int64_t(0))},
|
||||
// axes = [-2], keepdims = 1
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2}),
|
||||
MakeAttribute("keepdims", int64_t(1))},
|
||||
// axes = [-2, -1], keepdims = 0
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
|
||||
MakeAttribute("keepdims", int64_t(0))}
|
||||
});
|
||||
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
||||
float max_error;
|
||||
|
||||
for (size_t i = 0; i < std::get<0>(test_data).size(); i++) {
|
||||
max_error = 0;
|
||||
gradient_checker.ComputeGradientError(op_def, std::get<0>(test_data)[i],
|
||||
std::get<1>(test_data)[i], &max_error,
|
||||
std::get<2>(test_data)[i]);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GenerateRandomDataWithOneHot(
|
||||
std::vector<std::vector<float>>& x_datas,
|
||||
|
|
@ -426,149 +490,24 @@ TEST(GradientCheckerTest, GemmGrad) {
|
|||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReduceMeanGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
// Attribute axes supports negative values from opset 11.
|
||||
OpDef op_def{"ReduceMean", kOnnxDomain, 11};
|
||||
|
||||
// default
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// TODO: Fix forward kernel behavior for default axes
|
||||
// default axes, keepdims = 0
|
||||
/*
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
|
||||
{MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
*/
|
||||
|
||||
// axes = [0, 1, 2], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [0, 2], keepdims = 1
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 2})});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [0, 1], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [1], keepdims = 1
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{1}),
|
||||
MakeAttribute("keepdims", int64_t(1))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [2], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{2}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [-2], keepdims = 1
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2}),
|
||||
MakeAttribute("keepdims", int64_t(1))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [-2, -1], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
RunReductionTests(op_def);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReduceSumGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
// Attribute axes supports negative values from opset 11.
|
||||
OpDef op_def{"ReduceSum", kOnnxDomain, 11};
|
||||
|
||||
// default
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
RunReductionTests(op_def);
|
||||
}
|
||||
|
||||
// axes = [0, 1, 2], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
|
||||
// Attribute axes supports negative values from opset 11.
|
||||
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
|
||||
|
||||
// axes = [0, 2], keepdims = 1
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 2})});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [0, 1], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [1], keepdims = 1
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{1}),
|
||||
MakeAttribute("keepdims", int64_t(1))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [2], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{2}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [-2], keepdims = 1
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-2}),
|
||||
MakeAttribute("keepdims", int64_t(1))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// axes = [-1, -3], keepdims = 0
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{3}}, &max_error,
|
||||
{MakeAttribute("axes", std::vector<int64_t>{-1, -3}),
|
||||
MakeAttribute("keepdims", int64_t(0))});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
RunReductionTests(op_def);
|
||||
}
|
||||
|
||||
#ifndef USE_CUDA
|
||||
|
|
@ -1960,3 +1899,4 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
} // namespace onnxruntime
|
||||
|
||||
#endif // NDEBUG
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue