pqc-accelerate/PYNQ-zcu104_Files/Dilithium.ipynb

1566 lines
54 KiB
Text
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "33632a5e",
"metadata": {},
"outputs": [],
"source": [
"# oho A\n",
"# Install and import liboqs-python (software Dilithium)\n",
"# If you don't have liboqs / liboqs-python yet, uncomment the next lines.\n",
"# On a ZCU104 this will compile on ARM; it may take a while.\n",
"#\n",
"#!git clone --depth=1 https://github.com/open-quantum-safe/liboqs-python.git\n",
"#%cd liboqs-python\n",
"#!pip install .\n",
"#%cd -"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6093010b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Enabled signature mechanisms:\n",
"('Dilithium2', 'Dilithium3', 'Dilithium5', 'ML-DSA-44', 'ML-DSA-65', 'ML-DSA-87', 'Falcon-512', 'Falcon-1024', 'Falcon-padded-512', 'Falcon-padded-1024', 'SPHINCS+-SHA2-128f-simple', 'SPHINCS+-SHA2-128s-simple', 'SPHINCS+-SHA2-192f-simple', 'SPHINCS+-SHA2-192s-simple', 'SPHINCS+-SHA2-256f-simple', 'SPHINCS+-SHA2-256s-simple', 'SPHINCS+-SHAKE-128f-simple', 'SPHINCS+-SHAKE-128s-simple', 'SPHINCS+-SHAKE-192f-simple', 'SPHINCS+-SHAKE-192s-simple', 'SPHINCS+-SHAKE-256f-simple', 'SPHINCS+-SHAKE-256s-simple', 'MAYO-1', 'MAYO-2', 'MAYO-3', 'MAYO-5', 'cross-rsdp-128-balanced', 'cross-rsdp-128-fast', 'cross-rsdp-128-small', 'cross-rsdp-192-balanced', 'cross-rsdp-192-fast', 'cross-rsdp-192-small', 'cross-rsdp-256-balanced', 'cross-rsdp-256-fast', 'cross-rsdp-256-small', 'cross-rsdpg-128-balanced', 'cross-rsdpg-128-fast', 'cross-rsdpg-128-small', 'cross-rsdpg-192-balanced', 'cross-rsdpg-192-fast', 'cross-rsdpg-192-small', 'cross-rsdpg-256-balanced', 'cross-rsdpg-256-fast', 'cross-rsdpg-256-small', 'OV-Is', 'OV-Ip', 'OV-III', 'OV-V', 'OV-Is-pkc', 'OV-Ip-pkc', 'OV-III-pkc', 'OV-V-pkc', 'OV-Is-pkc-skc', 'OV-Ip-pkc-skc', 'OV-III-pkc-skc', 'OV-V-pkc-skc', 'SNOVA_24_5_4', 'SNOVA_24_5_4_SHAKE', 'SNOVA_24_5_4_esk', 'SNOVA_24_5_4_SHAKE_esk', 'SNOVA_37_17_2', 'SNOVA_25_8_3', 'SNOVA_56_25_2', 'SNOVA_49_11_3', 'SNOVA_37_8_4', 'SNOVA_24_5_5', 'SNOVA_60_10_4', 'SNOVA_29_6_5')\n",
"Using algorithm: ML-DSA-44\n",
"Software-only Dilithium verify OK? True\n"
]
}
],
"source": [
"# Cell 1: \n",
"\n",
"\n",
"import numpy as np\n",
"import math\n",
"import time\n",
"from pynq import Overlay\n",
"from pynq import allocate\n",
"from pynq import MMIO\n",
"\n",
"import oqs\n",
"\n",
"# Query liboqs once for the signature schemes compiled into this build\n",
"enabled_sigs = oqs.get_enabled_sig_mechanisms()\n",
"print(\"Enabled signature mechanisms:\")\n",
"print(enabled_sigs)\n",
"\n",
"# Take the first available name: ML-DSA-44 first, legacy Dilithium2 second\n",
"for candidate in (\"ML-DSA-44\", \"Dilithium2\"):\n",
"    if candidate in enabled_sigs:\n",
"        DILITHIUM_ALG = candidate\n",
"        break\n",
"else:\n",
"    raise RuntimeError(\"No Dilithium / ML-DSA-44 implementation enabled in liboqs.\")\n",
"\n",
"print(\"Using algorithm:\", DILITHIUM_ALG)\n",
"\n",
"# Software-only round trip: keygen, sign, verify\n",
"message = b\"Hello from ZCU104 + Dilithium\"\n",
"\n",
"with oqs.Signature(DILITHIUM_ALG) as signer:\n",
"    public_key = signer.generate_keypair()\n",
"    secret_key = signer.export_secret_key()\n",
"    signature = signer.sign(message)\n",
"    ok = signer.verify(message, signature, public_key)\n",
"\n",
"print(\"Software-only Dilithium verify OK?\", ok)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d32bcf5f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8913f4a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 13,
"id": "088eb969",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total 19388\n",
"drwxr-xr-x 3 root root 4096 Dec 13 15:53 .\n",
"drwxrwxrwx 9 xilinx xilinx 4096 Dec 13 15:57 ..\n",
"-rw-r--r-- 1 root root 19311209 Dec 13 15:53 base.bit\n",
"-rw-r--r-- 1 root root 527952 Dec 13 15:52 base.hwh\n",
"drwxr-xr-x 2 root root 4096 Dec 13 15:23 .ipynb_checkpoints\n",
"Sat Dec 13 03:57:25 PM UTC 2025\n"
]
}
],
"source": [
"#oho B\n",
"!ls -la /home/xilinx/jupyter_notebooks/kyber_ntt_and_dilithium_ntt\n",
"!date\n",
"# Loading the kyber-ntt_and_dilithium_ntt bit file to configure the PL\n",
"ol = Overlay('/home/xilinx/jupyter_notebooks/kyber_ntt_and_dilithium_ntt/base.bit')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "fc922ce1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['axi_dma_0', 'poly_mult_0', 'poly_mult_dil_0', 'axi_dma_1', 'ps_e_0'])\n",
"poly_mult_dil_0 @ 0x80020000\n",
"send _max_size: 16383\n",
"recv _max_size: 16383\n",
"{'C_DLYTMR_RESOLUTION': '125', 'C_ENABLE_MULTI_CHANNEL': '0', 'C_FAMILY': 'zynquplus', 'C_INCLUDE_MM2S': '1', 'C_INCLUDE_MM2S_DRE': '0', 'C_INCLUDE_MM2S_SF': '1', 'C_INCLUDE_S2MM': '1', 'C_INCLUDE_S2MM_DRE': '0', 'C_INCLUDE_S2MM_SF': '1', 'C_INCLUDE_SG': '0', 'C_INCREASE_THROUGHPUT': '0', 'C_MICRO_DMA': '0', 'C_MM2S_BURST_SIZE': '16', 'C_M_AXIS_MM2S_CNTRL_TDATA_WIDTH': '32', 'C_M_AXIS_MM2S_TDATA_WIDTH': '64', 'C_M_AXI_MM2S_ADDR_WIDTH': '32', 'C_M_AXI_MM2S_DATA_WIDTH': '64', 'C_M_AXI_S2MM_ADDR_WIDTH': '32', 'C_M_AXI_S2MM_DATA_WIDTH': '32', 'C_M_AXI_SG_ADDR_WIDTH': '32', 'C_M_AXI_SG_DATA_WIDTH': '32', 'C_NUM_MM2S_CHANNELS': '1', 'C_NUM_S2MM_CHANNELS': '1', 'C_PRMRY_IS_ACLK_ASYNC': '0', 'C_S2MM_BURST_SIZE': '16', 'C_SG_INCLUDE_STSCNTRL_STRM': '0', 'C_SG_LENGTH_WIDTH': '14', 'C_SG_USE_STSAPP_LENGTH': '0', 'C_S_AXIS_S2MM_STS_TDATA_WIDTH': '32', 'C_S_AXIS_S2MM_TDATA_WIDTH': '32', 'C_S_AXI_LITE_ADDR_WIDTH': '10', 'C_S_AXI_LITE_DATA_WIDTH': '32', 'Component_Name': 'base_axi_dma_1_0', 'c_addr_width': '32', 'c_dlytmr_resolution': '125', 'c_enable_multi_channel': '0', 'c_include_mm2s': '1', 'c_include_mm2s_dre': '0', 'c_include_mm2s_sf': '1', 'c_include_s2mm': '1', 'c_include_s2mm_dre': '0', 'c_include_s2mm_sf': '1', 'c_include_sg': '0', 'c_increase_throughput': '0', 'c_m_axi_mm2s_data_width': '64', 'c_m_axi_s2mm_data_width': '32', 'c_m_axis_mm2s_tdata_width': '64', 'c_micro_dma': '0', 'c_mm2s_burst_size': '16', 'c_num_mm2s_channels': '1', 'c_num_s2mm_channels': '1', 'c_prmry_is_aclk_async': '0', 'c_s2mm_burst_size': '16', 'c_s_axis_s2mm_tdata_width': '32', 'c_sg_include_stscntrl_strm': '0', 'c_sg_length_width': '14', 'c_sg_use_stsapp_length': '0', 'c_single_interface': '0', 'EDK_IPTYPE': 'PERIPHERAL', 'C_BASEADDR': '0x80030000', 'C_HIGHADDR': '0x8003FFFF', 'ADDR_WIDTH': '10', 'ARUSER_WIDTH': '0', 'AWUSER_WIDTH': '0', 'BUSER_WIDTH': '0', 'CLK_DOMAIN': 'base_ps_e_0_0_pl_clk0', 'DATA_WIDTH': '32', 'FREQ_HZ': '99999001', 'HAS_BRESP': '1', 'HAS_BURST': '0', 'HAS_CACHE': '0', 
'HAS_LOCK': '0', 'HAS_PROT': '0', 'HAS_QOS': '0', 'HAS_REGION': '0', 'HAS_RRESP': '1', 'HAS_WSTRB': '0', 'ID_WIDTH': '0', 'INSERT_VIP': '0', 'MAX_BURST_LENGTH': '1', 'NUM_READ_OUTSTANDING': '1', 'NUM_READ_THREADS': '1', 'NUM_WRITE_OUTSTANDING': '1', 'NUM_WRITE_THREADS': '1', 'PHASE': '0.0', 'PROTOCOL': 'AXI4LITE', 'READ_WRITE_MODE': 'READ_WRITE', 'RUSER_BITS_PER_BYTE': '0', 'RUSER_WIDTH': '0', 'SUPPORTS_NARROW_BURST': '0', 'WUSER_BITS_PER_BYTE': '0', 'WUSER_WIDTH': '0', 'HAS_TKEEP': '1', 'HAS_TLAST': '1', 'HAS_TREADY': '1', 'HAS_TSTRB': '0', 'LAYERED_METADATA': 'undef', 'TDATA_NUM_BYTES': '8', 'TDEST_WIDTH': '0', 'TID_WIDTH': '0', 'TUSER_WIDTH': '0'}\n"
]
}
],
"source": [
"#oho E\n",
"# Board inquiry for the ZCU104 overlay\n",
"\n",
"# List the IP blocks parsed from the .hwh so the exact names are visible\n",
"print(ol.ip_dict.keys())\n",
"\n",
"poly_name = 'poly_mult_dil_0'  # adjust if Vivado named it differently\n",
"poly_info = ol.ip_dict[poly_name]\n",
"poly_base = poly_info['phys_addr']\n",
"poly_span = poly_info['addr_range']\n",
"print(\"poly_mult_dil_0 @\", hex(poly_base))\n",
"mmio = MMIO(poly_base, poly_span)\n",
"\n",
"# NOTE: axi_dma_1 (not axi_dma_0) is the DMA that belongs to the Dilithium accelerator\n",
"dma = ol.axi_dma_1\n",
"dma_send, dma_recv = dma.sendchannel, dma.recvchannel\n",
"\n",
"# 1) The maximum transfer size PYNQ reports for each DMA channel:\n",
"print(\"send _max_size:\", dma_send._max_size)\n",
"print(\"recv _max_size:\", dma_recv._max_size)\n",
"\n",
"# 2) For comparison, dump the raw IP parameters recorded in the .hwh:\n",
"print(ol.ip_dict['axi_dma_1']['parameters'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dafae649",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "3572bb8e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1db1b137",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# 4. Dilithium ring parameters\n",
"Q = 8380417\n",
"N = 256\n",
"\n",
"def center_mod_q(x):\n",
"    \"\"\"Return the representative of x modulo Q that lies in (-Q/2, Q/2].\"\"\"\n",
"    residue = int(x) % Q\n",
"    # Residues in the upper half of [0, Q) are shifted down by Q.\n",
"    return residue - Q if residue > Q // 2 else residue\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "af55b05f",
"metadata": {},
"outputs": [],
"source": [
"def hw_poly_mult(a, b):\n",
"    \"\"\"\n",
"    Hardware polynomial multiplication in R_q[X]/(X^N + 1).\n",
"\n",
"    Coefficients are reduced into [0, Q), packed pairwise into 64-bit words\n",
"    (low 32 bits = a[i], high 32 bits = b[i]), sent through the DMA send\n",
"    channel, and N 32-bit result words are read back through the receive\n",
"    channel. Uses the module-level `mmio` and `dma` handles.\n",
"\n",
"    a, b: 1D numpy arrays of length N with integer entries (Python ints or np.int32/int64).\n",
"    Returns: numpy int64 array of length N, centered mod Q.\n",
"\n",
"    Both DMA buffers are released even if packing or a transfer raises.\n",
"    \"\"\"\n",
"    a = np.asarray(a, dtype=np.int64)\n",
"    b = np.asarray(b, dtype=np.int64)\n",
"    assert a.shape == (N,)\n",
"    assert b.shape == (N,)\n",
"\n",
"    in_buf = None\n",
"    out_buf = None\n",
"    try:\n",
"        # Input: N 64-bit words, each packing two 32-bit coefficients.\n",
"        in_buf = allocate(shape=(N,), dtype=np.uint64)\n",
"        # Output: N 32-bit coefficients.\n",
"        out_buf = allocate(shape=(N,), dtype=np.uint32)\n",
"\n",
"        # Pack with pure Python ints (no NumPy ufuncs) before storing as uint64.\n",
"        for i in range(N):\n",
"            a_i = int(a[i] % Q) & 0xFFFFFFFF\n",
"            b_i = int(b[i] % Q) & 0xFFFFFFFF\n",
"            in_buf[i] = np.uint64((b_i << 32) | a_i)\n",
"\n",
"        # Modeled on the Kyber notebook: start the IP core first\n",
"        # (ap_start = 1 at control register 0x00), then run the DMA transfers.\n",
"        mmio.write(0x00, 0x1)\n",
"\n",
"        dma.sendchannel.transfer(in_buf)\n",
"        dma.recvchannel.transfer(out_buf)\n",
"        dma.sendchannel.wait()\n",
"        dma.recvchannel.wait()\n",
"\n",
"        # Convert back to centered coefficients before the buffers are freed.\n",
"        c = np.empty(N, dtype=np.int64)\n",
"        for i in range(N):\n",
"            c[i] = center_mod_q(out_buf[i])\n",
"        return c\n",
"    finally:\n",
"        # Explicit release on every exit path; a failed transfer must not\n",
"        # leave the buffers allocated on a small board.\n",
"        if in_buf is not None:\n",
"            in_buf.freebuffer()\n",
"        if out_buf is not None:\n",
"            out_buf.freebuffer()\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "cb8a287b",
"metadata": {},
"outputs": [],
"source": [
"# Cell 3: Software negacyclic polynomial multiplication in R_q[X]/(X^N + 1)\n",
"\n",
"def sw_poly_mult(a, b):\n",
"    \"\"\"\n",
"    Schoolbook O(N^2) product in the negacyclic ring:\n",
"        c(X) = a(X) * b(X) mod (X^N + 1, Q)\n",
"\n",
"    Returns a length-N int64 array with coefficients centered mod Q.\n",
"    \"\"\"\n",
"    a = np.asarray(a, dtype=np.int64)\n",
"    b = np.asarray(b, dtype=np.int64)\n",
"    assert a.shape == (N,)\n",
"    assert b.shape == (N,)\n",
"\n",
"    c = np.zeros(N, dtype=np.int64)\n",
"\n",
"    for i in range(N):\n",
"        total = 0\n",
"        for j in range(N):\n",
"            # (i - j) % N is i - j when j <= i and i - j + N otherwise.\n",
"            term = int(a[j]) * int(b[(i - j) % N])\n",
"            # Wrapped terms (j > i) pick up a minus sign because X^N = -1.\n",
"            total += term if j <= i else -term\n",
"        c[i] = center_mod_q(total)\n",
"\n",
"    return c\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1e0e420",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7dfe972f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Max |c_sw - c_hw| mod Q = 0\n",
"Number of mismatching coefficients: 0\n",
"HW matches SW golden for this test vector.\n"
]
}
],
"source": [
"# Cell 4: Compare HW vs SW polynomial multiplication\n",
"\n",
"from numpy.random import default_rng\n",
"\n",
"rng = default_rng(12345)\n",
"\n",
"# Two random operand polynomials\n",
"a = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
"b = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
"\n",
"# Software golden result first, then the hardware result for the same operands\n",
"c_sw = sw_poly_mult(a, b)\n",
"c_hw = hw_poly_mult(a, b)\n",
"\n",
"# Differences are taken modulo Q so equivalent representatives compare equal\n",
"diff = (c_sw - c_hw) % Q\n",
"max_diff = int(np.abs(diff).max())\n",
"num_mism = int(np.count_nonzero(diff))\n",
"\n",
"print(\"Max |c_sw - c_hw| mod Q =\", max_diff)\n",
"print(\"Number of mismatching coefficients:\", num_mism)\n",
"\n",
"verdict = (\n",
"    \"HW matches SW golden for this test vector.\"\n",
"    if num_mism == 0\n",
"    else \"Mismatch: investigate HLS core / packing / reduction.\"\n",
")\n",
"print(verdict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa590582",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a7991e6",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 19,
"id": "99023a64",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"=== Software Dilithium using liboqs ===\n",
"Keygen time (SW): 0.001 s\n",
"Sign time (SW): 0.005 s\n",
"Verify time (SW): 0.001 s\n",
"Verification OK? True\n",
"\n",
"=== Hardware-accelerated poly_mult_dil ===\n",
"Software poly_mult time: 0.282400 s\n",
"Hardware poly_mult time (incl. DMA): 0.009862 s\n",
"Mismatch count: 0\n"
]
}
],
"source": [
"# Cell 5: Tie it together\n",
"\n",
"import time\n",
"\n",
"# 1. Software Dilithium keygen/sign/verify (liboqs)\n",
"# Intervals are measured with time.perf_counter(): unlike time.time() it is\n",
"# monotonic and intended for timing short durations.\n",
"message = b\"Hardware-software co-design demo on ZCU104\"\n",
"\n",
"with oqs.Signature(DILITHIUM_ALG) as signer:\n",
"    print(\"\\n=== Software Dilithium using liboqs ===\")\n",
"    t0 = time.perf_counter()\n",
"    public_key = signer.generate_keypair()\n",
"    secret_key = signer.export_secret_key()\n",
"    t1 = time.perf_counter()\n",
"    print(\"Keygen time (SW): %.3f s\" % (t1 - t0))\n",
"\n",
"    t0 = time.perf_counter()\n",
"    signature = signer.sign(message)\n",
"    t1 = time.perf_counter()\n",
"    print(\"Sign time (SW): %.3f s\" % (t1 - t0))\n",
"\n",
"    t0 = time.perf_counter()\n",
"    ok = signer.verify(message, signature, public_key)\n",
"    t1 = time.perf_counter()\n",
"    print(\"Verify time (SW): %.3f s\" % (t1 - t0))\n",
"    print(\"Verification OK?\", ok)\n",
"\n",
"# 2. Hardware-accelerated polynomial multiplication demo\n",
"print(\"\\n=== Hardware-accelerated poly_mult_dil ===\")\n",
"\n",
"# Fresh operands drawn from the generator created in the comparison cell\n",
"a = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
"b = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
"\n",
"t0 = time.perf_counter()\n",
"c_sw = sw_poly_mult(a, b)\n",
"t1 = time.perf_counter()\n",
"print(\"Software poly_mult time: %.6f s\" % (t1 - t0))\n",
"\n",
"t0 = time.perf_counter()\n",
"c_hw = hw_poly_mult(a, b)\n",
"t1 = time.perf_counter()\n",
"print(\"Hardware poly_mult time (incl. DMA): %.6f s\" % (t1 - t0))\n",
"\n",
"# Compare modulo Q so equivalent representatives count as equal\n",
"diff = (c_sw - c_hw) % Q\n",
"num_mism = int(np.count_nonzero(diff))\n",
"print(\"Mismatch count:\", num_mism)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33dcd211",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "36fb8ddb",
"metadata": {},
"outputs": [],
"source": [
"###########################################################################################"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac4e49fd",
"metadata": {},
"outputs": [],
"source": [
"## Test procedure analogous to the Kyber notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9ec9ac1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "818b97e0",
"metadata": {},
"outputs": [],
"source": [
"# oho D_dil\n",
"# 512-point NTT / INTT in pure Python, matching the HLS core\n",
"# and a SW NTT-based polynomial multiplication in R_q[X]/(X^256 + 1).\n",
"\n",
"import numpy as np\n",
"\n",
"# Ring parameters (keep them in sync with your hardware)\n",
"Q = 8380417\n",
"N = 256\n",
"N2 = 2 * N\n",
"\n",
"# These are exactly the twiddle constants we used in the HLS core\n",
"NTT_WLEN = [\n",
" 8380416, 4808194, 4614810, 2883726,\n",
" 6250525, 7044481, 3241972, 6644104,\n",
" 1921994\n",
"]\n",
"\n",
"NTT_WLEN_INV = [\n",
" 8380416, 3572223, 3761513, 5234739,\n",
" 3764867, 3227876, 6621070, 6125690,\n",
" 527981\n",
"]\n",
"\n",
"INV_NTT512 = 8364049 # 512^{-1} mod Q\n",
"\n",
"def mod_q(x: int) -> int:\n",
"    \"\"\"Reduce integer x modulo Q into [0, Q).\"\"\"\n",
"    # Python's % takes the sign of the (positive) modulus, so the result is\n",
"    # already in [0, Q) for negative x; no sign fix-up is needed.\n",
"    return x % Q\n",
"\n",
"def modadd(a: int, b: int) -> int:\n",
"    \"\"\"Return a + b with a single conditional subtraction of Q (no general reduction).\"\"\"\n",
"    total = a + b\n",
"    return total - Q if total >= Q else total\n",
"\n",
"def modsub(a: int, b: int) -> int:\n",
"    \"\"\"Return a - b with a single conditional addition of Q when the difference is negative.\"\"\"\n",
"    diff = a - b\n",
"    return diff + Q if diff < 0 else diff\n",
"\n",
"def mul_mod(a: int, b: int) -> int:\n",
"    \"\"\"Multiply a and b and reduce the product with mod_q.\"\"\"\n",
"    product = a * b\n",
"    return mod_q(product)\n",
"\n",
"def ntt_512_py(a, invert: bool = False):\n",
"    \"\"\"\n",
"    Iterative radix-2 NTT of size N2 (512), pure Python.\n",
"\n",
"    The input is copied and reduced into [0, Q); the caller's sequence is\n",
"    not modified and a new list is returned.\n",
"\n",
"    a: sequence of exactly N2 integers.\n",
"    invert = False: forward NTT (per-stage twiddles from NTT_WLEN)\n",
"    invert = True: inverse NTT (twiddles from NTT_WLEN_INV, then multiply by 512^{-1} mod Q)\n",
"\n",
"    Raises ValueError if the input does not have N2 entries.\n",
"    \"\"\"\n",
"    # Copy into a plain Python list of ints\n",
"    a = [mod_q(int(x)) for x in a]\n",
"    n2 = N2\n",
"    if len(a) != n2:\n",
"        # The twiddle tables only cover an N2-point transform.\n",
"        raise ValueError(f\"ntt_512_py expects {n2} coefficients, got {len(a)}\")\n",
"\n",
"    # Bit-reversal permutation\n",
"    j = 0\n",
"    for i in range(1, n2):\n",
"        bit = n2 >> 1\n",
"        while j & bit:\n",
"            j ^= bit\n",
"            bit >>= 1\n",
"        j ^= bit\n",
"        if i < j:\n",
"            a[i], a[j] = a[j], a[i]\n",
"\n",
"    # Butterfly stages: block length doubles each stage (2, 4, ..., n2),\n",
"    # and `stage` indexes the twiddle table entry for that length.\n",
"    length = 2\n",
"    stage = 0\n",
"    while length <= n2:\n",
"        wlen = NTT_WLEN_INV[stage] if invert else NTT_WLEN[stage]\n",
"        half = length >> 1\n",
"\n",
"        for i in range(0, n2, length):\n",
"            w = 1  # running power of wlen within this block\n",
"            for k in range(half):\n",
"                u = a[i + k]\n",
"                v = mul_mod(a[i + k + half], w)\n",
"                a[i + k] = modadd(u, v)\n",
"                a[i + k + half] = modsub(u, v)\n",
"                w = mul_mod(w, wlen)\n",
"\n",
"        length <<= 1\n",
"        stage += 1\n",
"\n",
"    if invert:\n",
"        # Multiply by 512^{-1} mod Q\n",
"        for i in range(n2):\n",
"            a[i] = mul_mod(a[i], INV_NTT512)\n",
"\n",
"    return a\n",
"\n",
"def poly_mul_sw_ntt(a, b):\n",
"    \"\"\"\n",
"    SW NTT-based polynomial multiplication:\n",
"        c(X) = a(X) * b(X) mod (X^256 + 1, Q)\n",
"    using the same 512-pt NTT+fold as the HLS core.\n",
"\n",
"    a, b: 1D arrays of length N (can be centered ints).\n",
"    Returns: length-N numpy int64 array, centered mod Q.\n",
"    \"\"\"\n",
"    a = np.asarray(a, dtype=np.int64)\n",
"    b = np.asarray(b, dtype=np.int64)\n",
"    assert a.shape == (N,) and b.shape == (N,)\n",
"\n",
"    # Map inputs into [0, Q) and zero-pad from N to N2 entries\n",
"    A = [mod_q(int(x)) for x in a] + [0] * (N2 - N)\n",
"    B = [mod_q(int(x)) for x in b] + [0] * (N2 - N)\n",
"\n",
"    # Forward NTT\n",
"    A = ntt_512_py(A, invert=False)\n",
"    B = ntt_512_py(B, invert=False)\n",
"\n",
"    # Pointwise multiply in NTT domain\n",
"    C = [mul_mod(A[i], B[i]) for i in range(N2)]\n",
"\n",
"    # Inverse NTT\n",
"    C = ntt_512_py(C, invert=True)\n",
"\n",
"    # Negacyclic fold: c[k] = C[k] - C[k+N] mod Q, then center.\n",
"    # modsub already returns a value in [0, Q), so centering is a single\n",
"    # comparison and does not depend on any helper from another cell.\n",
"    half_q = Q // 2\n",
"    c = np.empty(N, dtype=np.int64)\n",
"    for k in range(N):\n",
"        val = modsub(C[k], C[k + N])\n",
"        c[k] = val - Q if val > half_q else val\n",
"\n",
"    return c\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12e59835",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 22,
"id": "f676296e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"liboqs available, using algorithm: ML-DSA-44\n"
]
}
],
"source": [
"# oho F_dil liboqs setup for Dilithium\n",
"\n",
"# Defaults describe the 'liboqs unavailable' state; overwritten below on success\n",
"HAS_OQS = False\n",
"DILITHIUM_ALG = None\n",
"oqs_signer = None\n",
"oqs_pk = None\n",
"\n",
"try:\n",
"    import oqs\n",
"\n",
"    enabled = oqs.get_enabled_sig_mechanisms()\n",
"    # First match wins: ML-DSA-44 if present, otherwise Dilithium2\n",
"    DILITHIUM_ALG = next(\n",
"        (alg for alg in (\"ML-DSA-44\", \"Dilithium2\") if alg in enabled),\n",
"        None,\n",
"    )\n",
"    if DILITHIUM_ALG is None:\n",
"        print(\"No Dilithium / ML-DSA scheme enabled in liboqs; skipping liboqs benchmarks.\")\n",
"    else:\n",
"        oqs_signer = oqs.Signature(DILITHIUM_ALG)\n",
"        # One keypair is generated here and reused for all timings;\n",
"        # the secret key stays inside the signer object.\n",
"        oqs_pk = oqs_signer.generate_keypair()\n",
"        HAS_OQS = True\n",
"        print(\"liboqs available, using algorithm:\", DILITHIUM_ALG)\n",
"except ImportError:\n",
"    print(\"liboqs-python (oqs) not installed; skipping liboqs benchmarks.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91689702",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 21,
"id": "4aaa17d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Naive vs SW NTT mismatches: 0\n",
"Naive vs HW mismatches: 0\n",
"t_naive = 0.286735 s\n",
"t_sw_ntt = 0.081213 s\n",
"t_hw_ntt = 0.008756 s\n",
"Speedup naive / SW NTT: 3.53 x\n",
"Speedup SW NTT / HW NTT: 9.28 x\n"
]
}
],
"source": [
"# oho E_dil\n",
"# Compare SW naive vs SW NTT vs HW NTT for Dilithium ring multiplication.\n",
"\n",
"import time\n",
"from numpy.random import default_rng\n",
"\n",
"rng = default_rng(12345)\n",
"\n",
"# Benchmark-facing name for the O(N^2) golden multiply\n",
"def poly_mul_sw_naive(a, b):\n",
"    \"\"\"Forward to sw_poly_mult(a, b), the schoolbook reference multiply.\"\"\"\n",
"    product = sw_poly_mult(a, b)\n",
"    return product\n",
"\n",
"def benchmark_once():\n",
"    \"\"\"\n",
"    Time one ring multiplication per backend and check that the results agree.\n",
"\n",
"    Draws two random polynomials from the module-level `rng`, multiplies them\n",
"    with the naive O(N^2) routine, the Python NTT and the HW core, then prints\n",
"    mismatch counts against the naive result, the timings and the speedups.\n",
"    Returns None.\n",
"    \"\"\"\n",
"    # Random centered coefficients in (-Q/2, Q/2].\n",
"    # -Q//2 parses as (-Q)//2, one below the centered range, and the upper\n",
"    # bound of rng.integers is exclusive, so both bounds are written explicitly.\n",
"    a = rng.integers(low=-(Q // 2), high=Q // 2 + 1, size=N, dtype=np.int64)\n",
"    b = rng.integers(low=-(Q // 2), high=Q // 2 + 1, size=N, dtype=np.int64)\n",
"\n",
"    def _timed(multiply):\n",
"        # Run one backend on (a, b); return (product, elapsed seconds).\n",
"        t0 = time.perf_counter()\n",
"        product = multiply(a, b)\n",
"        return product, time.perf_counter() - t0\n",
"\n",
"    # 1) Naive SW (O(N^2))\n",
"    c_naive, t_naive = _timed(poly_mul_sw_naive)\n",
"    # 2) SW NTT (same algorithm as HW, but in Python)\n",
"    c_sw_ntt, t_sw_ntt = _timed(poly_mul_sw_ntt)\n",
"    # 3) HW NTT (HLS core via DMA)\n",
"    c_hw, t_hw = _timed(hw_poly_mult)\n",
"\n",
"    # Correctness checks vs naive, modulo Q\n",
"    diff_sw_ntt = (c_naive - c_sw_ntt) % Q\n",
"    diff_hw = (c_naive - c_hw) % Q\n",
"\n",
"    print(\"Naive vs SW NTT mismatches:\",\n",
"          int(np.count_nonzero(diff_sw_ntt)))\n",
"    print(\"Naive vs HW mismatches:\",\n",
"          int(np.count_nonzero(diff_hw)))\n",
"\n",
"    print(f\"t_naive = {t_naive:.6f} s\")\n",
"    print(f\"t_sw_ntt = {t_sw_ntt:.6f} s\")\n",
"    print(f\"t_hw_ntt = {t_hw:.6f} s\")\n",
"\n",
"    # Guard the divisions against a zero-length measurement\n",
"    if t_sw_ntt > 0 and t_hw > 0:\n",
"        print(\"Speedup naive / SW NTT: %.2f x\" % (t_naive / t_sw_ntt))\n",
"        print(\"Speedup SW NTT / HW NTT: %.2f x\" % (t_sw_ntt / t_hw))\n",
"\n",
"benchmark_once()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d500e63b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 23,
"id": "bee9cd02",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Naive vs SW NTT mismatches: 0\n",
"Naive vs HW mismatches: 0\n",
"t_naive = 0.293554 s\n",
"t_sw_ntt = 0.081380 s\n",
"t_hw_ntt = 0.008733 s\n",
"Speedup naive / SW NTT: 3.61 x\n",
"Speedup SW NTT / HW NTT: 9.32 x\n",
"liboqs Dilithium sign+verify OK? True, time: 0.002986 s\n"
]
}
],
"source": [
"# oho G_dil benchmark: SW naive vs SW NTT vs HW NTT vs liboqs Dilithium\n",
"\n",
"import time\n",
"from numpy.random import default_rng\n",
"\n",
"rng = default_rng(12345)\n",
"\n",
"# Benchmark-facing name for the O(N^2) golden multiply\n",
"def poly_mul_sw_naive(a, b):\n",
"    \"\"\"Forward to sw_poly_mult(a, b), the O(N^2) negacyclic reference multiply.\"\"\"\n",
"    product = sw_poly_mult(a, b)\n",
"    return product\n",
"\n",
"def benchmark_once():\n",
"    \"\"\"\n",
"    Time one ring multiplication per backend, then one liboqs sign+verify.\n",
"\n",
"    Draws two random polynomials from the module-level `rng`, multiplies them\n",
"    with the naive O(N^2) routine, the Python NTT and the HW core, and prints\n",
"    mismatch counts against the naive result, the timings and the speedups.\n",
"    If liboqs was set up (HAS_OQS and oqs_signer), it also times a full\n",
"    sign+verify with the shared signer and public key. Returns None.\n",
"    \"\"\"\n",
"    # Random centered coefficients in (-Q/2, Q/2].\n",
"    # -Q//2 parses as (-Q)//2, one below the centered range, and the upper\n",
"    # bound of rng.integers is exclusive, so both bounds are written explicitly.\n",
"    a = rng.integers(low=-(Q // 2), high=Q // 2 + 1, size=N, dtype=np.int64)\n",
"    b = rng.integers(low=-(Q // 2), high=Q // 2 + 1, size=N, dtype=np.int64)\n",
"\n",
"    def _timed(multiply):\n",
"        # Run one backend on (a, b); return (product, elapsed seconds).\n",
"        t0 = time.perf_counter()\n",
"        product = multiply(a, b)\n",
"        return product, time.perf_counter() - t0\n",
"\n",
"    # 1) Naive SW (O(N^2))\n",
"    c_naive, t_naive = _timed(poly_mul_sw_naive)\n",
"    # 2) SW NTT (same algorithm as HW, but in Python)\n",
"    c_sw_ntt, t_sw_ntt = _timed(poly_mul_sw_ntt)\n",
"    # 3) HW NTT (HLS core via DMA)\n",
"    c_hw, t_hw = _timed(hw_poly_mult)\n",
"\n",
"    # Correctness checks vs naive, modulo Q\n",
"    diff_sw_ntt = (c_naive - c_sw_ntt) % Q\n",
"    diff_hw = (c_naive - c_hw) % Q\n",
"\n",
"    print(\"Naive vs SW NTT mismatches:\",\n",
"          int(np.count_nonzero(diff_sw_ntt)))\n",
"    print(\"Naive vs HW mismatches:\",\n",
"          int(np.count_nonzero(diff_hw)))\n",
"\n",
"    print(f\"t_naive = {t_naive:.6f} s\")\n",
"    print(f\"t_sw_ntt = {t_sw_ntt:.6f} s\")\n",
"    print(f\"t_hw_ntt = {t_hw:.6f} s\")\n",
"\n",
"    # Guard the divisions against a zero-length measurement\n",
"    if t_sw_ntt > 0 and t_hw > 0:\n",
"        print(\"Speedup naive / SW NTT: %.2f x\" % (t_naive / t_sw_ntt))\n",
"        print(\"Speedup SW NTT / HW NTT: %.2f x\" % (t_sw_ntt / t_hw))\n",
"\n",
"    # 4) liboqs Dilithium sign+verify (full scheme, not just one poly mult)\n",
"    if HAS_OQS and oqs_signer is not None:\n",
"        # The message is the first 32 bytes of polynomial a, just to tie the\n",
"        # signature to the data used above.\n",
"        msg = a.tobytes()[:32]\n",
"        t0 = time.perf_counter()\n",
"        sig = oqs_signer.sign(msg)\n",
"        ok = oqs_signer.verify(msg, sig, oqs_pk)\n",
"        t1 = time.perf_counter()\n",
"        t_oqs = t1 - t0\n",
"        print(f\"liboqs Dilithium sign+verify OK? {ok}, time: {t_oqs:.6f} s\")\n",
"    else:\n",
"        print(\"liboqs Dilithium not available in this environment.\")\n",
"\n",
"benchmark_once()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ffbb055",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f592b8e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d26c9518",
"metadata": {},
"outputs": [],
"source": [
"#########################################################\n",
"# Fair comparison\n",
"#########################################################"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "963d4b71",
"metadata": {},
"outputs": [],
"source": [
"# oho H_dil simple R_q vector/matrix utilities and pluggable poly_mul backend\n",
"\n",
"import numpy as np\n",
"\n",
"Q = 8380417\n",
"N = 256\n",
"\n",
"def poly_mul_naive(a, b):\n",
"    \"\"\"Backend adapter: multiply with sw_poly_mult, the O(N^2) golden routine.\"\"\"\n",
"    product = sw_poly_mult(a, b)\n",
"    return product\n",
"\n",
"def poly_mul_sw_ntt_backend(a, b):\n",
"    \"\"\"Backend adapter: multiply with poly_mul_sw_ntt, the Python NTT-based routine.\"\"\"\n",
"    product = poly_mul_sw_ntt(a, b)\n",
"    return product\n",
"\n",
"def poly_mul_hw_backend(a, b):\n",
"    \"\"\"Backend adapter: multiply with hw_poly_mult, the FPGA routine driven via DMA.\"\"\"\n",
"    product = hw_poly_mult(a, b)\n",
"    return product\n",
"\n",
"def poly_add(a, b):\n",
"    \"\"\"Coefficient-wise sum of two polynomials, returned as int64 centered mod Q.\"\"\"\n",
"    lhs = np.asarray(a, dtype=np.int64)\n",
"    rhs = np.asarray(b, dtype=np.int64)\n",
"    total = (lhs + rhs) % Q\n",
"    # Centered representative: residues above Q // 2 are shifted down by Q\n",
"    return np.where(total > Q // 2, total - Q, total).astype(np.int64)\n",
"\n",
"def poly_sub(a, b):\n",
"    \"\"\"Coefficient-wise difference a - b, returned as int64 centered mod Q.\"\"\"\n",
"    lhs = np.asarray(a, dtype=np.int64)\n",
"    rhs = np.asarray(b, dtype=np.int64)\n",
"    delta = (lhs - rhs) % Q\n",
"    # Centered representative: residues above Q // 2 are shifted down by Q\n",
"    return np.where(delta > Q // 2, delta - Q, delta).astype(np.int64)\n",
"\n",
"def sample_poly_small(eta=2):\n",
"    \"\"\"Draw N coefficients uniformly from the integers -eta..eta with NumPy's global RNG (very crude).\"\"\"\n",
"    lowest, upper_exclusive = -eta, eta + 1\n",
"    return np.random.randint(lowest, upper_exclusive, size=N, dtype=np.int64)\n",
"\n",
"def sample_poly_uniform():\n",
"    \"\"\"Draw N coefficients uniformly from [0, Q) with NumPy's global RNG.\"\"\"\n",
"    coeffs = np.random.randint(0, Q, size=N, dtype=np.int64)\n",
"    return coeffs\n",
"\n",
"def vec_poly_add(v1, v2):\n",
"    \"\"\"Element-wise poly_add over two equally long lists of polynomials.\"\"\"\n",
"    assert len(v1) == len(v2)\n",
"    return [poly_add(p, q) for p, q in zip(v1, v2)]\n",
"\n",
"def vec_poly_sub(v1, v2):\n",
"    \"\"\"Element-wise poly_sub (v1[i] - v2[i]) over two equally long lists of polynomials.\"\"\"\n",
"    assert len(v1) == len(v2)\n",
"    return [poly_sub(p, q) for p, q in zip(v1, v2)]\n",
"\n",
"def mat_poly_mul_vec(A, x, poly_mul_func):\n",
"    \"\"\"\n",
"    Matrix-vector product over polynomials.\n",
"\n",
"    A: list of rows, each a list of polynomials; the column count is taken from A[0].\n",
"    x: list of polynomials, one per column.\n",
"    poly_mul_func: function(a, b) -> polynomial of length N.\n",
"    Returns: list with one polynomial per row of A (A * x).\n",
"    \"\"\"\n",
"    n_cols = len(A[0])\n",
"    assert len(x) == n_cols\n",
"    products = []\n",
"    for row in A:\n",
"        # Accumulate row . x, one poly_mul_func call per column\n",
"        row_sum = np.zeros(N, dtype=np.int64)\n",
"        for col in range(n_cols):\n",
"            row_sum = poly_add(row_sum, poly_mul_func(row[col], x[col]))\n",
"        products.append(row_sum)\n",
"    return products\n",
"\n",
"def vec_poly_mul_scalar(v, c_poly, poly_mul_func):\n",
"    \"\"\"\n",
"    Multiply every polynomial of a vector by one fixed polynomial.\n",
"\n",
"    v: list of polynomials\n",
"    c_poly: polynomial\n",
"    Returns: [poly_mul_func(v[i], c_poly)] in the order of v\n",
"    \"\"\"\n",
"    return [poly_mul_func(poly, c_poly) for poly in v]\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "425937d6",
"metadata": {},
"outputs": [],
"source": [
"# oho I_dil toy Dilithium-style keygen / sign / verify with pluggable poly_mul\n",
"\n",
"# Parameters mimicking Dilithium-II\n",
"K = 4 # rows of A\n",
"L = 4 # cols of A\n",
"\n",
"def toy_dil_keygen(poly_mul_func):\n",
"    \"\"\"\n",
"    Keygen-like flow:\n",
"      - Sample A (KxL matrix of uniform polynomials)\n",
"      - Sample s1 (L small polys) and s2 (K small polys)\n",
"      - Compute t = A*s1 + s2 with poly_mul_func\n",
"    Returns (pk, sk) where pk = (A, t) and sk = (s1, s2).\n",
"    \"\"\"\n",
"    # Sampling order (A, then s1, then s2) fixes which draws of the\n",
"    # global RNG each object receives.\n",
"    A = [[sample_poly_uniform() for _col in range(L)] for _row in range(K)]\n",
"    s1 = [sample_poly_small() for _ in range(L)]\n",
"    s2 = [sample_poly_small() for _ in range(K)]\n",
"\n",
"    t = vec_poly_add(mat_poly_mul_vec(A, s1, poly_mul_func), s2)\n",
"\n",
"    return (A, t), (s1, s2)\n",
"\n",
"def toy_dil_sign(poly_mul_func, sk, pk, msg_bytes):\n",
"    \"\"\"\n",
"    Sign-like flow:\n",
"      - Sample y (L small polys)\n",
"      - w = A*y\n",
"      - c = dummy challenge polynomial derived deterministically from msg_bytes\n",
"      - z = y + c*s1\n",
"      - r = w - c*s2\n",
"    Returns a toy signature (z, c, r).\n",
"\n",
"    The challenge is drawn from a private RandomState seeded with the first\n",
"    four message bytes, so signing does not reseed NumPy's global RNG.\n",
"    Callers that want reproducible keys or nonces must seed np.random\n",
"    themselves.\n",
"    \"\"\"\n",
"    A, _t = pk  # t is not needed for signing\n",
"    s1, s2 = sk\n",
"\n",
"    # Sample y from the global RNG\n",
"    y = [sample_poly_small() for _ in range(L)]\n",
"    # w = A*y\n",
"    w = mat_poly_mul_vec(A, y, poly_mul_func)\n",
"\n",
"    # Dummy \"hash\" c: coefficients in {-1, 0, 1} derived from msg_bytes.\n",
"    # This is NOT secure; it only injects the message into the flow.\n",
"    # Same seed and randint arguments as np.random.seed(seed) followed by\n",
"    # sample_poly_small(eta=1), without touching global state.\n",
"    seed = int.from_bytes(msg_bytes[:4], \"little\", signed=False)\n",
"    challenge_rng = np.random.RandomState(seed)\n",
"    c_poly = challenge_rng.randint(-1, 2, size=N, dtype=np.int64)\n",
"\n",
"    # z = y + c*s1\n",
"    cs1 = vec_poly_mul_scalar(s1, c_poly, poly_mul_func)\n",
"    z = vec_poly_add(y, cs1)\n",
"\n",
"    # r = w - c*s2\n",
"    cs2 = vec_poly_mul_scalar(s2, c_poly, poly_mul_func)\n",
"    r = vec_poly_sub(w, cs2)\n",
"\n",
"    return (z, c_poly, r)\n",
"\n",
"def toy_dil_verify(poly_mul_func, pk, msg_bytes, sig):\n",
"    \"\"\"\n",
"    Verify-like flow:\n",
"      - Recompute w' = A*z - c*t with poly_mul_func\n",
"      - Accept iff every w'[i] equals the signature's r[i] modulo Q\n",
"        (exact equality; real Dilithium would use norm bounds)\n",
"    msg_bytes is not used by this toy check.\n",
"    Returns True/False.\n",
"    \"\"\"\n",
"    A, t = pk\n",
"    z, c_poly, r = sig\n",
"\n",
"    Az = mat_poly_mul_vec(A, z, poly_mul_func)\n",
"    ct = vec_poly_mul_scalar(t, c_poly, poly_mul_func)\n",
"    w_prime = vec_poly_sub(Az, ct)\n",
"\n",
"    # Stops at the first differing polynomial\n",
"    return all(np.array_equal(wp % Q, rv % Q) for wp, rv in zip(w_prime, r))\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "8ea91187",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Toy Dilithium-style flows (same harness, different poly_mul) ===\n",
"\n",
"[toy-naive] message: b'Test message for toy Dilithium'\n",
"[toy-naive] signature preview:\n",
" z[0]: [6, 12, 7, 17, ...]\n",
" c : [0, -1, 1, -1, ...]\n",
" r[0]: [-657352, 4065526, -3395326, 1460624, ...]\n",
"[toy-naive] keygen: 4.116052 s\n",
"[toy-naive] sign: 6.019188 s\n",
"[toy-naive] verify: 5.229953 s\n",
"[toy-naive] verify OK? True\n",
"\n",
"[toy-sw-ntt] message: b'Test message for toy Dilithium'\n",
"[toy-sw-ntt] signature preview:\n",
" z[0]: [6, 12, 7, 17, ...]\n",
" c : [0, -1, 1, -1, ...]\n",
" r[0]: [-657352, 4065526, -3395326, 1460624, ...]\n",
"[toy-sw-ntt] keygen: 1.292702 s\n",
"[toy-sw-ntt] sign: 1.921250 s\n",
"[toy-sw-ntt] verify: 1.606101 s\n",
"[toy-sw-ntt] verify OK? True\n",
"\n",
"[toy-hw-ntt] message: b'Test message for toy Dilithium'\n",
"[toy-hw-ntt] signature preview:\n",
" z[0]: [6, 12, 7, 17, ...]\n",
" c : [0, -1, 1, -1, ...]\n",
" r[0]: [-657352, 4065526, -3395326, 1460624, ...]\n",
"[toy-hw-ntt] keygen: 0.154008 s\n",
"[toy-hw-ntt] sign: 0.222962 s\n",
"[toy-hw-ntt] verify: 0.184250 s\n",
"[toy-hw-ntt] verify OK? True\n",
"\n",
"Speedups (toy flows, sign only):\n",
" naive / sw-ntt: 3.13 x\n",
" sw-ntt / hw-ntt: 8.62 x\n",
"\n",
"=== liboqs real Dilithium (C implementation) ===\n",
"\n",
"[liboqs] message: b'Test message for liboqs Dilithium'\n",
"[liboqs] signature preview: len=2420, first 16 bytes: 12954eb436346ddf827f047c121627f6\n",
"[liboqs] keygen: 0.001172 s\n",
"[liboqs] sign: 0.001966 s\n",
"[liboqs] verify: 0.000734 s\n",
"[liboqs] verify OK? True\n"
]
},
{
"data": {
"text/plain": [
"(0.0011719209996954305, 0.0019655719997899723, 0.0007336500002566027)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# oho J_dil end-to-end benchmark: toy flows vs liboqs, with message & signature previews\n",
"\n",
"import time\n",
"import binascii\n",
"\n",
"def preview_poly(poly, n=4):\n",
" \"\"\"Return a short string preview of the first n coefficients of a polynomial.\"\"\"\n",
" poly = np.asarray(poly, dtype=np.int64)\n",
" coeffs = \", \".join(str(int(x)) for x in poly[:n])\n",
" return f\"[{coeffs}{', ...' if poly.size > n else ''}]\"\n",
"\n",
"def preview_sig_toy(sig, n=4):\n",
" \"\"\"Preview toy signature (z, c_poly, r) in a compact way.\"\"\"\n",
" z, c_poly, r = sig\n",
" # z: list of L polys, r: list of K polys\n",
" z0 = z[0] if len(z) > 0 else np.zeros(N, dtype=np.int64)\n",
" r0 = r[0] if len(r) > 0 else np.zeros(N, dtype=np.int64)\n",
" return {\n",
" \"z[0]\": preview_poly(z0, n),\n",
" \"c\": preview_poly(c_poly, n),\n",
" \"r[0]\": preview_poly(r0, n),\n",
" }\n",
"\n",
"def preview_sig_liboqs(sig_bytes, n=16):\n",
" \"\"\"Preview liboqs signature as first n bytes in hex.\"\"\"\n",
" sig = bytes(sig_bytes)\n",
" prefix = sig[:n]\n",
" return f\"len={len(sig)}, first {n} bytes: {binascii.hexlify(prefix).decode()}\"\n",
"\n",
"def run_toy_flow(poly_mul_func, label):\n",
" msg = b\"Test message for toy Dilithium\"\n",
" print(f\"\\n[{label}] message:\", msg)\n",
"\n",
" t0 = time.perf_counter()\n",
" pk, sk = toy_dil_keygen(poly_mul_func)\n",
" t1 = time.perf_counter()\n",
"\n",
" sig = toy_dil_sign(poly_mul_func, sk, pk, msg)\n",
" t2 = time.perf_counter()\n",
"\n",
" ok = toy_dil_verify(poly_mul_func, pk, msg, sig)\n",
" t3 = time.perf_counter()\n",
"\n",
" # Signature preview\n",
" sig_preview = preview_sig_toy(sig, n=4)\n",
"\n",
" print(f\"[{label}] signature preview:\")\n",
" print(f\" z[0]: {sig_preview['z[0]']}\")\n",
" print(f\" c : {sig_preview['c']}\")\n",
" print(f\" r[0]: {sig_preview['r[0]']}\")\n",
"\n",
" print(f\"[{label}] keygen: {t1 - t0:.6f} s\")\n",
" print(f\"[{label}] sign: {t2 - t1:.6f} s\")\n",
" print(f\"[{label}] verify: {t3 - t2:.6f} s\")\n",
" print(f\"[{label}] verify OK? {ok}\")\n",
" return (t1 - t0), (t2 - t1), (t3 - t2)\n",
"\n",
"def run_liboqs_flow():\n",
" if not HAS_OQS or oqs_signer is None:\n",
" print(\"[liboqs] not available.\")\n",
" return None\n",
"\n",
" msg = b\"Test message for liboqs Dilithium\"\n",
" print(\"\\n[liboqs] message:\", msg)\n",
"\n",
" t0 = time.perf_counter()\n",
" with oqs.Signature(DILITHIUM_ALG) as signer:\n",
" pk = signer.generate_keypair()\n",
" sk = signer.export_secret_key()\n",
" t1 = time.perf_counter()\n",
"\n",
" sig = signer.sign(msg)\n",
" t2 = time.perf_counter()\n",
"\n",
" ok = signer.verify(msg, sig, pk)\n",
" t3 = time.perf_counter()\n",
"\n",
" sig_preview = preview_sig_liboqs(sig, n=16)\n",
" print(f\"[liboqs] signature preview: {sig_preview}\")\n",
" print(f\"[liboqs] keygen: {t1 - t0:.6f} s\")\n",
" print(f\"[liboqs] sign: {t2 - t1:.6f} s\")\n",
" print(f\"[liboqs] verify: {t3 - t2:.6f} s\")\n",
" print(f\"[liboqs] verify OK? {ok}\")\n",
" return (t1 - t0), (t2 - t1), (t3 - t2)\n",
"\n",
"print(\"=== Toy Dilithium-style flows (same harness, different poly_mul) ===\")\n",
"\n",
"tkg_n, ts_n, tv_n = run_toy_flow(poly_mul_naive, \"toy-naive\")\n",
"tkg_s, ts_s, tv_s = run_toy_flow(poly_mul_sw_ntt_backend, \"toy-sw-ntt\")\n",
"tkg_h, ts_h, tv_h = run_toy_flow(poly_mul_hw_backend, \"toy-hw-ntt\")\n",
"\n",
"print(\"\\nSpeedups (toy flows, sign only):\")\n",
"print(\" naive / sw-ntt: %.2f x\" % (ts_n / ts_s))\n",
"print(\" sw-ntt / hw-ntt: %.2f x\" % (ts_s / ts_h))\n",
"\n",
"print(\"\\n=== liboqs real Dilithium (C implementation) ===\")\n",
"run_liboqs_flow()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3071d384",
"metadata": {},
"outputs": [],
"source": [
"#######################\n",
"# Interoperability Test\n",
"#######################"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "bacd163d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[interop] generating pk, sk with naive backend\n",
"\n",
"[interop] sign=naive, verify=sw-ntt\n",
" sign backend: poly_mul_naive\n",
" verify backend: poly_mul_sw_ntt_backend\n",
" sign time: 6.004031 s\n",
" verify time: 1.602216 s\n",
" verify OK? True\n",
"\n",
"[interop] sign=sw-ntt, verify=naive\n",
" sign backend: poly_mul_sw_ntt_backend\n",
" verify backend: poly_mul_naive\n",
" sign time: 1.920178 s\n",
" verify time: 5.216772 s\n",
" verify OK? True\n",
"\n",
"[interop] sign=hw-ntt, verify=sw-ntt\n",
" sign backend: poly_mul_hw_backend\n",
" verify backend: poly_mul_sw_ntt_backend\n",
" sign time: 0.222985 s\n",
" verify time: 1.599958 s\n",
" verify OK? True\n",
"\n",
"[interop] sign=sw-ntt, verify=hw-ntt\n",
" sign backend: poly_mul_sw_ntt_backend\n",
" verify backend: poly_mul_hw_backend\n",
" sign time: 1.918987 s\n",
" verify time: 0.183846 s\n",
" verify OK? True\n",
"\n",
"[interop] sign=naive, verify=hw-ntt\n",
" sign backend: poly_mul_naive\n",
" verify backend: poly_mul_hw_backend\n",
" sign time: 5.994883 s\n",
" verify time: 0.183779 s\n",
" verify OK? True\n",
"\n",
"[interop] sign=hw-ntt, verify=naive\n",
" sign backend: poly_mul_hw_backend\n",
" verify backend: poly_mul_naive\n",
" sign time: 0.225005 s\n",
" verify time: 5.256011 s\n",
" verify OK? True\n",
"\n",
"[interop] summary:\n",
" naive ↔ sw-ntt: True\n",
" hw-ntt ↔ sw-ntt: True\n",
" hw-ntt ↔ naive: True\n"
]
}
],
"source": [
"# oho K_dil interoperability tests between poly_mul backends\n",
"\n",
"import time\n",
"\n",
"def run_interop_case(label, poly_mul_sign, poly_mul_verify, pk, sk, msg):\n",
" \"\"\"Sign with one backend, verify with another.\"\"\"\n",
" print(f\"\\n[interop] {label}\")\n",
" t0 = time.perf_counter()\n",
" sig = toy_dil_sign(poly_mul_sign, sk, pk, msg)\n",
" t1 = time.perf_counter()\n",
" ok = toy_dil_verify(poly_mul_verify, pk, msg, sig)\n",
" t2 = time.perf_counter()\n",
" print(f\" sign backend: {poly_mul_sign.__name__}\")\n",
" print(f\" verify backend: {poly_mul_verify.__name__}\")\n",
" print(f\" sign time: {t1 - t0:.6f} s\")\n",
" print(f\" verify time: {t2 - t1:.6f} s\")\n",
" print(f\" verify OK? {ok}\")\n",
" return ok\n",
"\n",
"# Fix a message and a single keypair (so only backends differ)\n",
"msg = b\"Interop test message for toy Dilithium\"\n",
"print(\"[interop] generating pk, sk with naive backend\")\n",
"pk, sk = toy_dil_keygen(poly_mul_naive)\n",
"\n",
"# 1) Sign by SW naive, verify by SW NTT and vice versa\n",
"ok_1a = run_interop_case(\n",
" \"sign=naive, verify=sw-ntt\",\n",
" poly_mul_naive,\n",
" poly_mul_sw_ntt_backend,\n",
" pk, sk, msg,\n",
")\n",
"ok_1b = run_interop_case(\n",
" \"sign=sw-ntt, verify=naive\",\n",
" poly_mul_sw_ntt_backend,\n",
" poly_mul_naive,\n",
" pk, sk, msg,\n",
")\n",
"\n",
"# 2) Sign by HW NTT, verify by SW NTT and vice versa\n",
"ok_2a = run_interop_case(\n",
" \"sign=hw-ntt, verify=sw-ntt\",\n",
" poly_mul_hw_backend,\n",
" poly_mul_sw_ntt_backend,\n",
" pk, sk, msg,\n",
")\n",
"ok_2b = run_interop_case(\n",
" \"sign=sw-ntt, verify=hw-ntt\",\n",
" poly_mul_sw_ntt_backend,\n",
" poly_mul_hw_backend,\n",
" pk, sk, msg,\n",
")\n",
"\n",
"# Also useful: cross all three backends for sanity\n",
"ok_3a = run_interop_case(\n",
" \"sign=naive, verify=hw-ntt\",\n",
" poly_mul_naive,\n",
" poly_mul_hw_backend,\n",
" pk, sk, msg,\n",
")\n",
"ok_3b = run_interop_case(\n",
" \"sign=hw-ntt, verify=naive\",\n",
" poly_mul_hw_backend,\n",
" poly_mul_naive,\n",
" pk, sk, msg,\n",
")\n",
"\n",
"print(\"\\n[interop] summary:\")\n",
"print(\" naive ↔ sw-ntt:\", ok_1a and ok_1b)\n",
"print(\" hw-ntt ↔ sw-ntt:\", ok_2a and ok_2b)\n",
"print(\" hw-ntt ↔ naive:\", ok_3a and ok_3b)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98b90350",
"metadata": {},
"outputs": [],
"source": [
"######################################\n",
"# Throughput Testing\n",
"######################################"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "509500db",
"metadata": {},
"outputs": [],
"source": [
"# oho L_dil reusable-buffer variant of the HW poly multiply for throughput\n",
"\n",
"from pynq import allocate\n",
"\n",
"def make_hw_poly_mult_reusable():\n",
" \"\"\"\n",
" Returns a callable hw_poly_mult_reuse(a, b) that uses preallocated\n",
" DMA buffers under the hood, for high-throughput benchmarking.\n",
" \"\"\"\n",
" in_buf = allocate(shape=(N,), dtype=np.uint64)\n",
" out_buf = allocate(shape=(N,), dtype=np.uint32)\n",
"\n",
" def hw_poly_mult_reuse(a, b):\n",
" a = np.asarray(a, dtype=np.int64)\n",
" b = np.asarray(b, dtype=np.int64)\n",
" assert a.shape == (N,) and b.shape == (N,)\n",
"\n",
" # Pack two 32-bit coeffs into one 64-bit word\n",
" # Convention: low 32 bits = a[i], high 32 bits = b[i]\n",
" for i in range(N):\n",
" a_i = int(a[i] % Q) & 0xFFFFFFFF\n",
" b_i = int(b[i] % Q) & 0xFFFFFFFF\n",
" word = ((b_i << 32) | a_i)\n",
" in_buf[i] = np.uint64(word)\n",
"\n",
" # Start hardware core (ap_start = 1)\n",
" mmio.write(0x00, 0x01)\n",
"\n",
" # Kick DMA transfers\n",
" dma.sendchannel.transfer(in_buf)\n",
" dma.recvchannel.transfer(out_buf)\n",
" dma.sendchannel.wait()\n",
" dma.recvchannel.wait()\n",
"\n",
" # We don't need to return the result for pure throughput measurement,\n",
" # but it's useful to keep it functional.\n",
" c = np.empty(N, dtype=np.int64)\n",
" for i in range(N):\n",
" c[i] = center_mod_q(int(out_buf[i]))\n",
" return c\n",
"\n",
" return hw_poly_mult_reuse, in_buf, out_buf\n",
"\n",
"# Create a reusable HW function once\n",
"hw_poly_mult_reuse, hw_in_buf, hw_out_buf = make_hw_poly_mult_reusable()\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "f9b43699",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Throughput benchmark with num_ops = 10 ===\n",
"[naive] total time: 2.871808 s, ops/s: 3.48\n",
"[sw-ntt] total time: 0.811113 s, ops/s: 12.33\n",
"[hw-ntt] total time: 0.076152 s, ops/s: 131.32\n",
"\n",
"Throughput ratios (ops/s):\n",
" sw-ntt / naive: 3.54 x\n",
" hw-ntt / sw-ntt: 10.65 x\n",
" hw-ntt / naive: 37.71 x\n",
"=== Throughput benchmark with num_ops = 100 ===\n",
"[naive] total time: 28.092749 s, ops/s: 3.56\n",
"[sw-ntt] total time: 8.087545 s, ops/s: 12.36\n",
"[hw-ntt] total time: 0.763630 s, ops/s: 130.95\n",
"\n",
"Throughput ratios (ops/s):\n",
" sw-ntt / naive: 3.47 x\n",
" hw-ntt / sw-ntt: 10.59 x\n",
" hw-ntt / naive: 36.79 x\n",
"=== Throughput benchmark with num_ops = 1000 ===\n",
"[naive] total time: 275.062409 s, ops/s: 3.64\n",
"[sw-ntt] total time: 77.406778 s, ops/s: 12.92\n",
"[hw-ntt] total time: 7.654999 s, ops/s: 130.63\n",
"\n",
"Throughput ratios (ops/s):\n",
" sw-ntt / naive: 3.55 x\n",
" hw-ntt / sw-ntt: 10.11 x\n",
" hw-ntt / naive: 35.93 x\n"
]
}
],
"source": [
"# oho M_dil throughput benchmark: many independent multiplies\n",
"\n",
"import time\n",
"from numpy.random import default_rng\n",
"\n",
"rng = default_rng(123456)\n",
"\n",
"def throughput_benchmark(num_ops=100):\n",
" print(f\"=== Throughput benchmark with num_ops = {num_ops} ===\")\n",
"\n",
" # Pre-generate inputs so all backends see the same workload\n",
" A_batch = rng.integers(low=-Q//2, high=Q//2, size=(num_ops, N), dtype=np.int64)\n",
" B_batch = rng.integers(low=-Q//2, high=Q//2, size=(num_ops, N), dtype=np.int64)\n",
"\n",
" # 1) Naive SW (O(N^2))\n",
" t0 = time.perf_counter()\n",
" for k in range(num_ops):\n",
" _ = poly_mul_sw_naive(A_batch[k], B_batch[k])\n",
" t1 = time.perf_counter()\n",
" t_naive_total = t1 - t0\n",
" naive_ops_per_s = num_ops / t_naive_total\n",
"\n",
" print(f\"[naive] total time: {t_naive_total:.6f} s, ops/s: {naive_ops_per_s:.2f}\")\n",
"\n",
" # 2) SW NTT (Python 512-NTT)\n",
" t0 = time.perf_counter()\n",
" for k in range(num_ops):\n",
" _ = poly_mul_sw_ntt(A_batch[k], B_batch[k])\n",
" t1 = time.perf_counter()\n",
" t_sw_ntt_total = t1 - t0\n",
" sw_ntt_ops_per_s = num_ops / t_sw_ntt_total\n",
"\n",
" print(f\"[sw-ntt] total time: {t_sw_ntt_total:.6f} s, ops/s: {sw_ntt_ops_per_s:.2f}\")\n",
"\n",
" # 3) HW NTT (FPGA via reusable DMA buffers)\n",
" t0 = time.perf_counter()\n",
" for k in range(num_ops):\n",
" _ = hw_poly_mult_reuse(A_batch[k], B_batch[k])\n",
" t1 = time.perf_counter()\n",
" t_hw_total = t1 - t0\n",
" hw_ops_per_s = num_ops / t_hw_total\n",
"\n",
" print(f\"[hw-ntt] total time: {t_hw_total:.6f} s, ops/s: {hw_ops_per_s:.2f}\")\n",
"\n",
" # Aggregate ratios\n",
" print(\"\\nThroughput ratios (ops/s):\")\n",
" print(\" sw-ntt / naive: %.2f x\" % (sw_ntt_ops_per_s / naive_ops_per_s))\n",
" print(\" hw-ntt / sw-ntt: %.2f x\" % (hw_ops_per_s / sw_ntt_ops_per_s))\n",
" print(\" hw-ntt / naive: %.2f x\" % (hw_ops_per_s / naive_ops_per_s))\n",
"\n",
"# Example: varying load\n",
"throughput_benchmark(num_ops=10)\n",
"throughput_benchmark(num_ops=100)\n",
"# next line: Do you want to wait so long here ?\n",
"#throughput_benchmark(num_ops=1000)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61736ef2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}