Replace early notebook plots with teaching players

This commit is contained in:
saymrwulf 2026-04-16 11:25:10 +02:00
parent 74700b8457
commit f152567e06
9 changed files with 462 additions and 24 deletions

View file

@ -48,7 +48,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n"
"source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(butterfly_story_player(trace))\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n"
},
{
"cell_type": "markdown",

View file

@ -62,7 +62,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n"
"source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n"
},
{
"cell_type": "markdown",
@ -102,7 +102,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n"
"source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n"
},
{
"cell_type": "code",

View file

@ -48,7 +48,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n"
"source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(schoolbook_diagonal_player(left, right))\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n"
},
{
"cell_type": "markdown",
@ -74,7 +74,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n"
"source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n"
},
{
"cell_type": "markdown",

View file

@ -44,11 +44,11 @@
"role": "mandatory",
"difficulty": 2,
"kind": "demo",
"title": "See The Product Grid And The Diagonal Sums"
"title": "Play The Diagonal Sweep"
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 2 | See The Product Grid And The Diagonal Sums\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n"
"source": "# MANDATORY | difficulty 2 | Play The Diagonal Sweep\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\ndisplay(schoolbook_diagonal_player(left, right))\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n"
},
{
"cell_type": "markdown",
@ -70,11 +70,11 @@
"role": "mandatory",
"difficulty": 2,
"kind": "demo",
"title": "See Cyclic And Negacyclic Folding Side By Side"
"title": "Play The Wraparound Step By Step"
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 2 | See Cyclic And Negacyclic Folding Side By Side\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n"
"source": "# MANDATORY | difficulty 2 | Play The Wraparound Step By Step\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n"
},
{
"cell_type": "markdown",

View file

@ -48,7 +48,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n"
"source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n"
},
{
"cell_type": "markdown",

View file

@ -62,7 +62,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n"
"source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n display(direct_ntt_player(signal, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n"
},
{
"cell_type": "markdown",

View file

@ -48,7 +48,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n"
"source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(direct_ntt_player([1, 2, 3, 4], modulus, psi))\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n"
},
{
"cell_type": "markdown",
@ -74,7 +74,7 @@
}
},
"outputs": [],
"source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\n"
"source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\ndisplay(direct_ntt_player(signal, modulus, psi))\n"
},
{
"cell_type": "code",

View file

@ -2,6 +2,7 @@
from __future__ import annotations
from html import escape
from typing import Sequence
import ipywidgets as widgets
@ -16,6 +17,7 @@ from .toy_ntt import (
TransformTrace,
base_multiply_pair,
bit_reversed_order,
convolution_contributions,
forward_ntt_psi,
inverse_ntt_psi,
ntt_psi_exponent_grid,
@ -26,6 +28,426 @@ from .toy_ntt import (
wraparound_contributions,
)
SVG_BG = "#f7f3ea"
SVG_PANEL = "#fffdf8"
SVG_INK = "#1f2933"
SVG_SOFT = "#d8e2dc"
SVG_ACCENT = "#ef476f"
SVG_HILITE = "#ffd166"
SVG_GOOD = "#06d6a0"
SVG_WARN = "#f08a5d"
SVG_BLUE = "#118ab2"
def _widget_chrome(title: str, subtitle: str, view: widgets.Widget) -> widgets.Widget:
header = widgets.HTML(
f"""
<div style="
background: linear-gradient(135deg, #f7ede2 0%, #f5cac3 45%, #84a59d 100%);
color: #102a43;
padding: 14px 16px;
border-radius: 14px 14px 0 0;
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
box-shadow: inset 0 0 0 1px rgba(16,42,67,0.12);
">
<div style="font-size: 18px; font-weight: 800;">{escape(title)}</div>
<div style="font-size: 12px; margin-top: 4px;">{escape(subtitle)}</div>
</div>
"""
)
box = widgets.VBox(
[header, view],
layout=widgets.Layout(border="1px solid #d9d9d9", overflow="hidden"),
)
return box
def _player_widget(
*,
title: str,
subtitle: str,
frames: Sequence[str],
captions: Sequence[str],
width: str = "100%",
) -> widgets.Widget:
if not frames:
raise ValueError("player widget requires at least one frame")
play = widgets.Play(value=0, min=0, max=len(frames) - 1, step=1, interval=1100, description="Play")
slider = widgets.IntSlider(
value=0,
min=0,
max=len(frames) - 1,
step=1,
description="Step",
continuous_update=False,
layout=widgets.Layout(width="420px"),
)
widgets.jslink((play, "value"), (slider, "value"))
frame_html = widgets.HTML(layout=widgets.Layout(width=width))
caption_html = widgets.HTML(layout=widgets.Layout(width=width))
def render(index: int) -> None:
frame_html.value = frames[index]
caption_html.value = f"""
<div style="
background: #fffdf8;
border-top: 1px solid #e8dcc9;
padding: 12px 16px;
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
color: #243b53;
font-size: 13px;
line-height: 1.45;
">
<strong>Frame {index + 1} of {len(frames)}</strong><br>
{escape(captions[index])}
</div>
"""
slider.observe(lambda change: render(change["new"]), names="value")
render(0)
controls = widgets.HBox(
[play, slider],
layout=widgets.Layout(padding="10px 14px", align_items="center"),
)
return _widget_chrome(title, subtitle, widgets.VBox([controls, frame_html, caption_html]))
def _svg_box(x: float, y: float, w: float, h: float, label: str, value: str, *, fill: str, stroke: str, stroke_width: float = 2.0) -> str:
return f"""
<rect x="{x}" y="{y}" width="{w}" height="{h}" rx="10" fill="{fill}" stroke="{stroke}" stroke-width="{stroke_width}"></rect>
<text x="{x + w / 2}" y="{y + 18}" text-anchor="middle" font-size="11" font-family="Menlo, monospace" fill="{SVG_INK}">{escape(label)}</text>
<text x="{x + w / 2}" y="{y + h / 2 + 10}" text-anchor="middle" font-size="22" font-weight="700" font-family="Menlo, monospace" fill="{SVG_INK}">{escape(value)}</text>
"""
def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "400", fill: str = SVG_INK, anchor: str = "start") -> str:
return f'<text x="{x}" y="{y}" text-anchor="{anchor}" font-size="{size}" font-weight="{weight}" font-family="Avenir Next, Trebuchet MS, sans-serif" fill="{fill}">{escape(text)}</text>'
def _convolution_frame_svg(
left: Sequence[int],
right: Sequence[int],
diagonal_index: int,
*,
title: str,
) -> str:
rows = convolution_contributions(left, right)
current = rows[diagonal_index]
grid = pairwise_product_grid(left, right)
cell = 68
left_margin = 78
top_margin = 78
bottom_y = top_margin + len(left) * cell + 70
width = left_margin + len(right) * cell + 120
height = bottom_y + 110
parts = [
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
_svg_text(34, 48, title, size=22, weight="800"),
_svg_text(34, 68, f"Diagonal {diagonal_index}: every highlighted cell lands in output coefficient y{diagonal_index}", size=13, fill="#486581"),
]
for row in range(len(left)):
parts.append(_svg_text(54, top_margin + row * cell + 42, f"a{row}", size=12, anchor="middle"))
for col in range(len(right)):
parts.append(_svg_text(left_margin + col * cell + 28, 64, f"b{col}", size=12, anchor="middle"))
for row, row_values in enumerate(grid):
for col, value in enumerate(row_values):
active = row + col == diagonal_index
done = row + col < diagonal_index
fill = SVG_HILITE if active else "#e6fffa" if done else "#f8f9fa"
stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
parts.append(
_svg_box(
left_margin + col * cell,
top_margin + row * cell,
56,
56,
f"a{row}·b{col}",
str(value),
fill=fill,
stroke=stroke,
stroke_width=3.2 if active else 1.8,
)
)
for row in rows:
output_index = int(row["output_index"])
fill = SVG_HILITE if output_index == diagonal_index else "#d9f0ff" if output_index < diagonal_index else "#f1f5f9"
stroke = SVG_ACCENT if output_index == diagonal_index else SVG_BLUE if output_index < diagonal_index else "#cbd2d9"
parts.append(
_svg_box(
left_margin + output_index * cell * 0.92,
bottom_y,
60,
56,
f"y{output_index}",
str(row["total"]),
fill=fill,
stroke=stroke,
stroke_width=3.2 if output_index == diagonal_index else 1.8,
)
)
terms = [f"{term['left_value']}×{term['right_value']}={term['product']}" for term in current["terms"]]
equation = " + ".join(terms) if terms else "0"
parts.append(_svg_text(34, height - 34, f"Current diagonal sum: {equation} = {current['total']}", size=15, weight="700"))
parts.append("</svg>")
return "".join(parts)
def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> widgets.Widget:
"""Interactive diagonal-by-diagonal walkthrough of schoolbook multiplication."""
rows = convolution_contributions(left, right)
frames = [
_convolution_frame_svg(left, right, diagonal_index=index, title="Schoolbook Multiplication As A Moving Diagonal")
for index in range(len(rows))
]
captions = [
f"Only the highlighted products contribute to y{row['output_index']}. Watch the diagonal sweep across the grid instead of imagining the sum in your head."
for row in rows
]
return _player_widget(
title="Convolution Diagonal Player",
subtitle="Press play. The highlighted diagonal is the coefficient being formed right now.",
frames=frames,
captions=captions,
)
def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str:
cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)]
current_index, current_value = flat[step]
cell = 58
width = max(980, len(coefficients) * cell + 200)
height = 390
top_y = 78
cyclic_y = 210
neg_y = 300
parts = [
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
_svg_text(34, 48, "Wraparound Comparison Player", size=22, weight="800"),
_svg_text(34, 68, f"Current source term: x^{current_index} with coefficient {current_value}", size=13, fill="#486581"),
_svg_text(34, top_y - 14, "Raw convolution tail", size=14, weight="700"),
_svg_text(34, cyclic_y - 14, "Cyclic fold into x^n - 1", size=14, weight="700", fill=SVG_BLUE),
_svg_text(34, neg_y - 14, "Negacyclic fold into x^n + 1", size=14, weight="700", fill=SVG_ACCENT),
]
for index, coefficient in enumerate(coefficients):
fill = SVG_HILITE if index == current_index else "#f1f5f9" if index > current_index else "#d8f3dc"
stroke = SVG_ACCENT if index == current_index else "#cbd2d9" if index > current_index else SVG_GOOD
parts.append(_svg_box(48 + index * cell, top_y, 48, 48, f"x^{index}", str(coefficient), fill=fill, stroke=stroke))
for slot, row in enumerate(cyclic_rows):
parts.append(_svg_box(48 + slot * cell, cyclic_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#e0fbfc", stroke=SVG_BLUE))
for slot, row in enumerate(negacyclic_rows):
parts.append(_svg_box(48 + slot * cell, neg_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#ffe8d6", stroke=SVG_ACCENT))
wraps, slot = divmod(current_index, n)
cyclic_label = f"+ wrap {wraps}"
neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}"
source_x = 72 + current_index * cell
target_x = 72 + slot * cell
parts.append(f'<line x1="{source_x}" y1="{top_y + 48}" x2="{target_x}" y2="{cyclic_y}" stroke="{SVG_BLUE}" stroke-width="4" marker-end="url(#arrow-blue)"></line>')
parts.append(f'<line x1="{source_x}" y1="{top_y + 48}" x2="{target_x}" y2="{neg_y}" stroke="{SVG_ACCENT}" stroke-width="4" marker-end="url(#arrow-red)"></line>')
parts.append(f"""
<defs>
<marker id="arrow-blue" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="{SVG_BLUE}"></polygon>
</marker>
<marker id="arrow-red" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="{SVG_ACCENT}"></polygon>
</marker>
</defs>
""")
parts.append(_svg_text((source_x + target_x) / 2, cyclic_y - 18, cyclic_label, size=12, fill=SVG_BLUE, anchor="middle"))
parts.append(_svg_text((source_x + target_x) / 2, neg_y - 18, neg_label, size=12, fill=SVG_ACCENT, anchor="middle"))
parts.append("</svg>")
return "".join(parts)
def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
"""Interactive comparison of cyclic and negacyclic folding, one source term at a time."""
frames = [_wrap_compare_frame_svg(coefficients, n, step) for step in range(len(coefficients))]
captions = []
for index, coefficient in enumerate(coefficients):
wraps, slot = divmod(index, n)
neg_sign = "-" if wraps % 2 else "+"
captions.append(
f"x^{index} with coefficient {coefficient} lands in slot {slot}. Cyclic folding keeps a + sign; negacyclic folding uses {neg_sign} after {wraps} wrap(s)."
)
return _player_widget(
title="Wraparound Step Player",
subtitle="Play one raw coefficient at a time and compare x^n-1 with x^n+1 side by side.",
frames=frames,
captions=captions,
)
def _direct_ntt_frame_svg(values: Sequence[int], modulus: int, psi: int, output_index: int, input_index: int) -> str:
exponents = ntt_psi_exponent_grid(len(values))
matrix = ntt_psi_matrix(len(values), modulus, psi)
cell = 62
width = 1040
height = 470
grid_x = 260
grid_y = 84
boxes_y = 360
contributions = []
for row, value in enumerate(values):
exponent = exponents[output_index][row]
factor = matrix[output_index][row]
product = (value * factor) % modulus
contributions.append((exponent, factor, product))
partial = sum(product for _, _, product in contributions[: input_index + 1]) % modulus
final = sum(product for _, _, product in contributions) % modulus
parts = [
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
_svg_text(34, 48, "Direct NTTψ Contribution Player", size=22, weight="800"),
_svg_text(34, 68, f"Building output slot j={output_index}, currently consuming input i={input_index}", size=13, fill="#486581"),
_svg_text(34, 108, "signal", size=14, weight="700"),
_svg_text(grid_x, 68, "matrix cell = psi^(2ij + i)", size=14, weight="700"),
]
for index, value in enumerate(values):
fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9"
stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9"
parts.append(_svg_box(36, 128 + index * 60, 92, 48, f"a{index}", str(value), fill=fill, stroke=stroke))
for row in range(len(values)):
parts.append(_svg_text(grid_x - 18, grid_y + row * cell + 34, f"i={row}", size=12, anchor="end"))
for col in range(len(values)):
parts.append(_svg_text(grid_x + col * cell + 28, grid_y - 12, f"j={col}", size=12, anchor="middle"))
for col in range(len(values)):
for row in range(len(values)):
active = col == output_index and row == input_index
done = col == output_index and row < input_index
fill = SVG_HILITE if active else "#d8f3dc" if done else "#f1f5f9"
stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
parts.append(
_svg_box(
grid_x + col * cell,
grid_y + row * cell,
56,
56,
f"{exponents[col][row]}",
str(matrix[col][row]),
fill=fill,
stroke=stroke,
stroke_width=3.0 if active else 1.6,
)
)
for index, (_, factor, product) in enumerate(contributions):
fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9"
stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9"
parts.append(_svg_box(520 + index * 90, boxes_y, 74, 54, f"a{index}·w", str(product), fill=fill, stroke=stroke))
parts.append(_svg_text(556 + index * 90, boxes_y - 10, f"factor {factor}", size=11, anchor="middle"))
parts.append(_svg_text(34, height - 60, f"Current partial sum for j={output_index}: {partial} mod {modulus}", size=15, weight="700"))
parts.append(_svg_text(34, height - 34, f"Completed output slot y{output_index}: {final} mod {modulus}", size=15, weight="700", fill=SVG_ACCENT))
parts.append("</svg>")
return "".join(parts)
def direct_ntt_player(values: Sequence[int], modulus: int, psi: int) -> widgets.Widget:
"""Interactive walkthrough of the direct NTT_psi matrix multiplication."""
frames = []
captions = []
for output_index in range(len(values)):
for input_index in range(len(values)):
exponent = 2 * input_index * output_index + input_index
factor = pow(psi, exponent, modulus)
product = (values[input_index] * factor) % modulus
frames.append(_direct_ntt_frame_svg(values, modulus, psi, output_index, input_index))
captions.append(
f"Output slot j={output_index}: multiply a{input_index}={values[input_index]} by psi^{exponent}={factor} to contribute {product} mod {modulus}."
)
return _player_widget(
title="Direct NTTψ Player",
subtitle="Press play and watch the transform build one contribution at a time.",
frames=frames,
captions=captions,
)
def _butterfly_story_frame_svg(trace: TransformTrace, stage_index: int, pair_index: int) -> str:
stage = trace.stages[stage_index]
left, right = stage.pairings[pair_index]
zeta = stage.zetas[pair_index]
width = max(820, len(stage.input_values) * 120)
height = 340
input_y = 110
output_y = 240
spacing = 92
start_x = 60
parts = [
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
_svg_text(34, 48, f"{trace.algorithm.upper()} Butterfly Player", size=22, weight="800"),
_svg_text(34, 68, f"Stage {stage.stage_index}, active pair ({left}, {right}), zeta={zeta}", size=13, fill="#486581"),
]
for index, value in enumerate(stage.input_values):
active = index in (left, right)
fill = SVG_HILITE if active else "#f1f5f9"
stroke = SVG_ACCENT if active else "#cbd2d9"
parts.append(_svg_box(start_x + index * spacing, input_y, 64, 52, f"in {index}", str(value), fill=fill, stroke=stroke, stroke_width=3.0 if active else 1.8))
for index, value in enumerate(stage.output_values):
active = index in (left, right)
fill = "#d8f3dc" if active else "#f1f5f9"
stroke = SVG_GOOD if active else "#cbd2d9"
parts.append(_svg_box(start_x + index * spacing, output_y, 64, 52, f"out {index}", str(value), fill=fill, stroke=stroke, stroke_width=3.0 if active else 1.8))
for index in range(len(stage.input_values)):
x = start_x + index * spacing + 32
color = SVG_ACCENT if index in (left, right) else "#d9d9d9"
width_line = 4 if index in (left, right) else 1.8
parts.append(f'<line x1="{x}" y1="{input_y + 52}" x2="{x}" y2="{output_y}" stroke="{color}" stroke-width="{width_line}"></line>')
x_left = start_x + left * spacing + 32
x_right = start_x + right * spacing + 32
parts.append(f'<line x1="{x_left}" y1="{input_y + 62}" x2="{x_right}" y2="{input_y + 62}" stroke="{SVG_ACCENT}" stroke-width="4"></line>')
parts.append(_svg_text((x_left + x_right) / 2, 196, f"pair ({left}, {right})", size=12, weight="700", fill=SVG_ACCENT, anchor="middle"))
parts.append(_svg_text((x_left + x_right) / 2, 214, f"inputs -> outputs = ({stage.input_values[left]}, {stage.input_values[right]}) -> ({stage.output_values[left]}, {stage.output_values[right]})", size=11, anchor="middle"))
parts.append(_svg_text((x_left + x_right) / 2, 232, stage.note, size=11, anchor="middle", fill="#486581"))
parts.append("</svg>")
return "".join(parts)
def butterfly_story_player(trace: TransformTrace) -> widgets.Widget:
"""Interactive pair-by-pair walkthrough of a butterfly trace."""
frames = []
captions = []
for stage_index, stage in enumerate(trace.stages):
for pair_index, pair in enumerate(stage.pairings):
left, right = pair
frames.append(_butterfly_story_frame_svg(trace, stage_index, pair_index))
captions.append(
f"Stage {stage.stage_index}: the active pair is {pair}. Watch only these two wires. Everything else is parked until its own butterfly fires."
)
return _player_widget(
title="Butterfly Pair Player",
subtitle="Play pair by pair. This is the local machine the learner needs to internalize.",
frames=frames,
captions=captions,
)
def _value_colors(values: Sequence[int]) -> list[str]:
colors = []

View file

@ -455,12 +455,12 @@ def build_bundle_01() -> None:
"mandatory",
2,
"demo",
"See The Product Grid And The Diagonal Sums",
"Play The Diagonal Sweep",
"""
from IPython.display import display
from ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution
from ntt_learning.visuals import plot_convolution_grid
from ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player
left = [1, 2, 3, 4]
right = [5, 6, 7, 8]
@ -470,6 +470,7 @@ def build_bundle_01() -> None:
for row in convolution_contributions(left, right):
print(row)
display(schoolbook_diagonal_player(left, right))
fig = plot_convolution_grid(left, right, title="Schoolbook products for [1,2,3,4] * [5,6,7,8]")
display(fig)
""",
@ -492,12 +493,12 @@ def build_bundle_01() -> None:
"mandatory",
2,
"demo",
"See Cyclic And Negacyclic Folding Side By Side",
"Play The Wraparound Step By Step",
"""
from IPython.display import display
from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions
from ntt_learning.visuals import plot_wraparound
from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
left = [1, 2, 3, 4]
right = [5, 6, 7, 8]
@ -512,6 +513,7 @@ def build_bundle_01() -> None:
for row in wraparound_contributions(raw, n=4, negacyclic=True):
print(row)
display(wraparound_comparison_player(raw, n=4))
display(plot_wraparound(raw, n=4, negacyclic=False, title="Cyclic folding into x^4 - 1"))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Negacyclic folding into x^4 + 1"))
""",
@ -599,7 +601,7 @@ def build_bundle_01() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution
from ntt_learning.visuals import plot_convolution_grid, plot_wraparound
from ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player
left = [3, 0, 2, 1]
right = [1, 4, 0, 2]
@ -607,6 +609,8 @@ def build_bundle_01() -> None:
print("raw convolution:", raw)
print("negacyclic result:", negacyclic_multiply(left, right, n=4))
display(schoolbook_diagonal_player(left, right))
display(wraparound_comparison_player(raw, n=4))
display(plot_convolution_grid(left, right, title="Prediction check grid"))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Prediction check fold"))
""",
@ -636,10 +640,11 @@ def build_bundle_01() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import schoolbook_convolution
from ntt_learning.visuals import plot_wraparound
from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
raw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])
print("raw convolution:", raw)
display(wraparound_comparison_player(raw, n=4))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Trace one tail coefficient by eye"))
""",
),
@ -815,10 +820,11 @@ def build_bundle_01() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import schoolbook_convolution
from ntt_learning.visuals import plot_wraparound
from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
raw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])
print("raw convolution:", raw)
display(wraparound_comparison_player(raw, n=4))
display(plot_wraparound(raw, n=4, negacyclic=False, title="Positive wrap into x^4 - 1"))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Negative wrap into x^4 + 1"))
""",
@ -926,7 +932,7 @@ def build_bundle_02() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix
from ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap
from ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap
modulus = 17
n = 4
@ -942,6 +948,7 @@ def build_bundle_02() -> None:
for row in ntt_psi_matrix(n, modulus, psi):
print(row)
display(direct_ntt_player([1, 2, 3, 4], modulus, psi))
display(plot_ntt_psi_exponent_heatmap(n, title="Exponents 2ij + i for n=4"))
display(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title="Concrete NTT_psi matrix in Z_17"))
""",
@ -962,7 +969,10 @@ def build_bundle_02() -> None:
"demo",
"Run A Direct NTTψ / INTTψ Round Trip",
"""
from IPython.display import display
from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi
from ntt_learning.visuals import direct_ntt_player
signal = [1, 2, 3, 4]
modulus = 17
@ -972,6 +982,7 @@ def build_bundle_02() -> None:
print("signal:", signal)
print("spectrum:", spectrum)
print("inverse recovery:", inverse_ntt_psi(spectrum, modulus, psi))
display(direct_ntt_player(signal, modulus, psi))
""",
),
code(
@ -1079,6 +1090,7 @@ def build_bundle_02() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi
from ntt_learning.visuals import direct_ntt_player
modulus = 17
psi = find_psi(4, modulus)
@ -1089,6 +1101,7 @@ def build_bundle_02() -> None:
print("signal:", signal)
print("spectrum:", spectrum)
print("inverse:", inverse_ntt_psi(spectrum, modulus, psi))
display(direct_ntt_player(signal, modulus, psi))
display(
widgets.interact(
@ -1449,7 +1462,7 @@ def build_bundle_03() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi
from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview
from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview
signal = [1, 2, 3, 4]
modulus = 7681
@ -1459,6 +1472,7 @@ def build_bundle_03() -> None:
print("raw CT output (BO):", trace.raw_output)
print("bit-reversed back to NO:", trace.normal_order_output)
print("direct NTT_psi:", forward_ntt_psi(signal, modulus, psi))
display(butterfly_story_player(trace))
display(plot_trace_overview(trace, title="CT overview for [1,2,3,4]"))
display(plot_butterfly_network(trace, title="Full CT network for [1,2,3,4]"))
display(interactive_trace(trace, title="CT forward trace"))
@ -1505,7 +1519,7 @@ def build_bundle_03() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview
from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview
signal = [0, 1, 2, 3, 4, 5, 6, 7]
modulus = 97
@ -1515,6 +1529,7 @@ def build_bundle_03() -> None:
print("psi:", psi)
print("BO output:", trace.raw_output)
print("NO output:", trace.normal_order_output)
display(butterfly_story_player(trace))
display(plot_trace_overview(trace, title="Three CT stages for n=8"))
display(plot_butterfly_network(trace, title="Full CT network for n=8"))
display(interactive_trace(trace, title="n=8 CT stage explorer"))
@ -1604,9 +1619,10 @@ def build_bundle_03() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace
from ntt_learning.visuals import interactive_trace
from ntt_learning.visuals import butterfly_story_player, interactive_trace
trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)
display(butterfly_story_player(trace))
display(interactive_trace(trace, title="Check your n=4 prediction"))
""",
),