# Source code for qtdraw.util.util_axis

"""
Utility for axis actor and label, unit cell, and view vector.

This module contains utility for axis, etc.
"""

import vtk
import pyvista as pv
import numpy as np
from math import floor, ceil

from qtdraw.core.pyvista_widget_setting import widget_detail as detail
from qtdraw.core.pyvista_widget_setting import CHOP, DIGIT
from qtdraw.widget.color_palette import all_colors
from qtdraw.util.util import text_to_list


# ==================================================
def _create_label_axes_actor(A, label, size, bold, italic, color, scale):
    """
    Create orientation axes actor that shows the axes labels only.

    Args:
        A (numpy.ndarray): (a1, a2, a3) unit vectors, 4x4 [float].
        label (str): axes labels (spaces are removed, then parsed by text_to_list).
        size (int): font size (before zoom).
        bold (bool): bold face ?
        italic (bool): italic ?
        color (list): axes label colors, [[float]], RGB in unit of [0,1].
        scale (float): zoom factor.

    Returns:
        - (pyvista.AxesActor) -- axes actor.

    Note:
        - shaft and tip radii are set to zero, i.e., only the captions remain.
    """
    names = text_to_list(label.replace(" ", ""))

    # font size is zoomed manually.
    font_size = int(size * scale)

    # set transform for non-orthogonal axes (flattened 4x4 matrix, 16 elements).
    trans = vtk.vtkTransform()
    trans.SetMatrix(A.ravel().tolist())

    # create axes actor with label only.
    axes = pv.AxesActor()
    axes.SetShaftTypeToCylinder()
    axes.SetTipTypeToCone()
    axes.SetCylinderRadius(0.0)
    axes.SetConeRadius(0.0)
    axes.SetUserTransform(trans)

    # set label text.
    axes.SetXAxisLabelText(names[0])
    axes.SetYAxisLabelText(names[1])
    axes.SetZAxisLabelText(names[2])

    # set font properties of each caption.
    captions = (
        axes.GetXAxisCaptionActor2D(),
        axes.GetYAxisCaptionActor2D(),
        axes.GetZAxisCaptionActor2D(),
    )
    for caption, rgb in zip(captions, color):
        caption.GetPositionCoordinate().SetCoordinateSystemToWorld()
        caption.GetTextActor().SetTextScaleModeToViewport()

        prop = caption.GetCaptionTextProperty()
        set_bold = prop.BoldOn if bold else prop.BoldOff
        set_italic = prop.ItalicOn if italic else prop.ItalicOff
        set_bold()
        set_italic()
        prop.SetFontSize(font_size)
        prop.SetColor(*rgb)

    return axes


# ==================================================
def _create_axes_actor(
    A,
    label,
    label_size,
    label_bold,
    label_italic,
    label_color,
    scale,
    shaft_color,
    sphere_color,
    shaft_radius,
    tip_radius,
    tip_length,
    tip_resolution,
    sphere_radius,
    theta_phi_resolution,
):
    """
    Create custom axes actor.

    Args:
        A (numpy.ndarray): (a1, a2, a3) unit vectors, 4x4 [float].
        label (str): axes labels (None for no label).
        label_size (int): font size.
        label_bold (bool): bold face ?
        label_italic (bool): italic ?
        label_color (list): axes label color names, [str].
        scale (float): zoom factor.
        shaft_color (list): axes color names, [str].
        sphere_color (str): center color name.
        shaft_radius (float): axes cylinder radius.
        tip_radius (float): axes tip radius.
        tip_length (float): axes tip length.
        tip_resolution (int): axes tip resolution.
        sphere_radius (float): axes sphere radius.
        theta_phi_resolution (list): axes sphere theta, phi resolution, [int].

    Returns:
        - (vtk.vtkPropAssembly) -- custom axes actor.

    Note:
        - all color names are looked up in all_colors, and the RGB entry
          (index 1, divided by 255) is used as RGB float in unit of [0,1].
    """

    def to_rgb_float(name):
        # color name => RGB float in unit of [0,1].
        return np.array(all_colors[name][1]) / 255

    # convert from color name to RGB float.
    shaft_color = [to_rgb_float(c) for c in shaft_color]
    sphere_color = to_rgb_float(sphere_color)
    # label colors are normalized as well, since they are finally passed to
    # the VTK text property, which expects RGB in unit of [0,1].
    label_color = [to_rgb_float(c).tolist() for c in label_color]

    # create axes.
    assembly = vtk.vtkPropAssembly()
    for d, c in zip(A[0:3, 0:3].T, shaft_color):
        # axes arrows.
        g = pv.Arrow(
            direction=d,
            shaft_radius=shaft_radius,
            tip_radius=tip_radius,
            tip_length=tip_length,
            tip_resolution=tip_resolution,
        )
        actor = pv.Actor(mapper=pv.DataSetMapper(g))
        actor.GetProperty().SetColor(c)
        assembly.AddPart(actor)

        # dummy axes to keep rotation center as origin.
        g = pv.Sphere(radius=0.0, center=-np.array(d))
        actor = pv.Actor(mapper=pv.DataSetMapper(g))
        assembly.AddPart(actor)

    # center sphere (theta, phi are used differently).
    phi, theta = theta_phi_resolution
    g0 = pv.Sphere(radius=sphere_radius, theta_resolution=theta, phi_resolution=phi)
    actor = pv.Actor(mapper=pv.DataSetMapper(g0))
    actor.GetProperty().SetColor(sphere_color)
    assembly.AddPart(actor)

    # add axes label.
    if label is not None:
        lbl = _create_label_axes_actor(
            A,
            label=label,
            size=label_size,
            bold=label_bold,
            italic=label_italic,
            color=label_color,
            scale=scale,
        )
        assembly.AddPart(lbl)

    return assembly


