"""Pluggable linear solvers.

femorph-solver provides several backends for solving ``A x = b`` with a
sparse SPD or general sparse ``A``. SciPy's built-in SuperLU is always
available and is the final fallback of the ``"auto"`` picker; the other
backends are optional and require the extra dependencies listed in each
class's ``install_hint``.

Usage::

    from femorph_solver.solvers.linear import get_linear_solver

    Solver = get_linear_solver("cholmod")  # or "superlu", "umfpack", "pardiso", ...
    s = Solver(K_csc)
    x = s.solve(b)

``list_linear_solvers()`` reports every registered backend with an
``available`` flag so UIs / benchmarks can enumerate them.
"""
from __future__ import annotations
from femorph_solver._log import get_logger
from ._base import LinearSolverBase, SolverUnavailableError
from ._cg import CGSolver
from ._cholmod import CholmodSolver
from ._gmres import GMRESSolver
from ._mkl_pardiso import DirectMklPardisoSolver
from ._mumps import MUMPSSolver
from ._pardiso import PardisoSolver
from ._pyamg import PyAMGSolver
from ._superlu import SuperLUSolver
from ._umfpack import UMFPACKSolver
#: Backend name -> solver class for every linear solver this package knows
#: about. Insertion order is kept, so ``list_linear_solvers`` reports the
#: backends in the order registered below; if two classes ever shared a
#: ``name``, the one registered later would replace the earlier entry.
_REGISTRY: dict[str, type[LinearSolverBase]] = {}
for _backend in (
    SuperLUSolver,
    UMFPACKSolver,
    CholmodSolver,
    MUMPSSolver,
    PardisoSolver,
    DirectMklPardisoSolver,
    CGSolver,
    GMRESSolver,
    PyAMGSolver,
):
    _REGISTRY[_backend.name] = _backend
# Keep the module namespace identical to the comprehension form.
del _backend
#: Problem size (in DOFs) at or above which the auto picker prefers
#: Pardiso: only then does its threaded factorisation repay the ~200 ms
#: MKL-side setup overhead. Measured with ``perf/bench_solvers.py`` under
#: ``OMP_NUM_THREADS=4`` (a sensible default for 4-core machines): at
#: 32 940 DOFs Pardiso needs 1.63 s vs CHOLMOD's 1.03 s inside ARPACK,
#: while at 151 500 DOFs Pardiso wins 3.4× (7.3 s vs 23.1 s). The two
#: curves cross near 50 000 DOFs, so smaller problems stay on the
#: lower-overhead CHOLMOD/SuperLU path.
_PARDISO_SIZE_THRESHOLD = 50_000

#: SPD preference chain for problems at or above the threshold.
#: ``mkl_direct`` is the ctypes wrapper that bypasses pypardiso. On HEX8
#: modal benches it roughly halves modal_solve wall time relative to
#: ``pardiso`` (per-solve ~120 ms -> ~30 ms at 128×128×2) because it skips
#: pypardiso's per-call ``_check_A``, ``_check_b`` and
#: ``astype(int32)+1`` churn and advertises 0-based CSR to MKL directly.
#: Peak RSS grows about 10 % at that size, but the wall-time win dominates,
#: so ``mkl_direct`` heads the chain whenever it is available. MUMPS sits
#: below CHOLMOD and above UMFPACK: its multifrontal LDLᵀ is well-tuned
#: for large stiffness matrices, and its license is redistributable even
#: when MKL Pardiso's isn't.
_AUTO_SPD_LARGE = (
    "mkl_direct",
    "pardiso",
    "cholmod",
    "mumps",
    "umfpack",
    "superlu",
)
#: General (non-SPD) counterpart of ``_AUTO_SPD_LARGE``; CHOLMOD is omitted.
_AUTO_GENERAL_LARGE = (
    "mkl_direct",
    "pardiso",
    "mumps",
    "umfpack",
    "superlu",
)

#: SPD preference chain below the threshold, where Pardiso's setup cost
#: dominates wall time. On the bench CHOLMOD beats Pardiso ~1.6× under
#: 50k DOFs, so the lightweight backends go first. MUMPS carries setup
#: overhead similar to Pardiso's, so it also trails CHOLMOD/UMFPACK on
#: small systems.
_AUTO_SPD_SMALL = (
    "cholmod",
    "umfpack",
    "mkl_direct",
    "pardiso",
    "mumps",
    "superlu",
)
#: General (non-SPD) counterpart of ``_AUTO_SPD_SMALL``; CHOLMOD is omitted.
_AUTO_GENERAL_SMALL = (
    "umfpack",
    "mkl_direct",
    "pardiso",
    "mumps",
    "superlu",
)

