r"""
Simply-supported plate under uniform pressure (Navier series)
=============================================================

Classical thin-plate verification: a simply-supported square
plate of side :math:`a` and thickness :math:`h` carries a
uniform transverse pressure :math:`q`.  The Navier double-sine
expansion (Timoshenko & Woinowsky-Krieger 1959 §30 eqs. 115–117)
gives the deflection field (written here for the square plate, :math:`b = a`)

.. math::

    w(x, y) = \frac{16\, q}{\pi^{6} D}
    \sum_{m \,\text{odd}} \sum_{n \,\text{odd}}
    \frac{\sin(m \pi x / a)\, \sin(n \pi y / a)}
         {m\, n\, \bigl(m^{2}/a^{2} + n^{2}/a^{2}\bigr)^{2}},
    \qquad
    D = \frac{E\, h^{3}}{12\, (1 - \nu^{2})}.

Evaluated at the centre :math:`(a/2, a/2)`, with the summation truncated
after the first 25 odd modes in each direction, the series converges to

.. math::

    w_\mathrm{max} \;\approx\; 0.00406\, \frac{q\, a^{4}}{D}
    \qquad \text{(Timoshenko §31 Table 8).}

Implementation
--------------

A :math:`30 \times 30 \times 2`-element HEX8 mesh on the unit
square plate, with :math:`L/h = 50` (default geometry).
:class:`~femorph_solver.elements.hex8.Hex8` is invoked in
**enhanced-strain** mode (Simo–Rifai 1990) so the bending
response of the thin-plate slab is recovered without
shear-locking — the locked plain-Gauss form lands at ~50 %
error on this mesh; EAS closes that gap to ~5 %.

Symmetry / SS BCs:

* :math:`u_z = 0` along all four edges, full thickness — the
  standard solid-element analogue of a Kirchhoff simple
  support.
* In-plane rigid-body translation (:math:`u_x = u_y = 0`) pinned at
  the corner :math:`(0, 0, 0)`; :math:`u_y` is also pinned at the
  adjacent corner :math:`(a, 0, 0)` to suppress the remaining
  in-plane rigid-body rotation about :math:`z`.

Pressure is lumped onto the top-face nodes by uniform area
weighting (the lumping error is :math:`O(1/n_x^{2})`, well below
the plate-theory error on this mesh).

References
----------

* Timoshenko, S. P. and Woinowsky-Krieger, S. (1959) *Theory
  of Plates and Shells*, 2nd ed., McGraw-Hill, §30 (Navier
  series), §31 (centre-deflection tables).
* Cook, R. D., Malkus, D. S., Plesha, M. E., Witt, R. J. (2002)
  *Concepts and Applications of Finite Element Analysis*, 4th
  ed., Wiley, §12.5.
* Simo, J. C. and Rifai, M. S. (1990) "A class of mixed
  assumed-strain methods and the method of incompatible
  modes," *IJNME* 29 (8), 1595–1638 (HEX8 enhanced strain).
"""

from __future__ import annotations

import math

import numpy as np
import pyvista as pv

import femorph_solver
from femorph_solver import ELEMENTS

# %%
# Problem data
# ------------

# Material: structural steel.
E = 200e9  # Young's modulus [Pa]
NU = 0.3  # Poisson's ratio [-]
RHO = 7850.0  # density [kg/m^3]

# Geometry and load: square plate, span-to-thickness ratio a / h = 50.
a = 1.0  # plate edge length [m]
h = 0.02  # plate thickness [m]
q = 100e3  # uniform transverse pressure [Pa]

# Mesh density: NX x NY cells in-plane, NZ cells through the thickness.
NX = 30
NY = 30
NZ = 2

# Plate flexural rigidity D = E h^3 / (12 (1 - nu^2)).
D = E * h**3 / (12.0 * (1.0 - NU**2))

# %%
# Closed-form Navier series for the centre deflection
# ---------------------------------------------------


