Skip to content

desdeo_brb.brb

The main model class: BRBModel.

brb

Main BRB model class with an sklearn-compatible interface.

Provides the BRBModel class which supports fitting, predicting, and inspecting a Belief Rule-Based inference system.

BRBModel

BRBModel(precedent_referential_values: list[ndarray], consequent_referential_values: ndarray, rule_base: RuleBase | None = None, utility_fn: Callable[[ndarray], ndarray] | None = None, initial_rule_fn: Callable[[ndarray], float] | None = None, backend: str = 'numpy')

A trainable Belief Rule-Based inference model.

Implements an sklearn-compatible interface (fit, predict, score, get_params, set_params) for building and using BRB systems.

Parameters:

Name Type Description Default
precedent_referential_values list[ndarray]

List of 1D sorted arrays, one per attribute.

required
consequent_referential_values ndarray

1D sorted array of consequent values.

required
rule_base RuleBase | None

Optional pre-configured RuleBase. If None, a default one is constructed from the referential values.

None
utility_fn Callable[[ndarray], ndarray] | None

Optional utility function applied to consequent values before computing the scalar output.

None
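initial_rule_fn Callable[[ndarray], float] | None

Optional callable mapping a 1D array of antecedent values to a scalar. Used to compute initial belief degrees when rule_base is None.

None
backend str

Computation backend, either "numpy" or "jax". Any other value raises ValueError; "jax" raises ImportError if JAX is not installed (pip install desdeo-brb[jax]).

'numpy'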
Source code in src/desdeo_brb/brb.py
def __init__(
    self,
    precedent_referential_values: list[np.ndarray],
    consequent_referential_values: np.ndarray,
    rule_base: RuleBase | None = None,
    utility_fn: Callable[[np.ndarray], np.ndarray] | None = None,
    initial_rule_fn: Callable[[np.ndarray], float] | None = None,
    backend: str = "numpy",
) -> None:
    """Create a BRB model from referential values or an explicit rule base.

    Args:
        precedent_referential_values: One array of referential values per
            attribute; each is converted to a float array.
        consequent_referential_values: Consequent referential values,
            converted to a float array.
        rule_base: Pre-built ``RuleBase``. When ``None``, a default one is
            built via ``self._build_default_rule_base(initial_rule_fn)``.
        utility_fn: Optional utility function, stored for later use.
        initial_rule_fn: Forwarded to ``_build_default_rule_base``; unused
            when ``rule_base`` is supplied.
        backend: ``"numpy"`` or ``"jax"``.

    Raises:
        ValueError: If ``backend`` is neither ``"numpy"`` nor ``"jax"``.
        ImportError: If ``backend == "jax"`` but JAX is not available.
    """
    if backend == "jax":
        # Deferred import: the JAX helper module is only touched when requested.
        from desdeo_brb.jax_backend import JAX_AVAILABLE

        if not JAX_AVAILABLE:
            raise ImportError("Install JAX: pip install desdeo-brb[jax]")
    elif backend != "numpy":
        raise ValueError(f"backend must be 'numpy' or 'jax', got {backend!r}")
    self._backend = backend

    as_float = [np.asarray(values, dtype=float) for values in precedent_referential_values]
    self._precedent_referential_values = as_float
    self._consequent_referential_values = np.asarray(consequent_referential_values, dtype=float)
    self._utility_fn = utility_fn
    self._ref_value_lengths = list(map(len, as_float))

    # The default builder runs only after the cached attributes above are set.
    self.rule_base = (
        rule_base if rule_base is not None else self._build_default_rule_base(initial_rule_fn)
    )

predict

predict(X: ndarray) -> InferenceResult

Run the full inference pipeline on input data.

Parameters:

Name Type Description Default
X ndarray

Input array of shape (n_samples, n_attributes).

required

Returns:

Type Description
InferenceResult

An InferenceResult with all intermediate and final values.

Source code in src/desdeo_brb/brb.py
def predict(self, X: np.ndarray) -> InferenceResult:
    """Run the complete BRB inference pipeline on ``X``.

    Dispatches to ``self._predict_jax`` when the model uses the JAX backend
    and to ``self._predict_numpy`` otherwise.

    Args:
        X: Input array of shape ``(n_samples, n_attributes)``.

    Returns:
        An ``InferenceResult`` holding all intermediate and final values.
    """
    runner = self._predict_jax if self._backend == "jax" else self._predict_numpy
    return runner(X)

predict_values

predict_values(X: ndarray) -> ndarray

Convenience method returning only the scalar outputs.

Parameters:

Name Type Description Default
X ndarray

Input array of shape (n_samples, n_attributes).

required

Returns:

Type Description
ndarray

1-D array of shape (n_samples,).

Source code in src/desdeo_brb/brb.py
def predict_values(self, X: np.ndarray) -> np.ndarray:
    """Return only the scalar model outputs for ``X``.

    Thin wrapper over ``predict`` that discards the intermediate results.

    Args:
        X: Input array of shape ``(n_samples, n_attributes)``.

    Returns:
        The ``output`` attribute of the inference result, a 1-D array of
        shape ``(n_samples,)``.
    """
    inference = self.predict(X)
    return inference.output

explain

explain(X: ndarray, sample_idx: int = 0, top_k: int = 3, attribute_names: list[str] | None = None, consequent_name: str | None = None, threshold: float = 0.01) -> str

Predict on X and return a human-readable explanation.

Convenience wrapper that calls predict(X) and then InferenceResult.explain() with this model's rule base.

Parameters:

Name Type Description Default
X ndarray

Input array of shape (n_samples, n_attributes).

required
sample_idx int

Which sample in the batch to explain.

0
top_k int

Number of top-activated rules to show.

3
attribute_names list[str] | None

Display names for each attribute.

None
consequent_name str | None

Display name for the consequent.

None
threshold float

Minimum weight/belief to display.

0.01
Source code in src/desdeo_brb/brb.py
def explain(
    self,
    X: np.ndarray,
    sample_idx: int = 0,
    top_k: int = 3,
    attribute_names: list[str] | None = None,
    consequent_name: str | None = None,
    threshold: float = 0.01,
) -> str:
    """Predict on *X* and return a human-readable explanation.

    Convenience wrapper that calls ``predict(X)`` and then
    ``InferenceResult.explain()`` with this model's rule base.

    Args:
        X: Input array of shape ``(n_samples, n_attributes)``.
        sample_idx: Which sample in the batch to explain.
        top_k: Number of top-activated rules to show.
        attribute_names: Display names for each attribute.
        consequent_name: Display name for the consequent.
        threshold: Minimum weight/belief to display.
    """
    result = self.predict(X)
    return result.explain(
        sample_idx=sample_idx,
        top_k=top_k,
        rule_base=self.rule_base,
        attribute_names=attribute_names,
        consequent_name=consequent_name,
        threshold=threshold,
    )

fit

fit(X: ndarray, y: ndarray, fix_endpoints: bool = True, fix_endpoint_beliefs: bool = False, normalize_rule_weights: bool = True, method: str | None = None, optimizer_options: dict | None = None, n_restarts: int = 1, verbose: bool = False, **minimize_kwargs: Any) -> BRBModel

Train the model by minimizing MSE.

For the NumPy backend, supported methods are "SLSQP" (default), "trust-constr", "ipopt" (requires Pyomo), "DE" and "DE+SLSQP". For the JAX backend, the only supported method is "L-BFGS-B" (default), which uses exact jax.grad gradients.

Parameters:

Name Type Description Default
X ndarray

Training inputs, shape (n_samples, n_attributes).

required
y ndarray

Target values, shape (n_samples,).

required
fix_endpoints bool

If True, fix the first and last precedent referential values (endpoints of each attribute's range).

True
fix_endpoint_beliefs bool

If True, also fix the belief degrees for rules at the boundary referential values during training. Use this when the initial beliefs at the domain boundaries are known to be correct (e.g., from initial_rule_fn or verified expert knowledge) to prevent the optimizer from distorting endpoint predictions.

False
normalize_rule_weights bool

If True (default), constrain rule weights to sum to 1 during optimization. If False, only bound each rule weight individually to [0, 1]; the optimizer may pick any scaling and the final stored weights are renormalized. Removing the sum constraint can give SLSQP / trust-constr a less coupled search landscape.

True
method str | None

Optimization method. The NumPy backend supports "SLSQP", "trust-constr", "ipopt", "DE" and "DE+SLSQP"; the JAX backend supports "L-BFGS-B". If None, the backend default is used ("SLSQP" for NumPy, "L-BFGS-B" for JAX).

None
optimizer_options dict | None

Options dict passed to scipy.optimize.minimize as the options argument. Merged with sensible per-method defaults; user values override the defaults.

None
n_restarts int

Number of optimization runs (default 1). When > 1, the first run uses the unperturbed initial parameters and subsequent runs perturb the initial parameters with seeded random noise. The final model is the best of all runs as measured by training MSE. Multi-start is critical for escaping bad local minima.

1
verbose bool

If True, print optimizer progress.

False
**minimize_kwargs Any

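Extra keyword arguments forwarded to scipy.optimize.minimize. Only the "SLSQP", "trust-constr" and "L-BFGS-B" methods receive them; the "ipopt", "DE" and "DE+SLSQP" methods do not.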

{}

Returns:

Type Description
BRBModel

self

Source code in src/desdeo_brb/brb.py
def fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    fix_endpoints: bool = True,
    fix_endpoint_beliefs: bool = False,
    normalize_rule_weights: bool = True,
    method: str | None = None,
    optimizer_options: dict | None = None,
    n_restarts: int = 1,
    verbose: bool = False,
    **minimize_kwargs: Any,
) -> "BRBModel":
    """Train the model by minimizing MSE.

    For the NumPy backend, supported methods are ``"SLSQP"`` (default)
    and ``"trust-constr"``. For the JAX backend, the only supported
    method is ``"L-BFGS-B"`` (default), which uses exact ``jax.grad``
    gradients.

    Args:
        X: Training inputs, shape ``(n_samples, n_attributes)``.
        y: Target values, shape ``(n_samples,)``.
        fix_endpoints: If ``True``, fix the first and last precedent
            referential values (endpoints of each attribute's range).
        fix_endpoint_beliefs: If ``True``, also fix the belief degrees
            for rules at the boundary referential values during training.
            Use this when the initial beliefs at the domain boundaries
            are known to be correct (e.g., from ``initial_rule_fn`` or
            verified expert knowledge) to prevent the optimizer from
            distorting endpoint predictions.
        normalize_rule_weights: If ``True`` (default), constrain rule
            weights to sum to 1 during optimization. If ``False``, only
            bound each rule weight individually to [0, 1]; the optimizer
            may pick any scaling and the final stored weights are
            renormalized. Removing the sum constraint can give SLSQP /
            trust-constr a less coupled search landscape.
        method: scipy.optimize.minimize method. NumPy backend supports
            ``"SLSQP"`` and ``"trust-constr"``; JAX backend supports
            ``"L-BFGS-B"``. If ``None``, uses the backend default.
        optimizer_options: Options dict passed to ``scipy.optimize.minimize``
            as the ``options`` argument. Merged with sensible per-method
            defaults; user values override the defaults.
        n_restarts: Number of optimization runs (default 1). When > 1,
            the first run uses the unperturbed initial parameters and
            subsequent runs perturb the initial parameters with seeded
            random noise. The final model is the best of all runs as
            measured by training MSE. Multi-start is critical for
            escaping bad local minima.
        verbose: If ``True``, print optimizer progress.
        **minimize_kwargs: Extra keyword arguments forwarded to
            ``scipy.optimize.minimize``.

    Returns:
        self
    """
    # Validate method against backend
    if self._backend == "numpy":
        if method is None:
            method = "SLSQP"
        valid_numpy = ("SLSQP", "trust-constr", "ipopt", "DE", "DE+SLSQP")
        if method not in valid_numpy:
            raise ValueError(f"NumPy backend supports methods {valid_numpy}, got {method!r}")
        if method == "ipopt":
            try:
                from desdeo_brb.pyomo_backend import PYOMO_AVAILABLE
            except ImportError:
                PYOMO_AVAILABLE = False
            if not PYOMO_AVAILABLE:
                raise ImportError(
                    "Install Pyomo for IPOPT support: pip install desdeo-brb[pyomo]"
                )
    else:  # jax
        if method is None:
            method = "L-BFGS-B"
        if method != "L-BFGS-B":
            raise ValueError(f"JAX backend supports method='L-BFGS-B' only, got {method!r}")

    if n_restarts < 1:
        raise ValueError(f"n_restarts must be >= 1, got {n_restarts}")

    def _run_one(verbose_inner: bool = False) -> None:
        if self._backend == "jax":
            self._fit_jax(
                X,
                y,
                fix_endpoints,
                fix_endpoint_beliefs,
                method,
                optimizer_options,
                verbose=verbose_inner,
                normalize_rule_weights=normalize_rule_weights,
                **minimize_kwargs,
            )
        elif method == "ipopt":
            self._fit_pyomo(
                X,
                y,
                fix_endpoints,
                fix_endpoint_beliefs,
                normalize_rule_weights,
                optimizer_options,
                verbose_inner,
            )
        elif method == "DE":
            self._fit_de(
                X,
                y,
                fix_endpoints,
                fix_endpoint_beliefs,
                normalize_rule_weights,
                optimizer_options,
                verbose_inner,
            )
        elif method == "DE+SLSQP":
            self._fit_de_slsqp(
                X,
                y,
                fix_endpoints,
                fix_endpoint_beliefs,
                normalize_rule_weights,
                optimizer_options,
                verbose_inner,
            )
        else:
            self._fit_numpy(
                X,
                y,
                fix_endpoints,
                fix_endpoint_beliefs,
                method,
                optimizer_options,
                verbose=verbose_inner,
                normalize_rule_weights=normalize_rule_weights,
                **minimize_kwargs,
            )

    if n_restarts == 1:
        # Single run, surface optimizer verbosity through the inner call
        _run_one(verbose_inner=verbose)
        return self

    # Multi-start: snapshot the initial parameters, run multiple times,
    # keep the result with the lowest training MSE.
    initial_flat = self._flatten_params()

    best_mse = float("inf")
    best_rule_base: RuleBase | None = None
    best_restart_index = 0

    for restart in range(n_restarts):
        if restart == 0:
            self.rule_base = self._unflatten_params(initial_flat)
        else:
            rng = np.random.default_rng(restart)
            perturbed_flat = self._perturb_params(
                initial_flat, rng, fix_endpoints, normalize_rule_weights
            )
            self.rule_base = self._unflatten_params(perturbed_flat)

        _run_one(verbose_inner=verbose)

        y_pred = self.predict_values(X)
        mse = float(np.mean((y - y_pred) ** 2))
        improved = mse < best_mse
        if improved:
            best_mse = mse
            best_rule_base = self.rule_base
            best_restart_index = restart + 1

        if verbose:
            marker = " (best)" if improved else ""
            print(f"Restart {restart + 1}/{n_restarts}: MSE = {mse:.5f}{marker}")

    if best_rule_base is not None:
        self.rule_base = best_rule_base

    if verbose:
        print(f"Best result: restart {best_restart_index}, MSE = {best_mse:.5f}")

    return self

update_from_pyomo

update_from_pyomo(pyomo_model) -> None

Extract solved parameter values from a Pyomo model and update the rule base.

Reads the variable values for belief degrees, rule weights, attribute weights, and referential values from the Pyomo model and assembles them into a fresh RuleBase (with validation). Solver-tolerance violations are projected back onto the constraint surface (rows renormalized to sum to 1, attribute weights clipped to be non-negative, referential values sorted).

This method is the inverse of build_pyomo_brb_model() for the parameter-extraction direction. Users who want to optimize a custom Pyomo objective on top of the BRB structure can do the following:

from desdeo_brb.pyomo_backend import build_pyomo_brb_model
import pyomo.environ as pyo

m = build_pyomo_brb_model(brb, X, y)
m.del_component(m.obj)
m.obj = pyo.Objective(expr=my_custom_loss(m), sense=pyo.minimize)
pyo.SolverFactory("ipopt").solve(m)
brb.update_from_pyomo(m)
Source code in src/desdeo_brb/brb.py
def update_from_pyomo(self, pyomo_model) -> None:
    """Copy solved parameter values from a Pyomo model into a new rule base.

    Reads belief degrees (``beta``), rule weights (``theta``), attribute
    weights (``delta``) and precedent referential values (``A``) from
    ``pyomo_model``, repairs solver-tolerance violations, and assembles
    them into a fresh ``RuleBase`` (with validation). The repairs are:

    * belief degrees clipped to ``[0, 1]`` and each row divided by its sum
      (rows whose sum is zero are left as they are);
    * rule weights clipped to ``[0, 1]`` and normalized to sum to 1, or
      replaced by uniform weights if every weight clips to zero;
    * attribute weights clipped to be non-negative;
    * each attribute's referential values sorted ascending.

    ``pyomo_model`` must carry the ``_brb_*`` metadata attributes read
    below (e.g. a model produced by ``build_pyomo_brb_model``). Users who
    want to optimize a custom Pyomo objective on top of the BRB structure
    can call::

        from desdeo_brb.pyomo_backend import build_pyomo_brb_model
        import pyomo.environ as pyo

        m = build_pyomo_brb_model(brb, X, y)
        m.del_component(m.obj)
        m.obj = pyo.Objective(expr=my_custom_loss(m), sense=pyo.minimize)
        pyo.SolverFactory("ipopt").solve(m)
        brb.update_from_pyomo(m)

    Raises:
        ImportError: If Pyomo is not installed.
    """
    try:
        import pyomo.environ as pyo
    except ImportError as exc:  # pragma: no cover
        raise ImportError("Install Pyomo: pip install desdeo-brb[pyomo]") from exc

    m = pyomo_model
    # Structural metadata stored on the Pyomo model.
    n_rules = m._brb_n_rules
    n_consequents = m._brb_n_consequents
    n_attributes = m._brb_n_attributes
    ref_value_lengths = m._brb_ref_value_lengths
    antecedent_indices = m._brb_rule_antecedent_indices
    consequent_rv = m._brb_consequent_referential_values

    def _val(var, index) -> float:
        # Solved value of one indexed Pyomo variable as a Python float.
        return float(pyo.value(var[index]))

    # Belief degrees: clip into [0, 1], then rescale every row with a positive sum.
    # reshape keeps the 2-D shape even when n_rules or n_consequents is zero.
    raw_beliefs = np.array(
        [[_val(m.beta, (k, n)) for n in range(n_consequents)] for k in range(n_rules)],
        dtype=float,
    ).reshape(n_rules, n_consequents)
    clipped_beliefs = np.clip(raw_beliefs, 0.0, 1.0)
    row_totals = clipped_beliefs.sum(axis=1, keepdims=True)
    belief_degrees = clipped_beliefs / np.where(row_totals > 0, row_totals, 1.0)

    # Rule weights: clip into [0, 1]; normalize, falling back to uniform.
    theta = np.clip(np.array([_val(m.theta, k) for k in range(n_rules)]), 0.0, 1.0)
    theta_total = theta.sum()
    rule_weights = theta / theta_total if theta_total > 0 else np.full(n_rules, 1.0 / n_rules)

    # Attribute weights: only non-negativity is enforced.
    attribute_weights = np.clip(
        np.array(
            [[_val(m.delta, (k, i)) for i in range(n_attributes)] for k in range(n_rules)],
            dtype=float,
        ).reshape(n_rules, n_attributes),
        0.0,
        None,
    )

    # Referential values: sorted per attribute to restore ordering.
    precedent_referential_values: list[np.ndarray] = [
        np.sort(np.array([_val(m.A, (i, j)) for j in range(int(ref_value_lengths[i]))]))
        for i in range(n_attributes)
    ]

    self.rule_base = RuleBase(
        precedent_referential_values=precedent_referential_values,
        consequent_referential_values=np.asarray(consequent_rv),
        belief_degrees=belief_degrees,
        rule_weights=rule_weights,
        attribute_weights=attribute_weights,
        rule_antecedent_indices=np.asarray(antecedent_indices),
    )
    # Keep the model's cached referential value data in sync with the new rule base.
    self._precedent_referential_values = precedent_referential_values
    self._ref_value_lengths = list(map(len, precedent_referential_values))

fit_custom

fit_custom(loss_fn: Callable[[BRBModel], float], fix_endpoints: bool = True, fix_endpoint_beliefs: bool = False, normalize_rule_weights: bool = True, method: str = 'SLSQP', optimizer_options: dict | None = None, n_restarts: int = 1, constraints: list[dict] | None = None, verbose: bool = False, **minimize_kwargs: Any) -> BRBModel

Train using a user-supplied loss function.

The loss function receives the model instance (with updated parameters) and must return a scalar loss value. The model's parameters are updated internally before each call so the user can simply call model.predict_values() inside the loss.

Optimization uses scipy with finite differences regardless of the model's backend, since the user's loss function is opaque to JAX. The structural BRB constraints (belief degree row sums, rule weight sum, attribute weight bounds, referential value ordering) are always enforced; users may pass additional constraints.

Parameters:

Name Type Description Default
loss_fn Callable[[BRBModel], float]

Callable (model) -> float returning the scalar loss.

required
fix_endpoints bool

If True, fix the first and last precedent referential values for each attribute.

True
fix_endpoint_beliefs bool

If True, fix the belief degrees of rules at the boundary referential values.

False
normalize_rule_weights bool

If True, constrain rule weights to sum to 1 during optimization.

True
method str

scipy optimizer to use. Supported: "SLSQP" (default) and "trust-constr".

'SLSQP'
optimizer_options dict | None

Options dict passed to scipy.optimize.minimize, merged with sensible per-method defaults.

None
n_restarts int

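Number of optimization runs (default 1). When > 1, the first run starts from the current parameters and later runs start from seeded random perturbations of them; the result with the lowest loss_fn value is kept.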

1
constraints list[dict] | None

Additional constraints to add on top of the BRB structural constraints. For SLSQP, list of dicts with "type" / "fun" keys; for trust-constr, list of LinearConstraint / NonlinearConstraint objects.

None
verbose bool

If True, print per-restart loss.

False
**minimize_kwargs Any

Extra keyword arguments forwarded to scipy.optimize.minimize.

{}

Returns:

Type Description
BRBModel

self

Source code in src/desdeo_brb/brb.py
def fit_custom(
    self,
    loss_fn: Callable[["BRBModel"], float],
    fix_endpoints: bool = True,
    fix_endpoint_beliefs: bool = False,
    normalize_rule_weights: bool = True,
    method: str = "SLSQP",
    optimizer_options: dict | None = None,
    n_restarts: int = 1,
    constraints: list[dict] | None = None,
    verbose: bool = False,
    **minimize_kwargs: Any,
) -> "BRBModel":
    """Train using a user-supplied loss function.

    The loss function receives the model instance (with updated
    parameters) and must return a scalar loss value. The model's
    parameters are updated internally before each call so the user
    can simply call ``model.predict_values()`` inside the loss.

    Optimization uses scipy with finite differences regardless of the
    model's backend, since the user's loss function is opaque to JAX.
    The structural BRB constraints (belief degree row sums, rule weight
    sum, attribute weight bounds, referential value ordering) are
    always enforced; users may pass additional ``constraints``.

    Args:
        loss_fn: Callable ``(model) -> float`` returning the scalar loss.
        fix_endpoints: If ``True``, fix the first and last precedent
            referential values for each attribute.
        fix_endpoint_beliefs: If ``True``, fix the belief degrees of
            rules at the boundary referential values.
        normalize_rule_weights: If ``True``, constrain rule weights to
            sum to 1 during optimization.
        method: scipy optimizer to use. Supported: ``"SLSQP"`` (default)
            and ``"trust-constr"``.
        optimizer_options: Options dict passed to ``scipy.optimize.minimize``,
            merged with sensible per-method defaults.
        n_restarts: Number of optimization runs from perturbed initial
            points. The best result by ``loss_fn`` value is kept.
        constraints: Additional constraints to add on top of the BRB
            structural constraints. For SLSQP, list of dicts with
            ``"type"`` / ``"fun"`` keys; for trust-constr, list of
            ``LinearConstraint`` / ``NonlinearConstraint`` objects.
        verbose: If ``True``, print per-restart loss.
        **minimize_kwargs: Extra keyword arguments forwarded to
            ``scipy.optimize.minimize``.

    Returns:
        self
    """
    if method not in ("SLSQP", "trust-constr"):
        raise ValueError(
            f"fit_custom supports method='SLSQP' or 'trust-constr', got {method!r}"
        )
    if n_restarts < 1:
        raise ValueError(f"n_restarts must be >= 1, got {n_restarts}")

    def _run_one() -> None:
        self._fit_custom_inner(
            loss_fn,
            fix_endpoints,
            fix_endpoint_beliefs,
            normalize_rule_weights,
            method,
            optimizer_options,
            constraints,
            **minimize_kwargs,
        )

    if n_restarts == 1:
        _run_one()
        return self

    # Multi-start: snapshot initial parameters, run multiple times,
    # keep the rule base with the lowest loss_fn value.
    initial_flat = self._flatten_params()
    best_loss = float("inf")
    best_rule_base: RuleBase | None = None
    best_restart_index = 0

    for restart in range(n_restarts):
        if restart == 0:
            self.rule_base = self._unflatten_params(initial_flat)
        else:
            rng = np.random.default_rng(restart)
            perturbed_flat = self._perturb_params(
                initial_flat, rng, fix_endpoints, normalize_rule_weights
            )
            self.rule_base = self._unflatten_params(perturbed_flat)

        _run_one()

        loss = float(loss_fn(self))
        improved = loss < best_loss
        if improved:
            best_loss = loss
            best_rule_base = self.rule_base
            best_restart_index = restart + 1

        if verbose:
            marker = " (best)" if improved else ""
            print(f"Restart {restart + 1}/{n_restarts}: loss = {loss:.5f}{marker}")

    if best_rule_base is not None:
        self.rule_base = best_rule_base

    if verbose:
        print(f"Best result: restart {best_restart_index}, loss = {best_loss:.5f}")

    return self

get_params

get_params(deep: bool = True) -> dict[str, Any]

Get model parameters (sklearn-compatible).

Parameters:

Name Type Description Default
deep bool

If True, return nested parameters.

True

Returns:

Type Description
dict[str, Any]

Dictionary of parameter names to values.

Source code in src/desdeo_brb/brb.py
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """Return the model's parameters (sklearn-style).

    Args:
        deep: When ``True``, also include the current ``rule_base``.

    Returns:
        Mapping of parameter names to the stored values. The values are not
        copies; ``precedent_referential_values``,
        ``consequent_referential_values``, ``utility_fn`` and ``backend``
        are always present.
    """
    params: dict[str, Any] = dict(
        precedent_referential_values=self._precedent_referential_values,
        consequent_referential_values=self._consequent_referential_values,
        utility_fn=self._utility_fn,
        backend=self._backend,
    )
    if deep:
        params.update(rule_base=self.rule_base)
    return params

set_params

set_params(**params: Any) -> BRBModel

Set model parameters (sklearn-compatible).

Parameters:

Name Type Description Default
**params Any

Parameter names and values.

{}

Returns:

Type Description
BRBModel

self

Source code in src/desdeo_brb/brb.py
def set_params(self, **params: Any) -> "BRBModel":
    """Set model parameters (sklearn-compatible).

    Recognized keys are ``precedent_referential_values``,
    ``consequent_referential_values``, ``utility_fn``, ``rule_base`` and
    ``backend``; any other key is ignored. All new values are validated and
    converted before anything is assigned, so a failing call leaves the
    model unchanged.

    Note that changing the referential values does not rebuild
    ``rule_base``; pass a matching ``rule_base`` in the same call if needed.

    Args:
        **params: Parameter names and values.

    Returns:
        self

    Raises:
        ValueError: If ``backend`` is neither ``"numpy"`` nor ``"jax"``.
        ImportError: If ``backend == "jax"`` but JAX is not available.
    """
    # --- Validation / conversion phase: no attribute is touched yet. ---
    if "backend" in params:
        backend = params["backend"]
        if backend not in ("numpy", "jax"):
            raise ValueError(f"backend must be 'numpy' or 'jax', got {backend!r}")
        if backend == "jax":
            # Same availability check the constructor performs for backend="jax".
            from desdeo_brb.jax_backend import JAX_AVAILABLE

            if not JAX_AVAILABLE:
                raise ImportError("Install JAX: pip install desdeo-brb[jax]")

    new_precedent = None
    new_lengths = None
    if "precedent_referential_values" in params:
        new_precedent = [
            np.asarray(rv, dtype=float) for rv in params["precedent_referential_values"]
        ]
        new_lengths = [len(rv) for rv in new_precedent]

    new_consequent = None
    if "consequent_referential_values" in params:
        new_consequent = np.asarray(params["consequent_referential_values"], dtype=float)

    # --- Assignment phase: every value above is known to be valid. ---
    if new_precedent is not None:
        self._precedent_referential_values = new_precedent
        self._ref_value_lengths = new_lengths
    if new_consequent is not None:
        self._consequent_referential_values = new_consequent
    if "utility_fn" in params:
        self._utility_fn = params["utility_fn"]
    if "rule_base" in params:
        self.rule_base = params["rule_base"]
    if "backend" in params:
        self._backend = params["backend"]
    return self

score

score(X: ndarray, y: ndarray) -> float

Return negative MSE (sklearn convention: higher is better).

Parameters:

Name Type Description Default
X ndarray

Input array, shape (n_samples, n_attributes).

required
y ndarray

True target values, shape (n_samples,).

required

Returns:

Type Description
float

Negative mean squared error.

Source code in src/desdeo_brb/brb.py
def score(self, X: np.ndarray, y: np.ndarray) -> float:
    """Return negative MSE (sklearn convention: higher is better).

    Both ``y`` and the predictions are flattened before comparison. A
    column-vector target of shape ``(n_samples, 1)`` is therefore scored
    element-wise, rather than being broadcast against the 1-D predictions
    into an ``(n_samples, n_samples)`` matrix.

    Args:
        X: Input array, shape ``(n_samples, n_attributes)``.
        y: True target values, shape ``(n_samples,)`` or ``(n_samples, 1)``.

    Returns:
        Negative mean squared error.
    """
    y_true = np.asarray(y, dtype=float).ravel()
    y_pred = np.asarray(self.predict_values(X), dtype=float).ravel()
    return -float(np.mean((y_true - y_pred) ** 2))