library(tidyverse)
library(arrow)
library(knitr)
library(lubridate)
library(tidymodels)
library(finetune)
library(bonsai)
library(tictoc)
library(vip)
Dengue case classification
by symptoms and clinical condition
The objective of this notebook is to predict whether a suspected dengue case is confirmed, based on the patient's symptoms, clinical conditions, and other patient-related variables.
The trained model may then be used to reclassify suspected dengue cases with an inconclusive diagnosis.
Packages
Data
Data prior to 2016 do not include patient symptoms and clinical conditions.
Cases classified as inconclusive are discarded for model training.
# Data sources
# One parquet file per notification year, 2016 through 2021. The paths are
# generated from the year range so that adding a year is a one-token change.
files_list <- sprintf(
  "../dengue-data/parquets/dengue_%d.parquet",
  2016:2021
)
# Independent variables
# Column names of the symptom and clinical-condition indicators used as
# predictors. Each name appears exactly once ("HEPATOPAT" was previously
# listed twice).
x_vars <- c(
  "FEBRE", "MIALGIA", "CEFALEIA",
  "EXANTEMA", "VOMITO", "NAUSEA",
  "DOR_COSTAS", "CONJUNTVIT",
  "ARTRITE", "ARTRALGIA", "PETEQUIA_N",
  "LEUCOPENIA", "LACO", "DOR_RETRO",
  "DIABETES", "HEMATOLOG", "HEPATOPAT",
  "RENAL", "HIPERTENSA",
  "ACIDO_PEPT", "AUTO_IMUNE"
)
# Prepare data
# Builds the modeling table from the parquet files listed in `files_list`:
# the outcome (CLASSI_FIN), COMUNINF, IDADEanos, DT_SIN_PRI and the
# indicator columns named in `x_vars`.
dengue <- arrow::open_dataset(sources = files_list) %>%
  # Select variables
  select(all_of(c("CLASSI_FIN", "COMUNINF", "IDADEanos", "DT_SIN_PRI", x_vars))) %>%
  # Filter out "Inconclusivo" cases
  filter(CLASSI_FIN != "Inconclusivo") %>%
  # Collect data from parquet files
  collect() %>%
  # Prepare variables
  mutate(
    # Binary outcome: TRUE for every classification other than "Descartado",
    # FALSE otherwise (including any value that does not match the first rule).
    CLASSI_FIN = case_when(
      CLASSI_FIN != "Descartado" ~ TRUE,
      .default = FALSE
    ),
    # NOTE(review): the resulting factor levels are "FALSE", "TRUE"; yardstick
    # treats the first level as the event by default -- confirm this is the
    # intended event level for roc_auc.
    CLASSI_FIN = as.factor(CLASSI_FIN),
    DT_SIN_PRI = as_date(DT_SIN_PRI),
    COMUNINF = as.factor(COMUNINF),
    # Recode each indicator column to logical: TRUE when the value is "Sim".
    across(all_of(x_vars), ~ .x == "Sim")
  )
# Smaller dataset for tests
# Seed the RNG first so the same subsample is drawn on every run; the split
# and resampling steps further down set their own seeds.
set.seed(42)
# slice_sample() replaces the superseded sample_n(); when fewer than 500000
# rows are available it returns all of them instead of raising an error.
dengue <- slice_sample(dengue, n = 500000)
Modeling
Train and test dataset split
- Proportion of the data allocated to training: 3/4 (the remaining 1/4 is held out for testing)
# Split the data into training (3/4) and test (1/4) sets, stratified by the
# outcome. The seed makes the split reproducible.
set.seed(123)
dengue_split <- dengue %>%
  initial_split(prop = 3/4, strata = CLASSI_FIN)

dengue_train <- training(dengue_split)
dengue_test <- testing(dengue_split)
# Cross-validation folds on the training set, stratified by the outcome.
# vfold_cv() is called with its default number of folds.
set.seed(234)
dengue_folds <- vfold_cv(dengue_train, strata = CLASSI_FIN)
Recipes
# Preprocessing recipe: predict CLASSI_FIN from all remaining columns after
# dropping COMUNINF and DT_SIN_PRI, then integer-encode every predictor.
dengue_rec_1 <-
  recipe(CLASSI_FIN ~ ., data = dengue_train) %>%
  step_rm(COMUNINF, DT_SIN_PRI) %>%
  #step_date(DT_SIN_PRI, features = c("month", "week", "semester", "quarter"), keep_original_cols = FALSE) %>%
  step_integer(all_predictors())
