fix null dereference warning (#6437)

This commit is contained in:
Yufeng Li 2021-01-25 16:50:32 -08:00 committed by GitHub
parent f3a0344f9a
commit 7e42840298
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 49 deletions

View file

@ -227,30 +227,16 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
template <typename T>
Status Attention<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights;
const Tensor* weights = packed_weights_ ? nullptr : context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
if (packed_weights_) {
weights = nullptr;
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weight_shape_,
bias->Shape(),
mask_index,
past));
} else {
weights = context->Input<Tensor>(1);
//Normally we don't check if an input is NULL, but this one is needed to make VC++
//static analyzer happy
if (weights == nullptr)
return Status(common::ONNXRUNTIME, common::FAIL, "the second input cannot be NULL");
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
mask_index,
past));
}
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights ? weights->Shape() : weight_shape_,
bias->Shape(),
mask_index,
past));
const auto& shape = input->Shape().GetDims();
const int batch_size = static_cast<int>(shape[0]);
@ -279,7 +265,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
{
const int loop_len = 3 * batch_size * num_heads_;
const auto* input_data = input->template Data<T>();
const auto* weights_data = weights == nullptr ? nullptr : weights->template Data<T>();
const auto* weights_data = weights ? weights->template Data<T>() : nullptr;
const auto* bias_data = bias->template Data<T>();
const double cost =

View file

@ -35,7 +35,7 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), packed_b_ ? b_shape_ : b->Shape()));
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b ? b->Shape() : b_shape_));
Tensor* y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
@ -61,9 +61,9 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
const auto* a_data = a->template Data<uint8_t>();
auto* y_data = y->template MutableData<int32_t>();
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_b_) {
if (packed_b_) {
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
@ -76,25 +76,33 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
continue;
}
#endif
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
const bool b_is_signed = b->IsDataType<int8_t>();
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_offset,
b_data + helper.RightOffsets()[i],
static_cast<size_t>(helper.N()),
b_offset,
b_is_signed,
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
return Status::OK();
}
#endif
if (b != nullptr) {
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
const bool b_is_signed = b->IsDataType<int8_t>();
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_offset,
b_data + helper.RightOffsets()[i],
static_cast<size_t>(helper.N()),
b_offset,
b_is_signed,
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input B should not be null.");
}
return Status::OK();
}

View file

@ -46,7 +46,7 @@ class MatMulIntegerBase : public OpKernel {
#endif
protected:
bool b_is_signed_;
bool b_is_signed_{true};
TensorShape b_shape_;
BufferUniquePtr packed_b_;
};

View file

@ -80,7 +80,7 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()),
QLinearConv);
}
} // namespace contrib
#endif
@ -107,7 +107,7 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed
const size_t output_channels = static_cast<size_t>(shape[0]);
const size_t group_input_channels = static_cast<size_t>(shape[1]);
const size_t kernel_size =
static_cast<size_t>(std::accumulate(shape.data() + 2, shape.data() + rank, 1LL, std::multiplies<int64_t>()));
static_cast<size_t>(std::accumulate(shape.data() + 2, shape.data() + rank, 1LL, std::multiplies<int64_t>()));
const auto* Wdata = static_cast<const uint8_t*>(tensor.DataRaw());
W_shape_ = shape;
@ -165,7 +165,7 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, bool& is_packed
Status QLinearConv::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const Tensor* W = is_W_packed_ ? nullptr : context->Input<Tensor>(3);
const auto& W_shape = is_W_packed_ ? W_shape_ : W->Shape();
const auto& W_shape = W ? W->Shape() : W_shape_;
const bool is_W_signed = (W != nullptr) ? W->IsDataType<int8_t>() : is_W_signed_;
const int64_t N = X->Shape()[0];
@ -285,7 +285,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
}
#endif
if (use_reordered_W) {
if (reordered_W_buffer_) {
if (W == nullptr) {
// Weight was constant and reordered.
reordered_W = static_cast<uint8_t*>(reordered_W_buffer_.get());
} else {
// Weight tensor was not constant or prepacking is disabled.
@ -402,7 +403,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
static_cast<int64_t>(kernel_rank),
static_cast<uint8_t*>(col_buffer.get()) + group_id * col_buffer_size,
X_zero_point_value);
}
}
}
}

View file

@ -714,7 +714,7 @@ private:
for (size_t f = 0; f < M * N; f++) {
if (C[f] != CReference[f]) {
printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d!\n", M, N, K, offa, offb);
printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d!\n", M, N, K, int(offa), int(offb));
break;
}
}
@ -921,7 +921,7 @@ private:
for (size_t f = 0; f < M * N; f++) {
// Sensitive to comparing positive/negative zero.
if (C[f] != CReference[f]) {
printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d! %f %f\n", M, N, K, offa, offb, C[f], CReference[f]);
printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d! %f %f\n", M, N, K, int(offa), int(offb), C[f], CReference[f]);
break;
}
}