"""
Plotting mode shapes
====================

A recipe for rendering the lowest few modes of a structure as a grid
of sub-plots.  The key ingredients are
:func:`femorph_solver.io.modal_result_to_grid` (which attaches every
mode to the mesh as pyvista point-data arrays, here ``mode_<k>_disp``
and ``mode_<k>_magnitude``) and
:meth:`pyvista.DataSet.warp_by_vector` (which deforms the mesh along a
mode shape; each mode is first rescaled so that its peak displacement
is a fixed fraction of the model size).
"""

from __future__ import annotations

import numpy as np
import pyvista as pv

import femorph_solver
from femorph_solver import ELEMENTS

# %%
# Build a quick cantilever plate + solve
# --------------------------------------
#
# A thin ``LX × LY × LZ`` plate sampled with 21 × 21 × 3 grid points,
# i.e. 20 × 20 × 2 hexahedral cells once cast to an unstructured grid.
LX, LY, LZ = 1.0, 1.0, 0.01
grid = pv.StructuredGrid(
    *np.meshgrid(
        # One coordinate axis per (extent, point-count) pair; "ij"
        # indexing keeps the x/y/z axis order of the arrays.
        *(
            np.linspace(0.0, extent, n_pts)
            for extent, n_pts in ((LX, 21), (LY, 21), (LZ, 3))
        ),
        indexing="ij",
    )
).cast_to_unstructured_grid()

# Wrap the mesh in a solver model and give it the HEX8 element type with
# steel-like material constants (ANSYS-style labels: EX = Young's
# modulus, PRXY = Poisson's ratio, DENS = density).
m = femorph_solver.Model.from_grid(grid)
m.assign(
    ELEMENTS.HEX8,
    material=dict(EX=2.0e11, PRXY=0.30, DENS=7850.0),
)

# Clamp all DOFs of the nodes on the x = 0 face (the cantilever root).
# NOTE(review): ``ansys_node_num`` is read only after ``from_grid`` has
# run, presumably because ``from_grid`` stamps it onto ``grid``; keep
# this ordering.
pts = np.asarray(grid.points)
node_nums = np.asarray(grid.point_data["ansys_node_num"])
m.fix(
    nodes=node_nums[pts[:, 0] < 1e-9].tolist(),
    dof="ALL",
)

# Ask the modal solver for six modes.
res = m.solve_modal(n_modes=6)

# %%
# Attach every mode to the grid
# -----------------------------
#
# The returned mesh carries, for each mode ``k`` (numbered from 1), the
# point-data arrays ``mode_<k>_disp`` and ``mode_<k>_magnitude`` that the
# plotting step below reads.
grid_plot = femorph_solver.io.modal_result_to_grid(m, res)

# %%
# Render modes 1..6 in a 2 × 3 viewport grid
# ------------------------------------------
#
# Mode-shape amplitudes are arbitrary (an eigenvector is only defined up
# to a scale factor), so each mode is rescaled until its largest nodal
# displacement equals ``WARP_FRACTION`` of the mesh's bounding-box
# diagonal.  This keeps every panel legible whatever the model's size.
WARP_FRACTION = 0.1
N_ROWS, N_COLS = 2, 3

# Bounding-box diagonal of the mesh: the reference length for the warp.
model_size = grid_plot.length
# Never ask for more panels than the solver actually returned.
n_panels = min(N_ROWS * N_COLS, len(res.frequency))

plotter = pv.Plotter(
    shape=(N_ROWS, N_COLS), off_screen=True, window_size=(1200, 600)
)
for idx in range(n_panels):
    k = idx + 1  # mode arrays are numbered from 1
    row, col = divmod(idx, N_COLS)
    plotter.subplot(row, col)
    phi_k = np.asarray(grid_plot.point_data[f"mode_{k}_disp"])
    # Use the peak per-node displacement *magnitude* (not a single
    # component), so the largest excursion of the warped shape is exactly
    # WARP_FRACTION * model_size.  The tiny epsilon only guards against
    # division by zero for an all-zero mode.
    peak = np.linalg.norm(phi_k, axis=1).max()
    factor = WARP_FRACTION * model_size / (peak + 1e-30)
    # Faint undeformed outline for reference.
    plotter.add_mesh(grid_plot, style="wireframe", color="gray", opacity=0.3)
    plotter.add_mesh(
        grid_plot.warp_by_vector(f"mode_{k}_disp", factor=factor),
        scalars=f"mode_{k}_magnitude",
        cmap="viridis",
        show_scalar_bar=False,
    )
    plotter.add_text(
        f"mode {k}: {res.frequency[idx]:.1f} Hz",
        position="upper_edge",
        font_size=10,
    )
plotter.show()
