Fix compiler warnings treated as errors in GistEncodeDecode. (#7568)

* Fix compiler warning in GistEncodeDecode.

* Fix other use of member variable.

* Make `compression_type_` const.

* Change floor to floorf in CUDA code.

* Statically cast size_t to int in GIST CUDA kernels

* Add explicit cast to `long` in gist.cc

Co-authored-by: Derek Murray <demurra@microsoft.com>
This commit is contained in:
Derek Murray 2021-05-05 09:05:11 -07:00 committed by GitHub
parent 91985ab03d
commit 94c97ac8c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 10 deletions

View file

@ -188,9 +188,9 @@ std::vector<std::string> GistEncodeDecode::TargetOpTypes() const noexcept {
Status GistEncodeDecode::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
if (node.Description() != "Backward pass") {
if (GistEncodeDecode::AddEncodeDecode(graph, node, compression_type, logger)) {
if (GistEncodeDecode::AddEncodeDecode(graph, node, compression_type_, logger)) {
LOGS(logger, INFO) << "Gist applied to node name - " << node.Name() << ", node type - "
<< node.OpType() << ", of compr type - " << compression_type;
<< node.OpType() << ", of compr type - " << compression_type_;
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
}
}

View file

@ -13,7 +13,6 @@ namespace onnxruntime {
class GistEncodeDecode : public RewriteRule {
public:
int operator_type;
std::string compression_type;
static constexpr const char* GIST_PAIR_NODE_NAME_BASE = "gist";
@ -34,7 +33,7 @@ class GistEncodeDecode : public RewriteRule {
{"Relu", {"ReluGrad", "Shape", "Reshape"}}};
GistEncodeDecode() noexcept : RewriteRule("GistEncodeDecode") {}
GistEncodeDecode(int op_type, std::string compr_type) noexcept : RewriteRule("GistEncodeDecode"), operator_type(op_type), compression_type(std::move(compr_type)) {}
GistEncodeDecode(int op_type, std::string compr_type) noexcept : RewriteRule("GistEncodeDecode"), operator_type(op_type), compression_type_(std::move(compr_type)) {}
private:
int GenerateDecodePriority() const { return priority_generator_--; };
@ -42,6 +41,8 @@ class GistEncodeDecode : public RewriteRule {
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
bool AddEncodeDecode(Graph& graph, Node& curr_node, std::string compression_type, const logging::Logger& logger) const;
const std::string compression_type_;
};
} // namespace onnxruntime

View file

@ -87,7 +87,7 @@ Status GistPack1EncoderOp<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
ORT_RETURN_IF(X == nullptr, "X input is unavailable");
long int n = (X->Shape().Size() + GIST_PACK1_FACTOR - 1) / GIST_PACK1_FACTOR;
long n = static_cast<long>((X->Shape().Size() + GIST_PACK1_FACTOR - 1) / GIST_PACK1_FACTOR);
Tensor* Y = context->Output(0, TensorShape({n}));
typedef typename ToCudaType<T>::MappedType CudaT;
GistPack1EncoderImpl<CudaT>(

View file

@ -352,7 +352,7 @@ __global__ void _GistPackMsfp15DecoderKernel(
output_data[in_i] = 0.0;
} else {
// Find leading 1
uint8_t leading_bit_pos = floor(log2f(mantissa));
uint8_t leading_bit_pos = floorf(log2f(mantissa));
// Difference from shared exponent of this value
int exp_diff = 5 - leading_bit_pos;
// Adjust exponent
@ -466,9 +466,9 @@ void GistPackMsfp15EncoderImpl(
const size_t tile_size) {
assert(axis_size % tile_size == 0);
const int num_tiles = axis_size / tile_size;
const int num_tiles = static_cast<int>(axis_size / tile_size);
const int threads = pre_axis_size * num_tiles;
const int threads = static_cast<int>(pre_axis_size * num_tiles);
int blocksPerGrid = (int)(ceil(static_cast<float>(threads) / GridDim::maxThreadsPerBlock));
_GistPackMsfp15EncoderKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
@ -492,9 +492,9 @@ void GistPackMsfp15DecoderImpl(
const size_t tile_size) {
assert(axis_size % tile_size == 0);
const int num_tiles = axis_size / tile_size;
const int num_tiles = static_cast<int>(axis_size / tile_size);
const int threads = pre_axis_size * num_tiles;
const int threads = static_cast<int>(pre_axis_size * num_tiles);
int blocksPerGrid = (int)(ceil(static_cast<float>(threads) / GridDim::maxThreadsPerBlock));
_GistPackMsfp15DecoderKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(