Source code for qtdraw.widget.mathjax

"""
MathJaxSVG converter.

This module provides mathjax to SVG converter.
"""

import re
import hashlib
from pathlib import Path
import asyncio
import threading
from playwright.async_api import async_playwright
import xml.etree.ElementTree as ET

from qtdraw.core.qtdraw_info import __top_dir__
from qtdraw.widget.color_palette import all_colors

# ===============================
# Global constants.
_MATHJAX_PATH = str(Path(__top_dir__) / "qtdraw" / "mathjax" / "es5" / "tex-svg-full.js")
_HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <script>
        window.MathJax={{
            tex: {{ inlineMath: [['$','$'],['\\\\(','\\\\)']] }},
            svg: {{ fontCache: 'none' }}
        }};
    </script>
    <style>
        body {{
            margin: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            font-size: 10pt;
        }}
    </style>
</head>
<body>
    <div id="math">{latex}</div>
</body>
</html>
"""


# ===============================
[docs] class MathJaxSVG: _SVG_NS = "http://www.w3.org/2000/svg" # =============================== def __init__(self, cache_dir=None, clear_cache=False): """ MathJax converter. Args: cache_dir (str, optional): cache directory. clear_cache (bool, optional): clear disk cache ? """ self._svg_cache = {} # memory cache. # disk cache. self._cache_dir = cache_dir or (Path.home() / ".qtdraw" / "svg_cache") self._cache_dir.mkdir(parents=True, exist_ok=True) # clear disk cache. if clear_cache: for f in self._cache_dir.glob("*.svg"): try: f.unlink() except: pass ET.register_namespace("", self._SVG_NS) # run event loop in independent thread. self._thread = threading.Thread(target=self._thread_main, daemon=True) self._thread.start() # wait for execution of playwright. self._ready = threading.Event() self._ready.wait() # =============================== event loop in thread. def _thread_main(self): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) self._loop.create_task(self._async_init()) self._loop.run_forever() # =============================== async def _async_init(self): self._playwright = await async_playwright().start() self._browser = await self._playwright.chromium.launch(headless=True) self._ready.set() # complete execution. # ===============================
[docs] def convert(self, latex, color="black", size=10): """ Convert latex to SVG string. Args: latex (str): LaTeX code w/o $. color (str, optional): color name. size (int, optional): point. Returns: - (str) -- SVG string. - (tuple) -- width and height. """ return asyncio.run_coroutine_threadsafe(self._convert_async(latex, color, size), self._loop).result()
# =============================== implementaion for convert with async for Jupyter. async def _convert_async(self, latex, color, size): if latex in self._svg_cache: # use memory cache. svg_str = self._svg_cache[latex] else: cache_path = self._get_cache_path(latex) # use disk cache. if cache_path.exists(): svg_str = cache_path.read_text() else: # create SVG. page = await self._browser.new_page() html = _HTML_TEMPLATE.format(latex=latex) await page.set_content(html) await page.add_script_tag(path=_MATHJAX_PATH) await asyncio.sleep(0.05) svg_elem = await page.query_selector("mjx-container svg") if not svg_elem: await page.close() raise RuntimeError("Failed to get SVG element.") svg_str = await svg_elem.evaluate("el => el.outerHTML") await page.close() svg_str = self._flatten_svg_string(svg_str) self._svg_cache[latex] = svg_str # get scaled size. x, y, w, h = map(float, self._get_attribute(svg_str, "viewBox").split()) scale = size / 1000.0 wh = int(w * scale + 0.99999), int(h * scale + 0.99999) # set color. svg_str = self._replace_attribute(svg_str, "fill", f"{all_colors[color][0]}") return svg_str, wh # ===============================
[docs] def close(self): # write memory cache to disk cache. for latex, svg_str in self._svg_cache.items(): cache_path = self._get_cache_path(latex) if not cache_path.exists(): cache_path.write_text(svg_str) # close browser and playwright. asyncio.run_coroutine_threadsafe(self._async_close(), self._loop).result() self._loop.call_soon_threadsafe(self._loop.stop)
# =============================== async def _async_close(self): await self._browser.close() await self._playwright.stop() # =============================== def _get_cache_path(self, latex): hash_key = hashlib.sha256(f"{latex}".encode("utf-8")).hexdigest() return self._cache_dir / f"{hash_key}.svg" # =============================== @staticmethod def _get_attribute(svg_str, keyword): match = re.search(rf'{re.escape(keyword)}="([^"]+)"', svg_str) if match: return match.group(1) return None # =============================== @staticmethod def _replace_attribute(svg_str, keyword, value): if re.search(rf'{re.escape(keyword)}="[^"]+"', svg_str): return re.sub(rf'{re.escape(keyword)}="[^"]+"', f'{keyword}="{value}"', svg_str) else: return svg_str # =============================== @staticmethod def _flatten_svg_string(svg): if not svg: return "" root = ET.fromstring(svg) def unwrap_inner_svg(elem): for child in list(elem): if child.tag.endswith("svg"): for grand in list(child): elem.append(grand) elem.remove(child) else: unwrap_inner_svg(child) unwrap_inner_svg(root) # unify all "fill" to currentColor. for elem in root.iter(): if "fill" in elem.attrib and elem.attrib["fill"] != "none": elem.attrib["fill"] = "currentColor" return ET.tostring(root, encoding="unicode")