From e9b5d6cb673fa05cbd54f331a8bc77a0bf0406d6 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 20 Mar 2026 15:00:30 +0530 Subject: [PATCH 1/3] Fix df_with_chain2array() silently recycling data with unequal chain lengths --- NEWS.md | 1 + R/helpers-mcmc.R | 7 ++++++- tests/testthat/test-helpers-mcmc.R | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 3783ee321..48018fea3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # bayesplot (development version) +* Fix `df_with_chain2array()` silently recycling data when chains have unequal iterations. * Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior). * Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` . * Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`. diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index 41e2c4ee8..255ed7d6f 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -222,7 +222,12 @@ df_with_chain2array <- function(x) { a <- x[, !colnames(x) %in% "Chain", drop = FALSE] parnames <- colnames(a) a <- as.matrix(a) - x <- array(NA, dim = c(ceiling(nrow(a) / n_chain), n_chain, ncol(a))) + rows_per_chain <- table(chain) + if (length(unique(rows_per_chain)) != 1) { + abort("All chains must have the same number of iterations.") + } + n_iter <- as.integer(rows_per_chain[[1]]) + x <- array(NA, dim = c(n_iter, n_chain, ncol(a))) for (j in seq_len(n_chain)) { x[, j, ] <- a[chain == j,, drop=FALSE] } diff --git a/tests/testthat/test-helpers-mcmc.R b/tests/testthat/test-helpers-mcmc.R index 63a532f35..5db4df523 100644 --- a/tests/testthat/test-helpers-mcmc.R +++ b/tests/testthat/test-helpers-mcmc.R @@ -113,6 +113,15 @@ test_that("df_with_chain2array works", { expect_mcmc_array(a) expect_error(df_with_chain2array(dframe), "is_df_with_chain") + + # Unequal chain lengths should error, not silently recycle + unequal_df <- data.frame( + Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L), + V1 = rnorm(7), + V2 = rnorm(7) + ) + expect_error(df_with_chain2array(unequal_df), + "All chains must have the same number of iterations") }) From 834284aea21dfbcb2c7bfa9f8f6ececba56b78e4 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 20 Mar 2026 22:08:10 +0530 Subject: [PATCH 2/3] Move unequal chain length check to validate_df_with_chain() --- NEWS.md | 2 +- R/helpers-mcmc.R | 10 +++++----- tests/testthat/test-helpers-mcmc.R | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/NEWS.md b/NEWS.md index 48018fea3..2a7649d18 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,6 @@ # bayesplot (development version) -* Fix `df_with_chain2array()` silently recycling data when chains have unequal iterations. +* Validate equal chain lengths in `validate_df_with_chain()`. * Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior). * Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` . * Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`. diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index 255ed7d6f..fc615c8d0 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -210,6 +210,10 @@ validate_df_with_chain <- function(x) { x$chain <- NULL } x$Chain <- as.integer(x$Chain) + rows_per_chain <- table(x$Chain) + if (length(unique(rows_per_chain)) != 1) { + abort("All chains must have the same number of iterations.") + } x } @@ -222,11 +226,7 @@ df_with_chain2array <- function(x) { a <- x[, !colnames(x) %in% "Chain", drop = FALSE] parnames <- colnames(a) a <- as.matrix(a) - rows_per_chain <- table(chain) - if (length(unique(rows_per_chain)) != 1) { - abort("All chains must have the same number of iterations.") - } - n_iter <- as.integer(rows_per_chain[[1]]) + n_iter <- nrow(a) %/% n_chain x <- array(NA, dim = c(n_iter, n_chain, ncol(a))) for (j in seq_len(n_chain)) { x[, j, ] <- a[chain == j,, drop=FALSE] diff --git a/tests/testthat/test-helpers-mcmc.R b/tests/testthat/test-helpers-mcmc.R index 5db4df523..5b5b84932 100644 --- a/tests/testthat/test-helpers-mcmc.R +++ b/tests/testthat/test-helpers-mcmc.R @@ -114,12 +114,14 @@ test_that("df_with_chain2array works", { expect_error(df_with_chain2array(dframe), "is_df_with_chain") - # Unequal chain lengths should error, not silently recycle + # Unequal chain lengths should error via validate_df_with_chain unequal_df <- data.frame( Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L), V1 = rnorm(7), V2 = rnorm(7) ) + expect_error(validate_df_with_chain(unequal_df), + "All chains must have the same number of iterations") expect_error(df_with_chain2array(unequal_df), "All chains must have the same number of iterations") }) From d1b7c774c7b9d9f7440fb42f3009349ed865f784 Mon Sep 17 00:00:00 2001 From: jgabry Date: Fri, 20 Mar 2026 11:44:00 -0600 Subject: [PATCH 3/3] Handle invalid/misaligned Chain labels in df_with_chain2array() --- NEWS.md | 4 +++- R/helpers-mcmc.R | 5 +++++ R/mcmc-overview.R | 4 +++- man/MCMC-overview.Rd | 4 +++- tests/testthat/test-helpers-mcmc.R | 20 ++++++++++++++++++++ 5 files changed, 34 insertions(+), 3 deletions(-) diff --git a/NEWS.md b/NEWS.md index 2a7649d18..7ce82a4b1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,8 @@ # bayesplot (development version) -* Validate equal chain lengths in `validate_df_with_chain()`. +* Validate equal chain lengths in `validate_df_with_chain()`, reject missing + chain labels, and renumber data-frame chain labels internally when converting + to arrays. * Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior). * Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` . * Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`. diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index fc615c8d0..8d794d1a4 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -210,6 +210,9 @@ validate_df_with_chain <- function(x) { x$chain <- NULL } x$Chain <- as.integer(x$Chain) + if (anyNA(x$Chain)) { + abort("Chain values must not be NA.") + } rows_per_chain <- table(x$Chain) if (length(unique(rows_per_chain)) != 1) { abort("All chains must have the same number of iterations.") @@ -222,6 +225,8 @@ validate_df_with_chain <- function(x) { df_with_chain2array <- function(x) { x <- validate_df_with_chain(x) chain <- x$Chain + # Renumber arbitrary chain labels to the contiguous 1:N indices used internally. + chain <- match(chain, sort(unique(chain))) n_chain <- length(unique(chain)) a <- x[, !colnames(x) %in% "Chain", drop = FALSE] parnames <- colnames(a) diff --git a/R/mcmc-overview.R b/R/mcmc-overview.R index 403686412..59601c23c 100644 --- a/R/mcmc-overview.R +++ b/R/mcmc-overview.R @@ -24,7 +24,9 @@ #' frame with one column per parameter (if only a single chain or all chains #' have already been merged), or a data frame with one column per parameter plus #' an additional column `"Chain"` that contains the chain number (an integer) -#' corresponding to each row in the data frame. +#' corresponding to each row in the data frame. When a `"Chain"` column is +#' supplied, each chain must have the same number of iterations. Chain labels +#' are used to identify groups and are renumbered internally to `1:N`. #' * __draws__: Any of the `draws` formats supported by the #' \pkg{posterior} package. #' diff --git a/man/MCMC-overview.Rd b/man/MCMC-overview.Rd index d426d30f3..eed5117f6 100644 --- a/man/MCMC-overview.Rd +++ b/man/MCMC-overview.Rd @@ -25,7 +25,9 @@ already be merged (stacked). frame with one column per parameter (if only a single chain or all chains have already been merged), or a data frame with one column per parameter plus an additional column \code{"Chain"} that contains the chain number (an integer) -corresponding to each row in the data frame. +corresponding to each row in the data frame. When a \code{"Chain"} column is +supplied, each chain must have the same number of iterations. Chain labels +are used to identify groups and are renumbered internally to \code{1:N}. \item \strong{draws}: Any of the \code{draws} formats supported by the \pkg{posterior} package. } diff --git a/tests/testthat/test-helpers-mcmc.R b/tests/testthat/test-helpers-mcmc.R index 5b5b84932..a116be738 100644 --- a/tests/testthat/test-helpers-mcmc.R +++ b/tests/testthat/test-helpers-mcmc.R @@ -106,6 +106,14 @@ test_that("validate_df_with_chain works", { tbl <- tibble::tibble(parameter=rnorm(n=40), Chain=rep(1:4, each=10)) a <- validate_df_with_chain(tbl) expect_type(a$Chain, "integer") + + missing_chain_df <- data.frame( + Chain = c(1L, 1L, NA_integer_, NA_integer_), + V1 = rnorm(4), + V2 = rnorm(4) + ) + expect_error(validate_df_with_chain(missing_chain_df), + "Chain values must not be NA") }) test_that("df_with_chain2array works", { @@ -124,6 +132,17 @@ test_that("df_with_chain2array works", { "All chains must have the same number of iterations") expect_error(df_with_chain2array(unequal_df), "All chains must have the same number of iterations") + + renumbered_df <- data.frame( + Chain = c(2L, 2L, 3L, 3L), + V1 = 1:4, + V2 = 5:8 + ) + a <- df_with_chain2array(renumbered_df) + expect_equal(dim(a), c(2, 2, 2)) + expect_identical(unname(a[, 1, "V1"]), c(1L, 2L)) + expect_identical(unname(a[, 2, "V1"]), c(3L, 4L)) + expect_identical(as.character(dimnames(a)$Chain), c("1", "2")) }) @@ -316,6 +335,7 @@ test_that("diagnostic_factor.rhat works", { ) expect_identical(levels(r), c("low", "ok", "high")) }) + test_that("diagnostic_factor.neff_ratio works", { ratios <- new_neff_ratio(c(low = 0.05, low = 0.01, ok = 0.2, ok = 0.49,