// pm_test_dilithium.cpp – C-sim / C-synth testbench for Dilithium polymult IP #include #include #include "ntt.h" #include "test_case.h" int main() { // Top-level AXI4-Stream ports for the DUT hls::stream in_data; hls::stream out_data; coeff_axis_t local_stream; coeff_t actual_outputs[DILITHIUM_N]; int i; // ------------------------------------------------------------------------- // Write stimulus into input AXI4-Stream // // Protocol for poly_mult_dilithium(): // - first DILITHIUM_N words: a[0..N-1] // - next DILITHIUM_N words: b[0..N-1] // TLAST = 1 only on the very last word (b[N-1]) // ------------------------------------------------------------------------- // Send polynomial a for (i = 0; i < DILITHIUM_N; i++) { coeff_t val = input1_vals[i]; local_stream.data = (ap_int<32>)val; local_stream.keep = -1; local_stream.strb = -1; local_stream.last = 0; // not last yet in_data.write(local_stream); } // Send polynomial b for (i = 0; i < DILITHIUM_N; i++) { coeff_t val = input2_vals[i]; local_stream.data = (ap_int<32>)val; local_stream.keep = -1; local_stream.strb = -1; local_stream.last = (i == DILITHIUM_N - 1) ? 1 : 0; in_data.write(local_stream); } // ------------------------------------------------------------------------- // Call DUT // ------------------------------------------------------------------------- poly_mult_dilithium(in_data, out_data); // ------------------------------------------------------------------------- // Read result from output AXI4-Stream // // The core outputs exactly DILITHIUM_N coefficients; TLAST is 1 on the // final beat. We read all N coefficients into actual_outputs[]. // ------------------------------------------------------------------------- for (i = 0; i < DILITHIUM_N; i++) { coeff_axis_t out_word = out_data.read(); actual_outputs[i] = (coeff_t)out_word.data; // Optional: you can sanity-check TLAST here: // if (i == DILITHIUM_N - 1 && out_word.last != 1) ... } // ------------------------------------------------------------------------- // Compare against golden output // ------------------------------------------------------------------------- int ret_val = 0; for (i = 0; i < DILITHIUM_N; i++) { if (output_vals[i] != actual_outputs[i]) { ret_val++; std::cout << "Mismatch at index " << i << ": got " << (long long)actual_outputs[i] << ", expected " << (long long)output_vals[i] << std::endl; break; // stop at first mismatch (like your Kyber testbench) } } if (ret_val == 0) { std::cout << "All " << DILITHIUM_N << " coefficients match golden output." << std::endl; } return ret_val; }