diff --git a/orttraining/orttraining/core/optimizer/gist_encode_decode.cc b/orttraining/orttraining/core/optimizer/gist_encode_decode.cc index 7eed7486b6..8baea5a1ad 100644 --- a/orttraining/orttraining/core/optimizer/gist_encode_decode.cc +++ b/orttraining/orttraining/core/optimizer/gist_encode_decode.cc @@ -188,9 +188,9 @@ std::vector GistEncodeDecode::TargetOpTypes() const noexcept { Status GistEncodeDecode::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const { if (node.Description() != "Backward pass") { - if (GistEncodeDecode::AddEncodeDecode(graph, node, compression_type, logger)) { + if (GistEncodeDecode::AddEncodeDecode(graph, node, compression_type_, logger)) { LOGS(logger, INFO) << "Gist applied to node name - " << node.Name() << ", node type - " - << node.OpType() << ", of compr type - " << compression_type; + << node.OpType() << ", of compr type - " << compression_type_; rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; } } diff --git a/orttraining/orttraining/core/optimizer/gist_encode_decode.h b/orttraining/orttraining/core/optimizer/gist_encode_decode.h index 84c4ccd8cf..8b087c151d 100644 --- a/orttraining/orttraining/core/optimizer/gist_encode_decode.h +++ b/orttraining/orttraining/core/optimizer/gist_encode_decode.h @@ -13,7 +13,6 @@ namespace onnxruntime { class GistEncodeDecode : public RewriteRule { public: int operator_type; - std::string compression_type; static constexpr const char* GIST_PAIR_NODE_NAME_BASE = "gist"; @@ -34,7 +33,7 @@ class GistEncodeDecode : public RewriteRule { {"Relu", {"ReluGrad", "Shape", "Reshape"}}}; GistEncodeDecode() noexcept : RewriteRule("GistEncodeDecode") {} - GistEncodeDecode(int op_type, std::string compr_type) noexcept : RewriteRule("GistEncodeDecode"), operator_type(op_type), compression_type(std::move(compr_type)) {} + GistEncodeDecode(int op_type, std::string compr_type) noexcept : RewriteRule("GistEncodeDecode"), operator_type(op_type), compression_type_(std::move(compr_type)) {} private: int GenerateDecodePriority() const { return priority_generator_--; }; @@ -42,6 +41,8 @@ class GistEncodeDecode : public RewriteRule { bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; bool AddEncodeDecode(Graph& graph, Node& curr_node, std::string compression_type, const logging::Logger& logger) const; + + const std::string compression_type_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/gist/gist.cc b/orttraining/orttraining/training_ops/cuda/gist/gist.cc index 557487aef2..1b53081792 100644 --- a/orttraining/orttraining/training_ops/cuda/gist/gist.cc +++ b/orttraining/orttraining/training_ops/cuda/gist/gist.cc @@ -87,7 +87,7 @@ Status GistPack1EncoderOp::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); ORT_RETURN_IF(X == nullptr, "X input is unavailable"); - long int n = (X->Shape().Size() + GIST_PACK1_FACTOR - 1) / GIST_PACK1_FACTOR; + long n = static_cast((X->Shape().Size() + GIST_PACK1_FACTOR - 1) / GIST_PACK1_FACTOR); Tensor* Y = context->Output(0, TensorShape({n})); typedef typename ToCudaType::MappedType CudaT; GistPack1EncoderImpl( diff --git a/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu b/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu index 108e490fa2..6c301347a2 100644 --- a/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/gist/gist_impl.cu @@ -352,7 +352,7 @@ __global__ void _GistPackMsfp15DecoderKernel( output_data[in_i] = 0.0; } else { // Find leading 1 - uint8_t leading_bit_pos = floor(log2f(mantissa)); + uint8_t leading_bit_pos = floorf(log2f(mantissa)); // Difference from shared exponent of this value int exp_diff = 5 - leading_bit_pos; // Adjust exponent @@ -466,9 +466,9 @@ void GistPackMsfp15EncoderImpl( const size_t tile_size) { assert(axis_size % tile_size == 0); - const int num_tiles = axis_size / tile_size; + const int num_tiles = static_cast(axis_size / tile_size); - const int threads = pre_axis_size * num_tiles; + const int threads = static_cast(pre_axis_size * num_tiles); int blocksPerGrid = (int)(ceil(static_cast(threads) / GridDim::maxThreadsPerBlock)); _GistPackMsfp15EncoderKernel<<>>( @@ -492,9 +492,9 @@ void GistPackMsfp15DecoderImpl( const size_t tile_size) { assert(axis_size % tile_size == 0); - const int num_tiles = axis_size / tile_size; + const int num_tiles = static_cast(axis_size / tile_size); - const int threads = pre_axis_size * num_tiles; + const int threads = static_cast(pre_axis_size * num_tiles); int blocksPerGrid = (int)(ceil(static_cast(threads) / GridDim::maxThreadsPerBlock)); _GistPackMsfp15DecoderKernel<<>>(