# Source code for femorph_solver.solvers.linear._base
"""Common base class for pluggable linear solvers."""
from __future__ import annotations
from typing import ClassVar
import numpy as np
import scipy.sparse as sp
[docs]
class SolverUnavailableError(RuntimeError):
    """Raised when a registered solver's optional dependency is missing.

    It subclasses :class:`RuntimeError`, so callers that already catch
    ``RuntimeError`` also catch this error.
    """
[docs]
class LinearSolverBase:
    """Protocol every linear-solver backend implements.

    Concrete subclasses set the class attributes (``name``, ``kind``,
    ``spd_only``, ``install_hint``), override :meth:`available`, and
    implement :meth:`_factor` and :meth:`solve`.

    The constructor validates the shape of ``A``, records the solver
    options on the instance and then calls :meth:`_factor` once.
    """

    #: Short identifier used by ``get_linear_solver("...")``.
    name: ClassVar[str] = ""
    #: ``"direct"`` (factor + forward/back solve) or ``"iterative"``.
    kind: ClassVar[str] = "direct"
    #: ``True`` if the solver requires a symmetric positive-definite matrix.
    spd_only: ClassVar[bool] = False
    #: User-facing hint printed when the backend is unavailable.
    install_hint: ClassVar[str] = ""

    def __init__(
        self,
        A: sp.spmatrix | sp.sparray,
        *,
        assume_spd: bool = False,
        mixed_precision: bool | None = False,
    ) -> None:
        """Validate ``A``, store the solver options and factor ``A``.

        Parameters
        ----------
        A
            Square 2-D system matrix. Its dimension is stored as ``_n``.
        assume_spd
            Caller's assertion that ``A`` is symmetric positive-definite.
            It is stored as ``_assume_spd`` for :meth:`_factor` and
            subclasses to consult; the base class does not verify it.
        mixed_precision
            * ``False`` (default) — factor and solve in double precision.
            * ``True`` — factor in single precision and refine in double
              (for backends that support it; no-op elsewhere). Halves the
              factor footprint when the backend honours the request.
            * ``None`` — let the backend choose the default based on the
              problem (``assume_spd`` + size heuristic); the
              :class:`PardisoSolver` is currently the only backend that
              opts into MP here, and only when the input is large enough
              for the memory saving to matter.

        Raises
        ------
        ValueError
            If ``A`` is not a square 2-D matrix.
        NotImplementedError
            If the subclass does not override :meth:`_factor`.
        """
        # Reject non-square / non-2-D input up front. Otherwise ``_n``
        # would silently describe only the row count and the failure
        # would surface later inside a backend-specific factorization.
        shape = tuple(A.shape)
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError(
                f"{type(self).__name__} requires a square 2-D matrix; "
                f"got shape {shape}"
            )
        self._n = shape[0]
        self._assume_spd = assume_spd
        self._mixed_precision = mixed_precision
        # Set the options first so _factor can read them.
        self._factor(A)

    @staticmethod
    def available() -> bool:
        """Return ``True`` if the solver can be constructed on this install.

        The base implementation always returns ``True``. Backends with
        optional dependencies override it.
        """
        return True

    def _factor(self, A: sp.spmatrix | sp.sparray) -> None:
        """Precompute any factorization / preconditioner.

        Subclasses override. Called once by :meth:`__init__`.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement _factor()"
        )

    def solve(self, b: np.ndarray) -> np.ndarray:
        """Solve ``A x = b`` for a single right-hand side.

        Subclasses override. NOTE(review): ``b`` presumably has length
        ``self._n``; the base class does not check this.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement solve()"
        )