#8 | Cross-Validation Experiment
Cross-validation isn’t estimating the error of the specific model we have trained. So says the paper “Cross-Validation: what does it estimate and how well does it do it?”, which is cited in the early chapters of Causal Inference and Discovery in Python. Instead, it estimates the average error of the models produced by the training procedure across all training datasets of the same size.
Let’s call our training datasets XY_i, where i runs from 1 up to n. This notation captures that training data has both independent (X) and dependent (Y) variables.
(This would be a nice time to be using something like MathJax.)
Let population_error_XY_i be the population error of the model trained on the data XY_i.

Call the average population error across all models mean_population_error_XY. In reality we never know the population error of a model, but in this case we can calculate it since we are using a fabricated population dataset.
Let the cross-validation estimate of error for a model trained on data XY_i be cv_error_XY_i.

The paper shows that cv_error_XY_i will be closer to mean_population_error_XY than to population_error_XY_i, where “closer” is measured in mean squared error: averaged over training datasets, (cv_error_XY_i - mean_population_error_XY)^2 is smaller than (cv_error_XY_i - population_error_XY_i)^2.

Let’s see if we get a similar result. The paper focuses mostly on the linear case, so it would be nice to look at something different here, like a non-linear relationship.
Oh, since we’ve started comparing, it’s worth calling out that this little experiment will be much much less thorough than the paper.
Fabricated data
Here’s some code to create a simple dataset using polars. It has three independent variables (x1, x2, x3) drawn from different distributions and a dependent variable which is a multiplicative combination of these along with some noise e. A ReLU function is applied for some extra non-linearity.
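Something along these lines does the job. To be clear, the particular distributions, parameters and noise scale below are just illustrative picks, as is the make_population name; the important part is the shape of the data: three differently distributed inputs multiplied together, plus noise, passed through a ReLU.

```python
import numpy as np
import polars as pl


def make_population(n: int, seed: int = 42) -> pl.DataFrame:
    """Fabricate a population dataset: three independent variables from
    different distributions, combined multiplicatively with noise and a ReLU."""
    rng = np.random.default_rng(seed)
    x1 = rng.normal(loc=0.0, scale=1.0, size=n)    # normally distributed
    x2 = rng.uniform(low=0.0, high=10.0, size=n)   # uniformly distributed
    x3 = rng.exponential(scale=2.0, size=n)        # exponentially distributed
    e = rng.normal(loc=0.0, scale=1.0, size=n)     # noise
    y = np.maximum(0.0, x1 * x2 * x3 + e)          # multiplicative combination + ReLU
    return pl.DataFrame({"x1": x1, "x2": x2, "x3": x3, "y": y})
```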
Modelling
Let’s try to predict y from x1, x2 and x3 using a decision tree. We’re not trying to train a good model; rather, we are trying to understand the cross-validation estimate of our model’s performance. This means we can arbitrarily pick the model’s parameters, in this case tree depth, and keep it fixed.
The code below has functions to sample a single training dataset, fit a tree, compute the cross-validation estimate of error and calculate the population error. Mean squared error is used for error calculations.
There is a bit of a dance switching between polars and numpy arrays when passing data to scikit-learn. I think there is a better way to do this and I need to investigate that when I get a chance.
Since we are fabricating a population dataset of fixed size, we can sample from it to get training datasets and then evaluate the performance of the trained model on the entire population.
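Sketched out, those helpers could look roughly like this. The function names, the FEATURES constant and the choice of scoring="neg_mean_squared_error" for cross_val_score are illustrative; the polars-to-numpy conversion is the dance mentioned above.

```python
import numpy as np
import polars as pl
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

FEATURES = ["x1", "x2", "x3"]


def sample_training_data(population: pl.DataFrame, n: int, seed: int) -> pl.DataFrame:
    # Draw one training dataset XY_i from the fabricated population.
    return population.sample(n=n, seed=seed)


def fit_tree(train: pl.DataFrame, max_depth: int) -> DecisionTreeRegressor:
    # The polars -> numpy dance before handing data to scikit-learn.
    X = train.select(FEATURES).to_numpy()
    y = train["y"].to_numpy()
    return DecisionTreeRegressor(max_depth=max_depth).fit(X, y)


def cv_error(train: pl.DataFrame, max_depth: int, folds: int = 10) -> float:
    # Cross-validation estimate of the model's mean squared error,
    # computed on the training data only.
    X = train.select(FEATURES).to_numpy()
    y = train["y"].to_numpy()
    scores = cross_val_score(
        DecisionTreeRegressor(max_depth=max_depth),
        X,
        y,
        cv=folds,
        scoring="neg_mean_squared_error",
    )
    return float(-scores.mean())


def population_error(model: DecisionTreeRegressor, population: pl.DataFrame) -> float:
    # "True" error: mean squared error over the entire fabricated population.
    X = population.select(FEATURES).to_numpy()
    y = population["y"].to_numpy()
    return float(np.mean((y - model.predict(X)) ** 2))
```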
The Experiment
Let’s put it all together. We set the population size to 1,000,009 (yes, 9… a typo, but I’m sticking by it!). We build 1,000 trees, each trained on a sample of 1,000 data points from the population dataset.
The only parameter we set when building the tree is the max depth, arbitrarily choosing 10. It might overfit or it might not; that’s not important for this experiment.
We use 10-fold cross-validation.
We also include a paired t-test to give a little comfort that there are enough trees. I think the assumptions of a paired t-test are met, but I’ve tried not to overthink that.
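Building on the sketches above, the experiment loop could look something like this, with scipy.stats.ttest_rel doing the paired t-test. The constants and variable names are illustrative; the df_results column names match the output below.

```python
from scipy.stats import ttest_rel

POPULATION_SIZE = 1_000_009  # yes, 9
N_TREES = 1_000
TRAIN_SIZE = 1_000
MAX_DEPTH = 10

population = make_population(POPULATION_SIZE)

cv_errors, population_errors = [], []
for i in range(N_TREES):
    train = sample_training_data(population, n=TRAIN_SIZE, seed=i)
    model = fit_tree(train, max_depth=MAX_DEPTH)
    cv_errors.append(cv_error(train, max_depth=MAX_DEPTH, folds=10))
    population_errors.append(population_error(model, population))

cv_errors = np.array(cv_errors)
population_errors = np.array(population_errors)
mean_population_error = population_errors.mean()

# Squared distance from each CV estimate to (a) the mean population error
# across all models and (b) the population error of its own model.
delta_cv_mean = (cv_errors - mean_population_error) ** 2
delta_cv_specific = (cv_errors - population_errors) ** 2

# Paired t-test on the per-model squared deltas.
t_stat, p_value = ttest_rel(delta_cv_mean, delta_cv_specific)

df_results = pl.DataFrame(
    {
        "delta_cv_mean_population_error_squared": [delta_cv_mean.mean()],
        "delta_cv_population_error_squared": [delta_cv_specific.mean()],
        "paired_t_test_p_value": [p_value],
    }
)
```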
Here is df_results.
shape: (1, 3)
┌────────────────────────────────────────┬───────────────────────────────────┬───────────────────────┐
│ delta_cv_mean_population_error_squared ┆ delta_cv_population_error_squared ┆ paired_t_test_p_value │
│ ---                                    ┆ ---                               ┆ ---                   │
│ f64                                    ┆ f64                               ┆ f64                   │
╞════════════════════════════════════════╪═══════════════════════════════════╪═══════════════════════╡
│ 18963.842164                           ┆ 21313.904408                      ┆ 0.002268              │
└────────────────────────────────────────┴───────────────────────────────────┴───────────────────────┘
Hurray! The cross-validation estimate is closer to the mean of all the population errors than to the population error of its own model, and the paired t-test gives a very low p-value.
I’d be curious to see what happens as we move towards leave-one-out cross-validation, but that’s not for now.