# ==================================================
def _create_axes_actor_full(
    pv_widget,
    A,
    label_color,
    shaft_color,
    shaft_radius,
    shaft_resolution,
    tip_radius,
    tip_length,
    tip_resolution,
):
    """
    Create custom axes actor (crossed axes).

    Args:
        pv_widget (PyVistaWidget): pyvista widget.
        A (numpy.ndarray): (a1, a2, a3) unit vectors, 4x4 [float].
        label_color (list): axes label color names, [str] (not used, kept for interface).
        shaft_color (list): axes color names, [str].
        shaft_radius (float): axes cylinder radius.
        shaft_resolution (int): axes shaft resolution.
        tip_radius (float): axes tip radius.
        tip_length (float): axes tip length.
        tip_resolution (int): axes tip resolution.

    Note:
        - arrows are added directly to pv_widget as meshes named
          "axes_arrow_0", "axes_arrow_1", "axes_arrow_2", and nothing is returned.
        - no label is drawn here.
    """
    # convert from color name to RGB float.
    shaft_color = [(np.array(all_colors[c][1]) / 255) for c in shaft_color]

    # each arrow has total length `length`, and is centered at the origin.
    length = 2.5
    offset = -0.5 * length
    width = shaft_radius * 2.5

    for no, (d, c) in enumerate(zip(A[0:3, 0:3].T, shaft_color)):
        # normalized direction of the axis.
        d = np.asarray(d) / np.linalg.norm(d)

        obj = pv.Arrow(
            start=offset * d,
            direction=d,
            scale=length,
            shaft_radius=shaft_radius * width / length,
            tip_radius=1.2 * tip_radius * width / length,
            tip_length=tip_length * 0.15,
            shaft_resolution=shaft_resolution,
            tip_resolution=tip_resolution,
        )

        pv_widget.add_mesh(mesh=obj, smooth_shading=True, color=c, name=f"axes_arrow_{no}")

    return None


