pytorch/caffe2/operators/ctc_beam_search_decoder_op.cc
Duke Vijitbenjaronk d684112ec9 Output sequence probability with CTC beam search, optional multiple output sequences (#21927)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21927

Add `OUTPUT_PROB` output to CTCBeamSearchDecoderOp to return a probability for each sequence.

Add argument to output top-k instead of top-1 decoded sequences.

Reviewed By: SuperIRabbit

Differential Revision: D15797371

fbshipit-source-id: 737ca5cc4f90a0bcc3660ac9f58519a175977b69
2019-07-02 17:29:13 -07:00

188 lines
6.4 KiB
C++

#include "caffe2/operators/ctc_beam_search_decoder_op.h"
namespace caffe2 {
namespace {

// Returns a pointer to the alphabet-probability row for time step `t` and
// batch item `n` of a rank-3 float tensor laid out as
// [max_activation_length, batch_size, alphabet_size].
const float* getTensorDataPtr(const Tensor& tensor, int t, int n) {
  const auto dims = tensor.sizes();
  CAFFE_ENFORCE_EQ(dims.size(), 3);
  // Validate each index separately: checking only the flattened offset would
  // accept an out-of-range `n` compensated by a small `t`, silently reading
  // another row's data instead of failing.
  CAFFE_ENFORCE_LT(t, dims[0]);
  CAFFE_ENFORCE_LT(n, dims[1]);
  const int offset = (t * dims[1] + n) * dims[2];
  CAFFE_ENFORCE_LT(offset, tensor.numel());
  return tensor.template data<float>() + offset;
}

} // namespace
template <>
bool CTCBeamSearchDecoderOp<CPUContext>::RunOnDevice() {
  // CTC prefix beam search decoder: for each batch item, tracks the
  // beam_width_ most probable label prefixes across time steps and emits the
  // top num_candidates_ of them (lengths, flattened label values, and final
  // probabilities).
  //
  // INPUTS shape: max_activation_length x batch_size x alphabet_size
  auto& inputs = Input(INPUTS);
  // SEQ_LEN (optional input) shape: batch_size
  // VALUES (output) shape: sum over all decoded lengths
  const auto inputs_dims = inputs.sizes();
  int32_t max_activation_length = inputs_dims[0];
  int32_t batch_size = inputs_dims[1];
  int32_t alphabet_size = inputs_dims[2];
  // [batch_size]; when SEQ_LEN is absent every item uses the full length.
  const int* seq_len_data =
      (InputSize() == 2) ? Input(SEQ_LEN).data<int>() : nullptr;
  // Accumulates decoded label sequences of all batch items, in batch order;
  // copied into VALUES once the total length is known.
  vector<int32_t> values_cache;
  const int total_candidates = batch_size * num_candidates_;
  // OUTPUT_LEN[i * num_candidates_ + j] = length of the j-th candidate of
  // batch item i (stays 0 when fewer than num_candidates_ candidates exist).
  auto* output_len =
      Output(OUTPUT_LEN, vector<int64_t>{total_candidates}, at::dtype<int>());
  int* output_len_data = output_len->mutable_data<int>();
  memset(output_len_data, 0, total_candidates * sizeof(int));
  // NOTE(review): OUTPUT_PROB is written unconditionally even though the
  // schema allows NumOutputs(2, 3) — confirm Output(OUTPUT_PROB, ...) is
  // valid when the op is instantiated with only two outputs.
  auto* output_prob = Output(
      OUTPUT_PROB, vector<int64_t>{total_candidates}, at::dtype<float>());
  float* output_prob_data = output_prob->mutable_data<float>();
  memset(output_prob_data, 0, total_candidates * sizeof(float));
  for (int32_t i = 0; i < batch_size; ++i) {
    const int32_t activation_length =
        (seq_len_data) ? seq_len_data[i] : max_activation_length;
    // Current beam, keyed by descending probability; a multimap so distinct
    // prefixes with equal probability can coexist.
    std::multimap<float, vector<int32_t>, std::greater<float>> A_next_inv;
    // For a given time step, Pb maps prefixes to the probability of all
    // candidate sequences that end in a blank and Pnb maps prefixes to the
    // probability of all candidate sequences that don't end in a blank.
    vector<std::map<vector<int32_t>, float>> Pb(
        activation_length + 1, std::map<vector<int32_t>, float>());
    vector<std::map<vector<int32_t>, float>> Pnb(
        activation_length + 1, std::map<vector<int32_t>, float>());
    set<vector<int32_t>> A_prev;
    // Initialization: the empty prefix "ends in a blank" with probability 1.
    Pb[0][vector<int32_t>()] = 1;
    Pnb[0][vector<int32_t>()] = 0;
    A_prev.insert(vector<int32_t>());
    for (int t = 0; t < activation_length; t++) {
      // ctc[c]: score of symbol c at step t for batch item i.
      // NOTE(review): these values are consumed as probabilities (multiplied
      // together and compared against prune_threshold_), which appears to
      // contradict the schema's "logits (before softmax)" wording — confirm
      // whether callers are expected to apply softmax first.
      const float* ctc = getTensorDataPtr(inputs, t, i);
      // Prune the alphabet to symbols clearing the probability threshold.
      vector<int32_t> pruned_alpha;
      for (int32_t c = 0; c < alphabet_size; c++) {
        if (ctc[c] > prune_threshold_) {
          pruned_alpha.push_back(c);
        }
      }
      // If the pruned alphabet is empty, don't use pruning.
      if (pruned_alpha.size() == 0) {
        pruned_alpha = vector<int32_t>(alphabet_size);
        std::iota(pruned_alpha.begin(), pruned_alpha.end(), 0);
      }
      // Extend every prefix in the beam by each surviving symbol.
      for (auto const& l : A_prev) {
        // We skip the code handling the end character from the article since
        // our system does not support an end character.
        for (auto const c : pruned_alpha) {
          // Assumption: blank character always mapped to index 0
          if (c == 0) {
            // A blank leaves the prefix unchanged; it now ends in a blank.
            Pb[t + 1][l] += ctc[c] * (Pb[t][l] + Pnb[t][l]);
          } else {
            vector<int32_t> l_plus = vector<int32_t>(l);
            l_plus.push_back(c);
            if (l.size() > 0 && c == l.back()) {
              // Repeated symbol: a new copy is only emitted after a blank;
              // otherwise the repeat collapses into l itself.
              Pnb[t + 1][l_plus] += ctc[c] * Pb[t][l];
              Pnb[t + 1][l] += ctc[c] * Pnb[t][l];
            } else {
              Pnb[t + 1][l_plus] += ctc[c] * (Pb[t][l] + Pnb[t][l]);
            }
            // If l_plus is not itself a beam entry, fold in its own
            // previous-step mass here, since it will not be enumerated as an
            // extension source `l` during this time step.
            if (A_prev.find(l_plus) == A_prev.end()) {
              Pb[t + 1][l_plus] += ctc[0] * (Pb[t][l_plus] + Pnb[t][l_plus]);
              Pnb[t + 1][l_plus] += ctc[c] * Pnb[t][l_plus];
            }
          }
        }
      }
      // Total probability per prefix = blank-ending + non-blank-ending mass.
      std::map<vector<int32_t>, float> A_next(Pb[t + 1]);
      for (const auto& it : Pnb[t + 1]) {
        A_next[it.first] += it.second;
      }
      // Re-rank all prefixes and keep the beam_width_ best for next step.
      A_next_inv.clear();
      for (const auto& it : A_next) {
        A_next_inv.insert({it.second, it.first});
      }
      A_prev.clear();
      auto it = A_next_inv.begin();
      for (int j = 0; j < beam_width_; j++) {
        if (it == A_next_inv.end()) {
          break;
        }
        A_prev.insert(it->second);
        it++;
      }
    }
    // Emit up to num_candidates_ best sequences for this batch item; slots
    // beyond the available candidates keep the zeroed length/probability
    // from the memsets above.
    auto it = A_next_inv.begin();
    for (int index = 0; index < num_candidates_; index++, it++) {
      if (it == A_next_inv.end()) {
        break;
      }
      auto& candidate = it->second;
      output_len_data[i * num_candidates_ + index] = candidate.size();
      output_prob_data[i * num_candidates_ + index] =
          Pb.back()[candidate] + Pnb.back()[candidate];
      values_cache.insert(
          values_cache.end(), candidate.begin(), candidate.end());
    }
  }
  // VALUES can only be sized now that the total decoded length is known.
  int32_t values_cache_size = values_cache.size();
  auto* values =
      Output(VALUES, vector<int64_t>{values_cache_size}, at::dtype<int>());
  int* values_data = values->mutable_data<int>();
  for (int i = 0; i < values_cache.size(); ++i) {
    values_data[i] = values_cache.at(i);
  }
  values_cache.clear();
  return true;
}
REGISTER_CPU_OPERATOR(CTCBeamSearchDecoder, CTCBeamSearchDecoderOp<CPUContext>);

OPERATOR_SCHEMA(CTCBeamSearchDecoder)
    .NumInputs(1, 2)
    .NumOutputs(2, 3)
    .SetDoc(
        "Prefix beam search decoder for connectionist temporal classification.")
    .Arg(
        "beam_width",
        "Maximum number of candidates to carry over to next activation step.")
    .Arg(
        "prune_threshold",
        "Probability threshold below which outputs are ignored.")
    // Documents the num_candidates argument the op reads (it determines the
    // per-item candidate count referenced by the OUTPUT_LEN description).
    .Arg(
        "num_candidates",
        "Number of top decoded sequences (and probabilities) to output for "
        "each batch item. Defaults to 1 (top-1 decoding).")
    .Input(
        0,
        "INPUTS",
        "3D float Tensor sized [max_activation_length, batch_size, alphabet_size] "
        "of network logits (before softmax application).")
    .Input(
        1,
        "SEQ_LEN",
        "(optional) 1D int vector containing sequence lengths, "
        "having size [batch_size] "
        "seq_len will be set to max_time if not provided.")
    .Output(
        0,
        "OUTPUT_LEN",
        "Output_len matrix size (batch_size * num_candidates). "
        "Each index stores lengths of candidates for its corresponding batch item.")
    .Output(
        1,
        "VALUES",
        "Values vector, size (total_decoded_outputs). "
        "The flattened vector of final output sequences, in batch order.")
    // Size corrected to match the code, which allocates this output as
    // batch_size * num_candidates, not as the flattened values length.
    .Output(
        2,
        "OUTPUT_PROB",
        "Probability vector, size (batch_size * num_candidates). "
        "Each index stores the final output probability of its corresponding candidate.")
    .InheritOnnxSchema();

SHOULD_NOT_DO_GRADIENT(CTCBeamSearchDecoder);
} // namespace caffe2