diff --git a/NEWS.md b/NEWS.md index 3783ee32..7ce82a4b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # bayesplot (development version) +* Validate equal chain lengths in `validate_df_with_chain()`, reject missing + chain labels, and renumber data-frame chain labels internally when converting + to arrays. * Added unit tests for previously untested edge cases in `param_range()`, `param_glue()`, and `tidyselect_parameters()` (no-match, partial-match, and negation behavior). * Bumped minimum version for `rstantools` from `>= 1.5.0` to `>= 2.0.0` . * Use `rlang::warn()` and `rlang::inform()` for selected PPC user messages instead of base `warning()` and `message()`. diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index 41e2c4ee..8d794d1a 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -210,6 +210,13 @@ validate_df_with_chain <- function(x) { x$chain <- NULL } x$Chain <- as.integer(x$Chain) + if (anyNA(x$Chain)) { + abort("Chain values must not be NA.") + } + rows_per_chain <- table(x$Chain) + if (length(unique(rows_per_chain)) != 1) { + abort("All chains must have the same number of iterations.") + } x } @@ -218,11 +225,14 @@ validate_df_with_chain <- function(x) { df_with_chain2array <- function(x) { x <- validate_df_with_chain(x) chain <- x$Chain + # Renumber arbitrary chain labels to the contiguous 1:N indices used internally. + chain <- match(chain, sort(unique(chain))) n_chain <- length(unique(chain)) a <- x[, !colnames(x) %in% "Chain", drop = FALSE] parnames <- colnames(a) a <- as.matrix(a) - x <- array(NA, dim = c(ceiling(nrow(a) / n_chain), n_chain, ncol(a))) + n_iter <- nrow(a) %/% n_chain + x <- array(NA, dim = c(n_iter, n_chain, ncol(a))) for (j in seq_len(n_chain)) { x[, j, ] <- a[chain == j,, drop=FALSE] } diff --git a/R/mcmc-overview.R b/R/mcmc-overview.R index 40368641..59601c23 100644 --- a/R/mcmc-overview.R +++ b/R/mcmc-overview.R @@ -24,7 +24,9 @@ #' frame with one column per parameter (if only a single chain or all chains #' have already been merged), or a data frame with one column per parameter plus #' an additional column `"Chain"` that contains the chain number (an integer) -#' corresponding to each row in the data frame. +#' corresponding to each row in the data frame. When a `"Chain"` column is +#' supplied, each chain must have the same number of iterations. Chain labels +#' are used to identify groups and are renumbered internally to `1:N`. #' * __draws__: Any of the `draws` formats supported by the #' \pkg{posterior} package. #' diff --git a/man/MCMC-overview.Rd b/man/MCMC-overview.Rd index d426d30f..eed5117f 100644 --- a/man/MCMC-overview.Rd +++ b/man/MCMC-overview.Rd @@ -25,7 +25,9 @@ already be merged (stacked). frame with one column per parameter (if only a single chain or all chains have already been merged), or a data frame with one column per parameter plus an additional column \code{"Chain"} that contains the chain number (an integer) -corresponding to each row in the data frame. +corresponding to each row in the data frame. When a \code{"Chain"} column is +supplied, each chain must have the same number of iterations. Chain labels +are used to identify groups and are renumbered internally to \code{1:N}. \item \strong{draws}: Any of the \code{draws} formats supported by the \pkg{posterior} package. } diff --git a/tests/testthat/test-helpers-mcmc.R b/tests/testthat/test-helpers-mcmc.R index 63a532f3..a116be73 100644 --- a/tests/testthat/test-helpers-mcmc.R +++ b/tests/testthat/test-helpers-mcmc.R @@ -106,6 +106,14 @@ test_that("validate_df_with_chain works", { tbl <- tibble::tibble(parameter=rnorm(n=40), Chain=rep(1:4, each=10)) a <- validate_df_with_chain(tbl) expect_type(a$Chain, "integer") + + missing_chain_df <- data.frame( + Chain = c(1L, 1L, NA_integer_, NA_integer_), + V1 = rnorm(4), + V2 = rnorm(4) + ) + expect_error(validate_df_with_chain(missing_chain_df), + "Chain values must not be NA") }) test_that("df_with_chain2array works", { @@ -113,6 +121,28 @@ test_that("df_with_chain2array works", { expect_mcmc_array(a) expect_error(df_with_chain2array(dframe), "is_df_with_chain") + + # Unequal chain lengths should error via validate_df_with_chain + unequal_df <- data.frame( + Chain = c(1L, 1L, 1L, 1L, 2L, 2L, 2L), + V1 = rnorm(7), + V2 = rnorm(7) + ) + expect_error(validate_df_with_chain(unequal_df), + "All chains must have the same number of iterations") + expect_error(df_with_chain2array(unequal_df), + "All chains must have the same number of iterations") + + renumbered_df <- data.frame( + Chain = c(2L, 2L, 3L, 3L), + V1 = 1:4, + V2 = 5:8 + ) + a <- df_with_chain2array(renumbered_df) + expect_equal(dim(a), c(2, 2, 2)) + expect_identical(unname(a[, 1, "V1"]), c(1L, 2L)) + expect_identical(unname(a[, 2, "V1"]), c(3L, 4L)) + expect_identical(as.character(dimnames(a)$Chain), c("1", "2")) }) @@ -305,6 +335,7 @@ test_that("diagnostic_factor.rhat works", { ) expect_identical(levels(r), c("low", "ok", "high")) }) + test_that("diagnostic_factor.neff_ratio works", { ratios <- new_neff_ratio(c(low = 0.05, low = 0.01, ok = 0.2, ok = 0.49,