# ==================================================
def create_axes_widget(
    pv_widget,
    A,
    label="[x,y,z]",
    label_size=28,
    label_bold=True,
    label_italic=False,
    label_color=None,
    viewport=True,
    full=False,
):
    """
    Create axes widget.

    Args:
        pv_widget (PyVistaWidget): pyvista widget.
        A (numpy.ndarray): (a1, a2, a3) unit vectors, 4x4 [float].
        label (str, optional): axes labels.
        label_size (int, optional): font size.
        label_bold (bool, optional): bold face ?
        label_italic (bool, optional): italic ?
        label_color (list, optional): axes label color names, [str]. (default: all "black").
        viewport (bool, optional): set viewport ?
        full (bool, optional): full axes ?

    Note:
        - if label is None, no label is used.
        - if viewport is False, the widget covers the whole window and is not interactive.
        - if full is True, crossed arrows are added to pv_widget as meshes,
          and no orientation widget is created.
    """
    # avoid mutable default argument.
    if label_color is None:
        label_color = ["black", "black", "black"]

    scale = detail["scale"]
    pickable = detail["pickable"]
    shaft_color = detail["shaft_color"]
    sphere_color = detail["sphere_color"]
    shaft_radius = detail["shaft_radius"]
    shaft_resolution = detail["shaft_resolution"]
    tip_radius = detail["tip_radius"]
    tip_length = detail["tip_length"]
    tip_resolution = detail["tip_resolution"]
    sphere_radius = detail["sphere_radius"]
    theta_phi_resolution = detail["theta_phi_resolution"]

    if viewport:
        viewport = detail["viewport"]
    else:
        viewport = [0.0, 0.0, 1.0, 1.0]
        pickable = False

    # create axes actor.
    if full:
        _create_axes_actor_full(
            pv_widget,
            A,
            label_color=label_color,
            shaft_color=shaft_color,
            shaft_radius=shaft_radius,
            shaft_resolution=shaft_resolution,
            tip_radius=tip_radius,
            tip_length=tip_length,
            tip_resolution=tip_resolution,
        )
    else:
        marker = _create_axes_actor(
            A,
            label=label,
            label_size=label_size,
            label_bold=label_bold,
            label_italic=label_italic,
            label_color=label_color,
            scale=scale,
            shaft_color=shaft_color,
            sphere_color=sphere_color,
            shaft_radius=shaft_radius,
            tip_radius=tip_radius,
            tip_length=tip_length,
            tip_resolution=tip_resolution,
            sphere_radius=sphere_radius,
            theta_phi_resolution=theta_phi_resolution,
        )
        # create axes widget (marker exists only in this branch).
        pv_widget.add_orientation_widget(marker, interactive=pickable, viewport=viewport)
        pv_widget.renderer.axes_widget.SetZoom(scale)


# ==================================================
def create_unit_cell(A, origin, lower=None, dimensions=None):
    """
    Create unit cell mesh.

    Args:
        A (numpy.ndarray): (a1, a2, a3) unit vectors, 4x4 [float].
        origin (list or numpy.ndarray): origin, [float].
        lower (list, optional): lower bound indices, [int]. (default: [0,0,0]).
        dimensions (list, optional): repeat times, [int]. (default: [1,1,1]).

    Returns:
        - (pyvista.PolyData) -- unit cell mesh.
    """
    lower = [0, 0, 0] if lower is None else lower
    dimensions = [1, 1, 1] if dimensions is None else dimensions

    # corners of single box (0-3: bottom face, 4-7: top face), shifted by origin.
    corners = np.array(
        [
            [0, 0, 0],
            [1, 0, 0],
            [1, 1, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 1],
            [1, 1, 1],
            [0, 1, 1],
        ],
        dtype=np.float64,
    )
    corners = corners + np.array([origin] * len(corners), dtype=np.float64)

    # edges, each entry is given as [number of points, indices...].
    bottom = [5, 0, 1, 2, 3, 0]
    top = [5, 4, 5, 6, 7, 4]
    pillars = [2, 0, 4, 2, 1, 5, 2, 2, 6, 2, 3, 7]
    cell = pv.PolyData(corners, lines=bottom + top + pillars)
    cell.transform(A, inplace=True)

    # lattice points at which the box is repeated.
    lattice = pv.ImageData(dimensions=dimensions, origin=lower).cast_to_unstructured_grid()
    lattice.transform(A, inplace=True)

    return lattice.glyph(geom=cell, factor=1.0, scale=False, orient=False)


# ==================================================
def create_cell_grid(ilower, dims):
    """
    Create grid point.

    Args:
        ilower (list): start cell.
        dims (list): range in each dim.

    Returns:
        - (list) -- grid point, [str], e.g., "[0,0,1]" (the last index runs fastest).
    """
    # index range along each direction.
    ranges = [np.arange(ilower[k], ilower[k] + dims[k]) for k in range(3)]

    # mesh grid, then one row per grid point.
    mesh = np.meshgrid(*ranges, indexing="ij")
    points = np.stack(mesh, axis=-1).reshape(-1, 3)

    # point to string without space.
    return [str(p).replace(" ", "") for p in points.tolist()]


