Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions test/b/test_b_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_eval(self):
c = self.c
f = self.f
y = f.eval(c)
precalculated = np.asarray(
y_precalculated = np.asarray(
[
[
[19.8694, 19.7848, 20.3956],
Expand All @@ -111,8 +111,8 @@ def test_eval(self):
],
]
)
self.assertEqual((3, 3, 3), y.shape)
self.assertTrue(np.allclose(y, precalculated))
self.assertEqual(y_precalculated.shape, y.shape)
self.assertTrue(np.allclose(y, y_precalculated))

g = f.jac(c)
self.assertEqual(y.shape + c.shape, g.shape)
Expand Down Expand Up @@ -179,9 +179,9 @@ def test_bernstein_poly(self):
)
f = BernsteinPoly(c)
y = f.eval(c, x)
precalculated = np.asarray([19.8694, 32.0761, 19.6774])
self.assertEqual((3,), y.shape)
self.assertTrue(jnp.allclose(y, precalculated))
y_precalculated = np.asarray([19.8694, 32.0761, 19.6774])
self.assertEqual(y_precalculated.shape, y.shape)
self.assertTrue(np.allclose(y, y_precalculated))

g = f.jac_p(c, x)
self.assertEqual((3,) + d, g.shape)
Expand All @@ -192,20 +192,20 @@ def test_bernstein_poly(self):
self.assertTrue(np.all(g > 0.0))

def test_from_lookup_table(self):
k = 5
x = np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
y = np.square(x) + 2.0 * x + 3.0

f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True)
c = f.prior()
self.assertEqual((k + 1,), c.shape)
self.assertAlmostEqual(3.0, c[0])
self.assertAlmostEqual(3.4, c[1])
self.assertAlmostEqual(3.9, c[2])
self.assertAlmostEqual(4.5, c[3])
self.assertAlmostEqual(5.2, c[4])
self.assertAlmostEqual(6.0, c[5])
self.assertTrue(jnp.allclose(f.eval(c, x), y))
k = (3, 4, 2)
d = tuple([k_ + 1 for k_ in k])
c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0
x = (
np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]),
np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]),
np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]),
)
y = BernsteinGrid(x).eval(c)

f = BernsteinPoly.from_lookup_table(k, x, y)
b = f.prior()
self.assertEqual(c.shape, b.shape)
self.assertTrue(np.allclose(b, c))


class BSolveTest(unittest.TestCase):
Expand All @@ -217,49 +217,53 @@ class BSolveTest(unittest.TestCase):
def test_b_solve_0_2(self):
r"""Fit :math:`B_{0,2}(x)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = jnp.square(1.0 - x)

c = b_solve((k,), (x,), y, non_negative=True)
self.assertEqual((k + 1,), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(1.0, c[0].item())
self.assertAlmostEqual(0.0, c[1].item())
self.assertAlmostEqual(0.0, c[2].item())

def test_b_solve_1_2(self):
r"""Fit :math:`B_{1,2}(x)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = 2.0 * x * (1.0 - x)

c = b_solve((k,), (x,), y, non_negative=True)
self.assertEqual((k + 1,), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0].item())
self.assertAlmostEqual(1.0, c[1].item())
self.assertAlmostEqual(0.0, c[2].item())

def test_b_solve_2_2(self):
r"""Fit :math:`B_{2,2}(x)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = jnp.square(x)

c = b_solve((k,), (x,), y, non_negative=True)
self.assertEqual((k + 1,), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0].item())
self.assertAlmostEqual(0.0, c[1].item())
self.assertAlmostEqual(1.0, c[2].item())

