"""
.. _ref_solid186_example:

HEX20 — uniaxial tension on a 20-node hex
============================================

Single HEX20 brick — a unit cube with 20 nodes (8 corners plus the 12
mid-edge nodes of the quadratic serendipity family) — loaded in
uniaxial tension.  For a uniform traction on an 8-node serendipity
face, the consistent nodal loads are :math:`-1/12` of the total force
at each corner and :math:`+1/3` at each mid-edge node; they sum to
:math:`4(-1/12) + 4(1/3) = 1`.  Applying these produces a uniform σxx
field, so ``εxx = σ/E`` and ``εyy = εzz = −ν·εxx`` exactly at every
node.
"""

from __future__ import annotations

import numpy as np
import pyvista as pv
from vtkmodules.util.vtkConstants import VTK_QUADRATIC_HEXAHEDRON

import femorph_solver
from femorph_solver import ELEMENTS

# %%
# Reference 20-node unit cube
# ---------------------------
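# Corners 1-8 in VTK_QUADRATIC_HEXAHEDRON order, then mid-edge nodes 9-20 on
# the bottom, top, and vertical edges.  Node numbers are 1-based: row ``i`` of
# ``coords`` is node ``i + 1``, which is the numbering used by ``fix`` and
# ``apply_force`` below.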
E = 2.1e11  # Young's modulus, Pa
NU = 0.30  # Poisson's ratio
F_TOTAL = 4.2e4  # total axial force on the x = 1 face, N

# The eight cube corners, in VTK hexahedron order.
_corners = np.array(
    [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [1.0, 1.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0],
        [1.0, 1.0, 1.0],
        [0.0, 1.0, 1.0],
    ],
    dtype=np.float64,
)

# Corner pairs spanned by each mid-edge node, in VTK_QUADRATIC_HEXAHEDRON
# order: bottom ring, top ring, then the four vertical edges.
_edges = np.array(
    [
        [0, 1], [1, 2], [2, 3], [3, 0],  # bottom (z = 0)
        [4, 5], [5, 6], [6, 7], [7, 4],  # top (z = 1)
        [0, 4], [1, 5], [2, 6], [3, 7],  # vertical
    ]
)

# Mid-edge nodes sit exactly halfway along their edge; stack them after the
# corners to get the full 20-node coordinate table.
coords = np.vstack(
    [_corners, 0.5 * (_corners[_edges[:, 0]] + _corners[_edges[:, 1]])]
)

# %%
# Build the model
# ---------------
# A single VTK cell: its node count followed by the connectivity, which is
# just every row of ``coords`` in order.
n_nodes = len(coords)
cells = np.array([n_nodes, *range(n_nodes)], dtype=np.int64)
cell_types = np.array((VTK_QUADRATIC_HEXAHEDRON,), dtype=np.uint8)
grid = pv.UnstructuredGrid(cells, cell_types, coords)

# Isotropic linear-elastic material on the HEX20 element.
m = femorph_solver.Model.from_grid(grid)
m.assign(ELEMENTS.HEX20, material={"EX": E, "PRXY": NU})

# Symmetry BCs on the three planes through the origin.  Each face is fixed
# only in its normal direction, so in-plane (Poisson) contraction stays free.
x0 = [1, 4, 5, 8, 12, 16, 17, 20]  # nodes on x = 0
y0 = [1, 2, 5, 6, 9, 13, 17, 18]  # nodes on y = 0
z0 = [1, 2, 3, 4, 9, 10, 11, 12]  # nodes on z = 0
for dof, face_nodes in (("UX", x0), ("UY", y0), ("UZ", z0)):
    for node in face_nodes:
        m.fix(nodes=[node], dof=dof, value=0.0)

# Consistent nodal loads for a uniform traction on the 8-node x = 1 face:
# each corner carries −F/12 and each mid-edge node +F/3, so the total is
# 4·(−F/12) + 4·(F/3) = F_TOTAL.
face_x1_loads = (
    ((2, 3, 6, 7), -F_TOTAL / 12.0),  # corner nodes
    ((10, 14, 18, 19), F_TOTAL / 3.0),  # mid-edge nodes
)
for face_nodes, fx in face_x1_loads:
    for nn in face_nodes:
        m.apply_force(nn, fx=fx)

# %%
# Static solve and strain recovery
# --------------------------------
res = m.solve_static()
eps = res.elastic_strain(model=m)

# Closed-form reference.  The loaded face of the unit cube has area 1 m², so
# σxx = F_TOTAL / FACE_AREA and εxx = σxx / E; in uniaxial stress both
# lateral strains are −ν·εxx.
FACE_AREA = 1.0  # m²
eps_xx_expected = F_TOTAL / FACE_AREA / E
eps_yy_expected = -NU * eps_xx_expected
eps_zz_expected = eps_yy_expected

# NOTE(review): assumes columns 0-2 of ``eps`` are εxx, εyy, εzz (Voigt
# order).  Columns 0/1 are used as εxx/εyy elsewhere in this example;
# column 2 = εzz should be confirmed against ``elastic_strain``.
print(f"εxx expected = {eps_xx_expected:.3e}  |  recovered (mean) = {eps[:, 0].mean():.3e}")
print(f"εyy expected = {eps_yy_expected:.3e}  |  recovered (mean) = {eps[:, 1].mean():.3e}")
print(f"εzz expected = {eps_zz_expected:.3e}  |  recovered (mean) = {eps[:, 2].mean():.3e}")

# A mean can hide node-to-node scatter.  The largest nodal deviation checks
# the "exact at every node" claim directly.
eps_normal_expected = np.array([eps_xx_expected, eps_yy_expected, eps_zz_expected])
max_nodal_dev = np.abs(eps[:, :3] - eps_normal_expected).max()
print(f"max |ε − ε_expected| over all nodes (normal components) = {max_nodal_dev:.3e}")

# %%
# Plot the deformed cube, coloured by εxx
# ---------------------------------------
# Convert the static result to a pyvista grid and attach the recovered εxx as
# point data so it can drive the colour map.
grid = femorph_solver.io.static_result_to_grid(m, res)
grid.point_data["eps_xx"] = eps[:, 0]

# The displacements are sub-micrometre (εxx ≈ 2e-7 on a 1 m cube), so
# exaggerate them heavily to make the stretch visible.
warped = grid.warp_by_vector("displacement", factor=2.0e5)

pl = pv.Plotter(off_screen=True)
# Faint undeformed outline as a reference frame.
pl.add_mesh(grid, style="wireframe", color="gray", line_width=1, opacity=0.4)
pl.add_mesh(
    warped,
    scalars="eps_xx",
    show_edges=True,
    cmap="viridis",
    scalar_bar_args={"title": "εxx"},
)
pl.add_axes()
pl.show()