# ==================================================
def get_lattice_vector(crystal, cell):
    """
    Get lattice vector.

    Args:
        crystal (str): crystal.
        cell (dict): cell, with keys "a", "b", "c", "alpha", "beta", "gamma" (angles in degree).

    Returns:
        - (dict) -- cell, in which the constraint of the crystal is applied.
        - (list) -- A, 4x4 nested list, whose first three columns are (a1, a2, a3).

    Note:
        - constraints: monoclinic (alpha=gamma=90), orthorhombic (alpha=beta=gamma=90),
          trigonal/hexagonal (alpha=beta=90, gamma=120, b=a),
          tetragonal (alpha=beta=gamma=90, b=a), cubic (alpha=beta=gamma=90, b=c=a).
        - for other crystal, cell is used as it is.
    """
    a = float(cell["a"])
    b = float(cell["b"])
    c = float(cell["c"])
    alpha = float(cell["alpha"])
    beta = float(cell["beta"])
    gamma = float(cell["gamma"])

    # apply constraint of crystal system.
    if crystal == "monoclinic":
        alpha = 90.0
        gamma = 90.0
    elif crystal == "orthorhombic":
        alpha = 90.0
        beta = 90.0
        gamma = 90.0
    elif crystal in ["trigonal", "hexagonal"]:
        alpha = 90.0
        beta = 90.0
        gamma = 120.0
        b = a
    elif crystal == "tetragonal":
        alpha = 90.0
        beta = 90.0
        gamma = 90.0
        b = a
    elif crystal == "cubic":
        alpha = 90.0
        beta = 90.0
        gamma = 90.0
        b = a
        c = a

    ca = np.cos(alpha * np.pi / 180)
    cb = np.cos(beta * np.pi / 180)
    cc = np.cos(gamma * np.pi / 180)
    sc = np.sin(gamma * np.pi / 180)

    s = 1.0 - ca * ca - cb * cb - cc * cc + 2.0 * ca * cb * cc
    # s can be slightly negative for (nearly) flat cell, e.g., alpha=beta=gamma=120,
    # clamp it before sqrt in order to avoid NaN and RuntimeWarning.
    s = np.sqrt(max(s, 0.0))
    # keep small finite height for a3.
    s = max(CHOP, s)

    a1 = np.array([a, 0, 0]).round(DIGIT).tolist()
    a2 = np.array([b * cc, b * sc, 0]).round(DIGIT).tolist()
    a3 = np.array([c * cb, c * (ca - cb * cc) / sc, c * s / sc]).round(DIGIT).tolist()

    # 4x4 matrix with (a1, a2, a3) in columns.
    A = np.eye(4)
    A[0:3, 0] = a1
    A[0:3, 1] = a2
    A[0:3, 2] = a3
    A = A.round(DIGIT).tolist()

    cell = {"a": a, "b": b, "c": c, "alpha": alpha, "beta": beta, "gamma": gamma}

    return cell, A


# ==================================================
def get_repeat_range(lower, upper):
    """
    Get repeat range.

    Args:
        lower (list): lower bound.
        upper (list): upper bound.

    Returns:
        - (list) -- lower cell.
        - (list) -- size of repeat (at least 1 in each direction).
    """
    # lower bound is rounded down, upper bound (plus CHOP) is rounded up.
    start = [floor(lower[k]) for k in range(3)]
    stop = [ceil(upper[k] + CHOP) for k in range(3)]

    dims = [max(1, hi - lo) for lo, hi in zip(start, stop)]

    return start, dims


# ==================================================
def get_outside_box(point, lower, upper):
    """
    Get indices outside range.

    Args:
        point (numpy.ndarray): a set of points (one point per row).
        lower (list): lower bound.
        upper (list): upper bound.

    Returns:
        - (numpy.ndarray) -- list of indices of points outside the box (boundary is inside).
    """
    xmin, ymin, zmin = lower
    xmax, ymax, zmax = upper
    limits = ((xmin, xmax), (ymin, ymax), (zmin, zmax))

    # inside flag for each direction, then combine them.
    per_axis = [(point[:, k] >= lo) & (point[:, k] <= hi) for k, (lo, hi) in enumerate(limits)]
    inside = np.logical_and.reduce(per_axis)

    return np.flatnonzero(~inside)


