r"""
QUAD4 plane reference geometry — bilinear shape functions + 2×2 Gauss
=====================================================================

The 4-node bilinear plane element is mapped from the reference square
:math:`(\xi, \eta) \in [-1, 1]^2` in natural coordinates.  Each node
carries two in-plane translational DOFs, giving 8 DOFs per element.

Shape functions (bilinear Lagrange on the reference square):

.. math::

    N_i(\xi, \eta) = \tfrac{1}{4}(1 + \xi_i\xi)(1 + \eta_i\eta),
    \qquad i = 1, \ldots, 4,

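with :math:`(\xi_i, \eta_i) \in \{-1, +1\}^2` the natural coordinates of
corner node :math:`i`, numbered counter-clockwise from :math:`(-1, -1)`:
node 1 at :math:`(-1, -1)`, 2 at :math:`(+1, -1)`, 3 at :math:`(+1, +1)`,
4 at :math:`(-1, +1)`.  Stiffness and consistent mass use 2×2
Gauss-Legendre quadrature (4 points at :math:`\xi, \eta = \pm 1/\sqrt{3}`,
unit weights), which integrates polynomials up to cubic in each direction
exactly.  That makes it exact for the consistent-mass integrand
:math:`N^T N \det J` on any bilinear quad (constant density and
thickness), and for the stiffness integrand :math:`B^T D B \det J` when
the Jacobian is constant (parallelogram elements); on distorted quads the
stiffness integrand is rational and the rule is approximate.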

Plane stress or plane strain is selected by the ``_PLANE_MODE``
material flag (Cook 2002 §3.5–§3.6).

References
----------
* Zienkiewicz, O. C. and Taylor, R. L. (2013) *The Finite Element
  Method*, 7th ed., §6 + App. G.
* Cook, R. D., Malkus, D. S., Plesha, M. E., Witt, R. J. (2002)
  *Concepts and Applications of Finite Element Analysis*, 4th
  ed., Wiley, §5, §11.3.
* Hughes, T. J. R. (2000) *The Finite Element Method*, Dover, §3.1.

Implementation: :class:`femorph_solver.elements.quad4_plane.Quad4Plane`.
"""

from __future__ import annotations

import numpy as np
import pyvista as pv

# %%
# Reference quad
# --------------

# Corner nodes of the reference square in the z = 0 plane, listed
# counter-clockwise from (-1, -1); a VTK QUAD cell expects its points in
# boundary order, which the connectivity below relies on.
corners = np.array(
    [
        [-1.0, -1.0, 0.0],  # node 1
        [+1.0, -1.0, 0.0],  # node 2
        [+1.0, +1.0, 0.0],  # node 3
        [-1.0, +1.0, 0.0],  # node 4
    ]
)
# Legacy VTK cell array: the point count of the cell, then its point indices.
cells = np.concatenate(([4], np.arange(4, dtype=np.int64)))
cell_types = np.array([pv.CellType.QUAD], dtype=np.uint8)
ref_quad = pv.UnstructuredGrid(cells, cell_types, corners)

# %%
# 2×2 Gauss-Legendre points on the reference square, at :math:`\xi, \eta = \pm 1/\sqrt{3}`

# Abscissa of the one-dimensional two-point Gauss-Legendre rule (both weights are 1).
g = 1.0 / np.sqrt(3.0)
# Tensor-product rule with ξ varying slowest and η fastest:
# (-g, -g), (-g, +g), (+g, -g), (+g, +g).  z = 0 so the points sit in the
# same plane as the reference quad when plotted.
gauss = np.column_stack((np.repeat([-g, +g], 2), np.tile([-g, +g], 2), np.zeros(4)))

# %%
# Render the reference quad with its corner nodes and the 2×2 Gauss points

plotter = pv.Plotter(off_screen=True, window_size=(560, 480))
plotter.add_mesh(
    ref_quad,
    style="wireframe",
    color="black",
    line_width=2.5,
    label="reference QUAD4",
)
# Corner nodes and quadrature points share the sphere-glyph rendering; only
# the point array, glyph size, colour and legend label differ per layer.
for _pts, _size, _color, _label in (
    (corners, 18, "black", "corner nodes (4)"),
    (gauss, 16, "#d62728", "2×2 Gauss"),
):
    plotter.add_points(
        _pts,
        render_points_as_spheres=True,
        point_size=_size,
        color=_color,
        label=_label,
    )
del _pts, _size, _color, _label
# Look straight down onto the xi-eta plane, then tighten the framing.
plotter.view_xy()
plotter.camera.zoom(1.4)
plotter.add_legend(face=None, size=(0.22, 0.10), bcolor="white")
plotter.show()

# %%
# Verify partition of unity at the Gauss points
# ---------------------------------------------

# Natural coordinates (ξ_i, η_i) of corner node i, in the same
# counter-clockwise order as ``corners``: 1 (-1, -1), 2 (+1, -1),
# 3 (+1, +1), 4 (-1, +1).
xi_i = np.array([-1.0, +1.0, +1.0, -1.0])
eta_i = np.array([-1.0, -1.0, +1.0, +1.0])


def _shape_functions(xi, eta):
    """Return ``N[q, i] = ¼ (1 + ξ_i ξ_q)(1 + η_i η_q)`` for points q and nodes i.

    ``xi`` and ``eta`` are 1-D arrays of natural coordinates of equal length;
    the result has one row per point and one column per corner node.
    """
    return 0.25 * (1.0 + np.outer(xi, xi_i)) * (1.0 + np.outer(eta, eta_i))


# Partition of unity: the four shape functions sum to 1 at every Gauss point.
sums = _shape_functions(gauss[:, 0], gauss[:, 1]).sum(axis=1)
np.testing.assert_allclose(sums, 1.0, atol=1e-14)

# Kronecker-delta property: N_i(ξ_j, η_j) = δ_ij at the rendered corner nodes.
# Partition of unity holds for any node ordering, so this is the check that
# ``xi_i``/``eta_i`` list the corners in the same order as ``corners``.
np.testing.assert_allclose(
    _shape_functions(corners[:, 0], corners[:, 1]), np.eye(len(corners)), atol=1e-14
)
print(
    f"{len(gauss)} Gauss points; Σ N_i(ξ_q, η_q) = {sums.round(15)} "
    "— partition of unity verified."
)
