"""
.. _ref_solid185_static_cantilever_example:

HEX8 — cantilever plate static analysis
===========================================

Static analysis of a cantilever plate under a distributed tip load.
This example walks through the full
:meth:`femorph_solver.Model.solve_static` → :class:`StaticResult` →
:func:`femorph_solver.io.static_result_to_grid` → pyvista rendering workflow
and checks static equilibrium via :attr:`StaticResult.reaction`.

Euler–Bernoulli beam theory is included as a back-of-envelope
reference; HEX8 is a first-order hex with full 2 × 2 × 2 Gauss
integration, which exhibits well-known shear locking in bending: the
parasitic shear stiffening grows with the ratio of element length to
plate thickness, so it fades only when many elements are used along the
span.  The difference between the two is a feature of the element, not
a bug — swap in ``HEX20`` (quadratic) to recover EB to a few percent
with the same mesh.
"""

from __future__ import annotations

import numpy as np
import pyvista as pv
from vtkmodules.util.vtkConstants import VTK_HEXAHEDRON

import femorph_solver
from femorph_solver import ELEMENTS

# %%
# Geometry + material
# -------------------
# Steel cantilever, 1 m × 0.1 m × 0.05 m, meshed 40 × 4 × 4 hex
# (640 HEX8 elements, 1 025 nodes).
# Material (SI units).
E = 2.0e11  # Pa, Young's modulus
NU = 0.30  # Poisson's ratio
RHO = 7850.0  # kg/m^3, density

# Plate extents [m] and element counts along x, y, z.
LX, LY, LZ = 1.0, 0.1, 0.05
NX, NY, NZ = 40, 4, 4

# Total tip load, shared across the x = LX face.
F_TIP = -5.0e3  # N (downward)

# Evenly spaced lattice coordinates on each axis (count + 1 planes each).
xs, ys, zs = (
    np.linspace(0.0, extent, count + 1)
    for extent, count in ((LX, NX), (LY, NY), (LZ, NZ))
)
xx, yy, zz = np.meshgrid(xs, ys, zs, indexing="ij")
# One (x, y, z) row per lattice point.  C-order ravel makes z vary fastest,
# then y, then x.
points = np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))


def _node_idx(i: int, j: int, k: int) -> int:
    """Return the 0-based VTK point index of lattice point ``(i, j, k)``.

    Points are laid out with ``k`` varying fastest, then ``j``, then ``i``.
    This matches a C-order ravel of an ``indexing="ij"`` meshgrid.
    """
    stride_j = NZ + 1  # points in one z-column
    stride_i = (NY + 1) * stride_j  # points in one x = const slab
    return i * stride_i + j * stride_j + k


# (di, dj, dk) offsets of the eight hex corners, in VTK_HEXAHEDRON order.
# First come the four corners of the k face in order around the face, then
# the same four corners on the k + 1 face.
_HEX_CORNERS = (
    (0, 0, 0),
    (1, 0, 0),
    (1, 1, 0),
    (0, 1, 0),
    (0, 0, 1),
    (1, 0, 1),
    (1, 1, 1),
    (0, 1, 1),
)

cells_flat: list[int] = []
for i in range(NX):
    for j in range(NY):
        for k in range(NZ):
            # A VTK legacy cell record is the point count, then the point ids.
            cells_flat.append(8)
            cells_flat.extend(
                _node_idx(i + di, j + dj, k + dk) for di, dj, dk in _HEX_CORNERS
            )

n_cells = NX * NY * NZ
cell_types = np.full(n_cells, VTK_HEXAHEDRON, dtype=np.uint8)
cell_array = np.asarray(cells_flat, dtype=np.int64)
grid = pv.UnstructuredGrid(cell_array, cell_types, points)

# %%
# Build the model
# ---------------
m = femorph_solver.Model.from_grid(grid)
steel = {"EX": E, "PRXY": NU, "DENS": RHO}
m.assign(ELEMENTS.HEX8, material=steel)

# Node numbers and coordinates both come from the model's own grid, so
# entry n of each refers to the same point.
node_nums = np.asarray(m.grid.point_data["ansys_node_num"])
pts = np.asarray(m.grid.points)
x_coord = pts[:, 0]

# Clamp the ``x = 0`` face in all 3 DOFs.
x0_mask = x_coord < 1e-9
x0_nodes = node_nums[x0_mask].tolist()
for component in ("UX", "UY", "UZ"):
    m.fix(nodes=x0_nodes, dof=component)

