Becoming a Super Learner

What this tutorial is about

Super Learning, also known as Stacking, is an ensemble technique first introduced by Wolpert in 1992. Instead of selecting a single model based on cross-validation performance, the models' cross-validated predictions are combined by a meta-learner so as to minimize the cross-validation error. It has also been shown by van der Laan et al. that, at least asymptotically, the resulting Super Learner performs at least as well as its best-performing submodel.
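
To make the idea concrete, here is a minimal sketch of the procedure (not MLJ's implementation; the superlearn function and the (X, y) -> predictor convention are hypothetical, and X is assumed to be a matrix): each base learner is trained on the training part of every fold, and its out-of-fold predictions become the features on which the meta-learner is trained.

# Minimal sketch of stacking: `learners` and `metalearner` are assumed to be
# functions (X, y) -> predictor, where predictor(X) returns predictions.
function superlearn(learners, metalearner, folds, X, y)
    n, m = length(y), length(learners)
    Z = zeros(n, m)                          # level-one (out-of-fold) data
    for (train, test) in folds
        for (j, learner) in enumerate(learners)
            predictor = learner(X[train, :], y[train])
            Z[test, j] .= predictor(X[test, :])
        end
    end
    # The meta-learner combines the out-of-fold predictions, typically
    # by minimizing the cross-validation loss (e.g. a GLM on Z).
    return metalearner(Z, y)
end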

Why is it important for Targeted Learning?

The short answer is that the consistency (convergence in probability) of the targeted estimator depends on the consistency of at least one of the estimators of the nuisance parameters $Q_0$ and $G_0$. If we restrict ourselves to misspecification-prone models such as linear models, we have little chance of satisfying this criterion. Super Learning is a data-driven way to leverage a diverse set of models and build the best-performing estimator for both $Q_0$ and $G_0$.
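
For reference, with outcome $Y$, treatment $A$ and confounders $W$, these nuisance parameters are usually defined as $Q_0(A, W) = E[Y|A, W]$ (the outcome mean) and $G_0(A|W) = P(A|W)$ (the treatment mechanism, i.e. the propensity score).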

The dataset

Let's consider the case where Y is categorical. In TMLE.jl, this could be useful to learn:

  • The propensity score
  • The outcome model when the outcome is binary

We will use the following moons dataset:

using MLJ

X, y = MLJ.make_moons(1000)
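
Before going further, we can check how MLJ interprets this data; schema and scitype are part of the MLJ API. The features should be Continuous and the target a two-class categorical vector:

schema(X)
scitype(y)  # AbstractVector{Multiclass{2}}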

Defining a Super Learner in MLJ

In MLJ, a Super Learner can be defined using the Stack constructor. The three most important types of arguments to a Stack are:

  • metalearner: The metalearner used to combine the weak learners defined below. Typically a generalized linear model.
  • resampling: The cross-validation scheme, by default a 6-fold cross-validation. Since we are working with categorical data, it is a good idea to make sure the splits are balanced. We will thus use a StratifiedCV resampling strategy.

  • models...: A series of named MLJ models.

One important point is that MLJ does not provide any models itself, just the API; models have to be loaded from external compatible libraries. You can search for available models that match your data:

models(matching(X, y))
57-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :constructor, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :target_in_fit, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:
 (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )
 (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )
 (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )
 (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )
 (name = BayesianLDA, package_name = MultivariateStats, ... )
 (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )
 (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )
 (name = CatBoostClassifier, package_name = CatBoost, ... )
 (name = ConstantClassifier, package_name = MLJModels, ... )
 (name = DecisionTreeClassifier, package_name = BetaML, ... )
 ⋮
 (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )
 (name = SVC, package_name = LIBSVM, ... )
 (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )
 (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )
 (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )
 (name = StableForestClassifier, package_name = SIRUS, ... )
 (name = StableRulesClassifier, package_name = SIRUS, ... )
 (name = SubspaceLDA, package_name = MultivariateStats, ... )
 (name = XGBoostClassifier, package_name = XGBoost, ... )
Stack limitation

The Stack cannot contain <:Deterministic models for classification.
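
If in doubt, a model's prediction type can be queried from its registry metadata via MLJ's info function (the exact outputs in the comments below are indicative):

info("SVC", pkg="LIBSVM").prediction_type                           # :deterministic
info("KNNClassifier", pkg="NearestNeighborModels").prediction_type  # :probabilistic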

Let's load a few packages providing models and build our first Stack:

using MLJXGBoostInterface
using MLJLinearModels
using NearestNeighborModels

resampling = StratifiedCV()
metalearner = LogisticClassifier()

stack = Stack(
    metalearner = metalearner,
    resampling  = resampling,
    lr          = LogisticClassifier(),
    knn         = KNNClassifier(K=3)
)
ProbabilisticStack(
  metalearner = LogisticClassifier(
        lambda = 2.220446049250313e-16, 
        gamma = 0.0, 
        penalty = :l2, 
        fit_intercept = true, 
        penalize_intercept = false, 
        scale_penalty_with_samples = true, 
        solver = nothing), 
  resampling = StratifiedCV(
        nfolds = 6, 
        shuffle = false, 
        rng = Random.TaskLocalRNG()), 
  measures = nothing, 
  cache = true, 
  acceleration = ComputationalResources.CPU1{Nothing}(nothing), 
  lr = LogisticClassifier(
        lambda = 2.220446049250313e-16, 
        gamma = 0.0, 
        penalty = :l2, 
        fit_intercept = true, 
        penalize_intercept = false, 
        scale_penalty_with_samples = true, 
        solver = nothing), 
  knn = KNNClassifier(
        K = 3, 
        algorithm = :kdtree, 
        metric = Distances.Euclidean(0.0), 
        leafsize = 10, 
        reorder = true, 
        weights = NearestNeighborModels.Uniform()))

This Stack only contains two different models: a logistic classifier and a KNN classifier. A Stack is just like any MLJ model: it can be wrapped in a machine and fitted:

mach = machine(stack, X, y)
fit!(mach, verbosity=0)
trained Machine; does not cache data
  model: ProbabilisticStack(metalearner = LogisticClassifier(lambda = 2.220446049250313e-16, …), …)
  args: 
    1:	Source @083 ⏎ ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}
    2:	Source @701 ⏎ AbstractVector{ScientificTypesBase.Multiclass{2}}

Or evaluated. Because the Stack contains a cross-validation procedure, this will result in two nested levels of resampling.

evaluate!(mach, measure=log_loss, resampling=resampling)
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌──────────────────────┬───────────┬─────────────┐
│ measure              │ operation │ measurement │
├──────────────────────┼───────────┼─────────────┤
│ LogLoss(             │ predict   │ 2.22e-16    │
│   tol = 2.22045e-16) │           │             │
└──────────────────────┴───────────┴─────────────┘
┌──────────────────────────────────────────────────────────────┬─────────┐
│ per_fold                                                     │ 1.96*SE │
├──────────────────────────────────────────────────────────────┼─────────┤
│ [2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16] │ 0.0     │
└──────────────────────────────────────────────────────────────┴─────────┘
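
Once fitted, the Stack makes probabilistic predictions like any other probabilistic MLJ model:

ŷ = predict(mach, X)       # vector of UnivariateFinite distributions
pdf.(ŷ, 1)                 # predicted probabilities for class 1
predict_mode(mach, X)      # hard class labels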

A more advanced Stack

What are good Stack members? Virtually anything, provided they are MLJ models. Here are a few examples:

  • You can use the Stack to "select" model hyper-parameters: e.g. KNNClassifier(K=3) or KNNClassifier(K=2)?
  • You can also use self-tuning models. Note that because these models resort to cross-validation, fitting the stack will result in two nested levels of sample-splitting.

The following self-tuned XGBoost will vary some hyperparameters in an internal sample-splitting procedure in order to optimize the log-loss. It will then be combined with the rest of the models via the Stack's own sample-splitting procedure. Finally, the whole procedure is evaluated in an outer resampling loop.

xgboost = XGBoostClassifier(tree_method="hist")
self_tuning_xgboost = TunedModel(
    model = xgboost,
    resampling = resampling,
    tuning = Grid(goal=20),
    range = [
        range(xgboost, :max_depth, lower=3, upper=7),
        range(xgboost, :lambda, lower=1e-5, upper=10, scale=:log)
        ],
    measure = log_loss,
    cache=false
)

stack = Stack(
    metalearner         = metalearner,
    resampling          = resampling,
    self_tuning_xgboost = self_tuning_xgboost,
    lr                  = LogisticClassifier(),
    knn_2               = KNNClassifier(K=2),
    knn_3               = KNNClassifier(K=3),
    cache               = false
)

mach = machine(stack, X, y, cache=false)
evaluate!(mach, measure=log_loss, resampling=resampling)
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌──────────────────────┬───────────┬─────────────┐
│ measure              │ operation │ measurement │
├──────────────────────┼───────────┼─────────────┤
│ LogLoss(             │ predict   │ 2.22e-16    │
│   tol = 2.22045e-16) │           │             │
└──────────────────────┴───────────┴─────────────┘
┌──────────────────────────────────────────────────────────────┬─────────┐
│ per_fold                                                     │ 1.96*SE │
├──────────────────────────────────────────────────────────────┼─────────┤
│ [2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16, 2.22e-16] │ 0.0     │
└──────────────────────────────────────────────────────────────┴─────────┘
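
To judge whether stacking is worth the extra computation, the same evaluation can be run on any single candidate model, for instance:

evaluate(LogisticClassifier(), X, y, measure=log_loss, resampling=resampling)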

Diagnostic

Optionally, one can also investigate how successful the weak learners were in the Stack's internal cross-validation. This is done by specifying the measures keyword argument.

Here we look at both the Log-Loss and the AUC.

stack.measures = [log_loss, auc]
fit!(mach, verbosity=0)
report(mach).cv_report
(self_tuning_xgboost = PerformanceEvaluation(0.0137, 1.0),
 lr = PerformanceEvaluation(0.159, 0.985),
 knn_2 = PerformanceEvaluation(2.22e-16, 1.0),
 knn_3 = PerformanceEvaluation(2.22e-16, 1.0),)
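
Each entry of cv_report is a PerformanceEvaluation object, so individual results can be extracted programmatically, e.g. the aggregated [log_loss, auc] measurements for the logistic classifier:

report(mach).cv_report.lr.measurement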

One can look at the fitted parameters of the metalearner as well. Since each of the four submodels reports a probability for each of the two classes, the metalearner sees eight features, and its coefficients come in opposite-signed pairs:

fitted_params(mach).metalearner
(classes = CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 1],
 coefs = [:x1 => -245.66831606621056, :x2 => 245.66831656454454, :x3 => -200.67016078091012, :x4 => 200.67016078091018, :x5 => -249.9999999999445, :x6 => 249.9999999999445, :x7 => -249.9999999999445, :x8 => 249.9999999999445],
 intercept = 0.0,)

This page was generated using Literate.jl.