Multivariate clustering, climate model
This notebook aims to reproduce the methodology of the paper submitted to the SBD2023 conference, implementing the global and subset modelling with a multivariate approach.
This methodology compares the performance of models trained with data from the time series of all municipalities (global models) against models trained with subsets of the municipalities' time series (subset models).
Those subsets were created by a clustering algorithm applied to the climate time series.
Packages
library(tidyverse)
library(arrow)
library(tidymodels)
library(bonsai)
library(finetune)
library(modeltime)
library(timetk)
library(dtwclust)
library(kableExtra)
library(tictoc)
library(geobr)
library(DT)
library(sf)
source("../functions.R")
Load data
tdengue <- read_parquet(file = data_dir("bundled_data/tdengue.parquet")) %>%
  select(mun, date, starts_with(c("cases", "tmax", "tmin", "prec"))) %>%
  drop_na()
Cases, maximum temperature, minimum temperature, and precipitation variables are loaded, along with their time-lagged versions (from 1 to 6 weeks).
NA values are created when the lagged variables are computed. The rows containing those NA values are dropped due to constraints of the machine learning regressors.
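The lags ship precomputed in the bundled dataset. As a minimal sketch of how they could be derived with timetk, assuming a hypothetical unlagged table dengue_raw with columns mun, date, cases, tmax, tmin and prec:
# Sketch only: tk_augment_lags() appends columns named cases_lag1, ...,
# prec_lag6, computed within each municipality.
dengue_lagged <- dengue_raw %>%
  group_by(mun) %>%
  tk_augment_lags(c(cases, tmax, tmin, prec), .lags = 1:6) %>%
  ungroup() %>%
  drop_na()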
glimpse(tdengue)
Rows: 329,982
Columns: 33
$ mun <chr> "110002", "110002", "110002", "110002", "110002", "11000…
$ date <date> 2011-01-16, 2011-01-23, 2011-01-23, 2011-01-30, 2011-01…
$ cases <dbl> 0.07830184, 0.66762618, 0.66762618, 2.43559920, 2.435599…
$ cases_raw <dbl> 1, 2, 2, 5, 5, 0, 0, 1, 1, 2, 2, 1, 1, 0, 0, 2, 2, 1, 1,…
$ cases_cum_raw <dbl> 10, 12, 14, 19, 24, 24, 24, 25, 26, 28, 30, 31, 32, 32, …
$ cases_cum <dbl> -1.419910, -1.410558, -1.401206, -1.377827, -1.354447, -…
$ cases_lag1 <dbl> 0.07830184, 0.07830184, 0.66762618, 0.66762618, 2.435599…
$ cases_lag2 <dbl> 0.66762618, 0.07830184, 0.07830184, 0.66762618, 0.667626…
$ cases_lag3 <dbl> 0.66762618, 0.66762618, 0.07830184, 0.07830184, 0.667626…
$ cases_lag4 <dbl> 0.66762618, 0.66762618, 0.66762618, 0.07830184, 0.078301…
$ cases_lag5 <dbl> 0.66762618, 0.66762618, 0.66762618, 0.66762618, 0.078301…
$ cases_lag6 <dbl> -0.51102251, 0.66762618, 0.66762618, 0.66762618, 0.66762…
$ tmax <dbl> 0.52569297, 0.75229054, 0.75229054, 0.83110709, 0.831107…
$ tmax_lag1 <dbl> 0.52569297, 0.52569297, 0.75229054, 0.75229054, 0.831107…
$ tmax_lag2 <dbl> 0.39269005, 0.52569297, 0.52569297, 0.75229054, 0.752290…
$ tmax_lag3 <dbl> 0.39269005, 0.39269005, 0.52569297, 0.52569297, 0.752290…
$ tmax_lag4 <dbl> 0.29909540, 0.39269005, 0.39269005, 0.52569297, 0.525692…
$ tmax_lag5 <dbl> 0.29909540, 0.29909540, 0.39269005, 0.39269005, 0.525692…
$ tmax_lag6 <dbl> -14.51841512, 0.29909540, 0.29909540, 0.39269005, 0.3926…
$ tmin <dbl> 1.0212567, 0.9886102, 0.9886102, 1.0072654, 1.0072654, 1…
$ tmin_lag1 <dbl> 1.0212567, 1.0212567, 0.9886102, 0.9886102, 1.0072654, 1…
$ tmin_lag2 <dbl> 1.0585670, 1.0212567, 1.0212567, 0.9886102, 0.9886102, 1…
$ tmin_lag3 <dbl> 1.0585670, 1.0585670, 1.0212567, 1.0212567, 0.9886102, 0…
$ tmin_lag4 <dbl> 0.9932740, 1.0585670, 1.0585670, 1.0212567, 1.0212567, 0…
$ tmin_lag5 <dbl> 0.9932740, 0.9932740, 1.0585670, 1.0585670, 1.0212567, 1…
$ tmin_lag6 <dbl> -8.6560935, 0.9932740, 0.9932740, 1.0585670, 1.0585670, …
$ prec <dbl> 1.5673247, 1.1822375, 1.1822375, 0.9472381, 0.9472381, 1…
$ prec_lag1 <dbl> 1.5673247, 1.5673247, 1.1822375, 1.1822375, 0.9472381, 0…
$ prec_lag2 <dbl> 1.6101183, 1.5673247, 1.5673247, 1.1822375, 1.1822375, 0…
$ prec_lag3 <dbl> 1.6101183, 1.6101183, 1.5673247, 1.5673247, 1.1822375, 1…
$ prec_lag4 <dbl> 1.5969131, 1.6101183, 1.6101183, 1.5673247, 1.5673247, 1…
$ prec_lag5 <dbl> 1.5969131, 1.5969131, 1.6101183, 1.6101183, 1.5673247, 1…
$ prec_lag6 <dbl> -1.9551115, 1.5969131, 1.5969131, 1.6101183, 1.6101183, …
Clustering
Here we load the results from the clustering notebook.
clust_res <- readRDS("../dengue-cluster/m_clim_cluster_ids.rds") %>%
  st_drop_geometry() %>%
  select(mun = code_muni, group)
table(clust_res$group)
1 2 3 4 5
127 97 25 41 43
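For reference, a minimal sketch of how such groups could be produced with dtwclust; the actual procedure lives in the clustering notebook, and clim_series is a hypothetical list holding one climate matrix (weeks by variables) per municipality:
# Sketch only: hierarchical clustering of multivariate climate series under a
# DTW distance; the membership vector is stored in the @cluster slot.
clust <- tsclust(
  series = clim_series,
  type = "hierarchical",
  k = 5L,
  distance = "dtw_basic"
)
groups <- clust@cluster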
Join the clustering results with the bundled dataset.
tdengue <- left_join(tdengue, clust_res, by = "mun") %>%
  relocate(group, .after = mun)
Check for NAs.
table(is.na(tdengue$group))
FALSE
329982
Train and test split
Split the data into training and testing sets. The time_series_split function keeps the time series in temporal order (no shuffling) and handles the panel data format, as indicated by the message about overlapping timestamps. The last two years of data will be used as the testing set.
tdengue_split <- tdengue %>%
  time_series_split(
    date_var = date,
    assess = 52 * 2,
    cumulative = TRUE
  )
Data is not ordered by the 'date_var'. Resamples will be arranged by `date`.
Overlapping Timestamps Detected. Processing overlapping time series together using sliding windows.
tdengue_split
<Analysis/Assess/Total>
<261123/68859/329982>
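A quick sanity check on the split (a sketch; the exact dates depend on the bundled data): the assessment window should cover roughly the last two years.
range(training(tdengue_split)$date)
range(testing(tdengue_split)$date)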
K-folds
The training set will be split into k folds.
tdengue_split_folds <- training(tdengue_split) %>%
  vfold_cv(v = 5)
Recipes
The training specifications of the global and subset models are expressed as recipes. The procedure below creates a list of those recipes.
recipes_list <- list()
Global
The global training recipe uses data from all municipalities to train the models. A role check follows the code.
The date and group variables are removed prior to training.
The municipality identifier (mun) is treated as an id variable and takes no part as a predictor in the training process.
recipe_global <- recipe(cases ~ ., data = training(tdengue_split)) %>%
  step_rm(date, group) %>%
  update_role(mun, new_role = "id variable")

recipes_list <- append(recipes_list, list(global = recipe_global))

rm(recipe_global)
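To confirm the roles, a quick check (a sketch; summary() on an unprepped recipe lists each variable with its assigned role):
# `mun` should show the "id variable" role; `cases` is the outcome and the
# remaining columns are predictors. `date` and `group` are only removed by
# step_rm() once the recipe is prepped.
summary(recipes_list$global)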
Groups
For each group created by the clustering process, a specific training recipe is created (a verification sketch follows the loop). The first step filters the training set, keeping only the rows belonging to the group in the loop.
The date and group variables are removed prior to training.
The municipality identifier (mun) is treated as an id variable and takes no part as a predictor in the training process.
for(g in unique(tdengue$group)){
  tmp <- recipe(cases ~ ., data = training(tdengue_split)) %>%
    step_filter(group == !!g) %>%
    step_rm(date, group) %>%
    update_role(mun, new_role = "id variable")

  tmp <- list(tmp)
  tmp <- setNames(tmp, paste0("g", g))

  recipes_list <- append(recipes_list, tmp)

  rm(tmp)
}
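As a sketch, prepping and baking one of the group recipes verifies that the filter step keeps only that group's rows (bake(new_data = NULL) returns the processed training set):
recipes_list$g1 %>%
  prep() %>%
  bake(new_data = NULL) %>%
  nrow()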
Regressors specification
Random forest
A Random Forest specification using the ranger engine. The trees and min_n hyperparameters will be tuned.
rf_spec <- rand_forest(
  trees = tune(),
  min_n = tune()
) %>%
  set_engine("ranger") %>%
  set_mode("regression")
Workflow set
This step creates a workflow set, combining the training recipes and regressors specifications.
all_workflows <- workflow_set(
  preproc = recipes_list,
  models = list(rf = rf_spec),
  cross = TRUE
)
Tune
This step tunes the training hyperparameters of each workflow.
doParallel::registerDoParallel()

tic()
race_results <- all_workflows %>%
  workflow_map(
    "tune_race_anova",
    seed = 345,
    resamples = tdengue_split_folds,
    grid = 10,
    control = control_race(parallel_over = "everything"),
    verbose = TRUE
  )
i 1 of 6 tuning: global_rf
✔ 1 of 6 tuning: global_rf (2h 12m 49.1s)
i 2 of 6 tuning: g1_rf
✔ 2 of 6 tuning: g1_rf (31m 25.2s)
i 3 of 6 tuning: g2_rf
✔ 3 of 6 tuning: g2_rf (31m 8.5s)
i 4 of 6 tuning: g3_rf
✔ 4 of 6 tuning: g3_rf (3m 45.8s)
i 5 of 6 tuning: g4_rf
✔ 5 of 6 tuning: g4_rf (8m 5.6s)
i 6 of 6 tuning: g5_rf
✔ 6 of 6 tuning: g5_rf (8m 13.9s)
toc()
12931.669 sec elapsed
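As a sketch, the tuned workflows can be compared side by side with rank_results() from workflowsets:
# Rank the best configuration of each workflow by cross-validated RMSE.
race_results %>%
  rank_results(rank_metric = "rmse", select_best = TRUE) %>%
  filter(.metric == "rmse") %>%
  select(wflow_id, mean, std_err, rank)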
Fit
Each workflow will be finalized with its best hyperparameters, selected using the RMSE metric, and fitted on the full training set.
This procedure creates a list of trained models, each containing the fit result and the list of municipalities used to train that workflow.
The global workflow is trained with data from all municipalities; each subset workflow is trained with the municipalities assigned to its cluster by the clustering algorithm.
tic()
trained_models <- list()
for(w in unique(race_results$wflow_id)){
  best_tune <- race_results %>%
    extract_workflow_set_result(w) %>%
    select_best("rmse")

  final_fit <- race_results %>%
    extract_workflow(w) %>%
    finalize_workflow(best_tune) %>%
    fit(training(tdengue_split))

  mold <- extract_mold(final_fit)
  train_ids <- mold$extras$roles$`id variable` %>%
    distinct() %>%
    pull() %>%
    as.character()

  final_fit <- list(
    list(
      "final_fit" = final_fit,
      "train_ids" = train_ids
    )
  )
  final_fit <- setNames(final_fit, paste0(w))

  trained_models <- append(trained_models, final_fit)
}
toc()
5065.569 sec elapsed
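A sketch for inspecting the result: each list element is named after its workflow and stores the fitted workflow plus the municipalities used to train it.
names(trained_models)
# e.g. "global_rf" "g1_rf" "g2_rf" "g3_rf" "g4_rf" "g5_rf"
length(trained_models$global_rf$train_ids)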
Accuracy
After training each workflow, the accuracy of the models is obtained by applying the fitted models to the testing set.
For the global model, all municipalities are used for testing. For the subset models, only data from the municipalities in each subset are considered.
The RMSE and MAPE metrics are obtained for each workflow and municipality.
models_accuracy <- tibble()
for(t in 1:length(trained_models)){
  model_tbl <- modeltime_table(trained_models[[t]][[1]])

  testing_set <- testing(tdengue_split) %>%
    filter(mun %in% trained_models[[t]][[2]])

  calib_tbl <- model_tbl %>%
    modeltime_calibrate(
      new_data = testing_set,
      id = "mun"
    )

  res <- calib_tbl %>%
    modeltime_accuracy(
      acc_by_id = TRUE,
      metric_set = metric_set(rmse, mape)
    )

  res$.model_id <- word(names(trained_models[t]), 1, sep = "_")

  models_accuracy <- bind_rows(models_accuracy, res)
}
ℹ We have detected a possible intermittent series, you can change the default metric set to the extended_forecast_accuracy_metric_set() containing the MAAPE metric, which is more appropriate for this type of series.
ℹ We have detected a possible intermittent series, you can change the default metric set to the extended_forecast_accuracy_metric_set() containing the MAAPE metric, which is more appropriate for this type of series.
This plot presents the RMSE distribution across the workflows.
ggplot(data = models_accuracy, aes(x = .model_id, y = rmse, fill = .model_desc)) +
geom_boxplot()
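A compact numeric companion to the boxplot (a sketch): median RMSE per workflow across municipalities.
models_accuracy %>%
  group_by(.model_id) %>%
  summarise(median_rmse = median(rmse, na.rm = TRUE)) %>%
  arrange(median_rmse)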
Breakdown
Municipality names are retrieved with geobr to label the tables below.
mun_names <- lookup_muni(code_muni = "all") %>%
  mutate(code_muni = substr(code_muni, 0, 6)) %>%
  mutate(name_muni = paste0(name_muni, ", ", abbrev_state)) %>%
  select(code_muni, name_muni)
Using year 2010
models_accuracy %>%
  left_join(mun_names, by = c("mun" = "code_muni")) %>%
  select(.model_id, .model_desc, name_muni, rmse) %>%
  mutate(rmse = round(rmse, 2)) %>%
  arrange(.model_id, .model_desc, -rmse) %>%
  datatable(filter = "top")
models_accuracy %>%
  left_join(mun_names, by = c("mun" = "code_muni")) %>%
  select(.model_id, .model_desc, name_muni, rmse) %>%
  mutate(rmse = round(rmse, 2)) %>%
  group_by(.model_desc) %>%
  mutate(.model_id = case_when(
    .model_id != "global" ~ "cluster",
    .default = .model_id
  )) %>%
  pivot_wider(names_from = .model_id, values_from = rmse) %>%
  mutate(dif = round(global - cluster, 2)) %>%
  ungroup() %>%
  datatable(filter = "top")
models_accuracy %>%
  left_join(mun_names, by = c("mun" = "code_muni")) %>%
  select(.model_id, .model_desc, name_muni, rmse) %>%
  group_by(.model_desc) %>%
  mutate(.model_id = case_when(
    .model_id != "global" ~ "cluster",
    .default = .model_id
  )) %>%
  pivot_wider(names_from = .model_id, values_from = rmse) %>%
  mutate(dif = round(global - cluster, 2)) %>%
  arrange(.model_desc, dif) %>%
  ggplot(aes(x = global, y = cluster, fill = .model_desc, color = dif)) +
  geom_point(size = 2, alpha = .3) +
  viridis::scale_color_viridis(option = "inferno") +
  theme_bw() +
  labs(x = "Global model", y = "Subset models", title = "RMSE error obtained with global and subset training strategies")
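A sketch quantifying the scatterplot: the share of municipalities where the subset (cluster) strategy reached a lower RMSE than the global one.
models_accuracy %>%
  mutate(strategy = if_else(.model_id == "global", "global", "cluster")) %>%
  select(mun, strategy, rmse) %>%
  pivot_wider(names_from = strategy, values_from = rmse) %>%
  summarise(cluster_wins = mean(cluster < global, na.rm = TRUE))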
Session info
sessionInfo()
R version 4.3.2 (2023-10-31)
Platform: x86_64-conda-linux-gnu (64-bit)
Running under: CentOS Linux 7 (Core)
Matrix products: default
BLAS/LAPACK: /home/raphaelfs/miniconda3/envs/quarto/lib/libopenblasp-r0.3.25.so; LAPACK version 3.11.0
Random number generation:
RNG: L'Ecuyer-CMRG
Normal: Inversion
Sample: Rejection
locale:
[1] LC_CTYPE=pt_BR.UTF-8 LC_NUMERIC=C
[3] LC_TIME=pt_BR.UTF-8 LC_COLLATE=pt_BR.UTF-8
[5] LC_MONETARY=pt_BR.UTF-8 LC_MESSAGES=pt_BR.UTF-8
[7] LC_PAPER=pt_BR.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=pt_BR.UTF-8 LC_IDENTIFICATION=C
time zone: America/Sao_Paulo
tzcode source: system (glibc)
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] rlang_1.1.1 ranger_0.16.0 sf_1.0-14 DT_0.28
[5] geobr_1.8.1 tictoc_1.2 kableExtra_1.3.4 dtwclust_5.5.12
[9] dtw_1.23-1 proxy_0.4-27 timetk_2.8.2 modeltime_1.2.5
[13] finetune_1.1.0 bonsai_0.2.1 yardstick_1.2.0 workflowsets_1.0.0
[17] workflows_1.1.3 tune_1.1.2 rsample_1.2.0 recipes_1.0.6
[21] parsnip_1.1.0 modeldata_1.1.0 infer_1.0.4 dials_1.2.0
[25] scales_1.2.1 broom_1.0.4 tidymodels_1.0.0 arrow_12.0.0
[29] lubridate_1.9.2 forcats_1.0.0 stringr_1.5.0 dplyr_1.1.2
[33] purrr_1.0.1 readr_2.1.4 tidyr_1.3.0 tibble_3.2.1
[37] ggplot2_3.4.2 tidyverse_2.0.0
loaded via a namespace (and not attached):
[1] rstudioapi_0.14 jsonlite_1.8.5 magrittr_2.0.3
[4] modeltools_0.2-23 farver_2.1.1 nloptr_2.0.3
[7] rmarkdown_2.22 vctrs_0.6.3 minqa_1.2.6
[10] webshot_0.5.4 htmltools_0.5.5 curl_5.0.2
[13] sass_0.4.6 parallelly_1.36.0 StanHeaders_2.26.26
[16] bslib_0.4.2 KernSmooth_2.23-21 htmlwidgets_1.6.2
[19] plyr_1.8.8 cachem_1.0.8 zoo_1.8-12
[22] mime_0.12 lifecycle_1.0.3 iterators_1.0.14
[25] pkgconfig_2.0.3 Matrix_1.5-4.1 R6_2.5.1
[28] fastmap_1.1.1 future_1.32.0 shiny_1.7.4
[31] clue_0.3-64 digest_0.6.31 colorspace_2.1-0
[34] furrr_0.3.1 RSpectra_0.16-1 crosstalk_1.2.0
[37] labeling_0.4.2 fansi_1.0.4 timechange_0.2.0
[40] httr_1.4.6 compiler_4.3.2 doParallel_1.0.17
[43] bit64_4.0.5 withr_2.5.0 backports_1.4.1
[46] viridis_0.6.3 DBI_1.1.3 MASS_7.3-60
[49] lava_1.7.2.1 classInt_0.4-9 units_0.8-2
[52] tools_4.3.2 httpuv_1.6.11 flexclust_1.4-1
[55] future.apply_1.11.0 nnet_7.3-19 glue_1.6.2
[58] nlme_3.1-162 promises_1.2.0.1 grid_4.3.2
[61] cluster_2.1.4 reshape2_1.4.4 generics_0.1.3
[64] gtable_0.3.3 tzdb_0.4.0 class_7.3-22
[67] data.table_1.14.8 hms_1.1.3 xml2_1.3.4
[70] utf8_1.2.3 ggrepel_0.9.3 foreach_1.5.2
[73] pillar_1.9.0 later_1.3.1 splines_4.3.2
[76] lhs_1.1.6 lattice_0.21-8 survival_3.5-5
[79] bit_4.0.5 tidyselect_1.2.0 knitr_1.43
[82] gridExtra_2.3 svglite_2.1.1 stats4_4.3.2
[85] xfun_0.39 hardhat_1.3.0 timeDate_4022.108
[88] stringi_1.7.12 boot_1.3-28.1 DiceDesign_1.9
[91] yaml_2.3.7 evaluate_0.21 codetools_0.2-19
[94] cli_3.6.1 RcppParallel_5.1.7 rpart_4.1.19
[97] xtable_1.8-4 systemfonts_1.0.4 jquerylib_0.1.4
[100] munsell_0.5.0 Rcpp_1.0.10 globals_0.16.2
[103] parallel_4.3.2 ellipsis_0.3.2 gower_1.0.1
[106] assertthat_0.2.1 prettyunits_1.1.1 lme4_1.1-35.1
[109] GPfit_1.0-8 listenv_0.9.0 viridisLite_0.4.2
[112] ipred_0.9-13 e1071_1.7-13 xts_0.13.1
[115] prodlim_2019.11.13 rvest_1.0.3 shinyjs_2.1.0