#: Latch so the "pardiso would have won here but isn't installed" hint is
#: emitted at most once per process. The call site runs on every
#: modal/static solve, and users only need to see the install hint once.
_PARDISO_MISSING_WARNED = False
def _warn_pardiso_missing_if_beneficial(n_dof: int | None) -> None:
    """Emit a one-shot install hint when Pardiso would have been picked.

    The hint is logged at most once per process. It is suppressed when
    ``n_dof`` is unknown (``None``) or below ``_PARDISO_SIZE_THRESHOLD``.
    It is also suppressed when either MKL Pardiso backend is usable:
    ``pardiso`` (pypardiso) or ``mkl_direct`` (the direct ctypes wrapper).
    ``mkl_direct`` leads the large-problem auto chains, so if it is
    available, Pardiso is already in use and a "not installed" hint would
    be misleading.

    ``n_dof`` is the problem size in degrees of freedom, or ``None`` when
    unknown.
    """
    global _PARDISO_MISSING_WARNED
    if _PARDISO_MISSING_WARNED:
        return
    if n_dof is None or n_dof < _PARDISO_SIZE_THRESHOLD:
        return
    if PardisoSolver.available() or DirectMklPardisoSolver.available():
        return
    _PARDISO_MISSING_WARNED = True
    # This is a runtime-diagnostic soft signal (Pardiso would be faster,
    # but we're falling back gracefully), so it goes through the
    # femorph-solver logger at WARNING rather than ``warnings.warn``.
    # Users pick up the hint via ``femorph_solver.open_logger()`` without
    # having to enable the Python warning filter globally.
    get_logger(__name__).warning(
        "pardiso not installed; at %s DOFs it is typically 3x-5x faster "
        "than SuperLU/CHOLMOD on Intel hardware. Install with "
        "`pip install femorph_solver[pardiso]` to enable it in the auto "
        "solver chain. This message is emitted once per process.",
        f"{n_dof:,}",
    )
[docs]
def select_default_linear_solver(
    *,
    spd: bool = True,
    n_dof: int | None = None,
) -> str:
    """Return the best available backend name for the requested workload.

    The picker is size-aware. When ``n_dof >= _PARDISO_SIZE_THRESHOLD``
    (50 000 DOFs), the MKL Pardiso backends lead because their threaded
    factorisation dominates:

    * SPD: ``mkl_direct → pardiso → cholmod → mumps → umfpack → superlu``
    * general: ``mkl_direct → pardiso → mumps → umfpack → superlu``

    Below the threshold, or when ``n_dof`` is not known, Pardiso's fixed
    setup cost erases its factor-time lead, so the lightweight backends
    come first:

    * SPD: ``cholmod → umfpack → mkl_direct → pardiso → mumps → superlu``
    * general: ``umfpack → mkl_direct → pardiso → mumps → superlu``

    ``spd`` selects the SPD chain (``True``) or the general chain, which
    omits CHOLMOD (``False``). The first backend in the chosen chain that is
    registered and reports itself available is returned. ``"superlu"`` is
    the unconditional final fallback.

    For problems at or above the threshold, this function calls
    ``_warn_pardiso_missing_if_beneficial``. That helper logs a one-shot
    WARNING through the femorph-solver logger, not a :class:`UserWarning`,
    and points at the Pardiso install hint when Pardiso is not installed.
    """
    pardiso_big_enough = n_dof is not None and n_dof >= _PARDISO_SIZE_THRESHOLD
    if pardiso_big_enough:
        _warn_pardiso_missing_if_beneficial(n_dof)
    if spd:
        order = _AUTO_SPD_LARGE if pardiso_big_enough else _AUTO_SPD_SMALL
    else:
        order = _AUTO_GENERAL_LARGE if pardiso_big_enough else _AUTO_GENERAL_SMALL
    for name in order:
        cls = _REGISTRY.get(name)
        if cls is not None and cls.available():
            return name
    # SuperLU is SciPy's built-in solver, so it is the last resort even if
    # every optional backend reported itself unavailable.
    return "superlu"
[docs]
def get_linear_solver(
    name: str,
    *,
    n_dof: int | None = None,
    spd: bool = True,
) -> type[LinearSolverBase]:
    """Look up a registered linear-solver class by name.

    Passing ``"auto"`` resolves via :func:`select_default_linear_solver`.
    Pass ``n_dof`` so the auto picker can prefer Pardiso on large problems,
    where its parallel factorisation dominates, and skip it on small ones,
    where its fixed overhead erases the gain. Pass ``spd=False`` when ``A``
    is not symmetric positive definite, so the auto picker uses its general
    (non-SPD) preference chain. That chain omits CHOLMOD. ``spd`` defaults
    to ``True``, preserving the previous behaviour. Both ``n_dof`` and
    ``spd`` only affect ``"auto"``; explicit names are looked up as given.

    Raises :class:`KeyError` if the name is unknown, or
    :class:`SolverUnavailableError` if the backend is registered but its
    optional dependency is not installed.
    """
    if name == "auto":
        name = select_default_linear_solver(spd=spd, n_dof=n_dof)
    try:
        cls = _REGISTRY[name]
    except KeyError as exc:
        known = ", ".join(sorted(_REGISTRY))
        raise KeyError(f"unknown linear solver {name!r}; registered: {known}") from exc
    if not cls.available():
        raise SolverUnavailableError(
            f"linear solver {name!r} is registered but unavailable — {cls.install_hint}"
        )
    return cls
[docs]
def list_linear_solvers() -> list[dict[str, object]]:
    """Describe every registered backend, in registration order.

    Each row is a dict with the keys ``name``, ``available``, ``kind``,
    ``spd_only`` and ``install_hint``. ``available`` is obtained by calling
    the class's ``available()`` on every invocation, so the report reflects
    the current environment.
    """
    return [
        {
            "name": backend_name,
            "available": backend.available(),
            "kind": backend.kind,
            "spd_only": backend.spd_only,
            "install_hint": backend.install_hint,
        }
        for backend_name, backend in _REGISTRY.items()
    ]
# Public API: every registered backend class plus the lookup helpers.
__all__ = [
    "CGSolver",
    "CholmodSolver",
    "DirectMklPardisoSolver",
    "GMRESSolver",
    "LinearSolverBase",
    "MUMPSSolver",
    "PardisoSolver",
    "PyAMGSolver",
    "SolverUnavailableError",
    "SuperLUSolver",
    "UMFPACKSolver",
    "get_linear_solver",
    "list_linear_solvers",
    "select_default_linear_solver",
]