Specifications
XGB
# Boosted-tree classifier (xgboost engine). The number of trees, the minimum
# node size and the number of predictors sampled per split are left to tuning.
xgb_spec <-
  boost_tree(
    trees = tune(),
    min_n = tune(),
    mtry = tune()
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
Decision tree
# Single decision-tree classifier (rpart engine). The cost-complexity pruning
# parameter and the minimum node size are left to tuning.
cart_spec <-
  decision_tree(
    cost_complexity = tune(),
    min_n = tune()
  ) %>%
  set_engine("rpart") %>%
  set_mode("classification")
Workflows
# Workflow set pairing the single recipe with both model specifications.
# The resulting workflow ids are "recipe_1_xgb" and "recipe_1_cart".
all_workflows <-
  workflow_set(
    preproc = list(recipe_1 = dengue_rec_1),
    models = list(xgb = xgb_spec, cart = cart_spec)
  )
Tuning
# Register a parallel backend for the tuning step.
doParallel::registerDoParallel()

# Racing control object; parallel_over = "everything" asks tune to
# parallelize over both resamples and parameter combinations.
race_ctrl <- control_race(parallel_over = "everything")
# Start a timer for the tuning run; it is stopped by the toc() call below.
tic()

# Tune every workflow in the set with ANOVA racing over the CV folds,
# starting from 10 candidate parameter combinations per workflow.
race_results <-
  all_workflows %>%
  workflow_map(
    "tune_race_anova",
    seed = 345,
    resamples = dengue_folds,
    grid = 10,
    control = race_ctrl
  )
i Creating pre-processing data to finalize unknown parameter: mtry
# Stop the timer started by tic() and print the elapsed time.
toc()
1651.906 sec elapsed
Race metrics
# Rank the tuned workflows by ROC AUC. The returned tibble still contains one
# row per metric (accuracy and roc_auc) for each ranked configuration.
train_rank_results <- rank_results(race_results, rank_metric = "roc_auc")

train_rank_results
# A tibble: 4 × 9
wflow_id .config .metric mean std_err n preprocessor model rank
<chr> <chr> <chr> <dbl> <dbl> <int> <chr> <chr> <int>
1 recipe_1_xgb Preprocess… accura… 0.620 7.72e-4 10 recipe boos… 1
2 recipe_1_xgb Preprocess… roc_auc 0.633 1.10e-3 10 recipe boos… 1
3 recipe_1_cart Preprocess… accura… 0.615 7.98e-4 10 recipe deci… 2
4 recipe_1_cart Preprocess… roc_auc 0.622 1.42e-3 10 recipe deci… 2
# Plot the ROC AUC of the tuned workflows.
race_results %>%
  autoplot(metric = "roc_auc")
Last fit
# Id of the best workflow according to ROC AUC.
# The ranking table holds rows for both accuracy and roc_auc, so it is
# restricted to roc_auc before sorting; otherwise means of different metrics
# would be compared against each other.
selection_train <- train_rank_results %>%
  filter(.metric == "roc_auc") %>%
  arrange(desc(mean)) %>%
  pull(wflow_id) %>%
  first()

selection_train
[1] "recipe_1_xgb"
# Best hyperparameter combination of the selected workflow.
# `metric` is passed by name: newer tune releases require optional arguments
# of select_best() to be named.
# NOTE(review): the workflow is ranked by roc_auc but its hyperparameters are
# chosen by accuracy -- confirm that mixing the two metrics is intended.
best_results <- race_results %>%
  extract_workflow_set_result(selection_train) %>%
  select_best(metric = "accuracy")

best_results
# A tibble: 1 × 4
mtry trees min_n .config
<int> <int> <int> <chr>
1 3 174 8 Preprocessor1_Model02
# Finalize the selected workflow with the best hyperparameters, fit it on the
# training set and evaluate it on the test set of `dengue_split`.
# Note: the result is stored in a variable named `last_fit`, which shadows the
# tune::last_fit() function name. Calls still resolve to the function because
# R skips non-function objects when looking up a function call.
last_fit <- race_results %>%
  extract_workflow(selection_train) %>%
  finalize_workflow(best_results) %>%
  last_fit(dengue_split)
Evaluate on test
# Test-set metrics of the final fit.
last_fit %>%
  collect_metrics()
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.620 Preprocessor1_Model1
2 roc_auc binary 0.635 Preprocessor1_Model1
# Confusion matrix of the test-set predictions: observed class (CLASSI_FIN)
# against predicted class (.pred_class).
last_fit %>%
  collect_predictions() %>%
  conf_mat(truth = CLASSI_FIN, estimate = .pred_class)
Truth
Prediction FALSE TRUE
FALSE 16696 10861
TRUE 36661 60782
# Variable-importance plot of the underlying engine model from the final fit.
last_fit %>%
  extract_fit_engine() %>%
  vip()
Session info
# Print the R version, platform, locale and package versions used for this run.
sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.2 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
locale:
[1] LC_CTYPE=pt_BR.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] rpart_4.1.16 rlang_1.1.1 xgboost_1.7.5.1 vip_0.3.2
[5] tictoc_1.2 bonsai_0.2.1 finetune_1.1.0 yardstick_1.2.0
[9] workflowsets_1.0.1 workflows_1.1.3 tune_1.1.1 rsample_1.1.1
[13] recipes_1.0.6 parsnip_1.1.0 modeldata_1.1.0 infer_1.0.4
[17] dials_1.2.0 scales_1.2.1 broom_1.0.5 tidymodels_1.1.0
[21] knitr_1.43 arrow_12.0.1 lubridate_1.9.2 forcats_1.0.0
[25] stringr_1.5.0 dplyr_1.1.2 purrr_1.0.1 readr_2.1.4
[29] tidyr_1.3.0 tibble_3.2.1 ggplot2_3.4.2 tidyverse_2.0.0
loaded via a namespace (and not attached):
[1] nlme_3.1-155 bit64_4.0.5 doParallel_1.0.17
[4] DiceDesign_1.9 tools_4.1.2 backports_1.4.1
[7] utf8_1.2.3 R6_2.5.1 colorspace_2.1-0
[10] nnet_7.3-17 withr_2.5.0 gridExtra_2.3
[13] tidyselect_1.2.0 bit_4.0.5 compiler_4.1.2
[16] cli_3.6.1 labeling_0.4.2 digest_0.6.33
[19] minqa_1.2.5 rmarkdown_2.23 pkgconfig_2.0.3
[22] htmltools_0.5.5 lme4_1.1-34 parallelly_1.36.0
[25] lhs_1.1.6 fastmap_1.1.1 htmlwidgets_1.6.2
[28] rstudioapi_0.15.0 farver_2.1.1 generics_0.1.3
[31] jsonlite_1.8.7 magrittr_2.0.3 Matrix_1.6-0
[34] Rcpp_1.0.11 munsell_0.5.0 fansi_1.0.4
[37] GPfit_1.0-8 lifecycle_1.0.3 furrr_0.3.1
[40] stringi_1.7.12 yaml_2.3.7 MASS_7.3-55
[43] grid_4.1.2 parallel_4.1.2 listenv_0.9.0
[46] lattice_0.20-45 splines_4.1.2 hms_1.1.3
[49] pillar_1.9.0 boot_1.3-28 future.apply_1.11.0
[52] codetools_0.2-18 glue_1.6.2 evaluate_0.21
[55] data.table_1.14.8 nloptr_2.0.3 vctrs_0.6.3
[58] tzdb_0.4.0 foreach_1.5.2 gtable_0.3.3
[61] future_1.33.0 assertthat_0.2.1 xfun_0.39
[64] gower_1.0.1 prodlim_2023.03.31 class_7.3-20
[67] survival_3.2-13 timeDate_4022.108 iterators_1.0.14
[70] hardhat_1.3.0 lava_1.7.2.1 timechange_0.2.0
[73] globals_0.16.2 ellipsis_0.3.2 ipred_0.9-14