# Distributed downward tip load.  F_TIP is split evenly over every node on
# the ``x = LX`` face, so the nodal forces sum to F_TIP.
tip_mask = x_coord > LX - 1e-9
tip_nodes = node_nums[tip_mask].tolist()
fz_each = F_TIP / len(tip_nodes)
for node in tip_nodes:
    m.apply_force(int(node), fz=fz_each)

# %%
# Solve + reaction check
# ----------------------
# :meth:`Model.solve_static` returns a :class:`StaticResult` with
# ``displacement``, ``reaction``, and ``free_mask``.  Reactions are
# nonzero only at constrained DOFs; summing ``FZ`` at the clamp must
# equal ``-F_TIP`` to machine precision for a well-posed static solve.
res = m.solve_static()

# Sum the UZ reactions over every clamped node in one vectorised pass,
# instead of re-scanning the whole DOF map once per clamped node.
# NOTE(review): assumes ``dof_map()`` rows are ``(node_number, dof_index)``
# with 0-based dof_index (UX=0, UY=1, UZ=2), and that ``res.reaction`` is
# indexed by the same rows.  Confirm against the API.
dof = m.dof_map()
reaction = np.asarray(res.reaction)
clamp_uz_rows = np.isin(dof[:, 0], x0_nodes) & (dof[:, 1] == 2)
fz_clamp = float(reaction[clamp_uz_rows].sum())
print(f"Σ FZ reaction at clamp = {fz_clamp:.4e} N   (expected {-F_TIP:.4e})")
# Quantify the equilibrium check: the reactions must cancel the applied load.
print(f"relative equilibrium residual = {abs(fz_clamp + F_TIP) / abs(F_TIP):.2e}")

# %%
# Tip deflection vs Euler–Bernoulli
# ---------------------------------
# :math:`\delta_\mathrm{EB} = F L^3 / (3 E I)` with
# :math:`I = b h^3 / 12` (width :math:`b` = ``LY``, thickness
# :math:`h` = ``LZ``) is the slender-beam estimate.  Each element is half
# as long as the plate is thick (40 elements along a span 20× the
# thickness), and HEX8's shear locking gives a few percent error — swap
# in ``HEX20`` (see :ref:`ref_solid186_example`) to remove the locking.
# Second moment of area of the LY × LZ cross-section about the y axis.
# The load acts along z, so LZ is the bending depth.
I_y = LY * LZ**3 / 12.0
delta_eb = F_TIP * LX**3 / (3.0 * E * I_y)

# Re-bind ``grid`` to the solved mesh, which carries the nodal
# "displacement" array used below and by the rendering cell.
grid = femorph_solver.io.static_result_to_grid(m, res)
tip_mask = grid.points[:, 0] > LX - 1e-9
tip_uz = grid.point_data["displacement"][tip_mask, 2]
# The load points down (F_TIP < 0), so the most negative UZ on the tip
# face is the peak deflection.
w_tip_femorph_solver = tip_uz.min()

print(f"Euler-Bernoulli tip deflection = {delta_eb:.3e} m")
print(f"femorph-solver tip deflection (min UZ)   = {w_tip_femorph_solver:.3e} m")
print(
    f"relative error                  = {abs(w_tip_femorph_solver - delta_eb) / abs(delta_eb):.2%}"
)

# %%
# Render the deformed plate, coloured by displacement magnitude
# -------------------------------------------------------------
# Exaggerate the displacement so the bending is visible.  The legend label
# below quotes this factor, so keep the two in sync.
warp_factor = 20.0
warped = grid.warp_by_vector("displacement", factor=warp_factor)

plotter = pv.Plotter(off_screen=True)

# Faint wireframe of the undeformed geometry for reference.
undeformed_kwargs = {
    "style": "wireframe",
    "color": "gray",
    "opacity": 0.35,
    "label": "undeformed",
}
plotter.add_mesh(m.grid, **undeformed_kwargs)

# Deformed shape, coloured by the displacement magnitude.
deformed_kwargs = {
    "scalars": "displacement_magnitude",
    "show_edges": True,
    "cmap": "viridis",
    "scalar_bar_args": {"title": "|u| [m]"},
    "label": "deformed ×20",
}
plotter.add_mesh(warped, **deformed_kwargs)

plotter.add_legend()
plotter.add_axes()

eye = (2.4, -1.6, 1.0)
focal_point = (0.5, 0.05, 0.0)
view_up = (0.0, 0.0, 1.0)
plotter.camera_position = [eye, focal_point, view_up]
plotter.show()
