Source code for femorph_solver.solvers.linear._base

"""Common base class for pluggable linear solvers."""

from __future__ import annotations

from typing import ClassVar

import numpy as np
import scipy.sparse as sp


class SolverUnavailableError(RuntimeError):
    """Signal that a registered linear-solver backend cannot be used.

    Raised when the optional dependency that a registered solver relies on
    is missing from the current installation.
    """
class LinearSolverBase:
    """Shared base class for every pluggable linear-solver backend.

    A concrete backend is expected to:

    * fill in the class-level metadata (:attr:`name`, :attr:`kind`,
      :attr:`spd_only`, :attr:`install_hint`);
    * override :meth:`available` to report whether it can be built on the
      current install;
    * implement :meth:`_factor` (invoked once during construction) and
      :meth:`solve`.
    """

    #: Identifier looked up by ``get_linear_solver("...")``.
    name: ClassVar[str] = ""
    #: Either ``"direct"`` (factor, then forward/back solve) or ``"iterative"``.
    kind: ClassVar[str] = "direct"
    #: ``True`` when the backend only accepts symmetric positive-definite matrices.
    spd_only: ClassVar[bool] = False
    #: Message shown to the user when the backend is unavailable.
    install_hint: ClassVar[str] = ""

    def __init__(
        self,
        A: sp.spmatrix | sp.sparray,
        *,
        assume_spd: bool = False,
        mixed_precision: bool | None = False,
    ) -> None:
        """Record the solver options and factor ``A`` right away.

        Parameters
        ----------
        A
            System matrix. Its leading dimension is cached as ``self._n``,
            and the matrix itself is handed to :meth:`_factor`.
        assume_spd
            Caller-supplied flag, kept as ``self._assume_spd`` for the
            backend to consult.
        mixed_precision
            Precision policy, kept as ``self._mixed_precision``:

            * ``False`` (default) — factor and solve in double precision.
            * ``True`` — factor in single precision and refine in double
              (for backends that support it; a no-op elsewhere). This halves
              the factor footprint when the backend honours the request.
            * ``None`` — let the backend decide from the problem
              (``assume_spd`` plus a size heuristic). :class:`PardisoSolver`
              is currently the only backend that opts into mixed precision
              here, and only when the input is large enough for the memory
              saving to matter.
        """
        # The right-hand tuple is built first, so ``A.shape[0]`` is evaluated
        # before any attribute is assigned.
        self._n, self._assume_spd, self._mixed_precision = (
            A.shape[0],
            assume_spd,
            mixed_precision,
        )
        self._factor(A)

    @staticmethod
    def available() -> bool:
        """Report whether this backend can be constructed on the current install.

        The base implementation always answers ``True``; subclasses override
        it as needed.
        """
        return True

    def _factor(self, A: sp.spmatrix | sp.sparray) -> None:
        """Build whatever factorization or preconditioner ``solve`` needs.

        Must be overridden by subclasses. :meth:`__init__` calls it exactly
        once, after the solver options have been stored.
        """
        raise NotImplementedError()

    def solve(self, b: np.ndarray) -> np.ndarray:
        """Return ``x`` satisfying ``A x = b`` for one right-hand side ``b``."""
        raise NotImplementedError()