def test_b_solve_0_0_2_2(self):
r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = jnp.square(1.0 - x[jnp.newaxis, :]) * jnp.square(
1.0 - x[:, jnp.newaxis]
)

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(1.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand All @@ -273,7 +277,7 @@ def test_b_solve_0_0_2_2(self):
def test_b_solve_1_0_2_2(self):
r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = (
2.0
* x[:, jnp.newaxis]
Expand All @@ -283,6 +287,7 @@ def test_b_solve_1_0_2_2(self):

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand All @@ -296,13 +301,14 @@ def test_b_solve_1_0_2_2(self):
def test_b_solve_2_0_2_2(self):
r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = jnp.square(x[:, jnp.newaxis]) * jnp.square(
1.0 - x[jnp.newaxis, :]
)

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand All @@ -316,7 +322,7 @@ def test_b_solve_2_0_2_2(self):
def test_b_solve_0_1_2_2(self):
r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = (
2.0
* jnp.square(1.0 - x[:, jnp.newaxis])
Expand All @@ -326,6 +332,7 @@ def test_b_solve_0_1_2_2(self):

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(1.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand All @@ -339,7 +346,7 @@ def test_b_solve_0_1_2_2(self):
def test_b_solve_1_1_2_2(self):
r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = (
4.0
* x[:, jnp.newaxis]
Expand All @@ -350,6 +357,7 @@ def test_b_solve_1_1_2_2(self):

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand All @@ -363,7 +371,7 @@ def test_b_solve_1_1_2_2(self):
def test_b_solve_2_1_2_2(self):
r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = (
2.0
* jnp.square(x[:, jnp.newaxis])
Expand All @@ -372,6 +380,7 @@ def test_b_solve_2_1_2_2(self):
)

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertFalse(np.any(c < 0.0))
self.assertEqual((k + 1, k + 1), c.shape)
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
Expand All @@ -386,12 +395,13 @@ def test_b_solve_2_1_2_2(self):
def test_b_solve_0_2_2_2(self):
r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = jnp.square(1.0 - x[:, jnp.newaxis]) * jnp.square(
x[jnp.newaxis, :]
)

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertFalse(np.any(c < 0.0))
self.assertEqual((k + 1, k + 1), c.shape)
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
Expand All @@ -406,7 +416,7 @@ def test_b_solve_0_2_2_2(self):
def test_b_solve_1_2_2_2(self):
r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = (
2.0
* x[:, jnp.newaxis]
Expand All @@ -416,6 +426,7 @@ def test_b_solve_1_2_2_2(self):

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand All @@ -429,11 +440,12 @@ def test_b_solve_1_2_2_2(self):
def test_b_solve_2_2_2_2(self):
r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`."""
k = 2
x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
y = jnp.square(x[:, jnp.newaxis]) * jnp.square(x[jnp.newaxis, :])

c = b_solve((k, k), (x, x), y, non_negative=True)
self.assertEqual((k + 1, k + 1), c.shape)
self.assertFalse(np.any(c < 0.0))
self.assertAlmostEqual(0.0, c[0, 0].item())
self.assertAlmostEqual(0.0, c[0, 1].item())
self.assertAlmostEqual(0.0, c[0, 2].item())
Expand Down
57 changes: 20 additions & 37 deletions uncertaintyx/b/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Self

import jax
import jax.lax.linalg as jla
import jax.numpy as jnp
import jax.numpy.linalg as jli
import numpy as np
import optax
import optimistix
Expand Down Expand Up @@ -186,44 +186,35 @@ def b_solve(

N = len(k) # noqa: N806
bases = [b_basis(k[i], x[i]) for i in range(N)]
facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806
Q = [_[0] for _ in facts] # noqa: N806
R = [_[1] for _ in facts] # noqa: N806
grams = [jnp.dot(B, B.T) for B in bases] # noqa: N806

# compute the right hand side of the triangular equation
# compute the right hand side of the normal equation
rhs = y
for i in range(N):
rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0))
# solve the triangular equation
B = bases[i] # noqa: N806
rhs = jnp.tensordot(rhs, B, axes=(0, 1))
# solve the normal equation
c_unconstrained = rhs
if N > 1:
for i in range(N):
solve = jax.vmap(
lambda a, b: jla.triangular_solve(a, b, left_side=True),
in_axes=(None, i),
out_axes=i,
)
c_unconstrained = solve(R[i], c_unconstrained)
else:
c_unconstrained = jla.triangular_solve(
R[0], c_unconstrained, left_side=True
for i in range(N):
G = grams[i] # noqa: N806
c_unconstrained = jnp.tensordot(
c_unconstrained, jli.pinv(G), axes=(0, 1)
)

def hvp(c: Array):
"""The Hessian-vector product."""
res = c
for i in range(N):
res = jnp.tensordot(res, R[i], axes=(0, 1))
for i in range(N):
res = jnp.tensordot(res, R[i], axes=(0, 0))
G = grams[i] # noqa: N806
res = jnp.tensordot(res, G, axes=(0, 1))
return res

def nnls(c: Array, rhs: Array):
def nnls(c: Array):
"""
Non-negative least-squares solver.

Applies a positive transformation and an L-BFGS
optimizer to ensure non-negativity.
Applies a positive transformation and an L-BFGS optimizer
to ensure non-negativity.
"""

def forward(u: Array) -> Array:
Expand All @@ -250,10 +241,6 @@ def make_minimizer():
optax.lbfgs(), atol=atol, rtol=rtol, norm=optimistix.max_norm
)

# compute the right hand side of the normal equation
for i in range(N):
rhs = jnp.tensordot(rhs, R[i], axes=(0, 0))

u = inverse(jnp.abs(c) + jnp.finfo(c.dtype).eps)
optimum = optimistix.minimise(
misfit, make_minimizer(), u, max_steps=max_steps, throw=False
Expand All @@ -264,13 +251,7 @@ def make_minimizer():
nnls_needed = jnp.logical_and(
non_negative, jnp.any(c_unconstrained < 0.0)
)
return jax.lax.cond(
nnls_needed,
nnls,
lambda c, _: c,
c_unconstrained,
rhs,
)
return jax.lax.cond(nnls_needed, nnls, lambda c: c, c_unconstrained)


def _lower_bounds(
Expand Down Expand Up @@ -327,9 +308,10 @@ def __init__(
:param a: The lower bounds of the grid coordinates.
:param b: The upper bounds of the grid coordinates.
"""
N = len(x) # noqa: : N806
a = _lower_bounds(a, x)
b = _upper_bounds(b, x)
x = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x)
x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N))

def f(c: Array) -> Array:
r"""
Expand Down Expand Up @@ -420,9 +402,10 @@ def from_lookup_table(
:param rtol: The relative tolerance for terminating the solver.
:param max_steps: The maximum number of steps the solver can take.
"""
N = len(k) # noqa: : N806
a = _lower_bounds(a, x)
b = _upper_bounds(b, x)
x_ = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x)
x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N))
y_ = jnp.asarray(y)
c_ = b_solve(
k,
Expand Down
Loading