pqc-accelerate/HLS_Codes_Dilithium/pm_test.cpp

93 lines
3 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// pm_test_dilithium.cpp - C-sim / C-synth testbench for the Dilithium polynomial-multiplication (polymult) IP
#include <iostream>
#include <hls_stream.h>
#include "ntt.h"
#include "test_case.h"
int main()
{
// Top-level AXI4-Stream ports for the DUT
hls::stream<coeff_axis_t> in_data;
hls::stream<coeff_axis_t> out_data;
coeff_axis_t local_stream;
coeff_t actual_outputs[DILITHIUM_N];
int i;
// -------------------------------------------------------------------------
// Write stimulus into input AXI4-Stream
//
// Protocol for poly_mult_dilithium():
// - first DILITHIUM_N words: a[0..N-1]
// - next DILITHIUM_N words: b[0..N-1]
// TLAST = 1 only on the very last word (b[N-1])
// -------------------------------------------------------------------------
// Send polynomial a
for (i = 0; i < DILITHIUM_N; i++) {
coeff_t val = input1_vals[i];
local_stream.data = (ap_int<32>)val;
local_stream.keep = -1;
local_stream.strb = -1;
local_stream.last = 0; // not last yet
in_data.write(local_stream);
}
// Send polynomial b
for (i = 0; i < DILITHIUM_N; i++) {
coeff_t val = input2_vals[i];
local_stream.data = (ap_int<32>)val;
local_stream.keep = -1;
local_stream.strb = -1;
local_stream.last = (i == DILITHIUM_N - 1) ? 1 : 0;
in_data.write(local_stream);
}
// -------------------------------------------------------------------------
// Call DUT
// -------------------------------------------------------------------------
poly_mult_dilithium(in_data, out_data);
// -------------------------------------------------------------------------
// Read result from output AXI4-Stream
//
// The core outputs exactly DILITHIUM_N coefficients; TLAST is 1 on the
// final beat. We read all N coefficients into actual_outputs[].
// -------------------------------------------------------------------------
for (i = 0; i < DILITHIUM_N; i++) {
coeff_axis_t out_word = out_data.read();
actual_outputs[i] = (coeff_t)out_word.data;
// Optional: you can sanity-check TLAST here:
// if (i == DILITHIUM_N - 1 && out_word.last != 1) ...
}
// -------------------------------------------------------------------------
// Compare against golden output
// -------------------------------------------------------------------------
int ret_val = 0;
for (i = 0; i < DILITHIUM_N; i++) {
if (output_vals[i] != actual_outputs[i]) {
ret_val++;
std::cout << "Mismatch at index " << i
<< ": got " << (long long)actual_outputs[i]
<< ", expected " << (long long)output_vals[i]
<< std::endl;
break; // stop at first mismatch (like your Kyber testbench)
}
}
if (ret_val == 0) {
std::cout << "All " << DILITHIUM_N
<< " coefficients match golden output." << std::endl;
}
return ret_val;
}