# ==================================================
def get_hkl_from_camera(camera, A):
    """
    Get index from camera.

    Args:
        camera (Camera): camera (position and focal_point are used).
        A (ndarray or list): A = [a1, a2, a3] (upper-left 3x3 part is used).

    Returns:
        - (list) -- index (cannot determine hkl, return [0,0,0]).

    Note:
        - view vector (position - focal_point) is expanded by (a1, a2, a3),
          normalized by the largest component, and multiplied by s = 1-9.
        - rounded candidate is accepted when each index is within [-9,9],
          its direction is within 1 degree from the view vector, and
          mean rounding error is less than 0.1.
        - the candidate with the smallest angle is returned, and for equivalent
          candidates, such as [1,1,0] and [2,2,0], the smallest one is returned.
    """
    rounding_tol = 0.1
    angle_tol = 1.0  # [deg].
    # angle difference below this is regarded as numerical error [deg].
    # without this, equivalent candidates are selected by round-off error of arccos.
    tie_tol = 1.0e-4

    # accept nested list (e.g., from get_lattice_vector) as well as ndarray.
    A = np.asarray(A, dtype=float)[0:3, 0:3]

    vec = np.array(camera.position) - np.array(camera.focal_point)
    vec_norm = np.linalg.norm(vec)
    if vec_norm < 1e-10:
        return [0, 0, 0]
    unit_vec = vec / vec_norm

    # view vector in unit of (a1, a2, a3).
    try:
        hkl_raw = np.linalg.solve(A, vec)
    except np.linalg.LinAlgError:
        return [0, 0, 0]

    max_abs = np.abs(hkl_raw).max()
    if max_abs < 1e-10:
        return [0, 0, 0]
    scaled_base = hkl_raw / max_abs

    best_candidate = [0, 0, 0]
    min_err = float("inf")

    for s in range(1, 10):
        target = scaled_base * s
        candidate = np.round(target)

        if np.all(candidate == 0):
            continue
        if np.any(np.abs(candidate) > 9):
            continue

        candidate_vec = A @ candidate
        c_norm = np.linalg.norm(candidate_vec)
        if c_norm < 1e-10:
            continue
        unit_cand = candidate_vec / c_norm

        # angle between view vector and candidate direction.
        cos_theta = np.clip(np.dot(unit_vec, unit_cand), -1.0, 1.0)
        angle_err = np.degrees(np.arccos(cos_theta))
        if angle_err > angle_tol:
            continue

        diff = np.abs(target - candidate).mean()
        if diff < rounding_tol and angle_err < min_err - tie_tol:
            min_err = angle_err
            best_candidate = candidate.astype(int).tolist()

    return best_candidate


# ==================================================
def get_camera_params(hkl, A, camera=None, bounds=None):
    """
    Get camera parameters.

    Args:
        hkl (list): index (must not be [0,0,0], since the view vector is normalized).
        A (ndarray or list): A = [a1, a2, a3] (upper-left 3x3 part is used).
        camera (Camera, optional): current camera.
        bounds (ndarray, optional): render bounds, [xmin, xmax, ymin, ymax, zmin, zmax].

    Returns:
        - (ndarray) -- position.
        - (ndarray) -- focal point.
        - (ndarray) -- view up.

    Note:
        - view vector is the normalized h*a1 + k*a2 + l*a3.
        - focal point is the center of bounds if given, otherwise (a1+a2+a3)/2.
        - if camera is given, position is apart from the focal point along the view
          vector with the current camera distance, otherwise the view vector itself.
    """
    # accept nested list (e.g., from get_lattice_vector) as well as ndarray.
    A = np.asarray(A, dtype=float)[0:3, 0:3]
    n = np.array(hkl)

    view = n[0] * A[:, 0] + n[1] * A[:, 1] + n[2] * A[:, 2]
    norm = np.linalg.norm(view)
    view = view / norm

    # view up for view along cartesian axes.
    if np.allclose(view, [0, 1, 0]) or np.allclose(view, [0, 0, -1]):
        viewup = np.array([1, 0, 0])
    elif np.allclose(view, [0, 0, 1]) or np.allclose(view, [-1, 0, 0]):
        viewup = np.array([0, 1, 0])
    elif np.allclose(view, [1, 0, 0]) or np.allclose(view, [0, -1, 0]):
        viewup = np.array([0, 0, 1])
    else:
        # z axis projected onto the plane perpendicular to view (normalized).
        vz = np.sqrt(1.0 - view[2] * view[2])
        vx = -view[2] * view[0] / vz
        vy = -view[2] * view[1] / vz
        viewup = np.array([vx, vy, vz], dtype=np.float64)

    if bounds is None:
        focal = np.sum(A, axis=1) / 2.0
    else:
        focal = np.array([(bounds[0] + bounds[1]) / 2, (bounds[2] + bounds[3]) / 2, (bounds[4] + bounds[5]) / 2])

    if camera is not None:
        # keep current distance between camera and focal point.
        distance = np.linalg.norm(np.array(camera.position) - np.array(camera.focal_point))
        new_position = focal + (view * distance)
    else:
        new_position = view

    return new_position, focal, viewup