mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
add debug code
This commit is contained in:
parent
c89a798b73
commit
3c3103e5df
5 changed files with 29 additions and 4 deletions
|
|
@ -8,7 +8,7 @@
|
|||
#include <gsl/gsl>
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "contrib_ops/cpu/utils/debug_macros.h"
|
||||
#include "contrib_ops/cpu/utils/console_dumper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#pragma once
|
||||
#include "core/common/make_string.h"
|
||||
|
||||
// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
|
||||
#define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
#define DUMP_TENSOR_LEVEL 2
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4706)
|
||||
#pragma warning(disable : 4706)
|
||||
#endif
|
||||
#include <cub/cub.cuh>
|
||||
#if defined(_MSC_VER)
|
||||
|
|
@ -406,6 +406,9 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
|
|||
beam_hyp.done_ = true;
|
||||
if (atomicAdd(&state.not_done_count_, -1) == 1)
|
||||
state_cpu.not_done_count_ = 0; // Update the CPU side
|
||||
|
||||
printf("\n --- BeamSearchScorer_Process updated cpu state for batch %d\n", threadIdx.x);
|
||||
state_cpu.Print();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <cstdio>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -82,8 +83,21 @@ struct BeamScorerState {
|
|||
int eos_token_id_;
|
||||
bool early_stopping_;
|
||||
int not_done_count_; // When zero, every batch entry is done (starts at batch_size_)
|
||||
|
||||
int hypothesis_buffer_used_; // Offset of available buffer, or length of used buffer.
|
||||
|
||||
// Function to dump the struct data to stdout
|
||||
__host__ __device__ void Print() const {
|
||||
printf("BeamScorerState Dump:\n");
|
||||
printf(" batch_size_: %d\n", batch_size_);
|
||||
printf(" num_beams_: %d\n", num_beams_);
|
||||
printf(" max_length_: %d\n", max_length_);
|
||||
printf(" num_return_sequences_: %d\n", num_return_sequences_);
|
||||
printf(" pad_token_id_: %d\n", pad_token_id_);
|
||||
printf(" eos_token_id_: %d\n", eos_token_id_);
|
||||
printf(" early_stopping_: %s\n", early_stopping_ ? "true" : "false");
|
||||
printf(" not_done_count_: %d\n", not_done_count_);
|
||||
printf(" hypothesis_buffer_used_: %d\n", hypothesis_buffer_used_);
|
||||
}
|
||||
};
|
||||
|
||||
void LaunchInitializeBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, float length_penalty, gsl::span<HypothesisScore> beams, int num_beams, cudaStream_t stream);
|
||||
|
|
|
|||
|
|
@ -722,6 +722,9 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
|
|||
gsl::span<const float>& next_scores,
|
||||
gsl::span<const int32_t>& next_tokens,
|
||||
gsl::span<const int32_t>& next_indices) {
|
||||
printf("\n---Process ---\n");
|
||||
state_cpu_->Print();
|
||||
|
||||
cuda::LaunchBeamSearchScorer_Process(*state_cpu_,
|
||||
*state_gpu_,
|
||||
sequences.GetCurrentDeviceSequences(),
|
||||
|
|
@ -735,6 +738,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
|
|||
next_tokens,
|
||||
next_indices,
|
||||
stream_);
|
||||
|
||||
CUDA_CALL_THROW(cudaEventRecord(event_process_complete_.Get(), stream_));
|
||||
|
||||
cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_,
|
||||
|
|
@ -749,6 +753,10 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
|
|||
|
||||
bool CudaBeamSearchScorer::IsDoneLater() const {
|
||||
CUDA_CALL_THROW(cudaEventSynchronize(event_process_complete_.Get()));
|
||||
|
||||
printf("\n---IsDoneLater ---\n");
|
||||
state_cpu_->Print();
|
||||
|
||||
return state_cpu_->not_done_count_ == 0;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue