"""
.. _ref_solid185_strain_example:

HEX8 — elastic-strain post-processing
=========================================

Solve a HEX8 flat plate under uniaxial tension and recover the full
6-component elastic-strain tensor on the mesh with
:meth:`femorph_solver.result.StaticResult.elastic_strain` — element-nodal
strain averaged onto every grid point.

``result.elastic_strain(model=m)`` returns the nodal-averaged Voigt
strain ``(n_points, 6)`` (the canonical post-processing output);
:meth:`~femorph_solver.result.StaticResult.elastic_strain_per_element`
returns the per-element dict ``{elem_num: (n_nodes_in_elem, 6)}`` for
verification workflows that need raw element-by-element values.  Strain
is computed at each element's own nodes as
:math:`\\varepsilon(\\xi_\\text{node}) = B(\\xi_\\text{node})\\cdot u_e`
— no RST round-trip, no disk write.
"""

from __future__ import annotations

import numpy as np
import pyvista as pv
from vtkmodules.util.vtkConstants import VTK_HEXAHEDRON

import femorph_solver
from femorph_solver import ELEMENTS

# %%
# Problem setup
# -------------
# A 1 m × 0.4 m × 0.05 m steel plate meshed as a 20 × 8 × 1 HEX8
# brick (160 elements).  The plate is restrained on its ``x = 0`` face
# and pulled along ``x`` by a total force ``F_TOTAL`` applied to the
# nodes of its ``x = LX`` face; the exact constraints and the nodal
# load distribution are set up on the model further down.
E = 2.1e11  # Pa
NU = 0.30
RHO = 7850.0  # kg/m^3 (SI, consistent with E in Pa and lengths in m)
LX, LY, LZ = 1.0, 0.4, 0.05  # m
NX, NY, NZ = 20, 8, 1
F_TOTAL = 1.0e5  # N

# Structured grid points.  ``indexing="ij"`` followed by a C-order
# ``ravel`` places structured node (i, j, k) at flat point index
# ``(i * (NY + 1) + j) * (NZ + 1) + k``.
xs = np.linspace(0.0, LX, NX + 1)
ys = np.linspace(0.0, LY, NY + 1)
zs = np.linspace(0.0, LZ, NZ + 1)
xx, yy, zz = np.meshgrid(xs, ys, zs, indexing="ij")
points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)


# 0-based flat point index of a structured-grid node, matching the
# C-order ravel of the ``indexing="ij"`` meshgrid used to build ``points``.
def _node_idx(i: int, j: int, k: int) -> int:
    """Return the flat point index of structured node ``(i, j, k)``."""
    stride_k = 1
    stride_j = NZ + 1
    stride_i = (NY + 1) * stride_j
    return i * stride_i + j * stride_j + k * stride_k


# Legacy VTK cell array: for every brick, the point count (8) followed by
# its corners in VTK_HEXAHEDRON order.  Corners are given as (di, dj, dk)
# offsets from the brick's (i, j, k) node: the k face in the order
# (0,0) -> (1,0) -> (1,1) -> (0,1) in (i, j), then the k + 1 face likewise.
cells_flat: list[int] = []
for i in range(NX):
    for j in range(NY):
        for k in range(NZ):
            cells_flat.append(8)
            cells_flat.extend(
                _node_idx(i + di, j + dj, k + dk)
                for di, dj, dk in (
                    (0, 0, 0),
                    (1, 0, 0),
                    (1, 1, 0),
                    (0, 1, 0),
                    (0, 0, 1),
                    (1, 0, 1),
                    (1, 1, 1),
                    (0, 1, 1),
                )
            )

# One VTK_HEXAHEDRON type tag per brick.
n_cells = NX * NY * NZ
cell_types = np.full(n_cells, VTK_HEXAHEDRON, dtype=np.uint8)
grid = pv.UnstructuredGrid(np.asarray(cells_flat, dtype=np.int64), cell_types, points)

# %%
# Build the femorph-solver model
# ------------------------------
m = femorph_solver.Model.from_grid(grid)
m.assign(ELEMENTS.HEX8, material={"EX": E, "PRXY": NU, "DENS": RHO})

node_nums = np.asarray(m.grid.point_data["ansys_node_num"])
pts = np.asarray(m.grid.points)
TOL = 1e-9  # coordinate tolerance for picking nodes on faces / edges

on_x0 = pts[:, 0] < TOL
at_y0 = pts[:, 1] < TOL
at_z0 = pts[:, 2] < TOL
at_ymax = pts[:, 1] > LY - TOL

# Boundary conditions.  All are compatible with free Poisson contraction
# of a plate in uniaxial tension, so they restrain rigid motion only:
# * x = 0 face clamped in UX (symmetry plane): removes the x translation
#   and the rotations about y and z.
# * origin pinned in UY and UZ: removes the y and z translations.
# * node (0, LY, 0) pinned in UZ: removes the rotation about x, which the
#   constraints above leave free (UY = -θx·z and UZ = θx·y both vanish at
#   the origin), and which would otherwise make the stiffness singular.
x0_nodes = node_nums[on_x0].tolist()
m.fix(nodes=x0_nodes, dof="UX")
origin_nodes = node_nums[on_x0 & at_y0 & at_z0].tolist()
m.fix(nodes=origin_nodes, dof="UY")
m.fix(nodes=origin_nodes, dof="UZ")
y_edge_nodes = node_nums[on_x0 & at_ymax & at_z0].tolist()
m.fix(nodes=y_edge_nodes, dof="UZ")

