mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Print tensor snippet in dumping node Inputs/Outputs to StdOut (#10707)
* dump tensor snippet
This commit is contained in:
parent
a7738b52c5
commit
2fb2dae42f
3 changed files with 42 additions and 68 deletions
|
|
@ -4,7 +4,7 @@
|
|||
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
|
||||
|
||||
#include "core/framework/debug_node_inputs_outputs_utils.h"
|
||||
|
||||
#include "core/framework/print_tensor_utils.h"
|
||||
#include <iomanip>
|
||||
#include <cctype>
|
||||
#include <string>
|
||||
|
|
@ -57,64 +57,9 @@ bool FilterNode(const NodeDumpOptions& dump_options, const Node& node) {
|
|||
match_pattern(node.OpType(), dump_options.filter.op_type_pattern);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const BFloat16& value) {
|
||||
return out << value.ToFloat();
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const MLFloat16& value) {
|
||||
return out << static_cast<float>(value);
|
||||
}
|
||||
|
||||
void PrintValue(const uint8_t value) {
|
||||
std::cout << static_cast<uint32_t>(value);
|
||||
}
|
||||
|
||||
void PrintValue(const int8_t value) {
|
||||
std::cout << static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_floating_point<T>::value, void>::type
|
||||
PrintValue(const T& value) {
|
||||
std::cout << std::setprecision(8) << value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<!std::is_floating_point<T>::value, void>::type
|
||||
PrintValue(const T& value) {
|
||||
std::cout << value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DumpTensorToStdOut(const Tensor& tensor) {
|
||||
const auto& shape = tensor.Shape();
|
||||
auto num_items = shape.Size();
|
||||
|
||||
if (num_items == 0) {
|
||||
std::cout << "no data";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t num_dims = shape.NumDimensions();
|
||||
size_t num_rows = 1;
|
||||
if (num_dims > 1) {
|
||||
num_rows = static_cast<size_t>(shape[0]);
|
||||
}
|
||||
|
||||
size_t row_size = num_items / num_rows;
|
||||
|
||||
auto data = tensor.DataAsSpan<T>();
|
||||
|
||||
for (size_t row = 0; row < num_rows; ++row) {
|
||||
PrintValue(data[row * row_size]);
|
||||
for (size_t i = 1; i < row_size; ++i) {
|
||||
std::cout << ", ";
|
||||
PrintValue(data[row * row_size + i]);
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
void DumpTensorToStdOut(const Tensor& tensor, const NodeDumpOptions& dump_options) {
|
||||
onnxruntime::utils::PrintCpuTensor<T>(tensor, dump_options.snippet_threshold, dump_options.snippet_edge_items);
|
||||
}
|
||||
|
||||
PathString MakeTensorFileName(const std::string& tensor_name, const NodeDumpOptions& dump_options) {
|
||||
|
|
@ -349,7 +294,7 @@ void DumpCpuTensor(
|
|||
const Tensor& tensor, const TensorMetadata& tensor_metadata) {
|
||||
switch (dump_options.data_destination) {
|
||||
case NodeDumpOptions::DataDestination::StdOut: {
|
||||
DispatchOnTensorType(tensor.DataType(), DumpTensorToStdOut, tensor);
|
||||
DispatchOnTensorType(tensor.DataType(), DumpTensorToStdOut, tensor, dump_options);
|
||||
break;
|
||||
}
|
||||
case NodeDumpOptions::DataDestination::TensorProtoFiles: {
|
||||
|
|
@ -446,6 +391,10 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
|
|||
ORT_THROW("Unsupported data destination type: ", destination);
|
||||
}
|
||||
|
||||
// Snippet options for StdOut
|
||||
opts.snippet_threshold = ParseEnvironmentVariableWithDefault<int>(env_vars::kSnippetThreshold, kDefaultSnippetThreshold);
|
||||
opts.snippet_edge_items = ParseEnvironmentVariableWithDefault<int>(env_vars::kSnippetEdgeItems, kDefaultSnippetEdgeItems);
|
||||
|
||||
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kAppendRankToFileName, false)) {
|
||||
std::string rank = Env::Default().GetEnvironmentVar("OMPI_COMM_WORLD_RANK");
|
||||
if (rank.empty()) {
|
||||
|
|
|
|||
|
|
@ -53,6 +53,13 @@ constexpr const char* kSqliteDbPrefix = "ORT_DEBUG_NODE_IO_SQLITE_DB_PREFIX";
|
|||
// set to non-zero to confirm that dumping data files for all nodes is acceptable
|
||||
constexpr const char* kDumpingDataToFilesForAllNodesIsOk =
|
||||
"ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK";
|
||||
|
||||
// Total number of elements which trigger snippet rather than full dump (default 200). Value 0 disables snippet.
|
||||
constexpr const char* kSnippetThreshold = "ORT_DEBUG_NODE_IO_SNIPPET_THRESHOLD";
|
||||
// Number of array items in snippet at beginning and end of each dimension (default 3)
|
||||
constexpr const char* kSnippetEdgeItems = "ORT_DEBUG_NODE_IO_SNIPPET_EDGE_ITEMS";
|
||||
|
||||
|
||||
} // namespace debug_node_inputs_outputs_env_vars
|
||||
|
||||
constexpr char kFilterPatternDelimiter = ';';
|
||||
|
|
@ -102,6 +109,12 @@ struct NodeDumpOptions {
|
|||
Path output_dir;
|
||||
// the sqlite3 db to append dumped data
|
||||
Path sqlite_db_prefix;
|
||||
|
||||
// Total number of elements which trigger snippet rather than full array for Stdout. Value 0 disables snippet.
|
||||
int snippet_threshold;
|
||||
|
||||
// Number of array items in snippet at beginning and end of each dimension for Stdout.
|
||||
int snippet_edge_items;
|
||||
};
|
||||
|
||||
struct NodeDumpContext {
|
||||
|
|
@ -145,4 +158,4 @@ void DumpNodeOutputs(
|
|||
} // namespace utils
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -30,19 +30,31 @@ constexpr int64_t kDefaultSnippetThreshold = 200;
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintValue(const T& value) {
|
||||
inline void PrintValue(const T& value) {
|
||||
if (std::is_floating_point<T>::value)
|
||||
std::cout << std::setprecision(8) << value;
|
||||
else
|
||||
std::cout << value;
|
||||
}
|
||||
|
||||
// Explicit specialization for half
|
||||
template <> void PrintValue(const MLFloat16& value) {
|
||||
std::cout << std::setprecision(8) << (float)value;
|
||||
// Explicit specialization
|
||||
template <> inline void PrintValue(const MLFloat16& value) {
|
||||
std::cout << std::setprecision(8) << value.ToFloat();
|
||||
}
|
||||
|
||||
// Print 2D tensor snippet
|
||||
template <> inline void PrintValue(const BFloat16& value) {
|
||||
std::cout << std::setprecision(8) << value.ToFloat();
|
||||
}
|
||||
|
||||
template <> inline void PrintValue(const uint8_t& value) {
|
||||
std::cout << static_cast<uint32_t>(value);
|
||||
}
|
||||
|
||||
template <> inline void PrintValue(const int8_t& value) {
|
||||
std::cout << static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
// Print snippet of 2D tensor with shape (dim0, dim1)
|
||||
template <typename T>
|
||||
void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t edge_items) {
|
||||
for (int64_t i = 0; i < dim0; i++) {
|
||||
|
|
@ -58,7 +70,7 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
|
|||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
// Print 3D tensor
|
||||
// Print snippet of 3D tensor with shape (dim0, dim1, dim2)
|
||||
template <typename T>
|
||||
void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim2, int64_t edge_items) {
|
||||
for (int64_t i = 0; i < dim0; i++) {
|
||||
|
|
@ -66,7 +78,7 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
|
|||
for (int64_t j = 0; j < dim1; j++) {
|
||||
SKIP_NON_EDGE_ITEMS(dim1, j, edge_items);
|
||||
PrintValue(tensor[i * dim1 * dim2 + j * dim2]);
|
||||
for (int64_t k = 0; k < dim2; k++) {
|
||||
for (int64_t k = 1; k < dim2; k++) {
|
||||
SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items);
|
||||
std::cout << ", ";
|
||||
PrintValue(tensor[i * dim1 * dim2 + j * dim2 + k]);
|
||||
|
|
@ -98,7 +110,7 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim
|
|||
for (int64_t i = 0; i < dim0; i++) {
|
||||
for (int64_t j = 0; j < dim1; j++) {
|
||||
PrintValue(tensor[i * dim1 * dim2 + j * dim2]);
|
||||
for (int64_t k = 0; k < dim2; k++) {
|
||||
for (int64_t k = 1; k < dim2; k++) {
|
||||
std::cout << ", ";
|
||||
PrintValue(tensor[i * dim1 * dim2 + j * dim2 + k]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue