-
Notifications
You must be signed in to change notification settings - Fork 26
Update QR tests to avoid element-wise comparisons #2785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
8cb6d29
6c894c9
f8729c4
26d213a
706b3e2
707c5a0
9154330
e9973e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,70 @@ | ||||||||
| import numpy | ||||||||
|
|
||||||||
| from .helper import has_support_aspect64 | ||||||||
|
|
||||||||
|
|
||||||||
def gram(x, xp):
    """Return the Gram matrix ``X^H @ X`` over the trailing two axes.

    ``xp`` is the array namespace (e.g. numpy or dpnp) providing
    ``conjugate``/``swapaxes``; works on batched inputs as well.
    """
    x_h = xp.swapaxes(xp.conjugate(x), -1, -2)
    return x_h @ x
|
|
||||||||
|
|
||||||||
def get_R_from_raw(h, m, n, xp):
    """Extract the reduced R factor from NumPy-style raw QR output.

    In ``mode="raw"`` the factorization is returned transposed, so R
    sits in the lower triangle of ``h``; recover it via
    ``triu((tril(h))^T)`` and keep the leading ``k = min(m, n)`` rows,
    giving shape ``(..., k, n)``.
    """
    lower = xp.tril(h)
    candidate = xp.swapaxes(lower, -1, -2)[..., :m, :n]
    reduced_rows = min(m, n)
    return xp.triu(candidate)[..., :reduced_rows, :]
|
|
||||||||
|
|
||||||||
| def check_qr(a_np, a_xp, mode, xp): | ||||||||
| # QR is not unique: | ||||||||
| # element-wise comparison with NumPy may differ by sign/phase. | ||||||||
| # To verify correctness use mode-dependent functional checks: | ||||||||
| # complete/reduced: check decomposition Q @ R = A | ||||||||
| # raw/r: check invariant R^H @ R = A^H @ A | ||||||||
| if mode in ("complete", "reduced"): | ||||||||
| res = xp.linalg.qr(a_xp, mode) | ||||||||
| assert xp.allclose(res.Q @ res.R, a_xp, atol=1e-5) | ||||||||
|
|
||||||||
| # Since QR satisfies A = Q @ R with orthonormal Q (Q^H @ Q = I), | ||||||||
| # validate correctness via the invariant R^H @ R == A^H @ A | ||||||||
| # for raw/r modes | ||||||||
| elif mode == "raw": | ||||||||
| _, tau_np = numpy.linalg.qr(a_np, mode=mode) | ||||||||
| h_xp, tau_xp = xp.linalg.qr(a_xp, mode=mode) | ||||||||
|
|
||||||||
| m, n = a_np.shape[-2], a_np.shape[-1] | ||||||||
| Rraw_xp = get_R_from_raw(h_xp, m, n, xp) | ||||||||
|
|
||||||||
| # Use reduced QR as a reference: | ||||||||
| # reduced is validated via Q @ R == A | ||||||||
| exp_res = xp.linalg.qr(a_xp, mode="reduced") | ||||||||
| exp_r = exp_res.R | ||||||||
| assert xp.allclose(Rraw_xp, exp_r, atol=1e-4, rtol=1e-4) | ||||||||
|
|
||||||||
| exp_xp = gram(a_xp, xp) | ||||||||
|
|
||||||||
| # Compare R^H @ R == A^H @ A | ||||||||
| assert xp.allclose(gram(Rraw_xp, xp), exp_xp, atol=1e-4, rtol=1e-4) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use |
||||||||
|
|
||||||||
| assert tau_xp.shape == tau_np.shape | ||||||||
| if not has_support_aspect64(tau_xp.sycl_device): | ||||||||
| assert tau_xp.dtype.kind == tau_np.dtype.kind | ||||||||
| else: | ||||||||
| assert tau_xp.dtype == tau_np.dtype | ||||||||
|
|
||||||||
| else: # mode == "r" | ||||||||
| r_xp = xp.linalg.qr(a_xp, mode="r") | ||||||||
|
|
||||||||
| # Use reduced QR as a reference: | ||||||||
| # reduced is validated via Q @ R == A | ||||||||
| exp_res = xp.linalg.qr(a_xp, mode="reduced") | ||||||||
| exp_r = exp_res.R | ||||||||
|
Comment on lines
Comment on lines +63 to +64
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| assert xp.allclose(r_xp, exp_r, atol=1e-4, rtol=1e-4) | ||||||||
|
|
||||||||
| exp_xp = gram(a_xp, xp) | ||||||||
|
|
||||||||
| # Compare R^H @ R == A^H @ A | ||||||||
| assert xp.allclose(gram(r_xp, xp), exp_xp, atol=1e-4, rtol=1e-4) | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| has_support_aspect64, | ||
| numpy_version, | ||
| ) | ||
| from .qr_helper import check_qr | ||
| from .third_party.cupy import testing | ||
|
|
||
|
|
||
|
|
@@ -3584,7 +3585,7 @@ def test_error(self): | |
|
|
||
|
|
||
| class TestQr: | ||
| @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) | ||
| @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) | ||
| @pytest.mark.parametrize( | ||
| "shape", | ||
| [ | ||
|
|
@@ -3610,60 +3611,27 @@ class TestQr: | |
| "(2, 2, 4)", | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"]) | ||
| @pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"]) | ||
| def test_qr(self, dtype, shape, mode): | ||
| a = generate_random_numpy_array(shape, dtype, seed_value=81) | ||
| ia = dpnp.array(a) | ||
|
|
||
| if mode == "r": | ||
| np_r = numpy.linalg.qr(a, mode) | ||
| dpnp_r = dpnp.linalg.qr(ia, mode) | ||
| else: | ||
| np_q, np_r = numpy.linalg.qr(a, mode) | ||
|
|
||
| # check decomposition | ||
| if mode in ("complete", "reduced"): | ||
| result = dpnp.linalg.qr(ia, mode) | ||
| dpnp_q, dpnp_r = result.Q, result.R | ||
| assert dpnp.allclose( | ||
| dpnp.matmul(dpnp_q, dpnp_r), ia, atol=1e-05 | ||
| ) | ||
| else: # mode=="raw" | ||
| dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode) | ||
| assert_dtype_allclose(dpnp_q, np_q, factor=24) | ||
| a = generate_random_numpy_array(shape, dtype, seed_value=None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was that intended to pass |
||
| ia = dpnp.array(a, dtype=dtype) | ||
|
|
||
| if mode in ("raw", "r"): | ||
| assert_dtype_allclose(dpnp_r, np_r, factor=24) | ||
| check_qr(a, ia, mode, dpnp) | ||
|
|
||
| @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) | ||
| @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) | ||
antonwolfy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| @pytest.mark.parametrize( | ||
| "shape", | ||
| [(32, 32), (8, 16, 16)], | ||
| ids=["(32, 32)", "(8, 16, 16)"], | ||
| ) | ||
| @pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"]) | ||
| @pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"]) | ||
| def test_qr_large(self, dtype, shape, mode): | ||
| a = generate_random_numpy_array(shape, dtype, seed_value=81) | ||
| ia = dpnp.array(a) | ||
|
|
||
| if mode == "r": | ||
| np_r = numpy.linalg.qr(a, mode) | ||
| dpnp_r = dpnp.linalg.qr(ia, mode) | ||
| else: | ||
| np_q, np_r = numpy.linalg.qr(a, mode) | ||
|
|
||
| # check decomposition | ||
| if mode in ("complete", "reduced"): | ||
| result = dpnp.linalg.qr(ia, mode) | ||
| dpnp_q, dpnp_r = result.Q, result.R | ||
| assert dpnp.allclose(dpnp.matmul(dpnp_q, dpnp_r), ia, atol=1e-5) | ||
| else: # mode=="raw" | ||
| dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode) | ||
| assert_allclose(dpnp_q, np_q, atol=1e-4) | ||
| if mode in ("raw", "r"): | ||
| assert_allclose(dpnp_r, np_r, atol=1e-4) | ||
| check_qr(a, ia, mode, dpnp) | ||
|
|
||
| @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) | ||
| @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) | ||
| @pytest.mark.parametrize( | ||
| "shape", | ||
| [(0, 0), (0, 2), (2, 0), (2, 0, 3), (2, 3, 0), (0, 2, 3)], | ||
|
|
@@ -3676,65 +3644,22 @@ def test_qr_large(self, dtype, shape, mode): | |
| "(0, 2, 3)", | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"]) | ||
| @pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"]) | ||
| def test_qr_empty(self, dtype, shape, mode): | ||
| a = numpy.empty(shape, dtype=dtype) | ||
| ia = dpnp.array(a) | ||
|
|
||
| if mode == "r": | ||
| np_r = numpy.linalg.qr(a, mode) | ||
| dpnp_r = dpnp.linalg.qr(ia, mode) | ||
| else: | ||
| np_q, np_r = numpy.linalg.qr(a, mode) | ||
|
|
||
| if mode in ("complete", "reduced"): | ||
| result = dpnp.linalg.qr(ia, mode) | ||
| dpnp_q, dpnp_r = result.Q, result.R | ||
| else: | ||
| dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode) | ||
|
|
||
| assert_dtype_allclose(dpnp_q, np_q) | ||
| check_qr(a, ia, mode, dpnp) | ||
|
|
||
| assert_dtype_allclose(dpnp_r, np_r) | ||
|
|
||
| @pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"]) | ||
| @pytest.mark.parametrize("mode", ["complete", "reduced", "r", "raw"]) | ||
| def test_qr_strides(self, mode): | ||
| a = generate_random_numpy_array((5, 5)) | ||
| ia = dpnp.array(a) | ||
|
|
||
| # positive strides | ||
| if mode == "r": | ||
| np_r = numpy.linalg.qr(a[::2, ::2], mode) | ||
| dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode) | ||
| else: | ||
| np_q, np_r = numpy.linalg.qr(a[::2, ::2], mode) | ||
|
|
||
| if mode in ("complete", "reduced"): | ||
| result = dpnp.linalg.qr(ia[::2, ::2], mode) | ||
| dpnp_q, dpnp_r = result.Q, result.R | ||
| else: | ||
| dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode) | ||
|
|
||
| assert_dtype_allclose(dpnp_q, np_q) | ||
|
|
||
| assert_dtype_allclose(dpnp_r, np_r) | ||
|
|
||
| check_qr(a[::2, ::2], ia[::2, ::2], mode, dpnp) | ||
| # negative strides | ||
| if mode == "r": | ||
| np_r = numpy.linalg.qr(a[::-2, ::-2], mode) | ||
| dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode) | ||
| else: | ||
| np_q, np_r = numpy.linalg.qr(a[::-2, ::-2], mode) | ||
|
|
||
| if mode in ("complete", "reduced"): | ||
| result = dpnp.linalg.qr(ia[::-2, ::-2], mode) | ||
| dpnp_q, dpnp_r = result.Q, result.R | ||
| else: | ||
| dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode) | ||
|
|
||
| assert_dtype_allclose(dpnp_q, np_q) | ||
|
|
||
| assert_dtype_allclose(dpnp_r, np_r) | ||
| check_qr(a[::-2, ::-2], ia[::-2, ::-2], mode, dpnp) | ||
|
|
||
| def test_qr_errors(self): | ||
| a_dp = dpnp.array([[1, 2], [3, 5]], dtype="float32") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.