pqc-accelerate/HLS_Codes_Dilithium/pm_test.cpp

94 lines
3 KiB
C++
Raw Normal View History

// pm_test_dilithium.cpp - C-sim / C-synth testbench for the Dilithium polymult IP
#include <iostream>
#include <hls_stream.h>
#include "ntt.h"
#include "test_case.h"
int main()
{
// Top-level AXI4-Stream ports for the DUT
hls::stream<coeff_axis_t> in_data;
hls::stream<coeff_axis_t> out_data;
coeff_axis_t local_stream;
coeff_t actual_outputs[DILITHIUM_N];
int i;
// -------------------------------------------------------------------------
// Write stimulus into input AXI4-Stream
//
// Protocol for poly_mult_dilithium():
// - first DILITHIUM_N words: a[0..N-1]
// - next DILITHIUM_N words: b[0..N-1]
// TLAST = 1 only on the very last word (b[N-1])
// -------------------------------------------------------------------------
// Send polynomial a
for (i = 0; i < DILITHIUM_N; i++) {
coeff_t val = input1_vals[i];
local_stream.data = (ap_int<32>)val;
local_stream.keep = -1;
local_stream.strb = -1;
local_stream.last = 0; // not last yet
in_data.write(local_stream);
}
// Send polynomial b
for (i = 0; i < DILITHIUM_N; i++) {
coeff_t val = input2_vals[i];
local_stream.data = (ap_int<32>)val;
local_stream.keep = -1;
local_stream.strb = -1;
local_stream.last = (i == DILITHIUM_N - 1) ? 1 : 0;
in_data.write(local_stream);
}
// -------------------------------------------------------------------------
// Call DUT
// -------------------------------------------------------------------------
poly_mult_dilithium(in_data, out_data);
// -------------------------------------------------------------------------
// Read result from output AXI4-Stream
//
// The core outputs exactly DILITHIUM_N coefficients; TLAST is 1 on the
// final beat. We read all N coefficients into actual_outputs[].
// -------------------------------------------------------------------------
for (i = 0; i < DILITHIUM_N; i++) {
coeff_axis_t out_word = out_data.read();
actual_outputs[i] = (coeff_t)out_word.data;
// Optional: you can sanity-check TLAST here:
// if (i == DILITHIUM_N - 1 && out_word.last != 1) ...
}
// -------------------------------------------------------------------------
// Compare against golden output
// -------------------------------------------------------------------------
int ret_val = 0;
for (i = 0; i < DILITHIUM_N; i++) {
if (output_vals[i] != actual_outputs[i]) {
ret_val++;
std::cout << "Mismatch at index " << i
<< ": got " << (long long)actual_outputs[i]
<< ", expected " << (long long)output_vals[i]
<< std::endl;
break; // stop at first mismatch (like your Kyber testbench)
}
}
if (ret_val == 0) {
std::cout << "All " << DILITHIUM_N
<< " coefficients match golden output." << std::endl;
}
return ret_val;
}