Model Misspecification & Double Robustness

In this example we illustrate the double robustness property of TMLE in the classical backdoor adjustment setting for the Average Treatment Effect.

Let's consider the following simple data generating process:

\[\begin{aligned} W \sim \mathcal{N}(0, 1) \\ T \sim \mathcal{B}(\frac{1}{1 + e^{-(0.3 - 0.5 \cdot W)}}) \\ Y \sim \mathcal{N}(e^{1 - 10 \cdot T + W}, 1) \end{aligned}\]

using TMLE
using MLJ
using Distributions
using StableRNGs
using LogExpFunctions
using MLJGLMInterface
using DataFrames
using CairoMakie


"""
    μY(T, W)

Conditional mean of the outcome given treatment `T` and confounder `W`,
i.e. `exp(1 - 10T + W)` evaluated element-wise.
"""
μY(T, W) = @. exp(1 - 10 * T + 1 * W)

"""
    generate_data(; n = 1000, rng = StableRNG(123))

Simulate `n` observations from the example's data generating process.

The returned `DataFrame` has the columns, in order: the confounder `W`, the
treatment `T` (as floats), the same treatment as a categorical vector `Tcat`,
the observed outcome `Y`, and the two potential outcomes `Y₁` (everyone treated)
and `Y₀` (no one treated). The three outcome columns share the same noise draw.
"""
function generate_data(;n = 1000, rng = StableRNG(123))
    # Draws are taken from `rng` in a fixed order: confounder, treatment, noise.
    confounder = rand(rng, Normal(), n)
    propensity = @. logistic(0.3 - 0.5 * confounder)
    treatment = Float64.(rand(rng, n) .< propensity)
    noise = rand(rng, Normal(), n)

    # Observed outcome plus both counterfactual outcomes under the same noise.
    observed = μY(treatment, confounder) .+ noise
    treated = μY(fill(1.0, n), confounder) .+ noise
    untreated = μY(fill(0.0, n), confounder) .+ noise

    return DataFrame(
        W = confounder,
        T = treatment,
        Tcat = categorical(treatment),
        Y = observed,
        Y₁ = treated,
        Y₀ = untreated,
    )
end

"""
    plotY(data)

Scatter the outcome `Y` against the confounder `W`, with one series per value of
the treatment column `T`, and return the resulting figure.

`data` is expected to provide the columns `W`, `Y` and `T`.
"""
function plotY(data)
    fig = Figure()
    ax = Axis(fig[1, 1], xlabel="W", ylabel="Y")
    for (key, group) in pairs(groupby(data, :T))
        # x must be W and y must be Y so that the points agree with the axis labels.
        scatter!(ax, group.W, group.Y, label=string("T=", key.T))
    end
    # Attach the legend to this axis explicitly rather than to the "current" axis.
    axislegend(ax)
    return fig
end

# Simulate one sample of 1000 observations with a fixed seed and display it,
# one scatter series per treatment value.
data = generate_data(;n = 1000, rng = StableRNG(123))
plotY(data)
Example block output

Y is thus a nonlinear function of T and W. Despite the simplicity of the example, it is difficult to find a closed form solution for the true Average Treatment Effect. However, since we know the generating process, we can approximate it by Monte-Carlo simulation: for each simulated dataset we average the difference between the two potential outcomes. In the next two sections, we compare linear inference and TMLE and see how well their confidence intervals cover this Monte-Carlo approximation.

Estimation using a Linear model

We first propose to estimate the effect size using the classic linear inference method. Because our model does not contain the data generating process (and is hence mis-specified), there is no guarantee that the true effect size will be covered by our confidence interval. In fact, as the sample size grows, the confidence interval will inevitably shrink and fail to cover the ground truth. This can be seen from the following animation:

"""
    linear_inference(data)

Regress `Y` on `W` and `T` with a `LinearRegressor` and return
`(coef, lb, ub)` for the `T` row of the fitted coefficient table.

`coef` is read from the first column of the table; `lb` and `ub` are read from
its last two columns, which are taken to be the confidence interval bounds.
"""
function linear_inference(data)
    features = data[!, [:W, :T]]
    mach = machine(LinearRegressor(), features, data.Y)
    fit!(mach, verbosity=0)

    table = report(mach).coef_table
    row = findfirst(==("T"), table.rownms)
    columns = table.cols
    return (columns[1][row], columns[end - 1][row], columns[end][row])
end

"""
    repeat_inference(inference_method; n=1000, K=100)

Run `inference_method` on `K` independently simulated datasets of size `n`.

Repetition `k` uses the dataset generated with `StableRNG(k)`. Returns three
vectors of length `K`: the point estimates, the upper half-widths of the
confidence intervals (`ub - estimate`), and the Monte-Carlo approximation of the
effect on each dataset (mean of `Y₁ - Y₀`).
"""
function repeat_inference(inference_method; n=1000, K=100)
    point = Vector{Float64}(undef, K)
    halfwidth = similar(point)
    montecarlo = similar(point)
    for rep in 1:K
        sample = generate_data(;n=n, rng=StableRNG(rep))
        estimate, _, upper = inference_method(sample)
        point[rep] = estimate
        halfwidth[rep] = upper - estimate
        montecarlo[rep] = mean(sample.Y₁ .- sample.Y₀)
    end
    return point, halfwidth, montecarlo
