r"""
TET10 shape-function slices on the unit tetrahedron
====================================================

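Renders the quadratic-tetrahedron (TET10) shape functions
:math:`N_i^c(L)` (4 corners) and :math:`N_{ij}^m(L)` (6
mid-edges) sliced through the :math:`\xi_3 = 0` face of the
unit tetrahedron, where :math:`L_3 = 0`.  Only six of the ten are
plotted: :math:`N_3^c` and the three :math:`N_{i3}^m` contain a
factor :math:`L_3` and therefore vanish identically on that face.
In volume coordinates :math:`L_i` (where
:math:`L_0 = 1 - L_1 - L_2 - L_3`) the basis is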

.. math::

    N_i^c = L_i\, (2 L_i - 1), \qquad i = 0, \ldots, 3,

.. math::

    N_{ij}^m = 4\, L_i\, L_j, \qquad
    (i, j) \in \{(0,1), (1,2), (2,0), (0,3), (1,3), (2,3)\}.

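Each corner function equals 1 at its own corner and vanishes at
every other node: the other three corners and all six mid-edge
nodes (on an edge incident to corner :math:`i`, :math:`L_i = 1/2`
makes the factor :math:`2 L_i - 1` zero).  Each mid-edge function
equals 1 at its own mid-edge node and 0 at every other node,
although it is non-zero between nodes.  The figure below visualises
this property on a 2D cross-section.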

References
----------

* Cook, R. D., Malkus, D. S., Plesha, M. E., Witt, R. J.
  (2002) *Concepts and Applications of Finite Element
  Analysis*, 4th ed., Wiley, Table 6.5-1.
* Hughes, T. J. R. (2000) *The Finite Element Method —
  Linear Static and Dynamic Finite Element Analysis*, Dover,
  §3.8.
* Zienkiewicz, O. C., Taylor, R. L. (2013) *The Finite
  Element Method*, 7th ed., §8.8.2.

Implementation: :class:`femorph_solver.elements.tet10.Tet10`.
"""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np

# %%
# Unit-tet corner volume coordinates.

# Row k holds the volume coordinates (L_0, L_1, L_2, L_3) of corner k,
# i.e. corner k sits at the k-th unit vector.
corner_coords = np.identity(4)
# Node-index pairs of the six edges, in the same order as the
# mid-edge functions N_ij^m listed in the module docstring.
edge_pairs = [
    # edges of the L_3 = 0 face
    (0, 1),
    (1, 2),
    (2, 0),
    # edges meeting at corner 3
    (0, 3),
    (1, 3),
    (2, 3),
]


def n_corner(L, i):
    """Evaluate the TET10 corner shape function ``N_i^c = L_i (2 L_i - 1)``.

    ``L`` carries the volume coordinates along its last axis; any leading
    axes pass straight through, so the result has shape ``L.shape[:-1]``.
    ``i`` selects which volume coordinate (corner) to use.
    """
    li = L[..., i]
    return li * (2 * li - 1)


def n_mid_edge(L, i, j):
    """Evaluate the TET10 mid-edge shape function ``N_ij^m = 4 L_i L_j``.

    ``L`` carries the volume coordinates along its last axis; ``i`` and
    ``j`` pick the two coordinates of the edge's end nodes.  The result
    has shape ``L.shape[:-1]``.
    """
    li, lj = L[..., i], L[..., j]
    return 4 * li * lj


# %%
# Sample a triangular slice through the xi_3 = 0 face
#
# That face is the triangle with corners (L_0, L_1, L_2, L_3) =
# (1,0,0,0), (0,1,0,0), (0,0,1,0).  Sample (L_1, L_2) on a
# triangular grid; L_0 = 1 - L_1 - L_2 fills in.

n = 31
L1, L2 = np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n), indexing="ij")
# Decide which samples lie inside the triangle from the integer grid
# indices (i + j <= n - 1) rather than from the float test
# ``L1 + L2 <= 1``.  With some values of n, rounding in linspace may
# push a sum on the hypotenuse a hair above 1.  That would silently drop
# the boundary sample and leave notches along the diagonal edge of the
# contour plots.  With indexing="ij", mask[i, j] pairs L1[i] with L2[j].
grid_index = np.arange(n)
mask = grid_index[:, None] + grid_index[None, :] <= n - 1
L0 = 1.0 - L1 - L2
L3 = np.zeros_like(L0)
# Stack into shape (n, n, 4); set values outside the triangle to NaN
# so they don't render.
L = np.stack([L0, L1, L2, L3], axis=-1)
L[~mask] = np.nan

# %%
# Render — 2×3 grid: the top row holds the corner functions of the
# three corners on the xi_3 = 0 face, the bottom row the mid-edge
# functions of the three edges on that face.  The remaining four
# functions (corner 3 and the three mid-edges that include node 3)
# carry a factor L_3, vanish identically on this slice, and aren't
# worth plotting.

face_face_pairs = [(0, 1), (1, 2), (2, 0)]
n_corners_on_face = 3
fig, axes = plt.subplots(2, 3, figsize=(12, 7), constrained_layout=True)
fig.suptitle(
    "TET10 shape functions on the $\\xi_3 = 0$ face slice",
    fontsize=13,
)


def _draw_panel(ax, values, title):
    """Fill ``ax`` with a labelled filled-contour plot of ``values`` over (L_1, L_2)."""
    img = ax.contourf(L1, L2, values, levels=21, cmap="plasma")
    ax.set_aspect("equal")
    ax.set_title(title, fontsize=10)
    ax.set_xlabel(r"$L_1$")
    ax.set_ylabel(r"$L_2$")
    ax.figure.colorbar(img, ax=ax, shrink=0.85)


for col in range(n_corners_on_face):
    _draw_panel(
        axes[0, col],
        n_corner(L, col),
        f"$N_{{{col}}}^c = L_{{{col}}} (2 L_{{{col}}} - 1)$",
    )

for col, (i, j) in enumerate(face_face_pairs):
    _draw_panel(
        axes[1, col],
        n_mid_edge(L, i, j),
        f"$N_{{{i}{j}}}^m = 4 L_{{{i}}} L_{{{j}}}$",
    )

# plt.show() (unlike Figure.show()) runs the GUI event loop, so the
# window stays open when this file is executed as a plain script.
plt.show()

# %%
# Verify partition of unity at every node of the tetrahedron
#
# All 10 nodes — 4 corners + 6 mid-edges — collected into a
# single (10, 4) volume-coordinate array.  At every one,
# sum_i N_i^c + sum_{ij} N_{ij}^m must equal 1.

# Each mid-edge node sits halfway between its two end corners,
# listed in edge_pairs order.
mid_coords = np.array(
    [0.5 * (corner_coords[a] + corner_coords[b]) for a, b in edge_pairs]
)
all_node_coords = np.concatenate([corner_coords, mid_coords], axis=0)  # (10, 4)
corner_part = sum(n_corner(all_node_coords, k) for k in range(4))
mid_part = sum(n_mid_edge(all_node_coords, a, b) for a, b in edge_pairs)
total = corner_part + mid_part
np.testing.assert_allclose(total, 1.0, atol=1e-14)
print("OK — TET10 partition-of-unity holds at every reference-element node.")

# %%
# Verify Kronecker-delta interpolation at every node
#
# - corner i: N_i^c(corner_j) = delta_{ij}; all six mid-edge
#   functions vanish at every corner.
# - mid-edge (i, j): N_{ij}^m(mid_kl) = delta_{(i,j),(k,l)};
#   all four corner functions vanish at every mid-edge.

# Assemble the (10, 10) basis matrix: row = node, column = shape
# function (4 corners, then 6 mid-edges in edge_pairs order).  Because
# the nodes are listed in the same order, it must be the identity.
n_corner_at_all = np.column_stack([n_corner(all_node_coords, k) for k in range(4)])
n_mid_at_all = np.column_stack(
    [n_mid_edge(all_node_coords, a, b) for a, b in edge_pairs]
)
basis = np.concatenate([n_corner_at_all, n_mid_at_all], axis=1)  # (10, 10)
np.testing.assert_allclose(basis, np.eye(10), atol=1e-14)
print("OK — TET10 Kronecker-delta interpolation verified at every node.")
