
This vignette guides you through the new functionality that turns MrIML into a graphical network model (GNM). MrIML 2.0 also adds functions that can be used outside GNMs, such as for single-taxon SDMs, epidemiological models, gene-environment association (GEA) studies and landscape genetics. Advances in Tidymodels are at the core of these functions, as are recent advances in interpretable machine learning (e.g., for quantifying interactions quickly). The ability to capture uncertainty via bootstraps and to handle spatial patterns is also a feature of this MrIML module. The examples come from the microbiome world but generalize to any multi-response problem.

Currently, the models are optimized and tested on presence/absence data. Models suitable for abundance data will be added in the future.

First, we need to load up some extra functionality and some data. The first example is coinfection data from New Caledonian Zosterops species. A single continuous covariate is also included (scale.prop.zos), which reflects the relative abundance of Zosterops species among different sample sites.

library(dplyr) # provides %>% and mutate_all(), which are used below

# Response variables (e.g. SNPs, pathogens, species), sorted by name
Y <- dplyr::select(Bird.parasites,
                   -scale.prop.zos) %>%
  dplyr::select(sort(names(.)))
X <- dplyr::select(Bird.parasites,
                   scale.prop.zos) # feature set

X1 <- Y %>%
  dplyr::select(sort(names(.)))

X1_fact <- X1 %>%
  mutate_all(as.factor) %>% 
  mutate_all(~ifelse(. == 0, "absent", "present"))

Note that including ‘X1’ converts MrIML into a joint species distribution model (JSDM) by directly adding the presence/absence patterns of the other taxa to the model (alongside environmental/host covariates). The column order of X1 and Y needs to match.
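Because a mismatch in column order would pair the wrong taxa without raising an error, a quick check before fitting can help. The sketch below uses base R only and small made-up data frames; `Y_demo` and `X1_demo` are hypothetical names for illustration and are not part of mrIML.

```r
# Toy presence/absence data. The taxon names are borrowed from the example
# above, but the values are made up for illustration.
Y_demo <- data.frame(Plas = c(0, 1, 0),
                     Hkillangoi = c(1, 0, 0),
                     Microfilaria = c(0, 0, 1))

# Sort the columns alphabetically, as is done for Y above.
Y_demo <- Y_demo[, sort(names(Y_demo))]

# X1 is a copy of the responses, so it inherits the same column order.
X1_demo <- Y_demo

# TRUE only if both objects list the taxa in exactly the same order.
same_order <- identical(names(X1_demo), names(Y_demo))
same_order
```

Running the same `identical()` check on the real `X1` and `Y` objects just before model fitting is a cheap safeguard.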

Setting up the models

Built into the MrIML architecture (and a big advantage of Tidymodels) is the ability to change the underlying model easily. We set up two models to compare: a random forest (RF) and a logistic regression (labelled ‘lm’ in the code). MrIML takes advantage of multi-core processing, so we set it up here to run on 5 cores. These steps are the same as in MrIML 1.0.

# 100 trees are set for brevity; aim to start with 1000
model_rf <- rand_forest(trees = 100,
                        mode = "classification",
                        mtry = tune(),
                        min_n = tune()) %>%
  set_engine("randomForest")

model_lm <- #model used to generate yhat
  logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") #just for your response

cl <- parallel::makeCluster(5)
plan(cluster, workers=cl)

Running the models

Aside from adding JSDM functionality with the X1 argument, we have also enabled MrIML to tune hyperparameters using the very efficient ‘racing’ option (see Kuhn (2014)). In brief, racing evaluates the candidate parameter combinations on a small initial set of resamples and eliminates combinations that do not improve fit, using a repeated-measures ANOVA model. Setting racing to ‘FALSE’ reverts to a grid search, which is useful if you want to set the tuning grid size manually. Logistic regression has no parameters to tune, so set racing to ‘FALSE’ for that model.
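To give a feel for the racing idea, here is a minimal base R sketch on simulated scores. It is not the procedure MrIML uses internally: for brevity it swaps the repeated-measures ANOVA for a one-sided paired t-test against the current best candidate, and the candidate names, accuracies and thresholds are all made up.

```r
set.seed(11)

# Hypothetical "true" accuracy of four candidate hyperparameter settings.
true_acc <- c(a = 0.70, b = 0.80, c = 0.81, d = 0.60)
n_resamples <- 10
burn_in <- 3  # resamples evaluated before any candidate can be dropped

# scores[r, k]: noisy accuracy of candidate k on resample r.
scores <- matrix(NA_real_, n_resamples, length(true_acc),
                 dimnames = list(NULL, names(true_acc)))
alive <- names(true_acc)

for (r in seq_len(n_resamples)) {
  # Only candidates still in the race are evaluated on this resample.
  scores[r, alive] <- rnorm(length(alive), mean = true_acc[alive], sd = 0.02)
  if (r < burn_in || length(alive) == 1) next

  best <- alive[which.max(colMeans(scores[1:r, alive, drop = FALSE]))]
  for (k in setdiff(alive, best)) {
    # One-sided paired test: is candidate k worse than the current best?
    p <- t.test(scores[1:r, best], scores[1:r, k],
                paired = TRUE, alternative = "greater")$p.value
    if (p < 0.05) alive <- setdiff(alive, k)
  }
}

alive  # candidates that survived the race
```

The saving comes from the `alive` vector shrinking: clearly poor settings stop consuming model fits after the first few resamples.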

#random forest
yhats_rf <- mrIMLpredicts(X=X, Y=Y,
                          X1=X1,
                          Model=model_rf,
                          balance_data='no',
                          mode='classification',
                          seed = sample.int(1e8, 1),
                          morans=F,
                          prop=0.7, k=5, racing=T)

#linear model
yhats_lm <- mrIMLpredicts(X=X,Y=Y,
                            X1=X1_fact,
                            Model=model_lm , 
                            balance_data='no',
                            mode='classification',
                            seed = sample.int(1e8, 1),
                            prop=0.6, racing=F, k=5)
#> Warning: No tuning parameters have been detected, performance will be evaluated
#> using the resamples with no tuning. Did you want to [tune()] parameters?

Comparing performance

It’s important to check whether there is any advantage to using a random forest approach. Interpretation would be easier overall if logistic regression gave similar predictive performance.

ModelPerf_rf <- mrIMLperformance(yhats_rf,
                                 Model=model_rf,
                                 Y=Y,
                                 mode='classification')

ModelPerf_rf[[1]] #across all parasites
#>       response  model_name           roc_AUC               mcc
#> 1   Hkillangoi rand_forest 0.886292016806723 0.365745522166111
#> 2  Hzosteropis rand_forest 0.941444372153546 0.718241208825566
#> 3 Microfilaria rand_forest 0.928861788617886 0.437856895997608
#> 4         Plas rand_forest 0.864678899082569 0.496552216869595
#>         sensitivity               ppv       specificity         prevalence
#> 1 0.983193277310924              0.25 0.906976744186047  0.115812917594655
#> 2 0.933962264150943 0.793103448275862 0.942857142857143  0.265033407572383
#> 3 0.983739837398374 0.333333333333333 0.937984496124031 0.0979955456570156
#> 4 0.926605504587156 0.538461538461538 0.893805309734513  0.195991091314031
ModelPerf_rf[[2]] #overall
#> [1] 0.504599

ModelPerf_lm <- mrIMLperformance(yhats_lm,
                                 Model=model_lm,
                                 Y=Y, mode='classification')

ModelPerf_lm[[1]]
#>       response   model_name           roc_AUC               mcc
#> 1   Hkillangoi logistic_reg 0.803326474622771              <NA>
#> 2  Hzosteropis logistic_reg 0.815781274341798   0.4040224605406
#> 3 Microfilaria logistic_reg  0.91820987654321 0.433340830644394
#> 4         Plas logistic_reg 0.777072192513369 0.240820814553862
#>         sensitivity               ppv       specificity         prevalence
#> 1                 1                 0               0.9  0.115812917594655
#> 2 0.946564885496183  0.36734693877551               0.8  0.265033407572383
#> 3 0.981481481481482 0.333333333333333 0.929824561403509 0.0979955456570156
#> 4 0.963235294117647 0.181818181818182 0.784431137724551  0.195991091314031
ModelPerf_lm[[2]]
#> [1] 0.269546

plots <- mrPerformancePlot(ModelPerf1=ModelPerf_lm,
                           ModelPerf2=ModelPerf_rf,
                           mod_names=c('linear_reg','rand_forest'),
                           mode='classification' ) 
plots
#> [[1]]

#> 
#> [[2]]

#> 
#> [[3]]
#> # A tibble: 4 × 5
#>   response     outlier logistic_reg rand_forest diff_mod1_2
#>   <chr>        <lgl>          <dbl>       <dbl>       <dbl>
#> 1 Hkillangoi   NA             0           0.366     0.366  
#> 2 Hzosteropis  NA             0.404       0.718     0.314  
#> 3 Microfilaria NA             0.433       0.438     0.00452
#> 4 Plas         NA             0.241       0.497     0.256

If we just look at the AUC values, model performance looks fairly similar (in the run shown above, the mean AUC is about 0.91 for the RF and 0.83 for the logistic regression). However, when we look at the Matthews correlation coefficient (MCC) for each taxon, the logistic regression does much worse for H. killangoi and Plas (Plasmodium): its MCC is undefined for H. killangoi (reported as 0 in the comparison table) and 0.24 for Plas, compared with 0.37 and 0.50 for the RF. Remember that the classes are imbalanced, so AUC tends to be an inflated measure. Because the seed is drawn at random, exact values will differ between runs. This is evidence that non-linear relationships can make a difference to prediction overall. We will interrogate the RF model further.

We will now ask whether including putative associations between taxa improves model performance, overall or for particular parasites, or whether host relative abundance is enough.
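The summary numbers can be checked by hand from the printed tables. The base R sketch below copies the per-taxon values from the output above and averages them. Treating an undefined MCC as 0 is an assumption, chosen because it reproduces the overall values printed by `ModelPerf_rf[[2]]` and `ModelPerf_lm[[2]]`.

```r
# Per-taxon values copied from the output printed above
# (order: Hkillangoi, Hzosteropis, Microfilaria, Plas).
auc_rf <- c(0.886292016806723, 0.941444372153546,
            0.928861788617886, 0.864678899082569)
auc_lm <- c(0.803326474622771, 0.815781274341798,
            0.91820987654321, 0.777072192513369)
mcc_rf <- c(0.365745522166111, 0.718241208825566,
            0.437856895997608, 0.496552216869595)
mcc_lm <- c(NA, 0.4040224605406, 0.433340830644394, 0.240820814553862)

# Mean AUC per model: about 0.91 for rf and 0.83 for lm.
mean_auc <- c(rf = mean(auc_rf), lm = mean(auc_lm))
round(mean_auc, 2)

# Assumption: an undefined MCC (NA) is counted as 0, as in the
# comparison table above.
mcc_lm0 <- ifelse(is.na(mcc_lm), 0, mcc_lm)

# Mean MCC per model: about 0.505 for rf and 0.270 for lm.
mean_mcc <- c(rf = mean(mcc_rf), lm = mean(mcc_lm0))
round(mean_mcc, 3)
```

The gap between the two models is far wider on MCC than on AUC, which is the point of comparing both metrics on imbalanced data.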

yhats_rf_noAssoc <- mrIMLpredicts(X=X,
                                  Y=Y,
                                  X1=NULL, #no associations for this one
                                  Model=model_rf,
                                  balance_data='no',
                                  mode='classification',
                                  seed = sample.int(1e8, 1),
                                  prop=0.7,
                                  k=5,
                                  racing=T)

ModelPerf_rf_noAssoc <- mrIMLperformance(yhats_rf_noAssoc,
                                         Model=model_rf,
                                         Y=Y,
                                         mode='classification')

ModelPerf_rf_noAssoc[[1]]
#>       response  model_name           roc_AUC               mcc
#> 1   Hkillangoi rand_forest  0.61633428300095              <NA>
#> 2  Hzosteropis rand_forest 0.787623248572911 0.494420277562475
#> 3 Microfilaria rand_forest 0.715175953079179 0.371709945158283
#> 4         Plas rand_forest  0.78360768175583 0.460220713518774
#>         sensitivity               ppv       specificity         prevalence
#> 1                 1                 0 0.866666666666667  0.115812917594655
#> 2 0.936170212765957  0.48780487804878 0.807339449541284  0.265033407572383
#> 3 0.983870967741935 0.272727272727273 0.938461538461538 0.0979955456570156
#> 4 0.898148148148148 0.555555555555556 0.889908256880734  0.195991091314031
ModelPerf_rf[[1]] #performance including associations
#>       response  model_name           roc_AUC               mcc
#> 1   Hkillangoi rand_forest 0.886292016806723 0.365745522166111
#> 2  Hzosteropis rand_forest 0.941444372153546 0.718241208825566
#> 3 Microfilaria rand_forest 0.928861788617886 0.437856895997608
#> 4         Plas rand_forest 0.864678899082569 0.496552216869595
#>         sensitivity               ppv       specificity         prevalence
#> 1 0.983193277310924              0.25 0.906976744186047  0.115812917594655
#> 2 0.933962264150943 0.793103448275862 0.942857142857143  0.265033407572383
#> 3 0.983739837398374 0.333333333333333 0.937984496124031 0.0979955456570156
#> 4 0.926605504587156 0.538461538461538 0.893805309734513  0.195991091314031

You can see that including associations improved model performance overall, but particularly for H. killangoi and Microfilaria. Using MCC to compare models is problematic because, in the association-free model shown above, it is not defined for H. killangoi (NA, probably as there are no false negatives for this low-prevalence taxon). Positive predictive value (PPV) is useful in this case: without associations, PPV is 0 for H. killangoi and 0.27 for Microfilaria, so we can barely predict the occurrence of these taxa. Including associations increases PPV to 0.25 and 0.33 respectively. This is still not great, as most of our positive predictions for these taxa are false.
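For readers less familiar with these metrics, the following base R sketch computes MCC and PPV from confusion-matrix counts. It shows one way an undefined MCC can arise, namely when a whole row or column of the confusion matrix is empty. The counts are hypothetical, and the helper functions `mcc()` and `ppv()` are written for this illustration; they are not mrIML functions.

```r
# Matthews correlation coefficient from confusion-matrix counts.
mcc <- function(tp, fp, tn, fn) {
  num <- tp * tn - fp * fn
  den <- sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  # The ratio is 0 / 0 when a whole row or column of the matrix is empty.
  if (den == 0) NA_real_ else num / den
}

# Positive predictive value: share of predicted presences that are correct.
ppv <- function(tp, fp) {
  if (tp + fp == 0) NA_real_ else tp / (tp + fp)
}

# Hypothetical rare taxon: 10 presences among 100 samples.
# Model A never predicts a presence, so its MCC is undefined.
mcc_a <- mcc(tp = 0, fp = 0, tn = 90, fn = 10)

# Model B predicts 8 presences, 2 of which are correct (PPV = 0.25).
ppv_b <- ppv(tp = 2, fp = 6)
mcc_b <- mcc(tp = 2, fp = 6, tn = 84, fn = 8)

c(mcc_a = mcc_a, ppv_b = ppv_b, mcc_b = mcc_b)
```

A PPV of 0.25 means three out of four predicted presences are wrong, even though the model is correct for most samples overall.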

Downsampling

Including associations makes a difference, but how can we do better at predicting our two rarer taxa? Up-sampling is possible, but in this case we’ll try down-sampling to see if correcting for class imbalance improves our model fit.
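As a rough illustration of what down-sampling does, the base R sketch below randomly discards rows of the majority class until both classes are the same size. The data are simulated, and this is only a sketch of the general idea, not the code MrIML runs when `balance_data='down'`.

```r
set.seed(1)  # for a reproducible toy example

# Hypothetical imbalanced response: 90 absences and 10 presences.
toy <- data.frame(y = c(rep(0, 90), rep(1, 10)),
                  x = rnorm(100))

# Size of the rarest class.
n_min <- min(table(toy$y))

# For each class, keep a random subset of n_min row indices.
keep <- unlist(lapply(split(seq_len(nrow(toy)), toy$y),
                      function(idx) idx[sample.int(length(idx), n_min)]))

toy_down <- toy[keep, ]
table(toy_down$y)  # both classes now have n_min rows
```

The trade-off is that information from the discarded absences is lost, which matters most when the data set is already small.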

yhats_rf_downSamp <- mrIMLpredicts(X=X,
                                   Y=Y,
                                   X1=X1,
                                   Model=model_rf,
                                   balance_data='down', #down sampling
                                   mode='classification',
                                   seed = sample.int(1e8, 1),
                                   prop=0.75,
                                   k=5,
                                   racing=T)

ModelPerf_rf_downSamp <- mrIMLperformance(yhats_rf_downSamp,
                                          Model=model_rf,
                                          Y=Y,
                                          mode='classification')

ModelPerf_rf_downSamp[[1]]
#>       response  model_name           roc_AUC               mcc
#> 1   Hkillangoi rand_forest 0.893041237113402 0.489868155711274
#> 2  Hzosteropis rand_forest 0.915567765567766 0.645516132718224
#> 3 Microfilaria rand_forest  0.96504854368932 0.474958292677348
#> 4         Plas rand_forest  0.83719806763285 0.400395061824689
#>         sensitivity               ppv       specificity         prevalence
#> 1 0.814432989690722            0.8125 0.963414634146341  0.115812917594655
#> 2 0.897435897435897 0.742857142857143 0.886075949367089  0.265033407572383
#> 3 0.766990291262136                 1                 1 0.0979955456570156
#> 4               0.8 0.652173913043478               0.9  0.195991091314031

Look at those PPV values now: much better. Our false positive rate (1 - specificity) is below 15% for every taxon. Now that we are happy with the performance of the model, we can interrogate it further.

Interpreting the model

In many cases, like this data set, community or microbiome data tend to be small in size. When we apply stochastic machine learning algorithms to such data, it can lead to challenges. For instance, the importance of variables may vary substantially when we create multiple models using the same data and algorithm. To handle this variability and better understand prediction uncertainty, MrIML 2.0 has functionality to capture uncertainty in our tuned model using bootstraps. Additionally, this approach helps us estimate how variables affect the response, and these estimates align with the results obtained from traditional linear regression models (see Cook et al., 2021).
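The bootstrap idea itself is simple and can be sketched in a few lines of base R. The example below uses simulated data and a plain logistic regression rather than a tuned MrIML model, so it only illustrates the principle: refit the same model on rows resampled with replacement and summarise the spread of the estimates.

```r
set.seed(42)

# Simulated presence/absence response driven by one covariate.
n <- 200
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-0.5 + 1.2 * x))
dat <- data.frame(x = x, y = y)

# Refit the same model on rows resampled with replacement.
n_boot <- 50
slopes <- vapply(seq_len(n_boot), function(b) {
  idx <- sample.int(n, replace = TRUE)
  fit <- glm(y ~ x, family = binomial(), data = dat[idx, ])
  unname(coef(fit)["x"])
}, numeric(1))

# Percentile interval describing the uncertainty in the estimate.
quantile(slopes, probs = c(0.025, 0.5, 0.975))
```

With small data sets the interval can be wide, which is exactly the variability that a single model fit would hide.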

MrIML 2.0 makes it easy to get bootstrap estimates for a variety of interpretable machine learning tools, and uses these estimates to construct marginalized co-occurrence networks. First, let’s do the bootstrapping and calculate variable importance.

cl <- parallel::makeCluster(5) #can increase the number of cores as needed.
plan(cluster, workers=cl)

#do bootstraps.
bs_malaria <- mrBootstrap(yhats=yhats_rf,
                          Y=Y, 
                          num_bootstrap = 10,
                          downsample = TRUE,
                          mode='classification') 
#make sure downsample=TRUE as this did improve performance
#just 10 bootstraps to keep this short. We suggest using more for a final analysis (100 is reasonable but depends on how big the data is)

bs_impVI <- mrvip(
  mrBootstrap_obj = bs_malaria,
  yhats = yhats_rf_downSamp,
  X = X,
  X1 = X1,
  Y = Y,
  mode = 'classification',
  threshold = 0.0,
  global_top_var = 10,
  local_top_var = 5,
  taxa = NULL,
  ModelPerf = ModelPerf_rf_downSamp)

bs_impVI[[3]]  #importance plot. There are plenty of other insights possible

#'global_top_var' limits how many predictors are included in the community-wide plot. 'local_top_var' limits the number of individual taxa plots. 'threshold' excludes individual importance plots for taxa not well predicted by the model.

You can see that host abundance is the most important predictor of this parasite community (followed by H. zosteropis). However, the second figure shows that there is important variability among taxa. For example, H. zosteropis is the most important predictor of the occurrence of Microfilaria, and host abundance (shortened to sc..) is less important.

Bootstrap partial dependence plots

To look at the relationship between each variable and community structure, MrIML 2.0 has a convenient wrapper to plot bootstrapped partial dependencies for a taxon of interest.

pds <- mrPD_bootstrap(mrBootstrap_obj=bs_malaria,
                      vi_obj=bs_impVI,
                      X,
                      Y,
                      target='Plas',
                      global_top_var=5)

These plots show that, for example, the presence of Microfilaria greatly increases the probability of observing Plasmodium (from ~0.32 to 0.58 while holding all other variables at their mean value). Note that these are marginal relationships (i.e. isolating the effect of each predictor on its own). When host abundance is high, the probability of detecting Plasmodium decreases non-linearly (with a threshold around 0). The other parasites have less effect.
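To make the notion of a partial dependence estimate concrete, here is a minimal base R sketch on simulated data. It uses one common definition (fix a predictor at a value for every row, predict, then average over the observed values of the other predictors) with a logistic regression standing in for the tuned model. The variable names are hypothetical and the internals of `mrPD_bootstrap` may differ.

```r
set.seed(7)

# Simulated data: presence of taxon B raises the odds of observing taxon A.
n <- 300
taxon_b <- rbinom(n, 1, 0.3)
host <- rnorm(n)
taxon_a <- rbinom(n, 1, plogis(-1 + 1.5 * taxon_b - 0.8 * host))
dat <- data.frame(taxon_a, taxon_b, host)

fit <- glm(taxon_a ~ taxon_b + host, family = binomial(), data = dat)

# Partial dependence: fix one predictor at a value for every row,
# predict, and average over the observed values of the other predictors.
partial_dep <- function(model, data, var, value) {
  data[[var]] <- value
  mean(predict(model, newdata = data, type = "response"))
}

pd_absent  <- partial_dep(fit, dat, "taxon_b", 0)
pd_present <- partial_dep(fit, dat, "taxon_b", 1)
c(absent = pd_absent, present = pd_present)
```

Repeating this calculation on each bootstrap model gives a distribution of curves rather than a single line, which is what the bootstrapped plots display.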

If we want to explore the effect of host abundance overall we can use the ‘mr_Covar’ function.

covar <- mr_Covar(yhats=yhats_rf_downSamp,
                  X=X,
                  X1=X1,
                  Y=Y,
                  var='scale.prop.zos',
                  sdthresh =0.01) 

#sdthresh just plots the taxa responding the most.

Note that this isn’t bootstrapped; each line represents a taxon in this case. The second plot shows the community-wide change in occurrence probabilities across host abundance. Note that the occurrence probabilities of all taxa drop at intermediate levels of host abundance (0.75-1.25).

Co-occurrence network

We can use all the bootstrapped partial dependence estimates (pds) to construct a co-occurrence network. We show how this object can be converted to an igraph object and plotted. This is a directed network, and edges are scaled by the standard deviation of the marginal change in prediction. Red edges are positive associations (the predicted occurrence of a taxon increases with the presence of the other) and blue edges are negative (the predicted occurrence of a taxon decreases with the presence of the other).
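The sketch below shows, under stated assumptions, how a single edge could be derived from bootstrapped partial dependence values. The numbers are made up, and the exact calculation inside `mrCoOccurNet_bootstrap` is assumed rather than taken from its source. Only the column names `mean_strength` and `direction` come from the code that follows.

```r
# Hypothetical bootstrapped partial dependence values: predicted probability
# of a target taxon when a predictor taxon is absent vs present.
pd_boot <- data.frame(
  bootstrap = 1:4,
  absent  = c(0.30, 0.34, 0.31, 0.33),
  present = c(0.55, 0.60, 0.58, 0.57)
)

# Spread of the predictions along each bootstrap curve (two points here).
strength <- apply(pd_boot[, c("absent", "present")], 1, sd)

# One row of an edge list: average strength and the sign of the change.
edge <- data.frame(
  mean_strength = mean(strength),
  direction = ifelse(mean(pd_boot$present - pd_boot$absent) > 0,
                     "positive", "negative")
)
edge
```

In this toy case the predictor taxon raises the predicted probability in every bootstrap, so the edge would be drawn in red.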

assoc_net<- mrCoOccurNet_bootstrap (mrPD_obj=pds,
                                    Y=Y)

assoc_net_filtered <-  assoc_net %>% 
  filter(mean_strength > 0.1)
#Rule of thumb based on our simulations: associations with a mean strength < 0.05 are excluded. A stricter cutoff of 0.1 is used here.

#convert to igraph
g <- graph_from_data_frame(assoc_net_filtered,
                           directed=TRUE,
                           vertices=names(Y)) #matching Y data

E(g)$Value <- assoc_net_filtered$mean_strength

E(g)$Color <- ifelse(assoc_net_filtered$direction == "negative", "blue", "red")

# Convert the igraph object to a data frame that ggplot2 can draw (default ggnetwork layout)
gg <- ggnetwork(g)

# Plot the graph
ggplot(gg, aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_edges(aes(color = Color, linewidth = (Value)), 
             curvature = 0.2,
             arrow = arrow(length = unit(5, "pt"),
                           type = "closed")) + 
  geom_nodes(color = "gray", size = degree(g, mode = "out")/2)+
  scale_color_identity() +
  theme_void() +
  theme(legend.position = "none")  +
  geom_nodelabel_repel(aes(label = name),
                       box.padding = unit(0.5, "lines"),
                       data = gg,
                       size=2,
                       segment.colour = "black",
                       colour = "white", fill = "grey36")

1-way, 2-way and 3-way interactions

Finally, we can quantify the overall importance of interactions, as well as one-way and two-way interaction importance, using the same bootstrap approach. See https://github.com/mayer79/hstats for more details about the method.

int_ <- mrInteractions(yhats=yhats_rf,
                       X,
                       Y,
                       num_bootstrap=10,
                       feature = 'Plas',
                       top.int=10)
#10 bootstraps to keep it short. top.int focuses on the top 10 interactions (all of them in this case).

int_[[1]] # overall plot

int_[[2]] # individual plot for the response of choice 

int_[[3]] #two way plot

The first plot shows that interactions account for on average 27% (bootstrap interval 23-34%) of variation in predictions for H. killangoi and less for the other taxa. The second plot shows that interactions involving host abundance affect predictions of Plasmodium the most (but H. zosteropis is also important). This trend is similar community-wide. The next plot shows that the interaction between the Haemoproteus species is the most important 2-way interaction for Plasmodium, but this isn’t true community-wide, as the interaction between host abundance and H. zosteropis is the strongest overall. Taken together, we can see that interactions between taxa are mediated by host abundance.

Finally, we can explore specific interactions in more detail using 2D partial dependence plots. In this case we choose one of the more important 2-way interactions affecting the probability of detecting Plasmodium.
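The interaction strengths above are based on Friedman and Popescu's H-statistics, which the linked hstats package implements. As a rough guide to what the pairwise statistic measures, here is a simplified brute-force version in base R. The two prediction functions are hypothetical stand-ins for a fitted model, and the code is a sketch, not the hstats implementation.

```r
set.seed(3)
n <- 60
dat <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))

# Two hypothetical prediction functions standing in for a fitted model.
f_additive <- function(d) d$x1 + d$x2 + d$x3
f_interact <- function(d) d$x1 * d$x2 + d$x3

# Centred partial dependence of `vars`, evaluated at every observed row.
pd_centred <- function(f, data, vars) {
  pd <- vapply(seq_len(nrow(data)), function(i) {
    grid <- data
    for (v in vars) grid[[v]] <- data[[v]][i]
    mean(f(grid))
  }, numeric(1))
  pd - mean(pd)
}

# Pairwise H-squared: the share of the joint effect of two features
# that is not explained by the sum of their separate effects.
h_squared <- function(f, data, a, b) {
  pd_ab <- pd_centred(f, data, c(a, b))
  pd_a  <- pd_centred(f, data, a)
  pd_b  <- pd_centred(f, data, b)
  sum((pd_ab - pd_a - pd_b)^2) / sum(pd_ab^2)
}

h_add <- h_squared(f_additive, dat, "x1", "x2")  # essentially zero
h_int <- h_squared(f_interact, dat, "x1", "x2")  # clearly above zero
c(additive = h_add, interaction = h_int)
```

For a purely additive model the joint effect equals the sum of the separate effects, so the statistic is zero up to rounding error. Any product term pushes it above zero.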

fl <- mrFlashlight(yhats=yhats_rf_downSamp,
                   X=cbind(X, Y),
                   Y=Y,
                   response = "single",
                   index=4,
                   mode='classification') #index=4 selects Plasmodium

plot(light_profile2d(fl,c("scale.prop.zos","Hzosteropis")))+
  theme_bw()

So if H. zosteropis is not present and the relative abundance of Zosterops species is low, the probability of observing Plasmodium is high (above roughly 0.7).