mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Fix compiler warnings treated as errors in GistEncodeDecode. (#7568)
* Fix compiler warning in GistEncodeDecode. * Fix other use of member variable. * Make `compression_type_` const. * Change floor to floorf in CUDA code. * Statically cast size_t to int in GIST CUDA kernels * Add explicit cast to `long` in gist.cc Co-authored-by: Derek Murray <demurra@microsoft.com>
This commit is contained in:
parent
91985ab03d
commit
94c97ac8c2
4 changed files with 11 additions and 10 deletions
|
|
@ -188,9 +188,9 @@ std::vector<std::string> GistEncodeDecode::TargetOpTypes() const noexcept {
|
|||
|
||||
Status GistEncodeDecode::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const {
|
||||
if (node.Description() != "Backward pass") {
|
||||
if (GistEncodeDecode::AddEncodeDecode(graph, node, compression_type, logger)) {
|
||||
if (GistEncodeDecode::AddEncodeDecode(graph, node, compression_type_, logger)) {
|
||||
LOGS(logger, INFO) << "Gist applied to node name - " << node.Name() << ", node type - "
|
||||
<< node.OpType() << ", of compr type - " << compression_type;
|
||||
<< node.OpType() << ", of compr type - " << compression_type_;
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ namespace onnxruntime {
|
|||
class GistEncodeDecode : public RewriteRule {
|
||||
public:
|
||||
int operator_type;
|
||||
std::string compression_type;
|
||||
|
||||
static constexpr const char* GIST_PAIR_NODE_NAME_BASE = "gist";
|
||||
|
||||
|
|
@ -34,7 +33,7 @@ class GistEncodeDecode : public RewriteRule {
|
|||
{"Relu", {"ReluGrad", "Shape", "Reshape"}}};
|
||||
|
||||
GistEncodeDecode() noexcept : RewriteRule("GistEncodeDecode") {}
|
||||
GistEncodeDecode(int op_type, std::string compr_type) noexcept : RewriteRule("GistEncodeDecode"), operator_type(op_type), compression_type(std::move(compr_type)) {}
|
||||
GistEncodeDecode(int op_type, std::string compr_type) noexcept : RewriteRule("GistEncodeDecode"), operator_type(op_type), compression_type_(std::move(compr_type)) {}
|
||||
|
||||
private:
|
||||
int GenerateDecodePriority() const { return priority_generator_--; };
|
||||
|
|
@ -42,6 +41,8 @@ class GistEncodeDecode : public RewriteRule {
|
|||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
bool AddEncodeDecode(Graph& graph, Node& curr_node, std::string compression_type, const logging::Logger& logger) const;
|
||||
|
||||
const std::string compression_type_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ Status GistPack1EncoderOp<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* X = context->Input<Tensor>(0);
|
||||
ORT_RETURN_IF(X == nullptr, "X input is unavailable");
|
||||
|
||||
long int n = (X->Shape().Size() + GIST_PACK1_FACTOR - 1) / GIST_PACK1_FACTOR;
|
||||
long n = static_cast<long>((X->Shape().Size() + GIST_PACK1_FACTOR - 1) / GIST_PACK1_FACTOR);
|
||||
Tensor* Y = context->Output(0, TensorShape({n}));
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
GistPack1EncoderImpl<CudaT>(
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ __global__ void _GistPackMsfp15DecoderKernel(
|
|||
output_data[in_i] = 0.0;
|
||||
} else {
|
||||
// Find leading 1
|
||||
uint8_t leading_bit_pos = floor(log2f(mantissa));
|
||||
uint8_t leading_bit_pos = floorf(log2f(mantissa));
|
||||
// Difference from shared exponent of this value
|
||||
int exp_diff = 5 - leading_bit_pos;
|
||||
// Adjust exponent
|
||||
|
|
@ -466,9 +466,9 @@ void GistPackMsfp15EncoderImpl(
|
|||
const size_t tile_size) {
|
||||
|
||||
assert(axis_size % tile_size == 0);
|
||||
const int num_tiles = axis_size / tile_size;
|
||||
const int num_tiles = static_cast<int>(axis_size / tile_size);
|
||||
|
||||
const int threads = pre_axis_size * num_tiles;
|
||||
const int threads = static_cast<int>(pre_axis_size * num_tiles);
|
||||
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(threads) / GridDim::maxThreadsPerBlock));
|
||||
_GistPackMsfp15EncoderKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
|
|
@ -492,9 +492,9 @@ void GistPackMsfp15DecoderImpl(
|
|||
const size_t tile_size) {
|
||||
|
||||
assert(axis_size % tile_size == 0);
|
||||
const int num_tiles = axis_size / tile_size;
|
||||
const int num_tiles = static_cast<int>(axis_size / tile_size);
|
||||
|
||||
const int threads = pre_axis_size * num_tiles;
|
||||
const int threads = static_cast<int>(pre_axis_size * num_tiles);
|
||||
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(threads) / GridDim::maxThreadsPerBlock));
|
||||
_GistPackMsfp15DecoderKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
|
|
|
|||
Loading…
Reference in a new issue