From 0dac141b984e32a363ded2516884088ed2043637 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 19:54:28 +0200 Subject: [PATCH 01/15] Refactor Bernstein polynomial tests for clarity Refactor tests for Bernstein polynomial evaluation and lookup table creation. Update expected values and assertions to match new data structures. --- test/b/test_b_jax.py | 58 +++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index 2e274f3..c0c7427 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -92,7 +92,7 @@ def test_eval(self): c = self.c f = self.f y = f.eval(c) - precalculated = np.asarray( + y_precalculated = np.asarray( [ [ [19.8694, 19.7848, 20.3956], @@ -111,8 +111,8 @@ def test_eval(self): ], ] ) - self.assertEqual((3, 3, 3), y.shape) - self.assertTrue(np.allclose(y, precalculated)) + self.assertEqual(y_precalculated.shape, y.shape) + self.assertTrue(np.allclose(y, y_precalculated)) g = f.jac(c) self.assertEqual(y.shape + c.shape, g.shape) @@ -179,9 +179,9 @@ def test_bernstein_poly(self): ) f = BernsteinPoly(c) y = f.eval(c, x) - precalculated = np.asarray([19.8694, 32.0761, 19.6774]) - self.assertEqual((3,), y.shape) - self.assertTrue(jnp.allclose(y, precalculated)) + y_precalculated = np.asarray([19.8694, 32.0761, 19.6774]) + self.assertEqual(y_precalculated, y.shape) + self.assertTrue(jnp.allclose(y, y_precalculated)) g = f.jac_p(c, x) self.assertEqual((3,) + d, g.shape) @@ -192,21 +192,39 @@ def test_bernstein_poly(self): self.assertTrue(np.all(g > 0.0)) def test_from_lookup_table(self): - k = 5 - x = np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) - y = np.square(x) + 2.0 * x + 3.0 - - f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True) + k = (4, 3, 2) + d = tuple([k_ + 1 for k_ in k]) + x = ( + np.asarray([0.2718, 0.5772, 0.3141]), + np.asarray([0.5772, 0.3141, 0.2718]), + np.asarray([0.3141, 0.2718, 0.5772]), + ) + y = np.asarray( + [ + [ + [19.8694, 19.7848, 20.3956], + [17.5015, 17.4169, 18.0277], + [17.1208, 17.0362, 17.6470], + ], + [ + [34.5286, 34.4440, 35.0548], + [32.1607, 32.0761, 32.6869], + [31.7800, 31.6954, 32.3062], + ], + [ + [21.8998, 21.8152, 22.4260], + [19.5319, 19.4473, 20.0581], + [19.1512, 19.0666, 19.6774], + ], + ] + ) + + f = BernsteinPoly.from_lookup_table(k, x,, y, non_negative=True) c = f.prior() - self.assertEqual((k + 1,), c.shape) - self.assertAlmostEqual(3.0, c[0]) - self.assertAlmostEqual(3.4, c[1]) - self.assertAlmostEqual(3.9, c[2]) - self.assertAlmostEqual(4.5, c[3]) - self.assertAlmostEqual(5.2, c[4]) - self.assertAlmostEqual(6.0, c[5]) - self.assertTrue(jnp.allclose(f.eval(c, x), y)) - + c_expected = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 + self.assertEqual(c_expected.shape, c.shape) + self.assertTrue(np.allclose(c, c_expected)) + class BSolveTest(unittest.TestCase): """ From 5c35dafeea3f807c63b9413c3febd63854a3be48 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 19:57:47 +0200 Subject: [PATCH 02/15] Fix syntax error in from_lookup_table call --- test/b/test_b_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index c0c7427..0ae78a9 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -219,7 +219,7 @@ def test_from_lookup_table(self): ] ) - f = BernsteinPoly.from_lookup_table(k, x,, y, non_negative=True) + f = BernsteinPoly.from_lookup_table(k, x, y, non_negative=True) c = f.prior() c_expected = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 self.assertEqual(c_expected.shape, c.shape) From b4bd8c964cf9e0756f65c082e5424e37574b612d Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 20:12:44 +0200 Subject: [PATCH 03/15] Fix shape assertion and update comparison method --- test/b/test_b_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index 0ae78a9..a8b90c8 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -180,8 +180,8 @@ def test_bernstein_poly(self): f = BernsteinPoly(c) y = f.eval(c, x) y_precalculated = np.asarray([19.8694, 32.0761, 19.6774]) - self.assertEqual(y_precalculated, y.shape) - self.assertTrue(jnp.allclose(y, y_precalculated)) + self.assertEqual(y_precalculated.shape, y.shape) + self.assertTrue(np.allclose(y, y_precalculated)) g = f.jac_p(c, x) self.assertEqual((3,) + d, g.shape) From f34b21748a7e5d468855107b03619a04956c8c64 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 20:16:21 +0200 Subject: [PATCH 04/15] Modify test parameters and assertions in test_b_jax Updated the shape parameter in test case and added assertion for non-negativity of coefficients. --- test/b/test_b_jax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index a8b90c8..d53ceff 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -192,7 +192,7 @@ def test_bernstein_poly(self): self.assertTrue(np.all(g > 0.0)) def test_from_lookup_table(self): - k = (4, 3, 2) + k = (2, 2, 2) d = tuple([k_ + 1 for k_ in k]) x = ( np.asarray([0.2718, 0.5772, 0.3141]), @@ -223,6 +223,7 @@ def test_from_lookup_table(self): c = f.prior() c_expected = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 self.assertEqual(c_expected.shape, c.shape) + self.assertTrue(np.all(c > 0.0)) self.assertTrue(np.allclose(c, c_expected)) From 09fbec3f49229c12da73909480e70a36d1b98b1f Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 20:34:32 +0200 Subject: [PATCH 05/15] Refactor test_from_lookup_table and b_solve tests Refactor tests to use BernsteinGrid evaluation and update assertions for non-negativity checks. --- test/b/test_b_jax.py | 42 ++++++++++++++++++------------------------ 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index d53ceff..bfef043 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -194,37 +194,19 @@ def test_bernstein_poly(self): def test_from_lookup_table(self): k = (2, 2, 2) d = tuple([k_ + 1 for k_ in k]) + c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 x = ( np.asarray([0.2718, 0.5772, 0.3141]), np.asarray([0.5772, 0.3141, 0.2718]), np.asarray([0.3141, 0.2718, 0.5772]), ) - y = np.asarray( - [ - [ - [19.8694, 19.7848, 20.3956], - [17.5015, 17.4169, 18.0277], - [17.1208, 17.0362, 17.6470], - ], - [ - [34.5286, 34.4440, 35.0548], - [32.1607, 32.0761, 32.6869], - [31.7800, 31.6954, 32.3062], - ], - [ - [21.8998, 21.8152, 22.4260], - [19.5319, 19.4473, 20.0581], - [19.1512, 19.0666, 19.6774], - ], - ] - ) + y = BernsteinGrid(x).eval(c) f = BernsteinPoly.from_lookup_table(k, x, y, non_negative=True) - c = f.prior() - c_expected = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 - self.assertEqual(c_expected.shape, c.shape) - self.assertTrue(np.all(c > 0.0)) - self.assertTrue(np.allclose(c, c_expected)) + b = f.prior() + self.assertEqual(c.shape, b.shape) + self.assertFalse(np.any(b < 0.0)) + self.assertTrue(np.allclose(b, c)) class BSolveTest(unittest.TestCase): @@ -241,6 +223,7 @@ def test_b_solve_0_2(self): c = b_solve((k,), (x,), y, non_negative=True) self.assertEqual((k + 1,), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(1.0, c[0].item()) self.assertAlmostEqual(0.0, c[1].item()) self.assertAlmostEqual(0.0, c[2].item()) @@ -253,6 +236,7 @@ def test_b_solve_1_2(self): c = b_solve((k,), (x,), y, non_negative=True) self.assertEqual((k + 1,), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0].item()) self.assertAlmostEqual(1.0, c[1].item()) self.assertAlmostEqual(0.0, c[2].item()) @@ -265,6 +249,7 @@ def test_b_solve_2_2(self): c = b_solve((k,), (x,), y, non_negative=True) self.assertEqual((k + 1,), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0].item()) self.assertAlmostEqual(0.0, c[1].item()) self.assertAlmostEqual(1.0, c[2].item()) @@ -279,6 +264,7 @@ def test_b_solve_0_0_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(1.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -302,6 +288,7 @@ def test_b_solve_1_0_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -322,6 +309,7 @@ def test_b_solve_2_0_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -345,6 +333,7 @@ def test_b_solve_0_1_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(1.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -369,6 +358,7 @@ def test_b_solve_1_1_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -391,6 +381,7 @@ def test_b_solve_2_1_2_2(self): ) c = b_solve((k, k), (x, x), y, non_negative=True) + self.assertFalse(np.any(c < 0.0)) self.assertEqual((k + 1, k + 1), c.shape) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) @@ -411,6 +402,7 @@ def test_b_solve_0_2_2_2(self): ) c = b_solve((k, k), (x, x), y, non_negative=True) + self.assertFalse(np.any(c < 0.0)) self.assertEqual((k + 1, k + 1), c.shape) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) @@ -435,6 +427,7 @@ def test_b_solve_1_2_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) @@ -453,6 +446,7 @@ def test_b_solve_2_2_2_2(self): c = b_solve((k, k), (x, x), y, non_negative=True) self.assertEqual((k + 1, k + 1), c.shape) + self.assertFalse(np.any(c < 0.0)) self.assertAlmostEqual(0.0, c[0, 0].item()) self.assertAlmostEqual(0.0, c[0, 1].item()) self.assertAlmostEqual(0.0, c[0, 2].item()) From 3269efd17e5db071afb471816a84fe41e3191f31 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 20:48:52 +0200 Subject: [PATCH 06/15] Enhance tests with additional almost equal assertions Added assertions to compare elements of two arrays with almost equal checks. --- test/b/test_b_jax.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index bfef043..ed26a8b 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -206,8 +206,34 @@ def test_from_lookup_table(self): b = f.prior() self.assertEqual(c.shape, b.shape) self.assertFalse(np.any(b < 0.0)) - self.assertTrue(np.allclose(b, c)) - + self.assertAlmostEqual(c[0, 0, 0], b[0, 0, 0]) + self.assertAlmostEqual(c[0, 0, 1], b[0, 0, 1]) + self.assertAlmostEqual(c[0, 0, 2], b[0, 0, 2]) + self.assertAlmostEqual(c[0, 1, 0], b[0, 1, 0]) + self.assertAlmostEqual(c[0, 1, 1], b[0, 1, 1]) + self.assertAlmostEqual(c[0, 1, 2], b[0, 1, 2]) + self.assertAlmostEqual(c[0, 2, 0], b[0, 2, 0]) + self.assertAlmostEqual(c[0, 2, 1], b[0, 2, 1]) + self.assertAlmostEqual(c[0, 2, 2], b[0, 2, 2]) + self.assertAlmostEqual(c[1, 0, 0], b[1, 0, 0]) + self.assertAlmostEqual(c[1, 0, 1], b[1, 0, 1]) + self.assertAlmostEqual(c[1, 0, 2], b[1, 0, 2]) + self.assertAlmostEqual(c[1, 1, 0], b[1, 1, 0]) + self.assertAlmostEqual(c[1, 1, 1], b[1, 1, 1]) + self.assertAlmostEqual(c[1, 1, 2], b[1, 1, 2]) + self.assertAlmostEqual(c[1, 2, 0], b[1, 2, 0]) + self.assertAlmostEqual(c[1, 2, 1], b[1, 2, 1]) + self.assertAlmostEqual(c[1, 2, 2], b[1, 2, 2]) + self.assertAlmostEqual(c[2, 0, 0], b[2, 0, 0]) + self.assertAlmostEqual(c[2, 0, 1], b[2, 0, 1]) + self.assertAlmostEqual(c[2, 0, 2], b[2, 0, 2]) + self.assertAlmostEqual(c[2, 1, 0], b[2, 1, 0]) + self.assertAlmostEqual(c[2, 1, 1], b[2, 1, 1]) + self.assertAlmostEqual(c[2, 1, 2], b[2, 1, 2]) + self.assertAlmostEqual(c[2, 2, 0], b[2, 2, 0]) + self.assertAlmostEqual(c[2, 2, 1], b[2, 2, 1]) + self.assertAlmostEqual(c[2, 2, 2], b[2, 2, 2]) + class BSolveTest(unittest.TestCase): """ From 29c74ed7335d0689ce08d77da04afe25c7bcf39b Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 21:02:45 +0200 Subject: [PATCH 07/15] Refactor BernsteinPoly test for prior evaluation --- test/b/test_b_jax.py | 33 ++------------------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index ed26a8b..6c117db 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -202,37 +202,8 @@ def test_from_lookup_table(self): ) y = BernsteinGrid(x).eval(c) - f = BernsteinPoly.from_lookup_table(k, x, y, non_negative=True) - b = f.prior() - self.assertEqual(c.shape, b.shape) - self.assertFalse(np.any(b < 0.0)) - self.assertAlmostEqual(c[0, 0, 0], b[0, 0, 0]) - self.assertAlmostEqual(c[0, 0, 1], b[0, 0, 1]) - self.assertAlmostEqual(c[0, 0, 2], b[0, 0, 2]) - self.assertAlmostEqual(c[0, 1, 0], b[0, 1, 0]) - self.assertAlmostEqual(c[0, 1, 1], b[0, 1, 1]) - self.assertAlmostEqual(c[0, 1, 2], b[0, 1, 2]) - self.assertAlmostEqual(c[0, 2, 0], b[0, 2, 0]) - self.assertAlmostEqual(c[0, 2, 1], b[0, 2, 1]) - self.assertAlmostEqual(c[0, 2, 2], b[0, 2, 2]) - self.assertAlmostEqual(c[1, 0, 0], b[1, 0, 0]) - self.assertAlmostEqual(c[1, 0, 1], b[1, 0, 1]) - self.assertAlmostEqual(c[1, 0, 2], b[1, 0, 2]) - self.assertAlmostEqual(c[1, 1, 0], b[1, 1, 0]) - self.assertAlmostEqual(c[1, 1, 1], b[1, 1, 1]) - self.assertAlmostEqual(c[1, 1, 2], b[1, 1, 2]) - self.assertAlmostEqual(c[1, 2, 0], b[1, 2, 0]) - self.assertAlmostEqual(c[1, 2, 1], b[1, 2, 1]) - self.assertAlmostEqual(c[1, 2, 2], b[1, 2, 2]) - self.assertAlmostEqual(c[2, 0, 0], b[2, 0, 0]) - self.assertAlmostEqual(c[2, 0, 1], b[2, 0, 1]) - self.assertAlmostEqual(c[2, 0, 2], b[2, 0, 2]) - self.assertAlmostEqual(c[2, 1, 0], b[2, 1, 0]) - self.assertAlmostEqual(c[2, 1, 1], b[2, 1, 1]) - self.assertAlmostEqual(c[2, 1, 2], b[2, 1, 2]) - self.assertAlmostEqual(c[2, 2, 0], b[2, 2, 0]) - self.assertAlmostEqual(c[2, 2, 1], b[2, 2, 1]) - self.assertAlmostEqual(c[2, 2, 2], b[2, 2, 2]) + f = BernsteinPoly.from_lookup_table(k, x, y) + self.assertTrue(np.allclose(BernsteinGrid(x).eval(f.prior())), y) class BSolveTest(unittest.TestCase): From c84efa09ec573faaa6652efd0f90aa60a50c21d8 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 21:09:59 +0200 Subject: [PATCH 08/15] Fix assertion in BernsteinPoly test --- test/b/test_b_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index 6c117db..f79bccd 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -203,7 +203,7 @@ def test_from_lookup_table(self): y = BernsteinGrid(x).eval(c) f = BernsteinPoly.from_lookup_table(k, x, y) - self.assertTrue(np.allclose(BernsteinGrid(x).eval(f.prior())), y) + self.assertTrue(np.allclose(BernsteinGrid(x).eval(f.prior()), y)) class BSolveTest(unittest.TestCase): From 5ee091371018f1edfda2538af6b7ac9ce74c5501 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 21:21:43 +0200 Subject: [PATCH 09/15] Refactor assertions for BernsteinGrid evaluation --- test/b/test_b_jax.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index f79bccd..ca88249 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -203,7 +203,10 @@ def test_from_lookup_table(self): y = BernsteinGrid(x).eval(c) f = BernsteinPoly.from_lookup_table(k, x, y) - self.assertTrue(np.allclose(BernsteinGrid(x).eval(f.prior()), y)) + z = BernsteinGrid(x).eval(f.prior()) + self.assertAlmostEquals(y[0, 0, 0], z[0, 0, 0]) + self.assertAlmostEquals(y[0, 0, 1], z[0, 0, 1]) + self.assertAlmostEquals(y[0, 0, 2], z[0, 0, 2]) class BSolveTest(unittest.TestCase): From 1cbdcdb66088927140e4548e9706c5f518531589 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 21:27:47 +0200 Subject: [PATCH 10/15] Fix assertAlmostEquals to assertAlmostEqual --- test/b/test_b_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index ca88249..4d87bc4 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -204,9 +204,9 @@ def test_from_lookup_table(self): f = BernsteinPoly.from_lookup_table(k, x, y) z = BernsteinGrid(x).eval(f.prior()) - self.assertAlmostEquals(y[0, 0, 0], z[0, 0, 0]) - self.assertAlmostEquals(y[0, 0, 1], z[0, 0, 1]) - self.assertAlmostEquals(y[0, 0, 2], z[0, 0, 2]) + self.assertAlmostEqual(y[0, 0, 0], z[0, 0, 0]) + self.assertAlmostEqual(y[0, 0, 1], z[0, 0, 1]) + self.assertAlmostEqual(y[0, 0, 2], z[0, 0, 2]) class BSolveTest(unittest.TestCase): From cb20339229688d36c3815a2abbe3d23857c7b4c8 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 21:47:31 +0200 Subject: [PATCH 11/15] Modify test parameters in test_b_jax.py --- test/b/test_b_jax.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index 4d87bc4..3e5ee3d 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -192,21 +192,20 @@ def test_bernstein_poly(self): self.assertTrue(np.all(g > 0.0)) def test_from_lookup_table(self): - k = (2, 2, 2) + k = (3, 4, 2) d = tuple([k_ + 1 for k_ in k]) c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 x = ( - np.asarray([0.2718, 0.5772, 0.3141]), - np.asarray([0.5772, 0.3141, 0.2718]), - np.asarray([0.3141, 0.2718, 0.5772]), + np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]), + np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]), + np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]), ) y = BernsteinGrid(x).eval(c) f = BernsteinPoly.from_lookup_table(k, x, y) - z = BernsteinGrid(x).eval(f.prior()) - self.assertAlmostEqual(y[0, 0, 0], z[0, 0, 0]) - self.assertAlmostEqual(y[0, 0, 1], z[0, 0, 1]) - self.assertAlmostEqual(y[0, 0, 2], z[0, 0, 2]) + b = f.prior() + self.assertEqual(c.shape, b.shape) + self.assertTrue(np.allclose(b, c)) class BSolveTest(unittest.TestCase): From 32ed86efbba62acbe2695b086fd7bb7200fc52a1 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 22:08:35 +0200 Subject: [PATCH 12/15] Fix: normalization of grid coordinates in jax.py --- uncertaintyx/b/jax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/uncertaintyx/b/jax.py b/uncertaintyx/b/jax.py index 4ca8f38..00eaa3e 100644 --- a/uncertaintyx/b/jax.py +++ b/uncertaintyx/b/jax.py @@ -327,9 +327,10 @@ def __init__( :param a: The lower bounds of the grid coordinates. :param b: The upper bounds of the grid coordinates. """ + N = len(k) # noqa: : N806 a = _lower_bounds(a, x) b = _upper_bounds(b, x) - x = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x) + x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N)) def f(c: Array) -> Array: r""" @@ -420,9 +421,10 @@ def from_lookup_table( :param rtol: The relative tolerance for terminating the solver. :param max_steps: The maximum number of steps the solver can take. """ + N = len(k) # noqa: : N806 a = _lower_bounds(a, x) b = _upper_bounds(b, x) - x_ = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x) + x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N)) y_ = jnp.asarray(y) c_ = b_solve( k, From 17ff4f178270fc2599381552899858daaf4055b6 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 22:16:11 +0200 Subject: [PATCH 13/15] Fix: variable N to use length of x instead of k --- uncertaintyx/b/jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uncertaintyx/b/jax.py b/uncertaintyx/b/jax.py index 00eaa3e..0cdbc55 100644 --- a/uncertaintyx/b/jax.py +++ b/uncertaintyx/b/jax.py @@ -327,7 +327,7 @@ def __init__( :param a: The lower bounds of the grid coordinates. :param b: The upper bounds of the grid coordinates. """ - N = len(k) # noqa: : N806 + N = len(x) # noqa: : N806 a = _lower_bounds(a, x) b = _upper_bounds(b, x) x_ = tuple(jnp.asarray((x[i] - a[i]) / (b[i] - a[i])) for i in range(N)) From e01375a63c89f60c36083ed9bacb2cee1e0dc819 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 23:09:24 +0200 Subject: [PATCH 14/15] Refactor QR decomposition to use Gram matrices --- uncertaintyx/b/jax.py | 51 ++++++++++++++----------------------------- 1 file changed, 16 insertions(+), 35 deletions(-) diff --git a/uncertaintyx/b/jax.py b/uncertaintyx/b/jax.py index 0cdbc55..05660d5 100644 --- a/uncertaintyx/b/jax.py +++ b/uncertaintyx/b/jax.py @@ -5,8 +5,8 @@ from typing import Self import jax -import jax.lax.linalg as jla import jax.numpy as jnp +import jax.numpy.linalg as jli import numpy as np import optax import optimistix @@ -186,44 +186,35 @@ def b_solve( N = len(k) # noqa: N806 bases = [b_basis(k[i], x[i]) for i in range(N)] - facts = [jla.qr(B.T, full_matrices=False) for B in bases] # noqa: N806 - Q = [_[0] for _ in facts] # noqa: N806 - R = [_[1] for _ in facts] # noqa: N806 + grams = [jnp.dot(B, B.T) for B in bases] # noqa: N806 - # compute the right hand side of the triangular equation + # compute the right hand side of the normal equation rhs = y for i in range(N): - rhs = jnp.tensordot(rhs, Q[i], axes=(0, 0)) - # solve the triangular equation + B = bases[i] # noqa: N806 + rhs = jnp.tensordot(rhs, B, axes=(0, 1)) + # solve the normal equation c_unconstrained = rhs - if N > 1: - for i in range(N): - solve = jax.vmap( - lambda a, b: jla.triangular_solve(a, b, left_side=True), - in_axes=(None, i), - out_axes=i, - ) - c_unconstrained = solve(R[i], c_unconstrained) - else: - c_unconstrained = jla.triangular_solve( - R[0], c_unconstrained, left_side=True + for i in range(N): + G = grams[i] # noqa: N806 + c_unconstrained = jnp.tensordot( + c_unconstrained, jli.pinv(G), axes=(0, 1) ) def hvp(c: Array): """The Hessian-vector product.""" res = c for i in range(N): - res = jnp.tensordot(res, R[i], axes=(0, 1)) - for i in range(N): - res = jnp.tensordot(res, R[i], axes=(0, 0)) + G = grams[i] # noqa: N806 + res = jnp.tensordot(res, G, axes=(0, 1)) return res - def nnls(c: Array, rhs: Array): + def nnls(c: Array): """ Non-negative least-squares solver. - Applies a positive transformation and an L-BFGS - optimizer to ensure non-negativity. + Applies a positive transformation and an L-BFGS optimizer + to ensure non-negativity. """ def forward(u: Array) -> Array: @@ -250,10 +241,6 @@ def make_minimizer(): optax.lbfgs(), atol=atol, rtol=rtol, norm=optimistix.max_norm ) - # compute the right hand side of the normal equation - for i in range(N): - rhs = jnp.tensordot(rhs, R[i], axes=(0, 0)) - u = inverse(jnp.abs(c) + jnp.finfo(c.dtype).eps) optimum = optimistix.minimise( misfit, make_minimizer(), u, max_steps=max_steps, throw=False @@ -264,13 +251,7 @@ def make_minimizer(): nnls_needed = jnp.logical_and( non_negative, jnp.any(c_unconstrained < 0.0) ) - return jax.lax.cond( - nnls_needed, - nnls, - lambda c, _: c, - c_unconstrained, - rhs, - ) + return jax.lax.cond(nnls_needed, nnls, lambda c: c, c_unconstrained) def _lower_bounds( From 27745b43f9496de6f721aeff4ab272234f0362e7 Mon Sep 17 00:00:00 2001 From: "R. Quast" Date: Fri, 22 May 2026 23:26:41 +0200 Subject: [PATCH 15/15] Normalize x values in test_b_jax.py --- test/b/test_b_jax.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/b/test_b_jax.py b/test/b/test_b_jax.py index 3e5ee3d..0e27443 100644 --- a/test/b/test_b_jax.py +++ b/test/b/test_b_jax.py @@ -196,9 +196,9 @@ def test_from_lookup_table(self): d = tuple([k_ + 1 for k_ in k]) c = np.arange(np.prod(np.asarray(d))).reshape(d) + 1.0 x = ( - np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]), - np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]), - np.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]), + np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]), + np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]), + np.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]), ) y = BernsteinGrid(x).eval(c) @@ -217,7 +217,7 @@ class BSolveTest(unittest.TestCase): def test_b_solve_0_2(self): r"""Fit :math:`B_{0,2}(x)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(1.0 - x) c = b_solve((k,), (x,), y, non_negative=True) @@ -230,7 +230,7 @@ def test_b_solve_0_2(self): def test_b_solve_1_2(self): r"""Fit :math:`B_{1,2}(x)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = 2.0 * x * (1.0 - x) c = b_solve((k,), (x,), y, non_negative=True) @@ -243,7 +243,7 @@ def test_b_solve_1_2(self): def test_b_solve_2_2(self): r"""Fit :math:`B_{2,2}(x)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(x) c = b_solve((k,), (x,), y, non_negative=True) @@ -256,7 +256,7 @@ def test_b_solve_2_2(self): def test_b_solve_0_0_2_2(self): r"""Fit :math:`B_{(0,0),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(1.0 - x[jnp.newaxis, :]) * jnp.square( 1.0 - x[:, jnp.newaxis] ) @@ -277,7 +277,7 @@ def test_b_solve_0_0_2_2(self): def test_b_solve_1_0_2_2(self): r"""Fit :math:`B_{(1,0),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * x[:, jnp.newaxis] @@ -301,7 +301,7 @@ def test_b_solve_1_0_2_2(self): def test_b_solve_2_0_2_2(self): r"""Fit :math:`B_{(2,0),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(x[:, jnp.newaxis]) * jnp.square( 1.0 - x[jnp.newaxis, :] ) @@ -322,7 +322,7 @@ def test_b_solve_2_0_2_2(self): def test_b_solve_0_1_2_2(self): r"""Fit :math:`B_{(0,1),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * jnp.square(1.0 - x[:, jnp.newaxis]) @@ -346,7 +346,7 @@ def test_b_solve_0_1_2_2(self): def test_b_solve_1_1_2_2(self): r"""Fit :math:`B_{(1,1),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 4.0 * x[:, jnp.newaxis] @@ -371,7 +371,7 @@ def test_b_solve_1_1_2_2(self): def test_b_solve_2_1_2_2(self): r"""Fit :math:`B_{(2,1),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * jnp.square(x[:, jnp.newaxis]) @@ -395,7 +395,7 @@ def test_b_solve_2_1_2_2(self): def test_b_solve_0_2_2_2(self): r"""Fit :math:`B_{(0,2),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(1.0 - x[:, jnp.newaxis]) * jnp.square( x[jnp.newaxis, :] ) @@ -416,7 +416,7 @@ def test_b_solve_0_2_2_2(self): def test_b_solve_1_2_2_2(self): r"""Fit :math:`B_{(1,2),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = ( 2.0 * x[:, jnp.newaxis] @@ -440,7 +440,7 @@ def test_b_solve_1_2_2_2(self): def test_b_solve_2_2_2_2(self): r"""Fit :math:`B_{(2,2),(2,2)}(x_0, x_1)`.""" k = 2 - x = jnp.asarray([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) + x = jnp.asarray([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) y = jnp.square(x[:, jnp.newaxis]) * jnp.square(x[jnp.newaxis, :]) c = b_solve((k, k), (x, x), y, non_negative=True)