mirror of
https://github.com/saymrwulf/NTT-learning.git
synced 2026-06-04 23:39:43 +00:00
1226 lines
46 KiB
Python
1226 lines
46 KiB
Python
"""Blunt visual helpers for the NTT notebooks."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from html import escape
|
||
from typing import Sequence
|
||
|
||
import ipywidgets as widgets
|
||
import matplotlib
|
||
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
from IPython.display import clear_output, display
|
||
|
||
from .toy_ntt import (
|
||
TransformStage,
|
||
TransformTrace,
|
||
base_multiply_pair,
|
||
bit_reversed_order,
|
||
convolution_contributions,
|
||
forward_ntt_psi,
|
||
inverse_ntt_psi,
|
||
ntt_psi_exponent_grid,
|
||
ntt_psi_matrix,
|
||
pairwise_product_grid,
|
||
pointwise_multiply,
|
||
stage_pairings,
|
||
wraparound_contributions,
|
||
)
|
||
|
||
# Shared colour palette for the hand-built SVG player frames in this module.
SVG_BG = "#f7f3ea"  # outer background of every <svg> frame
SVG_PANEL = "#fffdf8"  # fill of the rounded content panel inside each frame
SVG_INK = "#1f2933"  # default text colour (_svg_text / _svg_box)
SVG_SOFT = "#d8e2dc"  # soft neutral; not referenced in the visible part of this file
SVG_ACCENT = "#ef476f"  # stroke of the active/current element; negacyclic highlights
SVG_HILITE = "#ffd166"  # fill of the active/current element
SVG_GOOD = "#06d6a0"  # stroke of already-processed elements
SVG_WARN = "#f08a5d"  # warning orange; same hex as the negative-value colour in _value_colors
SVG_BLUE = "#118ab2"  # cyclic-fold highlights and completed output slots
|
||
|
||
|
||
def _widget_chrome(title: str, subtitle: str, view: widgets.Widget) -> widgets.Widget:
    """Frame ``view`` as a card with a gradient title banner above it.

    Args:
        title: Bold banner heading; HTML-escaped before insertion.
        subtitle: Smaller line under the heading; HTML-escaped.
        view: Widget placed directly under the banner.

    Returns:
        A ``VBox`` containing the banner and ``view`` inside a thin grey border.
    """
    heading = escape(title)
    subheading = escape(subtitle)
    banner = widgets.HTML(
        f"""
        <div style="
            background: linear-gradient(135deg, #f7ede2 0%, #f5cac3 45%, #84a59d 100%);
            color: #102a43;
            padding: 14px 16px;
            border-radius: 14px 14px 0 0;
            font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
            box-shadow: inset 0 0 0 1px rgba(16,42,67,0.12);
        ">
            <div style="font-size: 18px; font-weight: 800;">{heading}</div>
            <div style="font-size: 12px; margin-top: 4px;">{subheading}</div>
        </div>
        """
    )
    card_layout = widgets.Layout(border="1px solid #d9d9d9", overflow="hidden")
    return widgets.VBox([banner, view], layout=card_layout)
|
||
|
||
|
||
def _player_widget(
|
||
*,
|
||
title: str,
|
||
subtitle: str,
|
||
frames: Sequence[str],
|
||
captions: Sequence[str],
|
||
width: str = "100%",
|
||
) -> widgets.Widget:
|
||
if not frames:
|
||
raise ValueError("player widget requires at least one frame")
|
||
|
||
play = widgets.Play(value=0, min=0, max=len(frames) - 1, step=1, interval=1100, description="Play")
|
||
slider = widgets.IntSlider(
|
||
value=0,
|
||
min=0,
|
||
max=len(frames) - 1,
|
||
step=1,
|
||
description="Step",
|
||
continuous_update=False,
|
||
layout=widgets.Layout(width="420px"),
|
||
)
|
||
widgets.jslink((play, "value"), (slider, "value"))
|
||
|
||
frame_html = widgets.HTML(layout=widgets.Layout(width=width))
|
||
caption_html = widgets.HTML(layout=widgets.Layout(width=width))
|
||
|
||
def render(index: int) -> None:
|
||
frame_html.value = frames[index]
|
||
caption_html.value = f"""
|
||
<div style="
|
||
background: #fffdf8;
|
||
border-top: 1px solid #e8dcc9;
|
||
padding: 12px 16px;
|
||
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
|
||
color: #243b53;
|
||
font-size: 13px;
|
||
line-height: 1.45;
|
||
">
|
||
<strong>Frame {index + 1} of {len(frames)}</strong><br>
|
||
{escape(captions[index])}
|
||
</div>
|
||
"""
|
||
|
||
slider.observe(lambda change: render(change["new"]), names="value")
|
||
render(0)
|
||
|
||
controls = widgets.HBox(
|
||
[play, slider],
|
||
layout=widgets.Layout(padding="10px 14px", align_items="center"),
|
||
)
|
||
return _widget_chrome(title, subtitle, widgets.VBox([controls, frame_html, caption_html]))
|
||
|
||
|
||
def _svg_box(x: float, y: float, w: float, h: float, label: str, value: str, *, fill: str, stroke: str, stroke_width: float = 2.0) -> str:
    """Return SVG for a rounded rectangle with a small caption and a large centred value.

    The caption baseline sits 18 units below the top edge. The value baseline
    sits 10 units below the vertical centre. Both strings are HTML-escaped and
    drawn in ``SVG_INK``.
    """
    mid_x = x + w / 2
    caption_y = y + 18
    value_y = y + h / 2 + 10
    safe_label = escape(label)
    safe_value = escape(value)
    return f"""
    <rect x="{x}" y="{y}" width="{w}" height="{h}" rx="10" fill="{fill}" stroke="{stroke}" stroke-width="{stroke_width}"></rect>
    <text x="{mid_x}" y="{caption_y}" text-anchor="middle" font-size="11" font-family="Menlo, monospace" fill="{SVG_INK}">{safe_label}</text>
    <text x="{mid_x}" y="{value_y}" text-anchor="middle" font-size="22" font-weight="700" font-family="Menlo, monospace" fill="{SVG_INK}">{safe_value}</text>
    """
|
||
|
||
|
||
def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "400", fill: str = SVG_INK, anchor: str = "start") -> str:
    """Return a single ``<text>`` element in the sans-serif UI font; ``text`` is HTML-escaped."""
    attributes = (
        f'x="{x}" y="{y}" text-anchor="{anchor}" font-size="{size}" '
        f'font-weight="{weight}" font-family="Avenir Next, Trebuchet MS, sans-serif" fill="{fill}"'
    )
    return f"<text {attributes}>{escape(text)}</text>"
|
||
|
||
|
||
def _convolution_frame_svg(
    left: Sequence[int],
    right: Sequence[int],
    diagonal_index: int,
    *,
    title: str,
) -> str:
    """Render one frame of the schoolbook-multiplication diagonal animation as SVG.

    The frame shows the grid of pairwise products (row = ``left`` index,
    column = ``right`` index). It highlights the cells with
    ``row + col == diagonal_index`` and marks earlier anti-diagonals as done.
    Below the grid it draws one box per output coefficient, taken from
    ``convolution_contributions``.

    Args:
        left: Coefficients of the first factor.
        right: Coefficients of the second factor.
        diagonal_index: Output coefficient being formed in this frame; it
            indexes the rows returned by ``convolution_contributions``.
        title: Heading drawn at the top of the frame.

    Returns:
        A standalone ``<svg>`` string.

    Raises:
        IndexError: if ``diagonal_index`` does not address a contribution row.
    """
    rows = convolution_contributions(left, right)
    current = rows[diagonal_index]
    grid = pairwise_product_grid(left, right)

    cell = 68
    left_margin = 78
    # The subtitle baseline is y=68. The b-column labels used to sit at y=64 and
    # collided with it, so the grid starts lower and the labels sit just above it.
    top_margin = 100
    column_label_y = top_margin - 12
    bottom_y = top_margin + len(left) * cell + 70

    # Output boxes are spaced slightly tighter than grid cells. There are
    # len(left) + len(right) - 1 of them, usually more than grid columns, so the
    # canvas must be wide enough for that row too, not just for the grid.
    output_step = cell * 0.92
    output_box_width = 60
    last_output = max((int(row["output_index"]) for row in rows), default=0)
    output_row_right = left_margin + last_output * output_step + output_box_width
    width = max(left_margin + len(right) * cell + 120, int(output_row_right) + 40)
    height = bottom_y + 110

    parts = [
        f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
        f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
        _svg_text(34, 48, title, size=22, weight="800"),
        _svg_text(34, 68, f"Diagonal {diagonal_index}: every highlighted cell lands in output coefficient y{diagonal_index}", size=13, fill="#486581"),
    ]

    for row in range(len(left)):
        parts.append(_svg_text(54, top_margin + row * cell + 42, f"a{row}", size=12, anchor="middle"))
    for col in range(len(right)):
        parts.append(_svg_text(left_margin + col * cell + 28, column_label_y, f"b{col}", size=12, anchor="middle"))

    for row, row_values in enumerate(grid):
        for col, value in enumerate(row_values):
            active = row + col == diagonal_index
            done = row + col < diagonal_index
            fill = SVG_HILITE if active else "#e6fffa" if done else "#f8f9fa"
            stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
            parts.append(
                _svg_box(
                    left_margin + col * cell,
                    top_margin + row * cell,
                    56,
                    56,
                    f"a{row}·b{col}",
                    str(value),
                    fill=fill,
                    stroke=stroke,
                    stroke_width=3.2 if active else 1.8,
                )
            )

    for row in rows:
        output_index = int(row["output_index"])
        is_current = output_index == diagonal_index
        is_done = output_index < diagonal_index
        fill = SVG_HILITE if is_current else "#d9f0ff" if is_done else "#f1f5f9"
        stroke = SVG_ACCENT if is_current else SVG_BLUE if is_done else "#cbd2d9"
        parts.append(
            _svg_box(
                left_margin + output_index * output_step,
                bottom_y,
                output_box_width,
                56,
                f"y{output_index}",
                str(row["total"]),
                fill=fill,
                stroke=stroke,
                stroke_width=3.2 if is_current else 1.8,
            )
        )

    terms = [f"{term['left_value']}×{term['right_value']}={term['product']}" for term in current["terms"]]
    equation = " + ".join(terms) if terms else "0"
    parts.append(_svg_text(34, height - 34, f"Current diagonal sum: {equation} = {current['total']}", size=15, weight="700"))
    parts.append("</svg>")
    return "".join(parts)
|
||
|
||
|
||
def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> widgets.Widget:
    """Interactive diagonal-by-diagonal walkthrough of schoolbook multiplication.

    Builds one SVG frame per output coefficient of ``left * right``, paired with
    a caption naming the coefficient being formed, and wraps them in a player.
    """
    contributions = convolution_contributions(left, right)
    frame_title = "Schoolbook Multiplication As A Moving Diagonal"
    frames = []
    captions = []
    for diagonal, entry in enumerate(contributions):
        frames.append(_convolution_frame_svg(left, right, diagonal_index=diagonal, title=frame_title))
        captions.append(
            f"Only the highlighted products contribute to y{entry['output_index']}. Watch the diagonal sweep across the grid instead of imagining the sum in your head."
        )
    return _player_widget(
        title="Convolution Diagonal Player",
        subtitle="Press play. The highlighted diagonal is the coefficient being formed right now.",
        frames=frames,
        captions=captions,
    )
|
||
|
||
|
||
def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str:
    """Render one frame of the cyclic-vs-negacyclic folding comparison as SVG.

    The top row shows the raw coefficients. Index ``step`` is highlighted and
    earlier indices are marked done. The two rows below show the per-slot
    totals from ``wraparound_contributions`` for the cyclic (x^n - 1) and
    negacyclic (x^n + 1) reductions. Two arrows show where coefficient ``step``
    lands (slot ``step mod n``), labelled with the wrap count. The negacyclic
    label carries the sign ``-`` for an odd wrap count.

    Args:
        coefficients: Raw coefficients; index = power of x.
        n: Degree used for folding.
        step: Index of the source coefficient animated in this frame; negative
            values count from the end.

    Returns:
        A standalone ``<svg>`` string.

    Raises:
        IndexError: if ``step`` is out of range for ``coefficients``.
    """
    cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
    negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
    # Indexing a range normalises negative steps and raises IndexError when out
    # of range, without materialising a list of (index, value) pairs.
    current_index = range(len(coefficients))[step]
    current_value = coefficients[current_index]

    cell = 58
    width = max(980, len(coefficients) * cell + 200)
    height = 390
    # The raw row's heading is drawn at top_y - 14. With top_y = 78 it landed at
    # y=64, on top of the subtitle at the same x (baseline y=68).
    top_y = 104
    cyclic_y = 210
    neg_y = 300

    parts = [
        f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
        f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
        _svg_text(34, 48, "Wraparound Comparison Player", size=22, weight="800"),
        _svg_text(34, 68, f"Current source term: x^{current_index} with coefficient {current_value}", size=13, fill="#486581"),
        _svg_text(34, top_y - 14, "Raw convolution tail", size=14, weight="700"),
        _svg_text(34, cyclic_y - 14, "Cyclic fold into x^n - 1", size=14, weight="700", fill=SVG_BLUE),
        _svg_text(34, neg_y - 14, "Negacyclic fold into x^n + 1", size=14, weight="700", fill=SVG_ACCENT),
    ]

    for index, coefficient in enumerate(coefficients):
        fill = SVG_HILITE if index == current_index else "#f1f5f9" if index > current_index else "#d8f3dc"
        stroke = SVG_ACCENT if index == current_index else "#cbd2d9" if index > current_index else SVG_GOOD
        parts.append(_svg_box(48 + index * cell, top_y, 48, 48, f"x^{index}", str(coefficient), fill=fill, stroke=stroke))

    for slot, row in enumerate(cyclic_rows):
        parts.append(_svg_box(48 + slot * cell, cyclic_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#e0fbfc", stroke=SVG_BLUE))
    for slot, row in enumerate(negacyclic_rows):
        parts.append(_svg_box(48 + slot * cell, neg_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#ffe8d6", stroke=SVG_ACCENT))

    # x^(wraps*n + slot) folds to slot; modulo x^n + 1 each wrap flips the sign.
    wraps, slot = divmod(current_index, n)
    cyclic_label = f"+ wrap {wraps}"
    neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}"
    source_x = 72 + current_index * cell  # centre of the 48-wide box at 48 + i*cell
    target_x = 72 + slot * cell

    parts.append(f'<line x1="{source_x}" y1="{top_y + 48}" x2="{target_x}" y2="{cyclic_y}" stroke="{SVG_BLUE}" stroke-width="4" marker-end="url(#arrow-blue)"></line>')
    parts.append(f'<line x1="{source_x}" y1="{top_y + 48}" x2="{target_x}" y2="{neg_y}" stroke="{SVG_ACCENT}" stroke-width="4" marker-end="url(#arrow-red)"></line>')
    parts.append(f"""
    <defs>
    <marker id="arrow-blue" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
    <polygon points="0 0, 8 3, 0 6" fill="{SVG_BLUE}"></polygon>
    </marker>
    <marker id="arrow-red" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
    <polygon points="0 0, 8 3, 0 6" fill="{SVG_ACCENT}"></polygon>
    </marker>
    </defs>
    """)
    parts.append(_svg_text((source_x + target_x) / 2, cyclic_y - 18, cyclic_label, size=12, fill=SVG_BLUE, anchor="middle"))
    parts.append(_svg_text((source_x + target_x) / 2, neg_y - 18, neg_label, size=12, fill=SVG_ACCENT, anchor="middle"))
    parts.append("</svg>")
    return "".join(parts)
|
||
|
||
|
||
def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
    """Interactive comparison of cyclic and negacyclic folding, one source term at a time.

    Each frame animates one raw coefficient x^i. Its caption states the target
    slot ``i mod n`` and the negacyclic sign, which is ``-`` after an odd
    number of wraps.
    """
    frames = []
    captions = []
    for power, coefficient in enumerate(coefficients):
        frames.append(_wrap_compare_frame_svg(coefficients, n, power))
        wrap_count, target_slot = divmod(power, n)
        sign = "+" if wrap_count % 2 == 0 else "-"
        captions.append(
            f"x^{power} with coefficient {coefficient} lands in slot {target_slot}. Cyclic folding keeps a + sign; negacyclic folding uses {sign} after {wrap_count} wrap(s)."
        )
    return _player_widget(
        title="Wraparound Step Player",
        subtitle="Play one raw coefficient at a time and compare x^n-1 with x^n+1 side by side.",
        frames=frames,
        captions=captions,
    )
|
||
|
||
|
||
def _direct_ntt_frame_svg(values: Sequence[int], modulus: int, psi: int, output_index: int, input_index: int) -> str:
    """Render one frame of the direct NTT_psi walkthrough as SVG.

    The frame shows three things:
    - the input signal;
    - the transform matrix, where each cell label is the psi exponent from
      ``ntt_psi_exponent_grid`` and each cell value is the entry from
      ``ntt_psi_matrix``;
    - the per-input products for output slot ``output_index``.

    Input ``input_index`` is highlighted and earlier inputs are marked consumed.

    Args:
        values: Input coefficients a_i.
        modulus: Modulus q; products and sums are reduced mod q.
        psi: Passed through to ``ntt_psi_matrix``.
        output_index: Output slot j being built.
        input_index: Input i consumed in this frame.

    Returns:
        A standalone ``<svg>`` string whose canvas grows with ``len(values)``.
    """
    size = len(values)
    exponents = ntt_psi_exponent_grid(size)
    matrix = ntt_psi_matrix(size, modulus, psi)

    cell = 62
    grid_x = 260
    # Header lines: subtitle at y=68, matrix caption at y=92. The old
    # grid_y=84 put the caption and the j-labels on top of the subtitle.
    grid_y = 124
    signal_step = 60
    contrib_x = 520
    contrib_step = 90

    # Both grids are indexed [output j][input i] here.
    contributions = []
    for row, value in enumerate(values):
        factor = matrix[output_index][row]
        contributions.append((exponents[output_index][row], factor, (value * factor) % modulus))

    partial = sum(product for _, _, product in contributions[: input_index + 1]) % modulus
    final = sum(product for _, _, product in contributions) % modulus

    # Size the canvas from len(values). The old fixed 1040x470 canvas let the
    # grid run into the product row once there were more than four inputs.
    grid_bottom = grid_y + size * cell
    signal_bottom = grid_y + max(size - 1, 0) * signal_step + 48
    boxes_y = max(grid_bottom, signal_bottom) + 30
    width = max(1040, contrib_x + size * contrib_step + 40, grid_x + size * cell + 60)
    height = max(470, boxes_y + 54 + 90)

    parts = [
        f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
        f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
        _svg_text(34, 48, "Direct NTTψ Contribution Player", size=22, weight="800"),
        _svg_text(34, 68, f"Building output slot j={output_index}, currently consuming input i={input_index}", size=13, fill="#486581"),
        _svg_text(34, grid_y - 12, "signal", size=14, weight="700"),
        _svg_text(grid_x, 92, "matrix cell = psi^(2ij + i)", size=14, weight="700"),
    ]

    for index, value in enumerate(values):
        fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9"
        stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9"
        parts.append(_svg_box(36, grid_y + index * signal_step, 92, 48, f"a{index}", str(value), fill=fill, stroke=stroke))

    for row in range(size):
        parts.append(_svg_text(grid_x - 18, grid_y + row * cell + 34, f"i={row}", size=12, anchor="end"))
    for col in range(size):
        parts.append(_svg_text(grid_x + col * cell + 28, grid_y - 12, f"j={col}", size=12, anchor="middle"))

    # Grid is drawn with rows = input i and columns = output j, hence [col][row].
    for col in range(size):
        for row in range(size):
            active = col == output_index and row == input_index
            done = col == output_index and row < input_index
            fill = SVG_HILITE if active else "#d8f3dc" if done else "#f1f5f9"
            stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
            parts.append(
                _svg_box(
                    grid_x + col * cell,
                    grid_y + row * cell,
                    56,
                    56,
                    f"{exponents[col][row]}",
                    str(matrix[col][row]),
                    fill=fill,
                    stroke=stroke,
                    stroke_width=3.0 if active else 1.6,
                )
            )

    for index, (_, factor, product) in enumerate(contributions):
        fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9"
        stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9"
        box_x = contrib_x + index * contrib_step
        parts.append(_svg_box(box_x, boxes_y, 74, 54, f"a{index}·w", str(product), fill=fill, stroke=stroke))
        parts.append(_svg_text(box_x + 36, boxes_y - 10, f"factor {factor}", size=11, anchor="middle"))

    parts.append(_svg_text(34, height - 60, f"Current partial sum for j={output_index}: {partial} mod {modulus}", size=15, weight="700"))
    parts.append(_svg_text(34, height - 34, f"Completed output slot y{output_index}: {final} mod {modulus}", size=15, weight="700", fill=SVG_ACCENT))
    parts.append("</svg>")
    return "".join(parts)
|
||
|
||
|
||
def direct_ntt_player(values: Sequence[int], modulus: int, psi: int) -> widgets.Widget:
    """Interactive walkthrough of the direct NTT_psi matrix multiplication.

    Emits one frame per (output j, input i) pair, iterating j in the outer loop.
    Each caption states the factor psi^(2ij + i) mod ``modulus`` and the
    resulting contribution a_i * factor mod ``modulus``.
    """
    size = len(values)
    frames = []
    captions = []
    for j in range(size):
        for i in range(size):
            power = 2 * i * j + i
            twiddle = pow(psi, power, modulus)
            term = (values[i] * twiddle) % modulus
            frames.append(_direct_ntt_frame_svg(values, modulus, psi, j, i))
            captions.append(
                f"Output slot j={j}: multiply a{i}={values[i]} by psi^{power}={twiddle} to contribute {term} mod {modulus}."
            )
    return _player_widget(
        title="Direct NTTψ Player",
        subtitle="Press play and watch the transform build one contribution at a time.",
        frames=frames,
        captions=captions,
    )
|
||
|
||
|
||
def _butterfly_story_frame_svg(trace: TransformTrace, stage_index: int, pair_index: int) -> str:
    """Render one butterfly of one stage as an SVG frame.

    The frame draws every input and output value of stage ``stage_index`` and
    emphasises the two wires of pair ``pair_index``; the pair's zeta appears in
    the subtitle. Below the wires it prints the pair's before/after values and
    ``stage.note``.
    """
    stage = trace.stages[stage_index]
    left, right = stage.pairings[pair_index]
    zeta = stage.zetas[pair_index]
    active_pair = (left, right)

    canvas_width = max(820, len(stage.input_values) * 120)
    canvas_height = 340
    input_y, output_y = 110, 240
    spacing, start_x = 92, 60

    def column_left(index: int) -> int:
        """Left edge of the 64-wide value box for wire ``index``."""
        return start_x + index * spacing

    def box_row(row_values, y, tag, active_fill, active_stroke):
        """One rounded box per value; members of the active pair get the active colours."""
        return [
            _svg_box(
                column_left(index),
                y,
                64,
                52,
                f"{tag} {index}",
                str(value),
                fill=active_fill if index in active_pair else "#f1f5f9",
                stroke=active_stroke if index in active_pair else "#cbd2d9",
                stroke_width=3.0 if index in active_pair else 1.8,
            )
            for index, value in enumerate(row_values)
        ]

    parts = [
        f'<svg viewBox="0 0 {canvas_width} {canvas_height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
        f'<rect x="18" y="18" width="{canvas_width - 36}" height="{canvas_height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
        _svg_text(34, 48, f"{trace.algorithm.upper()} Butterfly Player", size=22, weight="800"),
        _svg_text(34, 68, f"Stage {stage.stage_index}, active pair ({left}, {right}), zeta={zeta}", size=13, fill="#486581"),
    ]
    parts += box_row(stage.input_values, input_y, "in", SVG_HILITE, SVG_ACCENT)
    parts += box_row(stage.output_values, output_y, "out", "#d8f3dc", SVG_GOOD)

    # Vertical wires from each input box down to its output box.
    for index in range(len(stage.input_values)):
        wire_x = column_left(index) + 32
        on_pair = index in active_pair
        wire_color = SVG_ACCENT if on_pair else "#d9d9d9"
        wire_width = 4 if on_pair else 1.8
        parts.append(f'<line x1="{wire_x}" y1="{input_y + 52}" x2="{wire_x}" y2="{output_y}" stroke="{wire_color}" stroke-width="{wire_width}"></line>')

    x_left = column_left(left) + 32
    x_right = column_left(right) + 32
    centre_x = (x_left + x_right) / 2
    parts.append(f'<line x1="{x_left}" y1="{input_y + 62}" x2="{x_right}" y2="{input_y + 62}" stroke="{SVG_ACCENT}" stroke-width="4"></line>')
    parts.append(_svg_text(centre_x, 196, f"pair ({left}, {right})", size=12, weight="700", fill=SVG_ACCENT, anchor="middle"))
    before = f"({stage.input_values[left]}, {stage.input_values[right]})"
    after = f"({stage.output_values[left]}, {stage.output_values[right]})"
    parts.append(_svg_text(centre_x, 214, f"inputs -> outputs = {before} -> {after}", size=11, anchor="middle"))
    parts.append(_svg_text(centre_x, 232, stage.note, size=11, anchor="middle", fill="#486581"))
    parts.append("</svg>")
    return "".join(parts)
|
||
|
||
|
||
def butterfly_story_player(trace: TransformTrace) -> widgets.Widget:
    """Interactive pair-by-pair walkthrough of a butterfly trace.

    Emits one frame per (stage, pair), in stage order and then pairing order.
    """
    schedule = [
        (stage_position, pair_position, stage.stage_index, pair)
        for stage_position, stage in enumerate(trace.stages)
        for pair_position, pair in enumerate(stage.pairings)
    ]
    frames = [
        _butterfly_story_frame_svg(trace, stage_position, pair_position)
        for stage_position, pair_position, _, _ in schedule
    ]
    captions = [
        f"Stage {stage_label}: the active pair is {pair}. Watch only these two wires. Everything else is parked until its own butterfly fires."
        for _, _, stage_label, pair in schedule
    ]
    return _player_widget(
        title="Butterfly Pair Player",
        subtitle="Play pair by pair. This is the local machine the learner needs to internalize.",
        frames=frames,
        captions=captions,
    )
|
||
|
||
|
||
def _value_colors(values: Sequence[int]) -> list[str]:
|
||
colors = []
|
||
for value in values:
|
||
if value < 0:
|
||
colors.append("#f08a5d")
|
||
elif value == 0:
|
||
colors.append("#d9d9d9")
|
||
else:
|
||
colors.append("#7ad3a8")
|
||
return colors
|
||
|
||
|
||
def _draw_value_row(ax, values: Sequence[int], y: float, prefix: str) -> None:
    """Draw ``values`` on ``ax`` as a row of sign-coloured, rounded text boxes.

    Value ``k`` is placed at x = k, height ``y``. Its box reads ``{prefix}{k}``
    above the value.
    """
    for index, (value, color) in enumerate(zip(values, _value_colors(values))):
        ax.text(
            index,
            y,
            f"{prefix}{index}\n{value}",
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox={"boxstyle": "round,pad=0.35", "facecolor": color, "edgecolor": "#222222", "linewidth": 1.2},
        )
|
||
|
||
|
||
def _annotate_grid(ax, grid: Sequence[Sequence[int]]) -> None:
|
||
for row, row_values in enumerate(grid):
|
||
for column, value in enumerate(row_values):
|
||
ax.text(
|
||
column,
|
||
row,
|
||
str(value),
|
||
ha="center",
|
||
va="center",
|
||
color="#101010",
|
||
fontsize=10,
|
||
family="monospace",
|
||
)
|
||
|
||
|
||
def plot_integer_grid(
    grid: Sequence[Sequence[int]],
    *,
    title: str,
    x_label: str,
    y_label: str,
    cmap: str = "YlGnBu",
):
    """Plot a heatmap with the exact integer values written in every cell.

    The first index of ``grid`` maps to the y axis and the second to the x axis.
    The figure grows with the grid's dimensions.

    Raises:
        ValueError: if ``grid`` or its first row is empty.

    Returns:
        The matplotlib ``Figure``.
    """
    if not grid or not grid[0]:
        raise ValueError("plot_integer_grid requires a non-empty rectangular grid")

    row_count = len(grid)
    column_count = len(grid[0])
    figure, axis = plt.subplots(figsize=(max(6, column_count * 1.2), max(4, row_count * 0.85)))
    axis.imshow(grid, cmap=cmap, aspect="auto")
    axis.set_title(title, fontsize=14, fontweight="bold")
    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)
    axis.set_xticks(range(column_count))
    axis.set_yticks(range(row_count))
    _annotate_grid(axis, grid)
    figure.tight_layout()
    return figure
|
||
|
||
|
||
def plot_convolution_grid(
    left: Sequence[int], right: Sequence[int], title: str = "Schoolbook Product Grid"
):
    """Plot the full schoolbook multiplication table and the diagonal sums.

    The top panel is a heatmap of ``pairwise_product_grid(left, right)``
    (row = left index, column = right index) with every value written in its
    cell. The bottom panel lists the anti-diagonal sums
    ``y_k = sum of grid[i][j] with i + j = k``, i.e. the coefficients of the
    unreduced product.

    Raises:
        ValueError: if ``left`` or ``right`` is empty. Previously an empty input
            crashed later inside ``imshow`` with an unhelpful shape error.

    Returns:
        The matplotlib ``Figure``.
    """
    # len() rather than truthiness so array-like inputs are accepted too.
    if len(left) == 0 or len(right) == 0:
        raise ValueError("plot_convolution_grid requires non-empty coefficient lists")

    grid = pairwise_product_grid(left, right)
    diagonal_sums = [0] * (len(left) + len(right) - 1)
    for row in range(len(left)):
        for column in range(len(right)):
            diagonal_sums[row + column] += grid[row][column]

    fig, axes = plt.subplots(2, 1, figsize=(max(7, len(right) * 1.2), 6), height_ratios=[3, 1])
    heatmap_ax, sum_ax = axes

    heatmap_ax.imshow(grid, cmap="YlGnBu", aspect="auto")
    heatmap_ax.set_title(title, fontsize=14, fontweight="bold")
    heatmap_ax.set_xlabel("right coefficient index")
    heatmap_ax.set_ylabel("left coefficient index")
    heatmap_ax.set_xticks(range(len(right)))
    heatmap_ax.set_yticks(range(len(left)))
    _annotate_grid(heatmap_ax, grid)

    sum_ax.axis("off")
    sum_ax.set_title("Diagonal Sums = Convolution Coefficients", fontsize=12, fontweight="bold", pad=8)
    for index, value in enumerate(diagonal_sums):
        sum_ax.text(
            index,
            0,
            f"y{index}\n{value}",
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox={
                "boxstyle": "round,pad=0.35",
                "facecolor": "#f4f1de",
                "edgecolor": "#222222",
                "linewidth": 1.0,
            },
        )
    sum_ax.set_xlim(-0.5, len(diagonal_sums) - 0.5)
    sum_ax.set_ylim(-1, 1)
    fig.tight_layout()
    return fig
|
||
|
||
|
||
def plot_ntt_psi_exponent_heatmap(length: int, title: str = "NTT_psi Exponent Grid"):
    """Plot the exponent pattern 2ij + i used by the direct negative-wrapped NTT.

    Returns the matplotlib ``Figure`` produced by ``plot_integer_grid``.
    """
    # NOTE(review): plot_integer_grid puts the grid's first index on the y axis,
    # labelled "input index i" here. _direct_ntt_frame_svg reads this grid as
    # [output j][input i]. Confirm the orientation against toy_ntt before
    # trusting the axis labels.
    exponent_grid = ntt_psi_exponent_grid(length)
    return plot_integer_grid(
        exponent_grid,
        title=title,
        x_label="output index j",
        y_label="input index i",
        cmap="YlOrRd",
    )
|
||
|
||
|
||
def plot_ntt_psi_matrix_heatmap(length: int, modulus: int, psi: int, title: str = "NTT_psi Matrix Values"):
    """Plot the concrete direct transform matrix over Z_q.

    Returns the matplotlib ``Figure`` produced by ``plot_integer_grid``.
    """
    # NOTE(review): same orientation caveat as the exponent heatmap.
    # _direct_ntt_frame_svg indexes this matrix as [output j][input i], but the
    # y axis here is labelled "input index i". Confirm against toy_ntt.
    matrix_values = ntt_psi_matrix(length, modulus, psi)
    return plot_integer_grid(
        matrix_values,
        title=title,
        x_label="output index j",
        y_label="input index i",
        cmap="PuBuGn",
    )
|
||
|
||
|
||
def plot_wraparound(
    coefficients: Sequence[int],
    n: int,
    *,
    negacyclic: bool = True,
    title: str | None = None,
):
    """Plot how the tail wraps back into degree < n.

    The top row holds the raw coefficients and the bottom row holds the folded
    slot totals from ``wraparound_contributions``. Each contribution gets an
    arrow from its source index to its slot, coloured red for a negative sign
    and teal otherwise, and labelled with the sign and wrap count.

    Returns:
        The matplotlib ``Figure``.
    """
    folded = wraparound_contributions(coefficients, n=n, negacyclic=negacyclic)
    if title is not None:
        heading = title
    elif negacyclic:
        heading = "Negacyclic Folding"
    else:
        heading = "Cyclic Folding"

    fig, ax = plt.subplots(figsize=(max(8, len(coefficients) * 1.1), 5.5))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    raw_y = 2.4
    folded_y = 0.4
    _draw_value_row(ax, coefficients, raw_y, "x^")
    _draw_value_row(ax, [entry["total"] for entry in folded], folded_y, "slot ")

    label_y = (raw_y + folded_y) / 2 + 0.25
    for slot, entry in enumerate(folded):
        for contribution in entry["contributions"]:
            source = contribution["source_index"]
            negative = contribution["sign"] < 0
            color = "#d1495b" if negative else "#2a9d8f"
            sign_text = "-" if negative else "+"
            ax.annotate(
                "",
                xy=(slot, folded_y + 0.3),
                xytext=(source, raw_y - 0.25),
                arrowprops={"arrowstyle": "->", "color": color, "linewidth": 2.0},
            )
            ax.text(
                (slot + source) / 2,
                label_y,
                f"{sign_text} wrap {contribution['wraps']}",
                ha="center",
                va="center",
                fontsize=9,
                color=color,
                family="monospace",
            )

    ax.set_xlim(-0.8, max(len(coefficients), n) - 0.2)
    ax.set_ylim(-0.4, 3.2)
    fig.tight_layout()
    return fig
|
||
|
||
|
||
def plot_vector_comparison(
    left: Sequence[int],
    right: Sequence[int],
    *,
    left_label: str = "left",
    right_label: str = "right",
    title: str = "Vector Comparison",
):
    """Plot two vectors slot-by-slot with explicit differences.

    The figure has three stacked rows: ``left``, ``right``, and
    ``delta = right - left``. The first two rows use a fixed fill colour. The
    delta row colours each cell by sign (see ``_value_colors``).

    Raises:
        ValueError: if the vectors differ in length.

    Returns:
        The matplotlib ``Figure``.
    """
    if len(left) != len(right):
        raise ValueError("plot_vector_comparison requires equal-length vectors")

    differences = [int(right_value - left_value) for left_value, right_value in zip(left, right)]
    # (title, values, fixed fill). A fill of None means "colour each cell by sign".
    # Sign colouring is tied to the row's position, not its label text, so a
    # caller passing left_label="delta" no longer gets the wrong row recoloured.
    panels = (
        (left_label, left, "#edf6f9"),
        (right_label, right, "#fff3b0"),
        ("delta", differences, None),
    )

    fig, axes = plt.subplots(3, 1, figsize=(max(8, len(left) * 1.3), 7), height_ratios=[1, 1, 1])
    for ax, (label, values, row_color) in zip(axes, panels):
        ax.axis("off")
        ax.set_title(label, fontsize=12, fontweight="bold", pad=6)
        for index, value in enumerate(values):
            facecolor = row_color if row_color is not None else _value_colors([value])[0]
            ax.text(
                index,
                0,
                f"{index}\n{value}",
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.32",
                    "facecolor": facecolor,
                    "edgecolor": "#222222",
                    "linewidth": 1.1,
                },
            )
        ax.set_xlim(-0.8, len(values) - 0.2)
        ax.set_ylim(-0.8, 0.8)

    fig.suptitle(title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    return fig
|
||
|
||
|
||
def plot_bit_reversal_mapping(length: int, title: str = "Normal Order To Bit-Reversed Order"):
    """Plot the bit-reversal permutation as explicit wires.

    The left column ("NO") lists each index with its binary form. The right
    column ("BO") lists each index's target under ``bit_reversed_indices``.
    A line joins every index to its target.

    Raises:
        ValueError: if ``length`` is not a positive power of two.

    Returns:
        The matplotlib ``Figure``.
    """
    is_power_of_two = length > 0 and not (length & (length - 1))
    if not is_power_of_two:
        raise ValueError("plot_bit_reversal_mapping requires a power-of-two length")

    from .toy_ntt import bit_reversed_indices

    permutation = bit_reversed_indices(length)
    bits = length.bit_length() - 1

    fig, ax = plt.subplots(figsize=(8, max(4, length * 0.65)))
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axis("off")

    for source, target in enumerate(permutation):
        ax.text(
            0,
            -source,
            f"{source:>2} | {source:0{bits}b}",
            ha="center",
            va="center",
            family="monospace",
            bbox={"boxstyle": "round,pad=0.25", "facecolor": "#edf6f9", "edgecolor": "#264653"},
        )
        ax.text(
            4,
            -target,
            f"{target:>2} | {target:0{bits}b}",
            ha="center",
            va="center",
            family="monospace",
            bbox={"boxstyle": "round,pad=0.25", "facecolor": "#fff3b0", "edgecolor": "#9c6644"},
        )
        ax.plot([0.6, 3.4], [-source, -target], color="#7f5539", linewidth=2.2, alpha=0.9)

    ax.text(0, 1, "NO", ha="center", va="center", fontsize=12, fontweight="bold")
    ax.text(4, 1, "BO", ha="center", va="center", fontsize=12, fontweight="bold")
    ax.set_xlim(-1.2, 5.2)
    ax.set_ylim(-length + 0.2, 1.8)
    fig.tight_layout()
    return fig
|
||
|
||
|
||
def plot_butterfly_network(trace: TransformTrace, title: str | None = None):
    """Plot the whole staged network with pair links visible at each stage.

    Column 0 holds ``trace.input_values``. Column k holds the outputs of stage
    k. Between adjacent columns, every wire gets a faint line, and each
    butterfly pair gets a coloured link annotated with its zeta.

    Returns:
        The matplotlib ``Figure``.
    """
    heading = title if title is not None else f"{trace.algorithm.upper()} Butterfly Network"
    columns = [trace.input_values, *(stage.output_values for stage in trace.stages)]
    # Pair colours cycle through this palette within each stage.
    palette = ["#264653", "#2a9d8f", "#e76f51", "#8d99ae", "#c1121f", "#3a86ff"]

    fig, ax = plt.subplots(figsize=(max(10, len(columns) * 2.4), max(5, len(trace.input_values) * 0.8)))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    x_positions = [position * 2.4 for position in range(len(columns))]

    for column_index, values in enumerate(columns):
        x = x_positions[column_index]
        for row_index, value in enumerate(values):
            ax.text(
                x,
                -row_index,
                str(value),
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.26",
                    "facecolor": _value_colors([value])[0],
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )
        header = "input" if column_index == 0 else f"s{column_index}"
        ax.text(x, 1.1, header, ha="center", va="center", fontsize=11, fontweight="bold")

        if column_index == 0:
            continue

        previous_x = x_positions[column_index - 1]
        stage = trace.stages[column_index - 1]
        wire_start, wire_end = previous_x + 0.35, x - 0.35
        for index in range(len(values)):
            ax.plot([wire_start, wire_end], [-index, -index], color="#b0b0b0", linewidth=0.9, alpha=0.65)

        x_mid = (previous_x + x) / 2
        for pair_index, ((left, right), zeta) in enumerate(zip(stage.pairings, stage.zetas)):
            color = palette[pair_index % len(palette)]
            ax.plot([wire_start, wire_end], [-left, -left], color=color, linewidth=2.2)
            ax.plot([wire_start, wire_end], [-right, -right], color=color, linewidth=2.2)
            ax.plot([x_mid, x_mid], [-left, -right], color=color, linewidth=2.6, alpha=0.95)
            ax.text(
                x_mid,
                -((left + right) / 2),
                f"zeta={zeta}",
                ha="center",
                va="center",
                fontsize=8,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.18",
                    "facecolor": "#ffffff",
                    "edgecolor": color,
                    "linewidth": 1.0,
                },
            )

    ax.set_xlim(-1.1, x_positions[-1] + 1.1)
    ax.set_ylim(-len(trace.input_values) + 0.2, 1.8)
    fig.tight_layout()
    return fig
|
||
|
||
|
||
def plot_stage_pairing_map(
    length: int,
    block_size: int,
    *,
    title: str | None = None,
):
    """Draw one butterfly stage as a map of which index pairs exchange data.

    Every index ``0 .. length-1`` is drawn as a small box in the left column.
    Each pair returned by ``stage_pairings(length, block_size)`` becomes a
    coloured line from the first index's row to the second index's row, plus
    an ``"a <-> b"`` label in the right-hand column.

    Args:
        length: Number of coefficients in the transform.
        block_size: Block size handed straight to ``stage_pairings``.
        title: Optional figure title. When omitted, a default naming
            ``length`` and ``block_size`` is used.

    Returns:
        The matplotlib ``Figure`` holding the drawing.
    """
    heading = title if title is not None else f"Stage Pairing Map (n={length}, block={block_size})"
    palette = ["#264653", "#2a9d8f", "#e76f51", "#8d99ae", "#c1121f", "#3a86ff"]
    centred_mono = {"ha": "center", "va": "center", "family": "monospace"}

    pairings = stage_pairings(length, block_size)
    figure, axes = plt.subplots(figsize=(10, max(4.2, length * 0.55)))
    axes.set_title(heading, fontsize=14, fontweight="bold")
    axes.axis("off")

    # Left column: one boxed label per coefficient index, stacked downward.
    index_box = {"boxstyle": "round,pad=0.25", "facecolor": "#edf6f9", "edgecolor": "#264653"}
    for row in range(length):
        axes.text(0, -row, str(row), fontsize=10, bbox=dict(index_box), **centred_mono)

    # One coloured link per pair, with its label placed midway between the two rows.
    for position, (first, second) in enumerate(pairings):
        shade = palette[position % len(palette)]
        axes.plot([0.6, 4.2], [-first, -second], color=shade, linewidth=2.4)
        axes.text(
            5.4,
            -((first + second) / 2),
            f"{first} <-> {second}",
            fontsize=9,
            bbox={"boxstyle": "round,pad=0.22", "facecolor": "#ffffff", "edgecolor": shade},
            **centred_mono,
        )

    axes.text(0, 1.0, "indices", ha="center", va="center", fontsize=11, fontweight="bold")
    axes.text(5.4, 1.0, "pairs", ha="center", va="center", fontsize=11, fontweight="bold")
    axes.set_xlim(-1.0, 6.8)
    axes.set_ylim(-length + 0.2, 1.8)
    figure.tight_layout()
    return figure
|
||
|
||
|
||
def plot_stage_schedule(length: int, title: str | None = None):
|
||
"""Plot the full stage schedule for a power-of-two transform length."""
|
||
if length <= 0 or length & (length - 1):
|
||
raise ValueError("plot_stage_schedule requires a power-of-two length")
|
||
if title is None:
|
||
title = f"Butterfly Stage Schedule For n={length}"
|
||
|
||
stages = []
|
||
block_size = 2
|
||
stage_index = 1
|
||
while block_size <= length:
|
||
stages.append(
|
||
{
|
||
"stage": stage_index,
|
||
"block_size": block_size,
|
||
"pair_distance": block_size // 2,
|
||
"pair_count": len(stage_pairings(length, block_size)),
|
||
}
|
||
)
|
||
block_size *= 2
|
||
stage_index += 1
|
||
|
||
fig, ax = plt.subplots(figsize=(11, max(4.5, len(stages) * 0.95)))
|
||
ax.set_title(title, fontsize=14, fontweight="bold")
|
||
ax.axis("off")
|
||
|
||
headers = ["stage", "block", "distance", "pairs", "what happens"]
|
||
x_positions = [0, 2.0, 4.0, 6.0, 9.2]
|
||
for x, header in zip(x_positions, headers):
|
||
ax.text(x, 1.0, header, ha="center", va="center", fontsize=11, fontweight="bold")
|
||
|
||
for row_index, row in enumerate(stages):
|
||
y = -row_index
|
||
explanation = f"indices {row['pair_distance']} apart talk inside blocks of {row['block_size']}"
|
||
values = [
|
||
str(row["stage"]),
|
||
str(row["block_size"]),
|
||
str(row["pair_distance"]),
|
||
str(row["pair_count"]),
|
||
explanation,
|
||
]
|
||
for x, value in zip(x_positions, values):
|
||
ax.text(
|
||
x,
|
||
y,
|
||
value,
|
||
ha="center",
|
||
va="center",
|
||
fontsize=10,
|
||
family="monospace",
|
||
bbox={
|
||
"boxstyle": "round,pad=0.24",
|
||
"facecolor": "#edf6f9" if x < 8 else "#ffffff",
|
||
"edgecolor": "#264653",
|
||
"linewidth": 1.0,
|
||
},
|
||
)
|
||
|
||
ax.set_xlim(-1.0, 12.3)
|
||
ax.set_ylim(-len(stages) + 0.2, 1.7)
|
||
fig.tight_layout()
|
||
return fig
|
||
|
||
|
||
def plot_stage(stage: TransformStage, title: str | None = None):
    """Draw a single butterfly stage: inputs on top, outputs below, pairs in between.

    The stage's input values are drawn as a row at y=2.6 and its output values
    as a row at y=0.5, both via ``_draw_value_row``. Each ``(left, right)`` pair
    in ``stage.pairings`` gets its own colour and three marks:
    - a horizontal bar joining the two input positions
    - two vertical legs running down toward the output row
    - a label showing the pair and its matching entry from ``stage.zetas``

    ``stage.note`` is printed underneath.

    Args:
        stage: The stage to draw.
        title: Optional figure title; defaults to ``"<ALGORITHM> Stage <stage_index>"``.

    Returns:
        The matplotlib ``Figure`` holding the drawing.
    """
    if title is None:
        title = f"{stage.algorithm.upper()} Stage {stage.stage_index}"

    palette = ["#264653", "#2a9d8f", "#e76f51", "#8d99ae", "#c1121f", "#3a86ff"]
    top_row, bottom_row = 2.6, 0.5
    bar_y = top_row - 0.45  # horizontal pair bar sits just under the input row
    leg_end = bottom_row + 0.55  # legs stop just above the output row
    width = len(stage.input_values)

    figure, axes = plt.subplots(figsize=(max(8, width * 1.35), 5.8))
    axes.set_title(title, fontsize=14, fontweight="bold")
    axes.axis("off")

    _draw_value_row(axes, stage.input_values, top_row, "i")
    _draw_value_row(axes, stage.output_values, bottom_row, "o")

    for position, ((left, right), zeta) in enumerate(zip(stage.pairings, stage.zetas)):
        shade = palette[position % len(palette)]
        axes.plot([left, right], [bar_y, bar_y], color=shade, linewidth=2.5)
        for column in (left, right):
            axes.plot([column, column], [bar_y, leg_end], color=shade, linewidth=1.5, alpha=0.85)
        axes.text(
            (left + right) / 2,
            1.55,
            f"pair {left}-{right}\nzeta={zeta}",
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox={"boxstyle": "round,pad=0.35", "facecolor": "#ffffff", "edgecolor": shade, "linewidth": 1.4},
        )

    axes.text(width / 2 - 0.5, -0.05, stage.note, ha="center", va="center", fontsize=10, color="#333333")
    axes.set_xlim(-0.8, width - 0.2)
    axes.set_ylim(-0.5, 3.3)
    figure.tight_layout()
    return figure
|
||
|
||
|
||
def plot_transform_pipeline(
    left: Sequence[int],
    right: Sequence[int],
    *,
    modulus: int,
    psi: int,
    title: str = "Transform-Domain Multiply Pipeline",
):
    """Show the direct NTT_psi multiplication pipeline as six value columns.

    The columns, from left to right, are:
    - the two inputs
    - their forward transforms (``forward_ntt_psi``)
    - their slot-wise product (``pointwise_multiply``)
    - the inverse transform of that product (``inverse_ntt_psi``)

    Grey arrows join neighbouring columns, and three captions name the
    forward, pointwise and inverse steps.

    Args:
        left: First input coefficient vector.
        right: Second input coefficient vector.
        modulus: Modulus passed to every transform/product helper.
        psi: Root passed to the forward and inverse transforms.
        title: Figure title.

    Returns:
        The matplotlib ``Figure`` holding the drawing.
    """
    left_hat = forward_ntt_psi(left, modulus, psi)
    right_hat = forward_ntt_psi(right, modulus, psi)
    product_hat = pointwise_multiply(left_hat, right_hat, modulus)
    recovered = inverse_ntt_psi(product_hat, modulus, psi)

    lanes = (
        ("left", list(left)),
        ("right", list(right)),
        ("left_hat", left_hat),
        ("right_hat", right_hat),
        ("pointwise", product_hat),
        ("inverse", recovered),
    )
    lane_gap = 2.2
    last_lane = len(lanes) - 1
    arrow_style = {"arrowstyle": "->", "color": "#6c757d", "linewidth": 1.8}

    figure, axes = plt.subplots(figsize=(max(10, len(lanes) * 2.15), max(4.6, len(left) * 0.85)))
    axes.set_title(title, fontsize=14, fontweight="bold")
    axes.axis("off")

    for lane_number, (caption, column) in enumerate(lanes):
        x = lane_number * lane_gap
        for depth, value in enumerate(column):
            axes.text(
                x,
                -depth,
                str(value),
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.24",
                    "facecolor": _value_colors([value])[0],
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )
        axes.text(x, 1.1, caption, ha="center", va="center", fontsize=10, fontweight="bold")
        if lane_number == last_lane:
            continue
        # Arrow toward the next lane, roughly level with the middle of this column.
        arrow_y = -len(column) / 2 + 0.4
        axes.annotate("", xy=(x + 1.5, arrow_y), xytext=(x + 0.6, arrow_y), arrowprops=dict(arrow_style))

    for caption_x, caption in ((4.4, "NTT_psi"), (8.8, "slotwise *"), (11.0, "INTT_psi")):
        axes.text(caption_x, 1.6, caption, ha="center", va="center", fontsize=10, family="monospace")

    axes.set_xlim(-1.0, (len(lanes) - 1) * 2.2 + 1.1)
    axes.set_ylim(-len(left) + 0.2, 2.0)
    figure.tight_layout()
    return figure
|
||
|
||
|
||
def plot_base_multiply_pair_diagram(
    left: Sequence[int],
    right: Sequence[int],
    *,
    zeta: int,
    modulus: int,
    title: str = "Base Multiplication On A Degree-1 Pair",
):
    """Plot the two-term base multiplication block used in Kyber-style explanations.

    Three columns are drawn:
    - ``left``: the a-coefficients
    - ``right``: the b-coefficients
    - ``out``: the pair returned by ``base_multiply_pair(left, right, zeta, modulus)``,
      i.e. the c-coefficients

    Two formula boxes spell out ``c0 = a0*b0 + zeta*a1*b1`` and
    ``c1 = a0*b1 + a1*b0`` (mod ``modulus``) next to the computed values.
    Arrows run from each input column toward the output column.

    Args:
        left: Two coefficients ``(a0, a1)``.
        right: Two coefficients ``(b0, b1)``.
        zeta: The ``zeta`` constant used by ``base_multiply_pair`` and shown in the c0 formula.
        modulus: Modulus passed to ``base_multiply_pair`` and shown in both formulas.
        title: Figure title.

    Returns:
        The matplotlib ``Figure`` holding the diagram.

    Raises:
        ValueError: If ``left`` or ``right`` does not have exactly two entries.
    """
    if len(left) != 2 or len(right) != 2:
        raise ValueError("plot_base_multiply_pair_diagram expects two 2-entry vectors")

    result = base_multiply_pair(left, right, zeta, modulus)
    fig, ax = plt.subplots(figsize=(9, 4.6))
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axis("off")

    left_x = 0
    right_x = 2.8
    out_x = 6.8
    ys = [0.9, -0.7]  # y of entry 0 (top) and entry 1 (bottom) in every column

    for x, label, values, facecolor in [
        (left_x, "left", left, "#edf6f9"),
        (right_x, "right", right, "#fff3b0"),
        (out_x, "out", result, "#d8f3dc"),
    ]:
        ax.text(x, 1.8, label, ha="center", va="center", fontsize=12, fontweight="bold")
        for index, (y, value) in enumerate(zip(ys, values)):
            ax.text(
                x,
                y,
                f"{label}[{index}] = {value}",
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.3",
                    "facecolor": facecolor,
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )

    # Axes.annotate draws its arrow from ``xytext`` to ``xy``, and the "->"
    # head sits at ``xy``. The input side therefore goes in ``xytext`` and the
    # output column in ``xy``, so the arrows show data flowing into ``out``.
    arrow_routes = [
        # (input column x, start y, end y, colour)
        (left_x, 0.1, 0.9, "#355070"),  # from between the two left rows to out[0]
        (left_x, -0.7, -0.7, "#355070"),  # from left[1]'s row to out[1]
        (right_x, 0.1, 0.9, "#6d597a"),  # from between the two right rows to out[0]
        (right_x, -0.7, -0.7, "#6d597a"),  # from right[1]'s row to out[1]
    ]
    for source_x, start_y, end_y, color in arrow_routes:
        ax.annotate(
            "",
            xy=(out_x - 1.0, end_y),
            xytext=(source_x + 0.9, start_y),
            arrowprops={"arrowstyle": "->", "linewidth": 2.0, "color": color},
        )

    ax.text(
        4.8,
        1.05,
        f"c0 = a0*b0 + zeta*a1*b1 mod {modulus}\n= {result[0]}",
        ha="center",
        va="center",
        fontsize=10,
        family="monospace",
        bbox={"boxstyle": "round,pad=0.32", "facecolor": "#ffffff", "edgecolor": "#355070"},
    )
    ax.text(
        4.8,
        -1.0,
        f"c1 = a0*b1 + a1*b0 mod {modulus}\n= {result[1]}",
        ha="center",
        va="center",
        fontsize=10,
        family="monospace",
        bbox={"boxstyle": "round,pad=0.32", "facecolor": "#ffffff", "edgecolor": "#6d597a"},
    )
    ax.text(4.8, 0.0, f"zeta = {zeta}", ha="center", va="center", fontsize=10, family="monospace")
    ax.set_xlim(-1.0, 8.2)
    ax.set_ylim(-2.0, 2.2)
    fig.tight_layout()
    return fig
|
||
|
||
|
||
def plot_root_order_comparison(samples: Sequence[tuple[int, int]], title: str = "Root Existence Check"):
    """Tabulate, for each ``(n, q)`` sample, whether ``n`` and ``2n`` divide ``q - 1``.

    Each row shows ``n``, ``q`` and two yes/no verdicts: ``n | q-1`` and
    ``2n | q-1``. Verdict cells are tinted green for "yes" and salmon for
    "no"; the ``n`` and ``q`` cells use a neutral blue.

    Args:
        samples: ``(n, q)`` pairs to check, one table row each.
        title: Figure title.

    Returns:
        The matplotlib ``Figure`` holding the table.
    """
    fills = {"yes": "#d8f3dc", "no": "#f5cac3"}
    columns = [0, 2, 4.2, 6.8]

    def verdict(divides) -> str:
        return "yes" if divides else "no"

    figure, axes = plt.subplots(figsize=(10, max(4.5, len(samples) * 0.75)))
    axes.set_title(title, fontsize=14, fontweight="bold")
    axes.axis("off")

    for column_x, header in zip(columns, ("n", "q", "n | q-1", "2n | q-1")):
        axes.text(column_x, 1.2, header, ha="center", va="center", fontsize=11, fontweight="bold")

    for row_number, (n, q) in enumerate(samples):
        gap = q - 1
        cells = [str(n), str(q), verdict(gap % n == 0), verdict(gap % (2 * n) == 0)]
        for column_x, cell in zip(columns, cells):
            axes.text(
                column_x,
                -row_number,
                cell,
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={"boxstyle": "round,pad=0.25", "facecolor": fills.get(cell, "#edf6f9"), "edgecolor": "#222222"},
            )

    axes.set_xlim(-1.0, 8.0)
    axes.set_ylim(-len(samples) + 0.2, 1.8)
    figure.tight_layout()
    return figure
|
||
|
||
|
||
def plot_trace_overview(trace: TransformTrace, title: str | None = None):
    """Lay out a whole transform trace as side-by-side columns of values.

    Column 0 holds ``trace.input_values``. Column k (k >= 1) holds the
    ``output_values`` of ``trace.stages[k - 1]``. Each value is drawn in a box
    coloured by ``_value_colors``, and columns sit 2.0 units apart.

    Args:
        trace: The transform trace to draw.
        title: Optional figure title; defaults to ``"<ALGORITHM> Trace Overview"``.

    Returns:
        The matplotlib ``Figure`` holding the drawing.
    """
    if title is None:
        title = f"{trace.algorithm.upper()} Trace Overview"

    snapshots = [trace.input_values, *(stage.output_values for stage in trace.stages)]
    row_count = len(trace.input_values)

    figure, axes = plt.subplots(figsize=(max(9, len(snapshots) * 2.0), max(4.5, row_count * 0.7)))
    axes.set_title(title, fontsize=14, fontweight="bold")
    axes.axis("off")

    for column_number, snapshot in enumerate(snapshots):
        x = column_number * 2.0
        for depth, value in enumerate(snapshot):
            axes.text(
                x,
                -depth,
                str(value),
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.25",
                    "facecolor": _value_colors([value])[0],
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )
        header = f"stage {column_number}" if column_number else "input"
        axes.text(x, 1, header, ha="center", va="center", fontsize=11, fontweight="bold")

    axes.set_xlim(-1.0, (len(snapshots) - 1) * 2.0 + 1.0)
    axes.set_ylim(-row_count + 0.2, 1.8)
    figure.tight_layout()
    return figure
|
||
|
||
|
||
def interactive_trace(trace: TransformTrace, title: str | None = None):
    """Return a slider-based stage explorer for a transform trace.

    The widget is a ``VBox`` holding three parts:
    - an HTML heading
    - an ``IntSlider`` that selects a 1-based stage number
    - an ``Output`` area

    Moving the slider redraws the selected stage with ``plot_stage`` and prints
    one line per butterfly pair: its input values, output values and ``zeta``.
    A trace with no stages yields a disabled slider and a short message
    instead of raising.

    Args:
        trace: The transform trace to explore.
        title: Optional heading; defaults to ``"<ALGORITHM> Stage Explorer"``.
            It is HTML-escaped for the heading widget and used verbatim in the
            per-stage figure titles.

    Returns:
        The assembled ``ipywidgets.VBox``.
    """
    if title is None:
        title = f"{trace.algorithm.upper()} Stage Explorer"

    stage_count = len(trace.stages)
    slider = widgets.IntSlider(
        value=1,
        min=1,
        # max must stay >= min, so an empty trace still gets max=1. The slider
        # is then disabled because there is nothing to step through.
        max=max(1, stage_count),
        step=1,
        description="Stage",
        continuous_update=False,
        disabled=stage_count == 0,
    )
    output = widgets.Output()

    def render(stage_index: int) -> None:
        with output:
            clear_output(wait=True)
            if stage_count == 0:
                # Without this guard, trace.stages[0] raises IndexError while the widget is being built.
                print("This trace has no stages to explore.")
                return
            stage = trace.stages[stage_index - 1]
            fig = plot_stage(stage, title=f"{title} | Stage {stage_index}")
            try:
                display(fig)
            finally:
                # Close even if display fails, so figures do not pile up in pyplot.
                plt.close(fig)
            rows = []
            for pair, zeta in zip(stage.pairings, stage.zetas):
                left, right = pair
                rows.append(
                    f"pair {pair}: inputs=({stage.input_values[left]}, {stage.input_values[right]}) "
                    f"-> outputs=({stage.output_values[left]}, {stage.output_values[right]}) | zeta={zeta}"
                )
            print("\n".join(rows))

    slider.observe(lambda change: render(change["new"]), names="value")
    render(slider.value)
    # Escape the title so characters like "<" or "&" render literally in the HTML heading.
    widget = widgets.VBox([widgets.HTML(f"<h4>{escape(title)}</h4>"), slider, output])
    return widget
|
||
|
||
|
||
def show_trace(trace: TransformTrace, title: str | None = None):
    """Build the stage-explorer widget for *trace*, display it, and return it.

    Args:
        trace: The transform trace to explore.
        title: Optional heading forwarded to ``interactive_trace``.

    Returns:
        The widget produced by ``interactive_trace``, after it has been displayed.
    """
    explorer = interactive_trace(trace, title=title)
    display(explorer)
    return explorer
|