diff --git a/NEWS.md b/NEWS.md index 7ce82a4b..8f447d13 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # bayesplot (development version) +* Eliminate redundant data processing in `mcmc_areas_data()` by reusing the prepared MCMC array for both interval and density computation. * Validate equal chain lengths in `validate_df_with_chain()`, reject missing chain labels, and renumber data-frame chain labels internally when converting to arrays. diff --git a/R/mcmc-intervals.R b/R/mcmc-intervals.R index 1d505fdb..2dd4fd72 100644 --- a/R/mcmc-intervals.R +++ b/R/mcmc-intervals.R @@ -599,64 +599,13 @@ mcmc_intervals_data <- function(x, prob <- probs[1] prob_outer <- probs[2] - x <- prepare_mcmc_array(x, pars, regex_pars, transformations) - x <- merge_chains(x) - - data_long <- melt_mcmc(x) %>% + data_long <- melt_mcmc( + merge_chains(prepare_mcmc_array(x, pars, regex_pars, transformations)) + ) %>% dplyr::as_tibble() %>% rlang::set_names(tolower) - probs <- c(0.5 - prob_outer / 2, - 0.5 - prob / 2, - 0.5 + prob / 2, - 0.5 + prob_outer / 2) - - point_est <- match.arg(point_est) - m_func <- if (point_est == "mean") mean else median - - data <- data_long %>% - group_by(.data$parameter) %>% - summarise( - outer_width = prob_outer, - inner_width = prob, - point_est = point_est, - ll = unname(quantile(.data$value, probs[1])), - l = unname(quantile(.data$value, probs[2])), - m = m_func(.data$value), - h = unname(quantile(.data$value, probs[3])), - hh = unname(quantile(.data$value, probs[4])) - ) - - if (point_est == "none") { - data$m <- NULL - } - - color_by_rhat <- isTRUE(length(rhat) > 0) - - if (color_by_rhat) { - rhat <- drop_NAs_and_warn(new_rhat(rhat)) - - if (length(rhat) != nrow(data)) { - abort(paste( - "'rhat' has length", length(rhat), - "but 'x' has", nrow(data), "parameters." - )) - } - - rhat <- set_names(rhat, data$parameter) - - rhat_tbl <- rhat %>% - mcmc_rhat_data() %>% - select(all_of("parameter"), - rhat_value = "value", - rhat_rating = "rating", - rhat_description = "description") %>% - mutate(parameter = factor(.data$parameter, levels(data$parameter))) - - data <- dplyr::inner_join(data, rhat_tbl, by = "parameter") - } - - data + compute_intervals(data_long, prob, prob_outer, point_est, rhat) } @@ -691,10 +640,6 @@ mcmc_areas_data <- function(x, point_est <- match.arg(point_est) temp_point_est <- if (point_est == "none") "median" else point_est - intervals <- mcmc_intervals_data(x, pars, regex_pars, transformations, - prob = probs[1], prob_outer = probs[2], - point_est = temp_point_est, rhat = rhat) - x <- prepare_mcmc_array(x, pars, regex_pars, transformations) x <- merge_chains(x) @@ -702,6 +647,10 @@ mcmc_areas_data <- function(x, dplyr::as_tibble() %>% rlang::set_names(tolower) + intervals <- compute_intervals(data_long, prob = probs[1], + prob_outer = probs[2], + point_est = temp_point_est, rhat = rhat) + # Compute the density intervals data_inner <- data_long %>% compute_column_density( @@ -901,3 +850,61 @@ check_interval_widths <- function(prob, prob_outer) { } sort(c(prob, prob_outer)) } + +# Internal helper shared by mcmc_intervals_data() and mcmc_areas_data() +compute_intervals <- function(data_long, prob, prob_outer, + point_est = c("median", "mean", "none"), + rhat = numeric()) { + + probs <- c(0.5 - prob_outer / 2, + 0.5 - prob / 2, + 0.5 + prob / 2, + 0.5 + prob_outer / 2) + + point_est <- match.arg(point_est) + m_func <- if (point_est == "mean") mean else median + + data <- data_long %>% + group_by(.data$parameter) %>% + summarise( + outer_width = prob_outer, + inner_width = prob, + point_est = point_est, + ll = unname(quantile(.data$value, probs[1])), + l = unname(quantile(.data$value, probs[2])), + m = m_func(.data$value), + h = unname(quantile(.data$value, probs[3])), + hh = unname(quantile(.data$value, probs[4])) + ) + + if (point_est == "none") { + data$m <- NULL + } + + color_by_rhat <- isTRUE(length(rhat) > 0) + + if (color_by_rhat) { + rhat <- drop_NAs_and_warn(new_rhat(rhat)) + + if (length(rhat) != nrow(data)) { + abort(paste( + "'rhat' has length", length(rhat), + "but 'x' has", nrow(data), "parameters." + )) + } + + rhat <- set_names(rhat, data$parameter) + + rhat_tbl <- rhat %>% + mcmc_rhat_data() %>% + select(all_of("parameter"), + rhat_value = "value", + rhat_rating = "rating", + rhat_description = "description") %>% + mutate(parameter = factor(.data$parameter, levels(data$parameter))) + + data <- dplyr::inner_join(data, rhat_tbl, by = "parameter") + } + + data +}