diff --git a/ntt_learning/visuals.py b/ntt_learning/visuals.py index 76f3edb..e6212e4 100644 --- a/ntt_learning/visuals.py +++ b/ntt_learning/visuals.py @@ -153,6 +153,25 @@ def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "4 return f'{escape(text)}' +def _svg_badge( + x: float, + y: float, + text: str, + *, + fill: str = "#ffffff", + stroke: str = "#cbd2d9", + text_fill: str = SVG_INK, +) -> str: + width = max(66.0, 10.0 + len(text) * 7.2) + height = 24.0 + left = x - width / 2 + top = y - height + 6 + return f""" + + {escape(text)} + """ + + def _svg_multiline_text( x: float, y: float, @@ -202,6 +221,12 @@ def _svg_canvas_open(width: int, height: int) -> str: ) +def _wrap_label_geometry(source_x: float, target_x: float, band_top: float, band_bottom: float) -> tuple[float, float]: + label_x = (source_x + target_x) / 2 + label_y = (band_top + band_bottom) / 2 + return label_x, label_y + + def _html_token(label: str, value: str, *, fill: str, border: str, text: str = SVG_INK) -> str: return f"""
') - parts.append(f'') + parts.append(f'') + parts.append(f'') parts.append(f""" @@ -498,8 +526,8 @@ def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> s """) - parts.append(_svg_text((source_x + target_x) / 2, cyclic_y - 18, cyclic_label, size=12, fill=SVG_BLUE, anchor="middle")) - parts.append(_svg_text((source_x + target_x) / 2, neg_y - 18, neg_label, size=12, fill=SVG_ACCENT, anchor="middle")) + parts.append(_svg_badge(cyclic_label_x, cyclic_label_y, cyclic_label, fill="#f4fbff", stroke=SVG_BLUE, text_fill=SVG_BLUE)) + parts.append(_svg_badge(neg_label_x, neg_label_y, neg_label, fill="#fff0ec", stroke=SVG_ACCENT, text_fill=SVG_ACCENT)) parts.append("") return "".join(parts) diff --git a/tests/test_visuals.py b/tests/test_visuals.py index 6a7299f..e693f35 100644 --- a/tests/test_visuals.py +++ b/tests/test_visuals.py @@ -14,6 +14,8 @@ os.environ.setdefault("MPLCONFIGDIR", str(MPLCONFIGDIR)) from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi from ntt_learning.visuals import ( _convolution_frame_html, + _wrap_compare_frame_svg, + _wrap_label_geometry, butterfly_story_player, direct_ntt_player, schoolbook_diagonal_player, @@ -98,6 +100,37 @@ class VisualUxTests(unittest.TestCase): self.assertNotEqual(before, frame_html.value) self.assertIn("Frame 2 of", caption_html.value) + def test_wraparound_callouts_do_not_share_title_band(self) -> None: + cell = 58 + top_y = 118 + cyclic_y = 252 + neg_y = 348 + arrow_start_y = top_y + 48 + current_index = 6 + slot = current_index % 4 + source_x = 72 + current_index * cell + target_x = 72 + slot * cell + + cyclic_label_x, cyclic_label_y = _wrap_label_geometry(source_x, target_x, arrow_start_y + 10, cyclic_y - 26) + neg_label_x, neg_label_y = _wrap_label_geometry(source_x, target_x, cyclic_y + 56, neg_y - 26) + + self.assertLess(cyclic_label_y, cyclic_y - 28) + self.assertLess(neg_label_y, neg_y - 28) + self.assertGreater(cyclic_label_y, top_y + 24) + self.assertGreater(neg_label_y, cyclic_y + 24) + self.assertNotAlmostEqual(cyclic_label_y, cyclic_y - 18, delta=1) + self.assertNotAlmostEqual(neg_label_y, neg_y - 18, delta=1) + self.assertGreater(cyclic_label_x, 0) + self.assertGreater(neg_label_x, 0) + + def test_wraparound_frame_uses_badge_callouts_for_wrap_labels(self) -> None: + frame = _wrap_compare_frame_svg([5, 16, 34, 60, 61, 52, 32], n=4, step=6) + + self.assertIn('fill="#f4fbff"', frame) + self.assertIn('fill="#fff0ec"', frame) + self.assertIn("+ wrap 1", frame) + self.assertIn("- wrap 1", frame) + if __name__ == "__main__": unittest.main()