desdeo_brb.jax_backend
JAX-based inference and training. Requires the `jax` optional dependency (`pip install desdeo-brb[jax]`).
jax_backend
JAX backend for the BRB inference pipeline.
Provides JIT-compiled versions of the inference functions and a differentiable end-to-end inference function for gradient-based training.
Because `jax.jit` traces code and requires static array shapes, the varying-length referential value arrays are padded into fixed-size 2-D arrays (see `desdeo_brb.utils.pad_referential_values`). Unused entries are filled with `np.inf` and masked during computation.
The `rv_lengths` parameter is passed as a tuple of Python ints (not a JAX array) and declared as a static argument for JIT, so that the lengths are available as concrete values at trace time.
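JAX requires static arguments to be hashable, and a NumPy or JAX array is not. The short Python sketch below shows the conversion from an array of lengths to a tuple of plain ints. The variable names are illustrative.

```python
import numpy as np

rv_lengths_array = np.array([3, 2])
# A tuple of plain Python ints is hashable, which static JIT arguments require.
rv_lengths = tuple(int(n) for n in rv_lengths_array)

# An ndarray, by contrast, cannot be hashed.
try:
    hash(rv_lengths_array)
    array_is_hashable = True
except TypeError:
    array_is_hashable = False
```

With JAX installed, one way to mark such an argument static is `jax.jit(fn, static_argnames=("rv_lengths",))`. Inside the traced function, an expression like `padded_rv[j, :length]` then uses a concrete `length`. Whether the library uses `static_argnames` or `static_argnums` is not shown here.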
input_transform_jax
Transform raw inputs into belief distributions (JAX version).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `ndarray` | Input array of shape … | *required* |
| `padded_rv` | `ndarray` | Padded referential values of shape … | *required* |
| `rv_lengths` | `tuple[int, ...]` | Tuple of ints with the actual number of referential values per attribute. Must be concrete (not a traced array). | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | 3-D array of shape … with belief degrees. Padded positions are zero. |
Source code in src/desdeo_brb/jax_backend.py
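The following NumPy sketch shows one way such a transformation can work on padded referential values. It assumes the piecewise-linear (rule-based) input transformation common in the BRB literature. An input lying between two adjacent referential values splits its belief between them in proportion to distance. Clamping out-of-range inputs is an assumption, as is requiring at least two strictly increasing referential values per attribute. `input_transform_sketch` is a hypothetical name, not the library function.

```python
import numpy as np


def input_transform_sketch(X, padded_rv, rv_lengths):
    """Piecewise-linear BRB input transformation (NumPy sketch, not the library code)."""
    n_samples, n_attrs = X.shape
    alphas = np.zeros((n_samples, n_attrs, padded_rv.shape[1]))
    rows = np.arange(n_samples)
    for j, length in enumerate(rv_lengths):
        rv = padded_rv[j, :length]  # concrete length: drop the np.inf padding
        x = np.clip(X[:, j], rv[0], rv[-1])  # assumption: clamp out-of-range inputs
        # Index k of the interval [rv[k], rv[k + 1]] that contains x.
        k = np.clip(np.searchsorted(rv, x, side="right") - 1, 0, length - 2)
        upper_share = (x - rv[k]) / (rv[k + 1] - rv[k])
        alphas[rows, j, k] = 1.0 - upper_share
        alphas[rows, j, k + 1] = upper_share
    return alphas  # padded positions are never written, so they stay zero
```

For example, an input of 2.5 with referential values `[0, 5, 10]` yields the belief distribution `[0.5, 0.5, 0]`. The `[:length]` slice relies on `length` being a concrete int, which is the reason `rv_lengths` must not be a traced array.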
compute_activation_weights_jax
`compute_activation_weights_jax(alphas: ndarray, rule_antecedent_indices: ndarray, thetas: ndarray, deltas: ndarray) -> ndarray`
Compute activation weights (JAX version).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `alphas` | `ndarray` | 3-D array from … | *required* |
| `rule_antecedent_indices` | `ndarray` | Integer array | *required* |
| `thetas` | `ndarray` | Rule weights | *required* |
| `deltas` | `ndarray` | Attribute weights | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | 2-D array of shape … |
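The NumPy sketch below illustrates how activation weights are typically computed. It assumes the conjunctive (product) form common in the BRB literature. Each rule's matching degree is the product of its antecedents' belief degrees, each raised to a relative attribute weight, then multiplied by the rule weight and normalized across rules. Normalizing `deltas` by their maximum, the array shapes in the docstring, and the name `activation_weights_sketch` are all assumptions, not the library's documented behavior.

```python
import numpy as np


def activation_weights_sketch(alphas, rule_antecedent_indices, thetas, deltas):
    """Conjunctive BRB activation weights (NumPy sketch, not the library code).

    alphas: (n_samples, n_attributes, max_rv) input belief degrees
    rule_antecedent_indices: (n_rules, n_attributes) referential-value index per rule
    """
    attr_idx = np.arange(alphas.shape[1])
    # matched[s, r, i] = alphas[s, i, rule_antecedent_indices[r, i]]
    matched = alphas[:, attr_idx, rule_antecedent_indices]
    rel_deltas = deltas / deltas.max()  # assumption: normalize by the largest weight
    matching = np.prod(matched ** rel_deltas, axis=2)  # (n_samples, n_rules)
    raw = thetas * matching
    total = raw.sum(axis=1, keepdims=True)
    # Samples that activate no rule get all-zero weights instead of NaN.
    return np.divide(raw, total, out=np.zeros_like(raw), where=total > 0)
```

The advanced-indexing line gathers, for every sample, the belief degree each rule's antecedent receives on each attribute. No Python loop over rules is needed, which also suits JIT compilation.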
compute_combined_belief_degrees_jax
Combine belief degrees using the ER algorithm (JAX version).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `bre_matrix` | `ndarray` | Shape … | *required* |
| `weights` | `ndarray` | Shape … | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | Shape … |
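The following NumPy sketch shows the combination step. It assumes the analytical form of the evidential reasoning (ER) algorithm as usually written for BRB systems. It also assumes `bre_matrix` holds each rule's belief degrees over the consequents, with shape `(n_rules, n_consequents)`, and `weights` holds the activation weights, with shape `(n_samples, n_rules)`. Both shapes and the name `er_combine_sketch` are illustrative.

```python
import numpy as np


def er_combine_sketch(bre_matrix, weights):
    """Analytical ER combination (NumPy sketch, not the library code).

    bre_matrix: (n_rules, n_consequents) belief degrees of each rule
    weights:    (n_samples, n_rules) activation weights
    """
    n_consequents = bre_matrix.shape[1]
    rule_totals = bre_matrix.sum(axis=1)  # sum_i beta_{i,k}, shape (n_rules,)
    w = weights[:, :, None]
    # prod_k (w_k beta_{j,k} + 1 - w_k sum_i beta_{i,k}), shape (n_samples, n_consequents)
    prod_j = np.prod(w * bre_matrix[None] + 1.0 - w * rule_totals[None, :, None], axis=1)
    # prod_k (1 - w_k sum_i beta_{i,k}) and prod_k (1 - w_k), shape (n_samples,)
    prod_rest = np.prod(1.0 - weights * rule_totals[None, :], axis=1)
    prod_free = np.prod(1.0 - weights, axis=1)
    mu = 1.0 / (prod_j.sum(axis=1) - (n_consequents - 1) * prod_rest)
    numerator = mu[:, None] * (prod_j - prod_rest[:, None])
    return numerator / (1.0 - mu * prod_free)[:, None]
```

Two sanity checks follow from the formula. A single rule with weight 1 returns its own belief degrees unchanged. Two equally weighted rules with opposite certain beliefs combine to `[0.5, 0.5]`. The sketch does not guard against all-zero weights, for which the final denominator vanishes.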
compute_output_jax
Compute scalar outputs (JAX version, identity utility only).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `belief_degrees` | `ndarray` | Shape … | *required* |
| `consequents` | `ndarray` | Shape … | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | Shape … |
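With an identity utility, the scalar output reduces to a belief-weighted sum of the consequent values. The NumPy sketch below computes it. It assumes `belief_degrees` has shape `(n_samples, n_consequents)` and `consequents` has shape `(n_consequents,)`. How the library treats belief mass that is left unassigned is not shown here, so this sketch simply ignores it.

```python
import numpy as np


def output_sketch(belief_degrees, consequents):
    """Expected output under the identity utility: sum_j beta_j * D_j (sketch)."""
    # (n_samples, n_consequents) @ (n_consequents,) -> (n_samples,)
    return belief_degrees @ consequents
```

For beliefs `[0.5, 0.5]` over consequents `[0, 10]`, the output is 5.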
full_inference_jax
`full_inference_jax(flat_params: ndarray, X: ndarray, consequent_rv: ndarray, rule_antecedent_indices: ndarray, n_rules: int, n_consequents: int, n_attributes: int, rv_lengths: tuple[int, ...]) -> ndarray`
End-to-end differentiable inference from flat parameters to outputs.
This is a pure function suitable for `jax.jit` and `jax.grad`. It unflattens the parameter vector, runs all inference steps, and returns scalar outputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `flat_params` | `ndarray` | 1-D parameter vector (same layout as …). | *required* |
| `X` | `ndarray` | Input array | *required* |
| `consequent_rv` | `ndarray` | Consequent referential values | *required* |
| `rule_antecedent_indices` | `ndarray` | Integer array | *required* |
| `n_rules` | `int` | Number of rules (static). | *required* |
| `n_consequents` | `int` | Number of consequent values (static). | *required* |
| `n_attributes` | `int` | Number of attributes (static). | *required* |
| `rv_lengths` | `tuple[int, ...]` | Tuple of Python ints with referential value lengths per attribute (static; required for JIT tracing). | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | 1-D array of shape … |
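The layout of `flat_params` is not documented on this page. The NumPy sketch below therefore assumes a hypothetical order: the rule belief matrix, then rule weights, attribute weights, and finally the per-attribute referential values. It only illustrates the unflattening step. `unflatten_sketch` is an illustrative name, not a library function.

```python
import numpy as np


def unflatten_sketch(flat_params, n_rules, n_consequents, n_attributes, rv_lengths):
    """Split a flat vector into parameter blocks (HYPOTHETICAL layout, for illustration)."""
    # Assumed order: [belief matrix | rule weights | attribute weights | referential values]
    sizes = [n_rules * n_consequents, n_rules, n_attributes, sum(rv_lengths)]
    offsets = np.cumsum([0] + sizes)
    beliefs = flat_params[offsets[0]:offsets[1]].reshape(n_rules, n_consequents)
    thetas = flat_params[offsets[1]:offsets[2]]
    deltas = flat_params[offsets[2]:offsets[3]]
    rv_flat = flat_params[offsets[3]:offsets[4]]
    # One array per attribute, cut at the cumulative referential-value lengths.
    rvs = np.split(rv_flat, np.cumsum(rv_lengths)[:-1])
    return beliefs, thetas, deltas, rvs
```

Every slice boundary and reshape here is computed from `n_rules`, `n_consequents`, `n_attributes` and `rv_lengths`. This is consistent with those arguments being marked static above: under `jax.jit`, array shapes must be known at trace time.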
full_inference_jax_unconstrained
`full_inference_jax_unconstrained(flat_params: ndarray, X: ndarray, consequent_rv: ndarray, rule_antecedent_indices: ndarray, n_rules: int, n_consequents: int, n_attributes: int, rv_lengths: tuple[int, ...], normalize_rule_weights: bool = True) -> ndarray`
End-to-end inference from unconstrained parameters.
Wraps `full_inference_jax` with differentiable reparameterizations: softmax for belief-degree rows, softmax or sigmoid for rule weights, softplus for attribute weights, and sorting for referential values. This lets L-BFGS-B, which supports only box bounds, optimize without explicit equality constraints.
Parameters:

All other parameters are the same as for `full_inference_jax`, which this function wraps.

| Name | Type | Description | Default |
|---|---|---|---|
| `normalize_rule_weights` | `bool` | If `True`, apply softmax to rule weights (constraining them to the simplex). If `False`, apply sigmoid (each weight independently in (0, 1)). | `True` |
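The reparameterizations listed above can be sketched in NumPy as follows. Unconstrained values are mapped to valid BRB parameters: belief rows on the simplex, rule weights via softmax or sigmoid, positive attribute weights, and increasing referential values. The split into four raw arrays and the name `constrain_sketch` are assumptions for illustration. How the library slices its flat vector is not shown here.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softplus(z):
    return np.logaddexp(0.0, z)  # log(1 + exp(z)), computed stably


def constrain_sketch(raw_beliefs, raw_thetas, raw_deltas, raw_rvs, normalize_rule_weights=True):
    """Map unconstrained arrays to valid BRB parameters (sketch, hypothetical split)."""
    beliefs = softmax(raw_beliefs, axis=1)  # each rule's belief row sums to 1
    thetas = softmax(raw_thetas) if normalize_rule_weights else sigmoid(raw_thetas)
    deltas = softplus(raw_deltas)  # strictly positive attribute weights
    rvs = [np.sort(rv) for rv in raw_rvs]  # referential values in increasing order
    return beliefs, thetas, deltas, rvs
```

Every constraint now holds by construction, so an optimizer only needs simple box bounds on the raw values (or none at all). Sorting is differentiable almost everywhere, because it only permutes its inputs.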