end

"""
    plot_coverage(inference_method; n=1000, K=100)

Build the coverage figure for `inference_method` over `K` repetitions at sample
size `n`.

The estimates, their error bars and the Monte-Carlo estimates are wrapped in
`Observable`s so the figure can be refreshed in place. Returns
`(fig, estimates, errors, mcestimates, title)`, where all but `fig` are the
observables driving the plot.
"""
function plot_coverage(inference_method; n=1000, K=100)
    fig = Figure()
    title = Observable(string("N=", n))
    ax = Axis(fig[1, 1], xlabel="Repetition", ylabel="Estimate size", title=title)

    # Wrap each result vector in its own observable.
    estimates, errors, mcestimates =
        map(Observable, repeat_inference(inference_method; n=n, K=K))

    repetitions = 1:K
    method_label = replace(string(inference_method), "_" => " ")
    errorbars!(ax, repetitions, estimates, errors, color=:red, whiskerwidth = 10)
    scatter!(ax, repetitions, estimates, color=:red, label=method_label)
    scatter!(ax, repetitions, mcestimates, label="Monte Carlo estimate")
    axislegend()
    return fig, estimates, errors, mcestimates, title
end

"""
    update_observables!(estimates, errors, mcestimates, title, inference_method; n=1000, K=100)

Re-run `inference_method` over `K` repetitions at sample size `n` and push the
new results into the given observables, then set `title` to `"N=<n>"`.

The observables are updated in the order `estimates`, `errors`, `mcestimates`,
`title`.
"""
function update_observables!(estimates, errors, mcestimates, title, inference_method; n=1000, K=100)
    refreshed = repeat_inference(inference_method; n=n, K=K)
    for (observable, values) in zip((estimates, errors, mcestimates), refreshed)
        observable[] = values
    end
    return title[] = string("N=", n)
end

"""
    make_animation(inference_method; Ns=[...], K=100, framerate=1)

Record an animated coverage plot for `inference_method`, one frame per sample
size in `Ns`, into the file `"<inference_method>.gif"`, and return the value
returned by `record`.

# Keywords
- `Ns`: sample sizes to animate over, one frame each. Defaults to the schedule
  from 10 000 to 500 000 used in this example.
- `K`: number of repetitions drawn in every frame.
- `framerate`: frames per second of the recorded animation.

# Throws
- `ArgumentError` if `Ns` is empty.
"""
function make_animation(
    inference_method;
    Ns=[10_000, 25_000, 50_000, 75_000, 100_000, 250_000, 500_000],
    K=100,
    framerate=1,
)
    isempty(Ns) && throw(ArgumentError("Ns must contain at least one sample size"))
    # The initial figure is built from the first sample size so that it always
    # matches the first recorded frame, whatever schedule is passed.
    fig, estimates, errors, mcestimates, title =
        plot_coverage(inference_method; n=first(Ns), K=K)
    return record(fig, "$(inference_method).gif", Ns; framerate = framerate) do n
        update_observables!(estimates, errors, mcestimates, title, inference_method; n=n, K=K)
    end
end

# Record the coverage animation for the linear estimator (file "linear_inference.gif").
make_animation(linear_inference)
"linear_inference.gif"

Linear Inference

Estimation using TMLE

To solve this issue, we will now use TMLE to estimate the Average Treatment Effect. We will keep the mis-specified linear model to estimate E[Y|T,W] but will estimate p(T|W) with a logistic regression, which turns out to be the true generating model in this case. Because TMLE is doubly robust, it is enough for one of these two models to be correctly specified: even though the outcome model is wrong, the correctly specified treatment model is sufficient, and we see that we now have full coverage of the ground truth.

"""
    tmle_inference(data)

Estimate the Average Treatment Effect of `Tcat` (case `1.0` versus control
`0.0`) on `Y`, adjusting for the confounder `W`, with a `TMLEE` estimator.

The outcome model is a `LinearRegressor` and the treatment model a
`LinearBinaryClassifier`, both wrapped with `with_encoder`. Returns
`(estimate, lb, ub)`, where the bounds come from the confidence interval of a
`OneSampleTTest` on the result.
"""
function tmle_inference(data)
    estimand = ATE(
        outcome=:Y,
        treatment_values=(Tcat=(case=1.0, control=0.0),),
        treatment_confounders=(Tcat=[:W],)
    )
    estimator = TMLEE(
        models=Dict(
            :Y    => with_encoder(LinearRegressor()),
            :Tcat => with_encoder(LinearBinaryClassifier())
        )
    )
    # The estimator returns several values; only the first one is needed here.
    result = first(estimator(estimand, data; verbosity=0))
    lower, upper = confint(OneSampleTTest(result))
    return (TMLE.estimate(result), lower, upper)
end

# Record the coverage animation for the TMLE estimator (file "tmle_inference.gif").
make_animation(tmle_inference)
"tmle_inference.gif"

TMLE Inference


This page was generated using Literate.jl.