mirror of
https://github.com/saymrwulf/NTT-learning.git
synced 2026-05-14 20:47:53 +00:00
Replace jittery wraparound animation
This commit is contained in:
parent
5f41121fd5
commit
934586e422
2 changed files with 189 additions and 105 deletions
|
|
@ -114,6 +114,8 @@ def _player_widget(
|
||||||
background: #fffdf8;
|
background: #fffdf8;
|
||||||
border-top: 1px solid #e8dcc9;
|
border-top: 1px solid #e8dcc9;
|
||||||
padding: 12px 16px;
|
padding: 12px 16px;
|
||||||
|
min-height: 62px;
|
||||||
|
box-sizing: border-box;
|
||||||
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
|
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
|
||||||
color: #243b53;
|
color: #243b53;
|
||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
|
|
@ -153,25 +155,6 @@ def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "4
|
||||||
return f'<text x="{x}" y="{y}" text-anchor="{anchor}" font-size="{size}" font-weight="{weight}" font-family="Avenir Next, Trebuchet MS, sans-serif" fill="{fill}">{escape(text)}</text>'
|
return f'<text x="{x}" y="{y}" text-anchor="{anchor}" font-size="{size}" font-weight="{weight}" font-family="Avenir Next, Trebuchet MS, sans-serif" fill="{fill}">{escape(text)}</text>'
|
||||||
|
|
||||||
|
|
||||||
def _svg_badge(
|
|
||||||
x: float,
|
|
||||||
y: float,
|
|
||||||
text: str,
|
|
||||||
*,
|
|
||||||
fill: str = "#ffffff",
|
|
||||||
stroke: str = "#cbd2d9",
|
|
||||||
text_fill: str = SVG_INK,
|
|
||||||
) -> str:
|
|
||||||
width = max(66.0, 10.0 + len(text) * 7.2)
|
|
||||||
height = 24.0
|
|
||||||
left = x - width / 2
|
|
||||||
top = y - height + 6
|
|
||||||
return f"""
|
|
||||||
<rect x="{left}" y="{top}" width="{width}" height="{height}" rx="12" fill="{fill}" stroke="{stroke}" stroke-width="1.4"></rect>
|
|
||||||
<text x="{x}" y="{y}" text-anchor="middle" font-size="11" font-weight="700" font-family="Avenir Next, Trebuchet MS, sans-serif" fill="{text_fill}">{escape(text)}</text>
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _svg_multiline_text(
|
def _svg_multiline_text(
|
||||||
x: float,
|
x: float,
|
||||||
y: float,
|
y: float,
|
||||||
|
|
@ -221,12 +204,6 @@ def _svg_canvas_open(width: int, height: int) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _wrap_label_geometry(source_x: float, target_x: float, band_top: float, band_bottom: float) -> tuple[float, float]:
|
|
||||||
label_x = (source_x + target_x) / 2
|
|
||||||
label_y = (band_top + band_bottom) / 2
|
|
||||||
return label_x, label_y
|
|
||||||
|
|
||||||
|
|
||||||
def _html_token(label: str, value: str, *, fill: str, border: str, text: str = SVG_INK) -> str:
|
def _html_token(label: str, value: str, *, fill: str, border: str, text: str = SVG_INK) -> str:
|
||||||
return f"""
|
return f"""
|
||||||
<div style="
|
<div style="
|
||||||
|
|
@ -242,6 +219,8 @@ def _html_token(label: str, value: str, *, fill: str, border: str, text: str = S
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
gap: 3px;
|
gap: 3px;
|
||||||
text-align: center;
|
text-align: center;
|
||||||
|
line-height: 1.12;
|
||||||
|
overflow-wrap: anywhere;
|
||||||
">
|
">
|
||||||
<div style="font-size: 11px; font-weight: 700; letter-spacing: 0.03em; text-transform: uppercase;">{escape(label)}</div>
|
<div style="font-size: 11px; font-weight: 700; letter-spacing: 0.03em; text-transform: uppercase;">{escape(label)}</div>
|
||||||
<div style="font-size: 19px; font-weight: 800; font-family: Menlo, monospace;">{escape(value)}</div>
|
<div style="font-size: 19px; font-weight: 800; font-family: Menlo, monospace;">{escape(value)}</div>
|
||||||
|
|
@ -473,68 +452,187 @@ def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> wid
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str:
|
def _wrap_row_html(
|
||||||
|
label: str,
|
||||||
|
cards: Sequence[str],
|
||||||
|
*,
|
||||||
|
background: str,
|
||||||
|
border: str,
|
||||||
|
subtitle: str,
|
||||||
|
) -> str:
|
||||||
|
columns = "150px " + " ".join(["minmax(66px, 1fr)"] * len(cards))
|
||||||
|
return f"""
|
||||||
|
<div style="
|
||||||
|
padding: 14px;
|
||||||
|
border-radius: 18px;
|
||||||
|
background: {background};
|
||||||
|
border: 1px solid {border};
|
||||||
|
">
|
||||||
|
<div style="font-size:16px; font-weight:800; margin-bottom:6px; color:{SVG_INK};">{escape(label)}</div>
|
||||||
|
<div style="font-size:13px; color:#486581; margin-bottom:10px;">{escape(subtitle)}</div>
|
||||||
|
<div style="overflow-x:auto; padding-bottom:4px;">
|
||||||
|
<div style="
|
||||||
|
display:grid;
|
||||||
|
grid-template-columns:{columns};
|
||||||
|
gap:8px;
|
||||||
|
align-items:stretch;
|
||||||
|
min-width:max-content;
|
||||||
|
">
|
||||||
|
<div style="
|
||||||
|
padding:10px 12px;
|
||||||
|
border-radius:14px;
|
||||||
|
background:rgba(255,255,255,0.72);
|
||||||
|
border:1px solid rgba(16,42,67,0.1);
|
||||||
|
color:#334e68;
|
||||||
|
font-size:13px;
|
||||||
|
font-weight:700;
|
||||||
|
line-height:1.35;
|
||||||
|
">
|
||||||
|
{escape(label)}
|
||||||
|
</div>
|
||||||
|
{"".join(cards)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_compare_frame_html(coefficients: Sequence[int], n: int, step: int) -> str:
|
||||||
cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
|
cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
|
||||||
negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
|
negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
|
||||||
flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)]
|
flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)]
|
||||||
current_index, current_value = flat[step]
|
current_index, current_value = flat[step]
|
||||||
cell = 58
|
|
||||||
width = max(980, len(coefficients) * cell + 200)
|
|
||||||
height = 450
|
|
||||||
top_y = 118
|
|
||||||
cyclic_y = 252
|
|
||||||
neg_y = 348
|
|
||||||
|
|
||||||
parts = [
|
|
||||||
_svg_canvas_open(width, height),
|
|
||||||
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
|
|
||||||
_svg_text(34, 48, "Wraparound Comparison Player", size=22, weight="800"),
|
|
||||||
_svg_text(34, 74, f"Current source term: x^{current_index} with coefficient {current_value}", size=13, fill="#486581"),
|
|
||||||
_svg_text(34, top_y - 18, "Raw convolution tail", size=14, weight="700"),
|
|
||||||
_svg_text(34, cyclic_y - 18, "Cyclic fold into x^n - 1", size=14, weight="700", fill=SVG_BLUE),
|
|
||||||
_svg_text(34, neg_y - 18, "Negacyclic fold into x^n + 1", size=14, weight="700", fill=SVG_ACCENT),
|
|
||||||
]
|
|
||||||
|
|
||||||
for index, coefficient in enumerate(coefficients):
|
|
||||||
fill = SVG_HILITE if index == current_index else "#f1f5f9" if index > current_index else "#d8f3dc"
|
|
||||||
stroke = SVG_ACCENT if index == current_index else "#cbd2d9" if index > current_index else SVG_GOOD
|
|
||||||
parts.append(_svg_box(48 + index * cell, top_y, 48, 48, f"x^{index}", str(coefficient), fill=fill, stroke=stroke))
|
|
||||||
|
|
||||||
for slot, row in enumerate(cyclic_rows):
|
|
||||||
parts.append(_svg_box(48 + slot * cell, cyclic_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#e0fbfc", stroke=SVG_BLUE))
|
|
||||||
for slot, row in enumerate(negacyclic_rows):
|
|
||||||
parts.append(_svg_box(48 + slot * cell, neg_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#ffe8d6", stroke=SVG_ACCENT))
|
|
||||||
|
|
||||||
wraps, slot = divmod(current_index, n)
|
wraps, slot = divmod(current_index, n)
|
||||||
cyclic_label = f"+ wrap {wraps}"
|
cyclic_label = f"+ wrap {wraps}"
|
||||||
neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}"
|
neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}"
|
||||||
source_x = 72 + current_index * cell
|
|
||||||
target_x = 72 + slot * cell
|
|
||||||
arrow_start_y = top_y + 48
|
|
||||||
cyclic_label_x, cyclic_label_y = _wrap_label_geometry(source_x, target_x, arrow_start_y + 10, cyclic_y - 26)
|
|
||||||
neg_label_x, neg_label_y = _wrap_label_geometry(source_x, target_x, cyclic_y + 56, neg_y - 26)
|
|
||||||
|
|
||||||
parts.append(f'<line x1="{source_x}" y1="{arrow_start_y}" x2="{target_x}" y2="{cyclic_y}" stroke="{SVG_BLUE}" stroke-width="4" marker-end="url(#arrow-blue)"></line>')
|
raw_cards = []
|
||||||
parts.append(f'<line x1="{source_x}" y1="{arrow_start_y}" x2="{target_x}" y2="{neg_y}" stroke="{SVG_ACCENT}" stroke-width="4" marker-end="url(#arrow-red)"></line>')
|
for index, coefficient in enumerate(coefficients):
|
||||||
parts.append(f"""
|
active = index == current_index
|
||||||
<defs>
|
done = index < current_index
|
||||||
<marker id="arrow-blue" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
|
fill = SVG_HILITE if active else "#d8f3dc" if done else "#f1f5f9"
|
||||||
<polygon points="0 0, 8 3, 0 6" fill="{SVG_BLUE}"></polygon>
|
border = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
|
||||||
</marker>
|
raw_cards.append(_html_token(f"x^{index}", str(coefficient), fill=fill, border=border))
|
||||||
<marker id="arrow-red" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
|
|
||||||
<polygon points="0 0, 8 3, 0 6" fill="{SVG_ACCENT}"></polygon>
|
cyclic_cards = []
|
||||||
</marker>
|
for row in cyclic_rows:
|
||||||
</defs>
|
active = row["slot"] == slot
|
||||||
""")
|
fill = "#cfeffd" if active else "#eaf7fc"
|
||||||
parts.append(_svg_badge(cyclic_label_x, cyclic_label_y, cyclic_label, fill="#f4fbff", stroke=SVG_BLUE, text_fill=SVG_BLUE))
|
border = SVG_BLUE if active else "#9dc7d8"
|
||||||
parts.append(_svg_badge(neg_label_x, neg_label_y, neg_label, fill="#fff0ec", stroke=SVG_ACCENT, text_fill=SVG_ACCENT))
|
cyclic_cards.append(_html_token(f"slot {row['slot']}", str(row["total"]), fill=fill, border=border))
|
||||||
parts.append("</svg>")
|
|
||||||
return "".join(parts)
|
negacyclic_cards = []
|
||||||
|
for row in negacyclic_rows:
|
||||||
|
active = row["slot"] == slot
|
||||||
|
fill = "#ffe1d6" if active else "#fff0ea"
|
||||||
|
border = SVG_ACCENT if active else "#e6a98d"
|
||||||
|
negacyclic_cards.append(_html_token(f"slot {row['slot']}", str(row["total"]), fill=fill, border=border))
|
||||||
|
|
||||||
|
flow_cards = f"""
|
||||||
|
<div style="
|
||||||
|
display:grid;
|
||||||
|
grid-template-columns: repeat(auto-fit, minmax(240px, 1fr));
|
||||||
|
gap: 12px;
|
||||||
|
">
|
||||||
|
<div style="
|
||||||
|
padding:14px;
|
||||||
|
border-radius:18px;
|
||||||
|
background:#eef8ff;
|
||||||
|
border:1px solid #b8dbef;
|
||||||
|
">
|
||||||
|
<div style="font-size:16px; font-weight:800; color:{SVG_BLUE}; margin-bottom:8px;">Cyclic move</div>
|
||||||
|
<div style="font-size:14px; color:{SVG_INK}; line-height:1.45;">
|
||||||
|
x^{current_index} -> slot {slot}<br>
|
||||||
|
coefficient {current_value}<br>
|
||||||
|
sign {cyclic_label}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div style="
|
||||||
|
padding:14px;
|
||||||
|
border-radius:18px;
|
||||||
|
background:#fff1eb;
|
||||||
|
border:1px solid #efb7a7;
|
||||||
|
">
|
||||||
|
<div style="font-size:16px; font-weight:800; color:{SVG_ACCENT}; margin-bottom:8px;">Negacyclic move</div>
|
||||||
|
<div style="font-size:14px; color:{SVG_INK}; line-height:1.45;">
|
||||||
|
x^{current_index} -> slot {slot}<br>
|
||||||
|
coefficient {current_value}<br>
|
||||||
|
sign {neg_label}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
<div style="
|
||||||
|
width:100%;
|
||||||
|
max-width:100%;
|
||||||
|
box-sizing:border-box;
|
||||||
|
display:grid;
|
||||||
|
gap:14px;
|
||||||
|
font-family:'Avenir Next', 'Trebuchet MS', sans-serif;
|
||||||
|
color:{SVG_INK};
|
||||||
|
">
|
||||||
|
<div style="
|
||||||
|
display:flex;
|
||||||
|
flex-wrap:wrap;
|
||||||
|
gap:10px;
|
||||||
|
align-items:center;
|
||||||
|
">
|
||||||
|
<div style="
|
||||||
|
padding:10px 12px;
|
||||||
|
border-radius:14px;
|
||||||
|
background:#fff3cd;
|
||||||
|
border:2px solid #f2c94c;
|
||||||
|
font-size:15px;
|
||||||
|
font-weight:800;
|
||||||
|
">
|
||||||
|
Active source term: x^{current_index} = {current_value}
|
||||||
|
</div>
|
||||||
|
<div style="
|
||||||
|
padding:10px 12px;
|
||||||
|
border-radius:14px;
|
||||||
|
background:#eef6ff;
|
||||||
|
border:1px solid #bcd4f6;
|
||||||
|
font-size:13px;
|
||||||
|
color:#334e68;
|
||||||
|
">
|
||||||
|
Track where this one term lands in x^n - 1 and x^n + 1.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{_wrap_row_html(
|
||||||
|
"Raw convolution tail",
|
||||||
|
raw_cards,
|
||||||
|
background=SVG_PANEL,
|
||||||
|
border="#e8dcc9",
|
||||||
|
subtitle="Highlighted card = the exact term moving right now.",
|
||||||
|
)}
|
||||||
|
|
||||||
|
{flow_cards}
|
||||||
|
|
||||||
|
{_wrap_row_html(
|
||||||
|
"Cyclic fold into x^n - 1",
|
||||||
|
cyclic_cards,
|
||||||
|
background="#f5fbff",
|
||||||
|
border="#c7e0ec",
|
||||||
|
subtitle=f"x^{current_index} lands in slot {slot} with sign {cyclic_label}.",
|
||||||
|
)}
|
||||||
|
|
||||||
|
{_wrap_row_html(
|
||||||
|
"Negacyclic fold into x^n + 1",
|
||||||
|
negacyclic_cards,
|
||||||
|
background="#fff6f1",
|
||||||
|
border="#efc7b8",
|
||||||
|
subtitle=f"x^{current_index} lands in slot {slot} with sign {neg_label}.",
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
|
def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
|
||||||
"""Interactive comparison of cyclic and negacyclic folding, one source term at a time."""
|
"""Interactive comparison of cyclic and negacyclic folding, one source term at a time."""
|
||||||
frames = [_wrap_compare_frame_svg(coefficients, n, step) for step in range(len(coefficients))]
|
frames = [_wrap_compare_frame_html(coefficients, n, step) for step in range(len(coefficients))]
|
||||||
captions = []
|
captions = []
|
||||||
for index, coefficient in enumerate(coefficients):
|
for index, coefficient in enumerate(coefficients):
|
||||||
wraps, slot = divmod(index, n)
|
wraps, slot = divmod(index, n)
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,7 @@ os.environ.setdefault("MPLCONFIGDIR", str(MPLCONFIGDIR))
|
||||||
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
|
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
|
||||||
from ntt_learning.visuals import (
|
from ntt_learning.visuals import (
|
||||||
_convolution_frame_html,
|
_convolution_frame_html,
|
||||||
_wrap_compare_frame_svg,
|
_wrap_compare_frame_html,
|
||||||
_wrap_label_geometry,
|
|
||||||
butterfly_story_player,
|
butterfly_story_player,
|
||||||
direct_ntt_player,
|
direct_ntt_player,
|
||||||
schoolbook_diagonal_player,
|
schoolbook_diagonal_player,
|
||||||
|
|
@ -70,7 +69,6 @@ class VisualUxTests(unittest.TestCase):
|
||||||
psi = find_psi(order=4, modulus=17)
|
psi = find_psi(order=4, modulus=17)
|
||||||
trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], modulus=17, psi=psi)
|
trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], modulus=17, psi=psi)
|
||||||
players = [
|
players = [
|
||||||
wraparound_comparison_player([3, 0, 2, 1, 5, 4, 6], n=4),
|
|
||||||
direct_ntt_player([1, 2, 3, 4], modulus=17, psi=psi),
|
direct_ntt_player([1, 2, 3, 4], modulus=17, psi=psi),
|
||||||
butterfly_story_player(trace),
|
butterfly_story_player(trace),
|
||||||
]
|
]
|
||||||
|
|
@ -100,36 +98,24 @@ class VisualUxTests(unittest.TestCase):
|
||||||
self.assertNotEqual(before, frame_html.value)
|
self.assertNotEqual(before, frame_html.value)
|
||||||
self.assertIn("Frame 2 of", caption_html.value)
|
self.assertIn("Frame 2 of", caption_html.value)
|
||||||
|
|
||||||
def test_wraparound_callouts_do_not_share_title_band(self) -> None:
|
def test_wraparound_frame_uses_stable_html_bands(self) -> None:
|
||||||
cell = 58
|
frame = _wrap_compare_frame_html([5, 16, 34, 60, 61, 52, 32], n=4, step=6)
|
||||||
top_y = 118
|
|
||||||
cyclic_y = 252
|
|
||||||
neg_y = 348
|
|
||||||
arrow_start_y = top_y + 48
|
|
||||||
current_index = 6
|
|
||||||
slot = current_index % 4
|
|
||||||
source_x = 72 + current_index * cell
|
|
||||||
target_x = 72 + slot * cell
|
|
||||||
|
|
||||||
cyclic_label_x, cyclic_label_y = _wrap_label_geometry(source_x, target_x, arrow_start_y + 10, cyclic_y - 26)
|
self.assertNotIn("<svg", frame)
|
||||||
neg_label_x, neg_label_y = _wrap_label_geometry(source_x, target_x, cyclic_y + 56, neg_y - 26)
|
self.assertIn("Active source term: x^6 = 32", frame)
|
||||||
|
self.assertIn("Cyclic move", frame)
|
||||||
self.assertLess(cyclic_label_y, cyclic_y - 28)
|
self.assertIn("Negacyclic move", frame)
|
||||||
self.assertLess(neg_label_y, neg_y - 28)
|
self.assertIn("Cyclic fold into x^n - 1", frame)
|
||||||
self.assertGreater(cyclic_label_y, top_y + 24)
|
self.assertIn("Negacyclic fold into x^n + 1", frame)
|
||||||
self.assertGreater(neg_label_y, cyclic_y + 24)
|
|
||||||
self.assertNotAlmostEqual(cyclic_label_y, cyclic_y - 18, delta=1)
|
|
||||||
self.assertNotAlmostEqual(neg_label_y, neg_y - 18, delta=1)
|
|
||||||
self.assertGreater(cyclic_label_x, 0)
|
|
||||||
self.assertGreater(neg_label_x, 0)
|
|
||||||
|
|
||||||
def test_wraparound_frame_uses_badge_callouts_for_wrap_labels(self) -> None:
|
|
||||||
frame = _wrap_compare_frame_svg([5, 16, 34, 60, 61, 52, 32], n=4, step=6)
|
|
||||||
|
|
||||||
self.assertIn('fill="#f4fbff"', frame)
|
|
||||||
self.assertIn('fill="#fff0ec"', frame)
|
|
||||||
self.assertIn("+ wrap 1", frame)
|
self.assertIn("+ wrap 1", frame)
|
||||||
self.assertIn("- wrap 1", frame)
|
self.assertIn("- wrap 1", frame)
|
||||||
|
self.assertIn("display:grid", frame)
|
||||||
|
|
||||||
|
def test_player_caption_has_fixed_min_height(self) -> None:
|
||||||
|
player = wraparound_comparison_player([3, 0, 2, 1, 5, 4, 6], n=4)
|
||||||
|
_, _, caption_html = player_parts(player)
|
||||||
|
|
||||||
|
self.assertIn("min-height: 62px", caption_html.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue