NTT-learning/ntt_learning/toy_ntt.py

616 lines
21 KiB
Python
Raw Normal View History

"""Small, inspectable helpers for the NTT learning course."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Sequence
def _apply_mod(values: Sequence[int], modulus: int | None) -> list[int]:
    """Coerce every entry to ``int`` and optionally reduce it modulo ``modulus``.

    When ``modulus`` is ``None`` the values are only converted, not reduced.
    """
    if modulus is not None:
        return [int(entry) % modulus for entry in values]
    return list(map(int, values))
def schoolbook_convolution(
    left: Sequence[int], right: Sequence[int], modulus: int | None = None
) -> list[int]:
    """Return the coefficients of the product ``left * right``.

    Every coefficient of ``left`` is multiplied with every coefficient of
    ``right`` and accumulated at the slot given by the sum of their indices.
    An empty operand yields an empty result.  The result is passed through
    ``_apply_mod`` with ``modulus``.
    """
    if not (left and right):
        return []
    coefficients = [0] * (len(left) + len(right) - 1)
    for i, a in enumerate(left):
        for j, b in enumerate(right):
            coefficients[i + j] += a * b
    return _apply_mod(coefficients, modulus)
def pairwise_product_grid(left: Sequence[int], right: Sequence[int]) -> list[list[int]]:
    """Build the table of all pairwise products.

    Entry ``[i][j]`` of the result is ``left[i] * right[j]`` converted to ``int``.
    """
    grid: list[list[int]] = []
    for a in left:
        grid.append([int(a * b) for b in right])
    return grid
def convolution_contributions(left: Sequence[int], right: Sequence[int]) -> list[dict[str, object]]:
    """List, for every output coefficient, the products that feed into it.

    Each returned row has the keys ``output_index``, ``terms`` and ``total``.
    ``terms`` holds one dict per pair ``(left_index, right_index)`` whose
    indices sum to ``output_index``; ``total`` is the unreduced coefficient
    computed by ``schoolbook_convolution``.
    """
    totals = schoolbook_convolution(left, right)
    right_length = len(right)
    report: list[dict[str, object]] = []
    for k, total in enumerate(totals):
        contributing = []
        for i, a in enumerate(left):
            j = k - i
            # Only index pairs that land inside ``right`` contribute.
            if 0 <= j < right_length:
                b = right[j]
                contributing.append(
                    {
                        "left_index": i,
                        "right_index": j,
                        "left_value": int(a),
                        "right_value": int(b),
                        "product": int(a * b),
                    }
                )
        report.append({"output_index": k, "terms": contributing, "total": int(total)})
    return report
def negacyclic_reduce(
    coefficients: Sequence[int], n: int, modulus: int | None = None
) -> list[int]:
    """Reduce a coefficient list modulo ``x^n + 1``.

    The coefficient of degree ``d`` lands in slot ``d % n``; it is subtracted
    when ``d // n`` is odd and added otherwise.  The folded result is passed
    through ``_apply_mod`` with ``modulus``.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    folded = [0] * n
    for degree, coefficient in enumerate(coefficients):
        slot = degree % n
        if (degree // n) % 2:
            # An odd number of wraps contributes a factor of -1.
            folded[slot] -= coefficient
        else:
            folded[slot] += coefficient
    return _apply_mod(folded, modulus)
def wraparound_contributions(
    coefficients: Sequence[int], n: int, negacyclic: bool = True
) -> list[dict[str, object]]:
    """Explain which input coefficients end up in each of the ``n`` slots.

    Returns one row per slot with the keys ``slot``, ``contributions`` and
    ``total``.  Each contribution records the source index, how many times it
    wrapped, the sign applied (``-1`` only for an odd wrap count when
    ``negacyclic`` is true), the original value and the signed value.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    totals = [0] * n
    per_slot: list[list[dict[str, int]]] = [[] for _ in range(n)]
    for source_index, coefficient in enumerate(coefficients):
        wraps = source_index // n
        slot = source_index % n
        flipped = negacyclic and wraps % 2 == 1
        sign = -1 if flipped else 1
        signed_value = int(sign * coefficient)
        totals[slot] += signed_value
        per_slot[slot].append(
            {
                "source_index": source_index,
                "wraps": wraps,
                "sign": sign,
                "value": int(coefficient),
                "signed_value": signed_value,
            }
        )
    summary: list[dict[str, object]] = []
    for slot, (contributions, total) in enumerate(zip(per_slot, totals)):
        summary.append({"slot": slot, "contributions": contributions, "total": int(total)})
    return summary
def negacyclic_multiply(
    left: Sequence[int], right: Sequence[int], n: int, modulus: int | None = None
) -> list[int]:
    """Multiply two length-``n`` coefficient vectors modulo ``x^n + 1``.

    The full schoolbook product is computed first and then folded with
    ``negacyclic_reduce``.

    Raises:
        ValueError: if either operand does not have exactly ``n`` entries.
    """
    lengths_ok = len(left) == n and len(right) == n
    if not lengths_ok:
        raise ValueError("left and right must both have length n")
    full_product = schoolbook_convolution(left, right)
    return negacyclic_reduce(full_product, n=n, modulus=modulus)
def mod_inverse(value: int, modulus: int) -> int:
    """Compute ``value**-1`` in the integers modulo ``modulus``.

    Delegates to the built-in three-argument ``pow`` with exponent ``-1``,
    which raises ``ValueError`` when ``value`` has no inverse.
    """
    inverse = pow(value, -1, modulus)
    return inverse
def pointwise_multiply(
    left: Sequence[int], right: Sequence[int], modulus: int | None = None
) -> list[int]:
    """Multiply two equally long vectors entry by entry.

    The products are passed through ``_apply_mod`` with ``modulus``.

    Raises:
        ValueError: if the two inputs differ in length.
    """
    if len(left) != len(right):
        raise ValueError("left and right must have the same length")
    slot_products = []
    for a, b in zip(left, right):
        slot_products.append(int(a * b))
    return _apply_mod(slot_products, modulus)
def scale_values(values: Sequence[int], factor: int, modulus: int | None = None) -> list[int]:
    """Return ``factor * value`` for every value, passed through ``_apply_mod``."""
    multiplied = []
    for value in values:
        multiplied.append(int(factor * value))
    return _apply_mod(multiplied, modulus)
def base_multiply_pair(
    left: Sequence[int], right: Sequence[int], zeta: int, modulus: int
) -> list[int]:
    """Multiply ``a0 + a1*x`` by ``b0 + b1*x`` where ``x^2`` is replaced by ``zeta``.

    Returns ``[c0, c1]`` with ``c0 = a0*b0 + zeta*a1*b1`` and
    ``c1 = a0*b1 + a1*b0``, both reduced modulo ``modulus``.

    Raises:
        ValueError: if either operand does not have exactly two entries.
    """
    if not (len(left) == 2 and len(right) == 2):
        raise ValueError("base_multiply_pair expects two 2-term coefficient vectors")
    a0 = int(left[0])
    a1 = int(left[1])
    b0 = int(right[0])
    b1 = int(right[1])
    constant_term = (a0 * b0 + zeta * a1 * b1) % modulus
    linear_term = (a0 * b1 + a1 * b0) % modulus
    return [constant_term, linear_term]
def _proper_divisors(order: int) -> list[int]:
    """Return every positive divisor of ``order`` that is smaller than ``order``."""
    return [candidate for candidate in range(1, order) if order % candidate == 0]


def find_primitive_root(order: int, modulus: int) -> int:
    """Find the smallest primitive ``order``-th root of unity in Z_modulus.

    A candidate ``c`` qualifies when ``c**order == 1 (mod modulus)`` and
    ``c**d != 1`` for every proper divisor ``d`` of ``order``.

    Args:
        order: Required multiplicative order; must be positive and divide
            ``modulus - 1``.
        modulus: Modulus of the ring that is searched.

    Returns:
        The smallest qualifying element in ``1 .. modulus - 1``.

    Raises:
        ValueError: if ``order`` is not positive, does not divide
            ``modulus - 1``, or no element of that exact order is found.
    """
    if order <= 0:
        raise ValueError("order must be positive")
    if (modulus - 1) % order != 0:
        raise ValueError("order must divide modulus - 1")
    # The divisor list does not depend on the candidate, so compute it once.
    smaller_orders = _proper_divisors(order)
    # Start at 1 so that ``order == 1`` finds its only primitive root (1).
    # For every larger order, 1 is rejected below because 1**1 == 1.
    for candidate in range(1, modulus):
        if pow(candidate, order, modulus) != 1:
            continue
        if any(pow(candidate, divisor, modulus) == 1 for divisor in smaller_orders):
            continue
        return candidate
    raise ValueError(f"no primitive {order}-th root exists modulo {modulus}")
def find_psi(order: int, modulus: int) -> int:
    """Return a primitive ``2 * order``-th root ``psi`` with ``psi**order == -1``.

    The root comes from ``find_primitive_root``; the extra check confirms that
    raising it to ``order`` gives ``modulus - 1``.

    Raises:
        ValueError: if the root found does not satisfy ``psi**order == -1``,
            or if ``find_primitive_root`` itself raises.
    """
    full_order = 2 * order
    psi = find_primitive_root(full_order, modulus)
    minus_one = (modulus - 1) % modulus
    if pow(psi, order, modulus) == minus_one:
        return psi
    raise ValueError(f"primitive {full_order}-th root does not satisfy psi^order = -1")
def forward_ntt(values: Sequence[int], modulus: int, omega: int) -> list[int]:
    """Evaluate the positive-wrapped NTT straight from its definition.

    Output ``k`` is ``sum(values[j] * omega**(k*j)) mod modulus`` over all
    input positions ``j``.  An empty input returns an empty list.
    """
    n = len(values)
    if n == 0:
        return []
    return [
        sum(
            value * pow(omega, output_index * input_index, modulus)
            for input_index, value in enumerate(values)
        )
        % modulus
        for output_index in range(n)
    ]
def inverse_ntt(values: Sequence[int], modulus: int, omega: int) -> list[int]:
    """Undo ``forward_ntt`` using the inverse root and a ``1/n`` scaling.

    Output ``k`` is ``n**-1 * sum(values[j] * omega**(-k*j)) mod modulus``.
    An empty input returns an empty list.
    """
    n = len(values)
    if n == 0:
        return []
    omega_inverse = mod_inverse(omega, modulus)
    n_inverse = mod_inverse(n, modulus)
    return [
        (
            sum(
                value * pow(omega_inverse, output_index * input_index, modulus)
                for input_index, value in enumerate(values)
            )
            * n_inverse
        )
        % modulus
        for output_index in range(n)
    ]
def ntt_psi_exponent_grid(length: int) -> list[list[int]]:
    """Tabulate the exponents of ``psi`` used by the direct negative-wrapped NTT.

    Entry ``[k][j]`` is ``2*j*k + j``: the power of ``psi`` that multiplies
    input ``j`` when ``forward_ntt_psi`` computes output ``k``.
    """
    grid: list[list[int]] = []
    for output_index in range(length):
        grid.append(
            [2 * input_index * output_index + input_index for input_index in range(length)]
        )
    return grid
def ntt_psi_matrix(length: int, modulus: int, psi: int) -> list[list[int]]:
    """Build the matrix of the direct negative-wrapped transform.

    Entry ``[k][j]`` is ``psi**(2*j*k + j) mod modulus``.
    """
    matrix: list[list[int]] = []
    for output_index in range(length):
        matrix_row = []
        for input_index in range(length):
            exponent = 2 * input_index * output_index + input_index
            matrix_row.append(pow(psi, exponent, modulus))
        matrix.append(matrix_row)
    return matrix
def forward_ntt_psi(values: Sequence[int], modulus: int, psi: int) -> list[int]:
    """Evaluate the negative-wrapped NTT straight from its definition.

    Output ``k`` is ``sum(values[j] * psi**(2*j*k + j)) mod modulus`` over all
    input positions ``j``.  An empty input returns an empty list.
    """
    n = len(values)
    if n == 0:
        return []
    return [
        sum(
            value * pow(psi, 2 * input_index * output_index + input_index, modulus)
            for input_index, value in enumerate(values)
        )
        % modulus
        for output_index in range(n)
    ]
def inverse_ntt_psi(values: Sequence[int], modulus: int, psi: int) -> list[int]:
    """Undo ``forward_ntt_psi`` straight from the definition.

    Output ``j`` is ``n**-1 * sum(values[k] * psi**-(2*j*k + j)) mod modulus``.
    The negative exponent is handled by three-argument ``pow``.  An empty
    input returns an empty list.
    """
    n = len(values)
    if n == 0:
        return []
    n_inverse = mod_inverse(n, modulus)
    return [
        (
            n_inverse
            * sum(
                value * pow(psi, -(2 * output_index * input_index + output_index), modulus)
                for input_index, value in enumerate(values)
            )
        )
        % modulus
        for output_index in range(n)
    ]
def ct_butterfly_pair(top: int, bottom: int, zeta: int, modulus: int) -> tuple[int, int]:
    """Apply a Cooley-Tukey butterfly to one pair.

    Returns ``(top + zeta*bottom, top - zeta*bottom)``, each reduced modulo
    ``modulus``.
    """
    scaled_bottom = (zeta * bottom) % modulus
    upper = (top + scaled_bottom) % modulus
    lower = (top - scaled_bottom) % modulus
    return upper, lower
def gs_butterfly_pair(top: int, bottom: int, zeta: int, modulus: int) -> tuple[int, int]:
    """Apply a Gentleman-Sande butterfly to one pair.

    Returns ``(top + bottom, zeta * (top - bottom))``, each reduced modulo
    ``modulus``.
    """
    upper = (top + bottom) % modulus
    difference = (top - bottom) % modulus
    lower = (zeta * difference) % modulus
    return upper, lower
def stage_pairings(length: int, block_size: int) -> list[tuple[int, int]]:
    """List the index pairs combined by one butterfly stage.

    The range ``0 .. length - 1`` is cut into consecutive blocks of
    ``block_size``; inside each block, position ``i`` of the first half is
    paired with position ``i`` of the second half.

    Raises:
        ValueError: if ``length`` is not positive, ``block_size`` is not a
            positive even integer, or ``block_size`` does not divide ``length``.
    """
    if length <= 0:
        raise ValueError("length must be positive")
    if block_size <= 0 or block_size % 2 != 0:
        raise ValueError("block_size must be a positive even integer")
    if length % block_size != 0:
        raise ValueError("block_size must divide length")
    half = block_size // 2
    return [
        (block_start + offset, block_start + offset + half)
        for block_start in range(0, length, block_size)
        for offset in range(half)
    ]
def _expand_zetas(zetas: int | Iterable[int], pair_count: int) -> list[int]:
    """Normalise ``zetas`` to a list with one entry per butterfly pair.

    A single integer is repeated ``pair_count`` times.  Any other iterable is
    converted entry by entry to ``int`` and must have exactly ``pair_count``
    entries.

    Raises:
        ValueError: if an iterable of the wrong length is supplied.
    """
    if not isinstance(zetas, int):
        expanded = list(map(int, zetas))
        if len(expanded) == pair_count:
            return expanded
        raise ValueError(f"expected {pair_count} zetas, received {len(expanded)}")
    return [zetas] * pair_count
@dataclass(frozen=True)
class ButterflyAction:
    """Immutable record of one butterfly applied to a pair of positions."""

    # Indices of the two positions the butterfly touched.
    pair: tuple[int, int]
    # Twiddle factor recorded for this butterfly.
    zeta: int
    # Values at the two positions before the butterfly.
    inputs: tuple[int, int]
    # Values at the two positions after the butterfly.
    outputs: tuple[int, int]
@dataclass(frozen=True)
class TransformStage:
    """Immutable snapshot of a single stage of a fast NTT / iNTT trace."""

    # Short tag naming the butterfly family of the stage.
    algorithm: str
    # Position of the stage within its trace.
    stage_index: int
    # Full working array before the stage.
    input_values: tuple[int, ...]
    # Full working array after the stage.
    output_values: tuple[int, ...]
    # Index pairs combined by the stage's butterflies.
    pairings: tuple[tuple[int, int], ...]
    # One twiddle factor per entry of ``pairings``, in the same order.
    zetas: tuple[int, ...]
    # Human-readable description of what the stage does.
    note: str
@dataclass(frozen=True)
class TransformTrace:
    """Immutable record of a complete fast transform, for notebook display."""

    # Short tag naming the butterfly family of the transform.
    algorithm: str
    # Modulus all arithmetic was reduced by.
    modulus: int
    # Root of unity the transform was called with.
    root: int
    # Input array as handed to the transform.
    input_values: tuple[int, ...]
    # Per-stage snapshots in execution order.
    stages: tuple[TransformStage, ...]
    # Output exactly as produced by the last stage.
    raw_output: tuple[int, ...]
    # Output rearranged into normal (natural) index order.
    normal_order_output: tuple[int, ...]
    # Output after an extra scaling step; ``None`` when no scaling applies.
    scaled_output: tuple[int, ...] | None = None
def apply_ct_stage(
    values: Sequence[int], block_size: int, zetas: int | Iterable[int], modulus: int
) -> tuple[list[int], list[ButterflyAction]]:
    """Run one Cooley-Tukey stage over ``values`` and record every butterfly.

    Pairs come from ``stage_pairings`` and twiddles from ``_expand_zetas``.
    Every butterfly reads from an untouched copy of the input, and the input
    sequence itself is not modified.

    Returns:
        ``(new_values, actions)``: the updated array and one
        ``ButterflyAction`` per pair, with the twiddle stored reduced modulo
        ``modulus``.
    """
    before = list(values)
    after = list(values)
    index_pairs = stage_pairings(len(before), block_size)
    twiddles = _expand_zetas(zetas, len(index_pairs))
    trace: list[ButterflyAction] = []
    for (top_index, bottom_index), twiddle in zip(index_pairs, twiddles):
        top, bottom = before[top_index], before[bottom_index]
        result = ct_butterfly_pair(top, bottom, twiddle, modulus)
        after[top_index], after[bottom_index] = result
        trace.append(
            ButterflyAction(
                pair=(top_index, bottom_index),
                zeta=twiddle % modulus,
                inputs=(top, bottom),
                outputs=result,
            )
        )
    return after, trace
def apply_gs_stage(
    values: Sequence[int], block_size: int, zetas: int | Iterable[int], modulus: int
) -> tuple[list[int], list[ButterflyAction]]:
    """Run one Gentleman-Sande stage over ``values`` and record every butterfly.

    Pairs come from ``stage_pairings`` and twiddles from ``_expand_zetas``.
    Every butterfly reads from an untouched copy of the input, and the input
    sequence itself is not modified.

    Returns:
        ``(new_values, actions)``: the updated array and one
        ``ButterflyAction`` per pair, with the twiddle stored reduced modulo
        ``modulus``.
    """
    before = list(values)
    after = list(values)
    index_pairs = stage_pairings(len(before), block_size)
    twiddles = _expand_zetas(zetas, len(index_pairs))
    trace: list[ButterflyAction] = []
    for (top_index, bottom_index), twiddle in zip(index_pairs, twiddles):
        top, bottom = before[top_index], before[bottom_index]
        result = gs_butterfly_pair(top, bottom, twiddle, modulus)
        after[top_index], after[bottom_index] = result
        trace.append(
            ButterflyAction(
                pair=(top_index, bottom_index),
                zeta=twiddle % modulus,
                inputs=(top, bottom),
                outputs=result,
            )
        )
    return after, trace
def action_rows(actions: Sequence[ButterflyAction]) -> list[dict[str, object]]:
    """Convert butterfly actions into plain dicts for tabular display.

    Each dict carries the keys ``pair``, ``zeta``, ``inputs`` and ``outputs``
    copied from the corresponding action.
    """
    table: list[dict[str, object]] = []
    for entry in actions:
        table.append(
            {
                "pair": entry.pair,
                "zeta": entry.zeta,
                "inputs": entry.inputs,
                "outputs": entry.outputs,
            }
        )
    return table
def stage_rows(stage: TransformStage) -> list[dict[str, object]]:
    """Describe every butterfly of ``stage`` as a plain dict.

    For each ``(pair, zeta)`` of the stage, the row reports the pair, the
    twiddle, and the values found at the pair's two indices in the stage's
    input and output arrays.
    """

    def describe(pair: tuple[int, int], zeta: int) -> dict[str, object]:
        first, second = pair
        return {
            "pair": pair,
            "zeta": zeta,
            "inputs": (stage.input_values[first], stage.input_values[second]),
            "outputs": (stage.output_values[first], stage.output_values[second]),
        }

    return [describe(pair, zeta) for pair, zeta in zip(stage.pairings, stage.zetas)]
def bit_reverse(index: int, width: int) -> int:
    """Mirror the lowest ``width`` bits of ``index``.

    Bits of ``index`` above ``width`` are ignored.

    Raises:
        ValueError: if ``width`` is negative.
    """
    if width < 0:
        raise ValueError("width must be non-negative")
    mirrored = 0
    remaining = index
    for _ in range(width):
        # Move the current lowest bit of the input to the low end of the result.
        mirrored = mirrored * 2 + remaining % 2
        remaining //= 2
    return mirrored
def bit_reversed_indices(length: int) -> list[int]:
    """Build the bit-reversal permutation of ``range(length)``.

    Entry ``i`` of the result is ``i`` with its ``log2(length)`` bits mirrored.

    Raises:
        ValueError: if ``length`` is not a positive power of two.
    """
    if length <= 0:
        raise ValueError("length must be positive")
    is_power_of_two = (length & (length - 1)) == 0
    if not is_power_of_two:
        raise ValueError("bit_reversed_indices requires a power-of-two length")
    bit_count = length.bit_length() - 1
    permutation = []
    for position in range(length):
        permutation.append(bit_reverse(position, bit_count))
    return permutation
def bit_reversed_order(values: Sequence[int]) -> list[int]:
    """Permute ``values`` so position ``i`` holds ``values[bit_reverse(i)]``.

    An empty input returns an empty list; otherwise the length must satisfy
    the requirements of ``bit_reversed_indices``.
    """
    if len(values) == 0:
        return []
    reordered = []
    for source in bit_reversed_indices(len(values)):
        reordered.append(values[source])
    return reordered
def _interleave(left: Sequence[int], right: Sequence[int]) -> list[int]:
    """Weave two sequences together as ``[left[0], right[0], left[1], ...]``.

    Entries are converted to ``int``; pairing stops at the shorter input.
    """
    woven: list[int] = []
    for from_left, from_right in zip(left, right):
        woven.append(int(from_left))
        woven.append(int(from_right))
    return woven
def _renumber_stages(stages: Sequence[TransformStage]) -> tuple[TransformStage, ...]:
    """Return copies of ``stages`` whose ``stage_index`` counts up from 1.

    All other fields are carried over unchanged.
    """
    renumbered = []
    for number, original in enumerate(stages, start=1):
        renumbered.append(
            TransformStage(
                algorithm=original.algorithm,
                stage_index=number,
                input_values=original.input_values,
                output_values=original.output_values,
                pairings=original.pairings,
                zetas=original.zetas,
                note=original.note,
            )
        )
    return tuple(renumbered)
def _merge_child_stages(
    left_stages: Sequence[TransformStage], right_stages: Sequence[TransformStage]
) -> list[TransformStage]:
    """Fuse the stage lists of two recursive branches into full-width stages.

    Stage ``k`` of the result interleaves the arrays of stage ``k`` from both
    branches: left-branch index ``i`` maps to ``2*i`` and right-branch index
    ``i`` maps to ``2*i + 1``.  Algorithm tag and note come from the left
    branch, and ``stage_index`` is set to ``0`` as a placeholder.

    Raises:
        ValueError: if the two branches have a different number of stages.
    """
    if len(left_stages) != len(right_stages):
        raise ValueError("expected the same number of stages on both recursive branches")
    fused: list[TransformStage] = []
    for from_left, from_right in zip(left_stages, right_stages):
        combined_inputs = _interleave(from_left.input_values, from_right.input_values)
        combined_outputs = _interleave(from_left.output_values, from_right.output_values)
        even_slot_pairs = tuple((2 * a, 2 * b) for a, b in from_left.pairings)
        odd_slot_pairs = tuple((2 * a + 1, 2 * b + 1) for a, b in from_right.pairings)
        fused.append(
            TransformStage(
                algorithm=from_left.algorithm,
                stage_index=0,
                input_values=tuple(combined_inputs),
                output_values=tuple(combined_outputs),
                pairings=even_slot_pairs + odd_slot_pairs,
                zetas=from_left.zetas + from_right.zetas,
                note=from_left.note,
            )
        )
    return fused
def _fast_ntt_psi_ct_recursive(
    values: Sequence[int], modulus: int, psi: int
) -> tuple[tuple[int, ...], list[TransformStage]]:
    """Recursive core of the Cooley-Tukey fast negative-wrapped NTT.

    The input is split into even- and odd-indexed halves, each half is
    transformed with the root ``psi**2``, the two half-spectra are
    interleaved, and neighbouring slots are combined with CT butterflies.

    Args:
        values: Coefficients in normal order.  The length must be a power of
            two; this function does not check it.
        modulus: Modulus for all arithmetic.
        psi: Primitive ``2 * len(values)``-th root of unity modulo ``modulus``.

    Returns:
        ``(output, stages)``: ``output`` is the spectrum in bit-reversed
        order and ``stages`` the full-width stages in execution order, each
        with ``stage_index`` still set to ``0``.
    """
    size = len(values)
    if size == 1:
        return (int(values[0]),), []
    psi_squared = pow(psi, 2, modulus)
    even_output, even_stages = _fast_ntt_psi_ct_recursive(values[0::2], modulus, psi_squared)
    odd_output, odd_stages = _fast_ntt_psi_ct_recursive(values[1::2], modulus, psi_squared)
    merged_stages = _merge_child_stages(even_stages, odd_stages)
    stage_input = tuple(_interleave(even_output, odd_output))
    output = list(stage_input)
    half = size // 2
    # Number of bits needed to index one half-spectrum (``half`` is a power of two).
    width = half.bit_length() - 1
    pairings = []
    zetas = []
    for column in range(half):
        left = 2 * column
        right = 2 * column + 1
        # The half-spectra arrive in bit-reversed order, so slot ``column``
        # holds frequency ``bit_reverse(column)``.  The butterfly for frequency
        # k needs the twiddle psi**(2k + 1); using ``column`` directly is only
        # correct while the permutation is the identity (size <= 4).
        frequency = bit_reverse(column, width)
        zeta = pow(psi, 2 * frequency + 1, modulus)
        output[left], output[right] = ct_butterfly_pair(
            stage_input[left], stage_input[right], zeta, modulus
        )
        pairings.append((left, right))
        zetas.append(zeta)
    merged_stages.append(
        TransformStage(
            algorithm="ct",
            stage_index=0,
            input_values=stage_input,
            output_values=tuple(output),
            pairings=tuple(pairings),
            zetas=tuple(zetas),
            note="Interleave recursive sub-transforms, then apply adjacent CT butterflies.",
        )
    )
    return tuple(output), merged_stages
def fast_ntt_psi_ct_trace(values: Sequence[int], modulus: int, psi: int) -> TransformTrace:
    """Run the CT fast NTT and return its full stage-by-stage trace.

    ``raw_output`` is the result exactly as the recursion produces it;
    ``normal_order_output`` is the same data passed through
    ``bit_reversed_order``.  ``scaled_output`` is always ``None``.  An empty
    input gives a trace with empty tuples.

    Raises:
        ValueError: if the input length is not a power of two.
    """
    if not values:
        return TransformTrace("ct", modulus, psi, (), (), (), (), None)
    count = len(values)
    if count & (count - 1):
        raise ValueError("fast_ntt_psi_ct_trace requires a power-of-two length")
    spectrum, recorded_stages = _fast_ntt_psi_ct_recursive(values, modulus, psi)
    original_input = tuple(int(value) for value in values)
    reordered = tuple(bit_reversed_order(spectrum))
    return TransformTrace(
        algorithm="ct",
        modulus=modulus,
        root=psi,
        input_values=original_input,
        stages=_renumber_stages(recorded_stages),
        raw_output=spectrum,
        normal_order_output=reordered,
        scaled_output=None,
    )
def fast_ntt_psi_ct(values: Sequence[int], modulus: int, psi: int) -> list[int]:
    """Return only the raw output of ``fast_ntt_psi_ct_trace`` as a list."""
    trace = fast_ntt_psi_ct_trace(values, modulus, psi)
    return list(trace.raw_output)
def _fast_intt_psi_gs_recursive(
    values: Sequence[int], modulus: int, psi: int
) -> tuple[tuple[int, ...], list[TransformStage]]:
    """Recursive core of the Gentleman-Sande fast negative-wrapped inverse NTT.

    Neighbouring slots are first combined with GS butterflies. The even and
    odd slots of the result are then transformed recursively with the root
    ``psi**2``, and the two results are interleaved.

    Args:
        values: Spectrum in bit-reversed order.  The length must be a power
            of two; this function does not check it.
        modulus: Modulus for all arithmetic.
        psi: Primitive ``2 * len(values)``-th root of unity modulo ``modulus``.

    Returns:
        ``(output, stages)``: ``output`` is in normal order and has not been
        scaled by the inverse of the length; ``stages`` lists the full-width
        stages in execution order, each with ``stage_index`` set to ``0``.
    """
    size = len(values)
    if size == 1:
        return (int(values[0]),), []
    stage_input = tuple(int(value) for value in values)
    stage_output = list(stage_input)
    half = size // 2
    # Number of bits needed to index one half-spectrum (``half`` is a power of two).
    width = half.bit_length() - 1
    pairings = []
    zetas = []
    for column in range(half):
        left = 2 * column
        right = 2 * column + 1
        # The input is bit-reversed, so slots (2*column, 2*column + 1) hold
        # frequencies k and k + half with k = bit_reverse(column).  Separating
        # them needs the inverse of psi**(2k + 1); using ``column`` directly
        # is only correct while the permutation is the identity (size <= 4).
        frequency = bit_reverse(column, width)
        zeta = mod_inverse(pow(psi, 2 * frequency + 1, modulus), modulus)
        stage_output[left], stage_output[right] = gs_butterfly_pair(
            stage_input[left], stage_input[right], zeta, modulus
        )
        pairings.append((left, right))
        zetas.append(zeta)
    psi_squared = pow(psi, 2, modulus)
    even_output, even_stages = _fast_intt_psi_gs_recursive(
        stage_output[0::2], modulus, psi_squared
    )
    odd_output, odd_stages = _fast_intt_psi_gs_recursive(
        stage_output[1::2], modulus, psi_squared
    )
    return (
        tuple(_interleave(even_output, odd_output)),
        [
            TransformStage(
                algorithm="gs",
                stage_index=0,
                input_values=stage_input,
                output_values=tuple(stage_output),
                pairings=tuple(pairings),
                zetas=tuple(zetas),
                note="Apply adjacent GS butterflies, then recurse on even and odd branches.",
            ),
            *_merge_child_stages(even_stages, odd_stages),
        ],
    )
def fast_intt_psi_gs_trace(values: Sequence[int], modulus: int, psi: int) -> TransformTrace:
    """Run the GS fast inverse NTT and return its full stage-by-stage trace.

    ``raw_output`` and ``normal_order_output`` both hold the unscaled result
    of the recursion. ``scaled_output`` is that result multiplied by the
    inverse of the length modulo ``modulus``.  An empty input gives a trace
    with empty tuples.

    Raises:
        ValueError: if the input length is not a power of two.
    """
    if not values:
        return TransformTrace("gs", modulus, psi, (), (), (), (), ())
    count = len(values)
    if count & (count - 1):
        raise ValueError("fast_intt_psi_gs_trace requires a power-of-two length")
    unscaled, recorded_stages = _fast_intt_psi_gs_recursive(values, modulus, psi)
    count_inverse = mod_inverse(count, modulus)
    rescaled = tuple((count_inverse * entry) % modulus for entry in unscaled)
    original_input = tuple(int(value) for value in values)
    return TransformTrace(
        algorithm="gs",
        modulus=modulus,
        root=psi,
        input_values=original_input,
        stages=_renumber_stages(recorded_stages),
        raw_output=unscaled,
        normal_order_output=unscaled,
        scaled_output=rescaled,
    )
def fast_intt_psi_gs(values: Sequence[int], modulus: int, psi: int) -> list[int]:
    """Return the scaled output of ``fast_intt_psi_gs_trace`` as a list.

    A missing or empty ``scaled_output`` yields an empty list.
    """
    scaled = fast_intt_psi_gs_trace(values, modulus, psi).scaled_output
    if not scaled:
        return []
    return list(scaled)