"""
.. _ref_link180_example:

TRUSS2 — axial truss under end load
====================================

Single TRUSS2 spar fixed at node 1 and pulled along the x-axis at
node 2. The tip displacement is compared to the closed-form rod
solution ``u = P L / (E A)`` and the axial stress is plotted on the
deformed mesh.
"""

from __future__ import annotations

import numpy as np
import pyvista as pv
from vtkmodules.util.vtkConstants import VTK_LINE

import femorph_solver
from femorph_solver import ELEMENTS

# %%
# Problem data
# ------------
# Steel rod, 1 m long, 100 mm² cross-section, pulled with a 1 kN tip load.
E = 210.0e9  # Young's modulus [Pa] (210 GPa)
A = 100.0e-6  # cross-sectional area [m²] (100 mm²)
L = 1.0  # bar length [m]
P = 1_000.0  # axial tip load [N]; positive = tension

# %%
# Build the model
# ---------------
# ``femorph_solver.Model.assign`` collapses the legacy ``et / mp / r``
# trio into one call.  A TRUSS2 has three translational DOFs per node,
# so node 1 is clamped in all of them, while node 2 keeps only UX free:
# its transverse DOFs (UY, UZ) are fixed to kill the transverse rigid modes.
# One two-node line cell along +x, from the origin to x = L.
points = np.zeros((2, 3), dtype=np.float64)
points[1, 0] = L
cell_types = np.full(1, VTK_LINE, dtype=np.uint8)
cells = np.array([2, 0, 1], dtype=np.int64)  # VTK layout: [n_points, i0, i1]
grid = pv.UnstructuredGrid(cells, cell_types, points)

m = femorph_solver.Model.from_grid(grid)
steel = {"EX": E, "DENS": 7850.0}
m.assign(ELEMENTS.TRUSS2, material=steel, real=(A,))

# Node 1: every translational DOF clamped.  Node 2: UY and UZ fixed to
# suppress transverse rigid-body motion, leaving only UX free.
m.fix(nodes=[1], dof="ALL")
for component in ("UY", "UZ"):
    m.fix(nodes=[2], dof=component)
m.apply_force(2, fx=P)

# %%
# Static solve + analytical comparison
# ------------------------------------
# The rod equation gives ``u_tip = P L / (E A)``. With a single TRUSS2
# this is exact (linear shape functions are sufficient for a prismatic
# bar under end load).
res = m.solve_static()

# Each ``dof_map()`` row pairs a node number (column 0) with a DOF
# component index (column 1); component 0 is taken to be UX here.
dof = m.dof_map()
tip_ux_rows = np.flatnonzero((dof[:, 0] == 2) & (dof[:, 1] == 0))
if tip_ux_rows.size == 0:
    raise RuntimeError("UX of node 2 was not found in the DOF map")
u_tip = res.displacement[tip_ux_rows[0]]
u_expected = P * L / (E * A)

print(f"TRUSS2 tip UX        = {u_tip:.6e} m")
print(f"Analytical PL/(EA)    = {u_expected:.6e} m")
# ``np.isclose`` silently adds ``atol=1e-8``, which is a sizeable fraction
# of this small displacement and would swamp ``rtol=1e-10``.  Use a purely
# relative check that, unlike a bare ``assert``, also survives ``python -O``.
np.testing.assert_allclose(u_tip, u_expected, rtol=1e-10, atol=0.0)

# %%
# Post-processing: axial stress
# -----------------------------
# For a single TRUSS2 the axial stress is uniform: ``σ = E · (Δu / L)``,
# where ``Δu`` is the elongation. Because node 1 is clamped, ``Δu`` is
# simply the tip displacement. Carry the stress as cell data on the mesh
# for plotting.
# Node 1 is clamped, so the elongation equals the tip displacement.
sigma_axial = E * (u_tip / L)
print(f"Axial stress          = {sigma_axial:.3e} Pa (= P/A = {P / A:.3e})")

# Scatter the solver's flat DOF vector into an (n_points, 3) array of nodal
# displacement vectors that PyVista can warp the mesh with.
grid = m.grid.copy()
displacement = np.zeros((grid.n_points, 3), dtype=np.float64)
# Build the node-number -> point-index map once instead of rescanning the
# whole DOF map for every grid point.
point_of_node = {
    int(node_num): i for i, node_num in enumerate(grid.point_data["ansys_node_num"])
}
for row, (node_num, comp) in enumerate(dof[:, :2]):
    i = point_of_node.get(int(node_num))
    if i is not None:  # ignore DOFs of nodes absent from the plotted grid
        # NOTE(review): assumes component indices 0-2 map to UX/UY/UZ.
        displacement[i, int(comp)] = res.displacement[row]
grid.point_data["displacement"] = displacement
# The stress is uniform, so every cell carries the same value.
grid.cell_data["sigma_axial"] = np.full(grid.n_cells, sigma_axial)

# %%
# Plot the deformed truss coloured by axial stress
# ------------------------------------------------
# Exaggerate the tiny elongation (scale factor 1e5) so it is visible.
warped = grid.warp_by_vector("displacement", factor=1.0e5)

undeformed_style = dict(
    style="wireframe",
    color="gray",
    line_width=3,
    label="undeformed",
)
deformed_style = dict(
    scalars="sigma_axial",
    line_width=6,
    scalar_bar_args={"title": "axial stress [Pa]"},
    label="deformed (×1e5)",
)

plotter = pv.Plotter(off_screen=True)
plotter.add_mesh(grid, **undeformed_style)
plotter.add_mesh(warped, **deformed_style)
plotter.add_legend()
plotter.add_axes()
plotter.show()