# Uniform traction σxx = F_TOTAL / (LY·LZ) on the x = LX face, lumped into
# consistent nodal forces: each node receives σxx times its tributary
# area.  On this uniformly spaced face mesh, a node lying on a y- or
# z-boundary of the face has half the tributary width in that direction.
# An equal split over all nodes would overload the face edges, so the
# traction would not be uniform.
on_x_end = pts[:, 0] > LX - TOL
x_end_nodes = node_nums[on_x_end].tolist()
y_end = pts[on_x_end, 1]
z_end = pts[on_x_end, 2]
w_y = np.where((y_end < TOL) | (y_end > LY - TOL), 0.5, 1.0)
w_z = np.where((z_end < TOL) | (z_end > LZ - TOL), 0.5, 1.0)
weights = w_y * w_z
fx_nodal = F_TOTAL * weights / weights.sum()  # sums exactly to F_TOTAL
for nn, fx in zip(x_end_nodes, fx_nodal):
    m.apply_force(int(nn), fx=float(fx))

# %%
# Static solve
# ------------
res = m.solve_static()

# %%
# Recover elastic strain
# ----------------------
# Default call returns nodal-averaged strain of shape ``(n_points, 6)``:
# columns are ``[εxx, εyy, εzz, γxy, γyz, γxz]`` with *engineering*
# shears (canonical Voigt strain-recovery output).
eps = res.elastic_strain(model=m)
print(f"eps shape: {eps.shape}")

# Analytical: uniform σxx = F_TOTAL / (LY · LZ), εxx = σ / E,
# εyy = εzz = -ν · εxx.  All three normal components are checked
# against their closed-form values.
sigma_xx = F_TOTAL / (LY * LZ)
eps_xx_expected = sigma_xx / E
eps_yy_expected = -NU * eps_xx_expected
eps_zz_expected = eps_yy_expected
print(f"εxx expected = {eps_xx_expected:.3e}")
print(f"εxx recovered (mean over nodes) = {eps[:, 0].mean():.3e}")
print(f"εyy recovered (mean)            = {eps[:, 1].mean():.3e}")
print(f"εyy analytical                  = {eps_yy_expected:.3e}")
print(f"εzz recovered (mean)            = {eps[:, 2].mean():.3e}")
print(f"εzz analytical                  = {eps_zz_expected:.3e}")

# Worst relative deviation of the node-averaged normal strains from theory.
expected_normal = np.array([eps_xx_expected, eps_yy_expected, eps_zz_expected])
rel_err_normal = np.abs(eps[:, :3].mean(axis=0) - expected_normal) / np.abs(expected_normal)
print(f"max relative error (εxx, εyy, εzz) = {rel_err_normal.max():.2e}")

# %%
# Per-element arrays keyed by element number — handy for spotting strain
# jumps across element boundaries or for element-wise strain norms.
per_elem = res.elastic_strain_per_element(model=m)
first_elem, first_block = next(iter(per_elem.items()))
summary = (
    f"per-element dict has {len(per_elem)} elements; "
    f"first key = {first_elem}, "
    f"strain block shape = {first_block.shape}"
)
print(summary)

# %%
# Visualise εxx on the deformed mesh
# ----------------------------------
# :func:`femorph_solver.io.static_result_to_grid` scatters the per-node
# displacement onto the grid as ``"displacement"`` point data in one
# call — no hand-rolled dof-map loop required.  ``elastic_strain`` is
# already indexed in grid-point order so εxx drops straight onto the grid.
grid = femorph_solver.io.static_result_to_grid(m, res)
grid.point_data["eps_xx"] = eps[:, 0]

# Elastic displacements are tiny compared with the plate.  For these
# constants εxx ≈ 2.4e-5, so a fixed ×200 would draw the end-face stretch
# as only ≈ 5 mm on a 1 m plate.  Instead, scale the warp so the largest
# displacement is drawn at 10 % of the plate length (factor 1 if the
# field is identically zero).
u_max = float(np.linalg.norm(np.asarray(grid.point_data["displacement"]), axis=1).max())
warp_factor = 0.1 * LX / u_max if u_max > 0.0 else 1.0
warped = grid.warp_by_vector("displacement", factor=warp_factor)

plotter = pv.Plotter(off_screen=True)
plotter.add_mesh(
    grid,
    style="wireframe",
    color="gray",
    line_width=1,
    opacity=0.4,
    label="undeformed",
)
plotter.add_mesh(
    warped,
    scalars="eps_xx",
    show_edges=True,
    cmap="viridis",
    scalar_bar_args={"title": "εxx"},
    label=f"εxx (deformed ×{warp_factor:.4g})",
)
plotter.add_legend()
plotter.add_axes()
plotter.view_xy()
plotter.show()
