diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index 2e274f3..0e27443 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -92,7 +92,7 @@ def test_eval(self): c = self.c f = self.f y = f.eval(c) - precalculated = np.asarray( + y_precalculated = np.asarray( [ [ [19.8694, 19.7848, 20.3956], @@ -111,8 +111,8 @@ def test_eval(self): ], ] ) - self.assertEqual((3, 3, 3), y.shape) - self.assertTrue(np.allclose(y, precalculated)) + self.assertEqual(y_precalculated.shape, y.shape) + self.assertTrue(np.allclose(y, y_precalculated)) g = f.jac(c) self.assertEqual(y.shape + c.shape, g.shape) @@ -179,9 +179,9 @@ def test_bernstein_poly(self): ) f = BernsteinPoly(c) y = f.eval(c, x) - precalculated = np.asarray([19.8694, 32.0761, 19.6774]) - self.assertEqual((3,), y.shape) - self.assertTrue(jnp.allclose(y, precalculated)) + y_precalculated = np.asarray([19.8694, 32.0761, 19.6774]) + self.assertEqual(y_precalculated.shape, y.shape) + self.assertTrue(np.allclose(y, y_precalculated)) g = f.jac_p(c, x) self.assertEqual((3,) + d, g.shape) @@ -192,20 +192,20 @@ def test_bernstein_poly(self): self.assertTrue(np.all(g > 0.0)) def test_from_lookup_table(self): - k = 5 - x = np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) - y = np.square(x) + 2.0 * x + 3.0 - - f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True) - c = f.prior() - self.assertEqual((k + 1,), c.shape) - self.assertAlmostEqual(3.0, c[0]) - self.assertAlmostEqual(3.4, c[1]) - self.assertAlmostEqual(3.9, c[2]) - self.assertAlmostEqual(4.5, c[3]) - self.assertAlmostEqual(5.2, c[4]) - self.assertAlmostEqual(6.0, c[5]) - self.assertTrue(jnp.allclose(f.eval(c, x), y)) + k = (3, 4, 2) + d = tuple([k_ + 1 for k_ in k]) + c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 + x = ( + np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]), + np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]), + np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]), + ) + y = BernsteinGrid(x).eval(c) + + f = BernsteinPoly.from_lookup_table(k, x, y) + b = f.prior() + self.assertEqual(c.shape, b.shape) + self.assertTrue(np.allclose(b, c)) class BSolveTest(unittest.TestCase): @@ -217,11 +217,12 @@ class BSolveTest(unittest.TestCase): def test_b_solve_0_2(self): r"""Fit :math:`B_{0,2}(x)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(1.0 - x) c = b_solve((k,), (x,), y, non_negative=True) self.assertEqual((k + 1,), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(1.0, c[0].item()) self.assertAlmostEqual(0.0, c[1].item()) self.assertAlmostEqual(0.0, c[2].item()) @@ -229,11 +230,12 @@ def test_b_solve_0_2(self): def test_b_solve_1_2(self): r"""Fit :math:`B_{1,2}(x)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = 2.0 * x * (1.0 - x) c = b_solve((k,), (x,), y, non_negative=True) self.assertEqual((k + 1,), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0].item()) self.assertAlmostEqual(1.0, c[1].item()) self.assertAlmostEqual(0.0, c[2].item()) @@ -241,11 +243,12 @@ def test_b_solve_1_2(self): def test_b_solve_2_2(self): r"""Fit :math:`B_{2,2}(x)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(x) c = b_solve((k,), (x,), y, non_negative=True) self.assertEqual((k + 1,), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0].item()) self.assertAlmostEqual(0.0, c[1].item()) self.assertAlmostEqual(1.0, c[2].item()) @@ -253,13 +256,14 @@ def test_b_solve_2_2(self): def test_b_solve_0_0_2_2(self): r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(1.0 - x[jnp.newaxis, :]) * jnp.square( 1.0 - x[:, jnp.newaxis] ) c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(1.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -273,7 +277,7 @@ def test_b_solve_0_0_2_2(self): def test_b_solve_1_0_2_2(self): r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * x[:, jnp.newaxis] @@ -283,6 +287,7 @@ def test_b_solve_1_0_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -296,13 +301,14 @@ def test_b_solve_1_0_2_2(self): def test_b_solve_2_0_2_2(self): r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(x[:, jnp.newaxis]) * jnp.square( 1.0 - x[jnp.newaxis, :] ) c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -316,7 +322,7 @@ def test_b_solve_2_0_2_2(self): def test_b_solve_0_1_2_2(self): r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * jnp.square(1.0 - x[:, jnp.newaxis]) @@ -326,6 +332,7 @@ def test_b_solve_0_1_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(1.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -339,7 +346,7 @@ def test_b_solve_0_1_2_2(self): def test_b_solve_1_1_2_2(self): r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 4.0 * x[:, jnp.newaxis] @@ -350,6 +357,7 @@ def test_b_solve_1_1_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -363,7 +371,7 @@ def test_b_solve_1_1_2_2(self): def test_b_solve_2_1_2_2(self): r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * jnp.square(x[:, jnp.newaxis]) @@ -372,6 +380,7 @@ def test_b_solve_2_1_2_2(self): ) c = b_solve((k, k), (x, x), y, non_negative=True) + self.assertFalse(np.any(c < 0.0)) self.assertEqual((k + 1, k + 1), c.shape) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) @@ -386,12 +395,13 @@ def test_b_solve_2_1_2_2(self): def test_b_solve_0_2_2_2(self): r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(1.0 - x[:, jnp.newaxis]) * jnp.square( x[jnp.newaxis, :] ) c = b_solve((k, k), (x, x), y, non_negative=True) + self.assertFalse(np.any(c < 0.0)) self.assertEqual((k + 1, k + 1), c.shape) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) @@ -406,7 +416,7 @@ def test_b_solve_0_2_2_2(self): def test_b_solve_1_2_2_2(self): r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * x[:, jnp.newaxis] @@ -416,6 +426,7 @@ def test_b_solve_1_2_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -429,11 +440,12 @@ def test_b_solve_1_2_2_2(self): def test_b_solve_2_2_2_2(self): r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(x[:, jnp.newaxis]) * jnp.square(x[jnp.newaxis, :]) c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) diff --git a/uncertaintyx/b/jax.py b/uncertaintyx/b/jax.py index 4ca8f38..05660d5 100644 --- a/uncertaintyx/b/jax.py +++ b/uncertaintyx/b/jax.py @@ -5,8 +5,8 @@ from typing import Self import jax -import jax.lax.linalg as jla import jax.numpy as jnp +import jax.numpy.linalg as jli import numpy as np import optax import optimistix @@ -186,44 +186,35 @@ def b_solve( N = len(k) # noqa: N806 bases = [b_basis(k[i], x[i]) for i in range(N)] - facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806 - Q = [_[0] for _ in facts] # noqa: N806 - R = [_[1] for _ in facts] # noqa: N806 + grams = [jnp.dot(B, B.T) for B in bases] # noqa: N806 - # compute the right hand side of the triangular equation + # compute the right hand side of the normal equation rhs = y for i in range(N): - rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0)) - # solve the triangular equation + B = bases[i] # noqa: N806 + rhs = jnp.tensordot(rhs, B, axes=(0, 1)) + # solve the normal equation c_unconstrained = rhs - if N > 1: - for i in range(N): - solve = jax.vmap( - lambda a, b: jla.triangular_solve(a, b, left_side=True), - in_axes=(None, i), - out_axes=i, - ) - c_unconstrained = solve(R[i], c_unconstrained) - else: - c_unconstrained = jla.triangular_solve( - R[0], c_unconstrained, left_side=True + for i in range(N): + G = grams[i] # noqa: N806 + c_unconstrained = jnp.tensordot( + c_unconstrained, jli.pinv(G), axes=(0, 1) ) def hvp(c: Array): """The Hessian-vector product.""" res = c for i in range(N): - res = jnp.tensordot(res, R[i], axes=(0, 1)) - for i in range(N): - res = jnp.tensordot(res, R[i], axes=(0, 0)) + G = grams[i] # noqa: N806 + res = jnp.tensordot(res, G, axes=(0, 1)) return res - def nnls(c: Array, rhs: Array): + def nnls(c: Array): """ Non-negative least-squares solver. - Applies a positive transformation and an L-BFGS - optimizer to ensure non-negativity. + Applies a positive transformation and an L-BFGS optimizer + to ensure non-negativity. """ def forward(u: Array) -> Array: @@ -250,10 +241,6 @@ def make_minimizer(): optax.lbfgs(), atol=atol, rtol=rtol, norm=optimistix.max_norm ) - # compute the right hand side of the normal equation - for i in range(N): - rhs = jnp.tensordot(rhs, R[i], axes=(0, 0)) - u = inverse(jnp.abs(c) + jnp.finfo(c.dtype).eps) optimum = optimistix.minimise( misfit, make_minimizer(), u, max_steps=max_steps, throw=False @@ -264,13 +251,7 @@ def make_minimizer(): nnls_needed = jnp.logical_and( non_negative, jnp.any(c_unconstrained < 0.0) ) - return jax.lax.cond( - nnls_needed, - nnls, - lambda c, _: c, - c_unconstrained, - rhs, - ) + return jax.lax.cond(nnls_needed, nnls, lambda c: c, c_unconstrained) def _lower_bounds( @@ -327,9 +308,10 @@ def __init__( :param a: The lower bounds of the grid coordinates. :param b: The upper bounds of the grid coordinates. """ + N = len(x) # noqa: : N806 a = _lower_bounds(a, x) b = _upper_bounds(b, x) - x = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x) + x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N)) def f(c: Array) -> Array: r""" @@ -420,9 +402,10 @@ def from_lookup_table( :param rtol: The relative tolerance for terminating the solver. :param max_steps: The maximum number of steps the solver can take. """ + N = len(k) # noqa: : N806 a = _lower_bounds(a, x) b = _upper_bounds(b, x) - x_ = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x) + x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N)) y_ = jnp.asarray(y) c_ = b_solve( k,