def navier_w_max(
    q: float,
    a: float,
    D: float,
    n_terms: int = 25,
    *,
    b: float | None = None,
) -> float:
    r"""Sum the Navier double-sine series for the centre deflection.

    Evaluates the Navier solution for a simply-supported rectangular plate
    under uniform pressure at its centre :math:`(a/2, b/2)`:

    .. math::

        w_\mathrm{max} = \frac{16\, q}{\pi^{6} D}
        \sum_{m\,\text{odd}} \sum_{n\,\text{odd}}
        \frac{\sin(m\pi/2)\, \sin(n\pi/2)}
             {m\, n\, (m^{2}/a^{2} + n^{2}/b^{2})^{2}}.

    Parameters
    ----------
    q
        Uniform transverse pressure.
    a
        Plate side length along x.
    D
        Flexural rigidity :math:`E h^3 / (12 (1 - \nu^2))`.
    n_terms
        Number of odd modes kept per direction, i.e. ``m, n = 1, 3, ...,
        2 * n_terms - 1``.
    b
        Plate side length along y.  Defaults to ``a`` (square plate).

    Returns
    -------
    float
        Centre deflection, positive in the direction of ``q``, in the units
        implied by the inputs.

    Raises
    ------
    ValueError
        If ``n_terms < 1``.  An empty series would otherwise silently give 0.
    """
    if n_terms < 1:
        raise ValueError(f"n_terms must be >= 1, got {n_terms}")
    b = a if b is None else b

    odd = range(1, 2 * n_terms + 1, 2)
    # For odd k, sin(k*pi/2) = (-1)**((k - 1) // 2).  Using the exact +/-1
    # avoids the rounding noise of a floating-point sine.
    total = math.fsum(
        (-1) ** ((mm - 1) // 2 + (nn - 1) // 2) / (mm * nn * (mm**2 / a**2 + nn**2 / b**2) ** 2)
        for mm in odd
        for nn in odd
    )
    return (16.0 * q / (math.pi**6 * D)) * total


w_max_published = navier_w_max(q, a, D)
# Tabulated Timoshenko coefficient (0.00406 for a square plate), shown
# alongside the truncated Navier sum for comparison.
w_max_table8 = 0.00406 * q * a**4 / D

summary_lines = (
    f"a = {a} m, h = {h} m, L/h = {a / h:.0f}",
    f"E = {E / 1e9:.0f} GPa, ν = {NU}, D = {D:.3e} N m",
    f"q = {q / 1e3:.1f} kPa",
    f"w_max  (Navier sum, 25 odd modes per axis) = {w_max_published * 1e6:.3f} µm",
    f"w_max  (Timoshenko Table 8 ≈ 0.00406 q a^4/D) = {w_max_table8 * 1e6:.3f} µm",
)
for summary_line in summary_lines:
    print(summary_line)

# %%
# Build a 30 × 30 × 2 HEX8 mesh
# -----------------------------

# Node coordinates per axis: NX / NY cells across the plate, NZ cells
# through the thickness.
xs, ys, zs = (
    np.linspace(0.0, a, NX + 1),
    np.linspace(0.0, a, NY + 1),
    np.linspace(0.0, h, NZ + 1),
)
grid_x, grid_y, grid_z = np.meshgrid(xs, ys, zs, indexing="ij")
grid = pv.StructuredGrid(grid_x, grid_y, grid_z).cast_to_unstructured_grid()
print(f"mesh: {grid.n_points} nodes, {grid.n_cells} HEX8 cells")

# Enhanced-strain HEX8 element plus the steel material defined above.  No
# cell selection is passed to ``assign`` (presumably it targets the whole grid).
m = femorph_solver.Model.from_grid(grid)
hex8_eas = ELEMENTS.HEX8(integration="enhanced_strain")
steel = {"EX": E, "PRXY": NU, "DENS": RHO}
m.assign(hex8_eas, material=steel)

# %%
# Boundary conditions
# -------------------
#
# UZ pinned along all four edges (full thickness) → SS support;
# in-plane rigid-body translation pinned at the corner (0, 0, 0);
# the residual in-plane rotation about z is removed by pinning UY at
# the adjacent corner (a, 0, 0).

pts = np.asarray(m.grid.points)
tol = 1e-9


def _node_at(x: float, y: float, z: float) -> int:
    """Return the 0-based index of the mesh node at ``(x, y, z)``.

    A node matches when every coordinate is within ``tol`` of the target.
    Raises ``ValueError`` when no node matches, rather than the bare
    ``IndexError`` that indexing an empty ``np.where`` result would give.
    """
    hits = np.flatnonzero(np.all(np.abs(pts - np.array([x, y, z])) < tol, axis=1))
    if hits.size == 0:
        raise ValueError(f"no mesh node at ({x}, {y}, {z})")
    return int(hits[0])


# Every node on the lateral boundary, at all z levels.
edge = (
    (np.abs(pts[:, 0]) < tol)
    | (np.abs(pts[:, 0] - a) < tol)
    | (np.abs(pts[:, 1]) < tol)
    | (np.abs(pts[:, 1] - a) < tol)
)
# Node ids passed to the model are the 0-based point indices + 1.
m.fix(nodes=(np.flatnonzero(edge) + 1).tolist(), dof="UZ")

# Pin both in-plane translations at the origin corner.
corner = _node_at(0.0, 0.0, 0.0)
m.fix(nodes=[corner + 1], dof="UX")
m.fix(nodes=[corner + 1], dof="UY")

# A rigid rotation about z through the origin gives u_y = theta * x, so
# pinning UY at (a, 0, 0) removes that mode.
corner_x = _node_at(a, 0.0, 0.0)
m.fix(nodes=[corner_x + 1], dof="UY")

# %%
# Lumped pressure on the top face
# -------------------------------
#
# Each top-face node carries the pressure resultant over its tributary
# area: a full cell for interior nodes, half a cell on an edge and a
# quarter cell at a corner.  This is the consistent nodal load for a
# uniform pressure on bilinear faces.  Splitting q a^2 equally over all
# (NX + 1)(NY + 1) top nodes would over-load the supported edge nodes,
# whose load goes straight into the reactions.  It would also under-load
# every interior node by the factor NX NY / ((NX + 1)(NY + 1)).

top = np.where(np.abs(pts[:, 2] - h) < tol)[0]

# Halve the tributary width once for each in-plane boundary the node lies
# on.  Cell sizes a / NX and a / NY follow the uniform np.linspace spacing
# used to build the mesh.
on_x_boundary = (np.abs(pts[top, 0]) < tol) | (np.abs(pts[top, 0] - a) < tol)
on_y_boundary = (np.abs(pts[top, 1]) < tol) | (np.abs(pts[top, 1] - a) < tol)
tributary_area = (
    (a / NX)
    * (a / NY)
    * np.where(on_x_boundary, 0.5, 1.0)
    * np.where(on_y_boundary, 0.5, 1.0)
)
# The tributary areas must tile the top face exactly.
assert np.isclose(tributary_area.sum(), a * a), "tributary areas do not sum to the plate area"

f_nodes = -q * tributary_area
for n_idx, fz in zip(top, f_nodes):
    m.apply_force(int(n_idx + 1), fz=float(fz))

# %%
# Static solve + centre-deflection extraction
# -------------------------------------------

res = m.solve_static()
# Nodal displacements as rows of (UX, UY, UZ).
u = np.asarray(res.displacement, dtype=np.float64).reshape(-1, 3)

# The node nearest the mid-surface centre point (a/2, a/2, h/2).
centre = np.array([a / 2, a / 2, h / 2])
gap_to_centre = np.linalg.norm(pts - centre, axis=1)
idx_centre = int(gap_to_centre.argmin())
# The pressure is applied in -z, so flip the sign to report a positive sag.
w_computed = -float(u[idx_centre, 2])

err_pct = (w_computed - w_max_published) / w_max_published * 100.0

report = (
    f"w_max computed (HEX8 EAS, 30×30×2) = {w_computed * 1e6:+.3f} µm",
    f"w_max published (Navier)            = {w_max_published * 1e6:+.3f} µm",
    f"relative error                      = {err_pct:+.2f} %",
)
print()
print(*report, sep="\n")

# Generous 10 % acceptance band on the HEX8 EAS result at L/h = 50 on this
# mesh.  Mesh convergence is studied explicitly in the validation suite.
assert abs(err_pct) < 10.0, f"w_max deviation {err_pct:.2f}% exceeds 10%"

# %%
# Render the deflected plate, coloured by UZ
# ------------------------------------------

grid_render = m.grid.copy()
grid_render.point_data["displacement"] = u
grid_render.point_data["UZ"] = u[:, 2]

# Exaggerate the deflection 200x for visibility (matches the ×200 in the
# scalar-bar title below).
warp = grid_render.warp_by_vector("displacement", factor=2.0e2)

undeformed_style = {
    "style": "wireframe",
    "color": "grey",
    "opacity": 0.35,
    "line_width": 1,
    "label": "undeformed",
}
deformed_style = {
    "scalars": "UZ",
    "cmap": "viridis",
    "show_edges": False,
    "scalar_bar_args": {"title": "UZ [m]  (deformed ×200)"},
}

plotter = pv.Plotter(off_screen=True, window_size=(720, 480))
plotter.add_mesh(grid_render, **undeformed_style)
plotter.add_mesh(warp, **deformed_style)
plotter.view_isometric()
plotter.camera.zoom(1.05)
plotter.show()
