I am interested in experimenting with the recently published Invariant Risk Minimization (IRM) technique (see https://arxiv.org/abs/1907.02893). In short, this method is designed to identify direct causes of an output by looking for correlations that are invariant across datasets from different environments. Each environment consists of data with the same set of variables, but with different noise or interventions applied to one or more of the variables. The authors give an example of classifying photos of cows vs. camels, in which a classifier may learn to predict based on background (green for cow, brown (sand) for camel), because background is confounded with the visual properties that are more causally responsible for something being a cow or a camel.
I am specifically interested in applying the technique to observational healthcare data, since that data naturally comes from different environments, and a model trained on it should generalize to additional environments. IRM is specifically designed to maximize out-of-distribution performance.
My question is whether it is feasible at this point to use fastai for this. It appears that this would require a custom pipeline. The paper's basic PyTorch training loop is shown below:
```python
def compute_penalty(losses, dummy_w):
    # losses[0::2] / losses[1::2] split the per-sample losses into two halves
    # (every other element, starting from the first / second element)
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]  # grad returns a tuple, hence [0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()

for iteration in range(50000):
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        # p randomly permutes the cases in x_e, so that the two halves
        # used by compute_penalty are random splits of the environment
        p = torch.randperm(len(x_e))
        error_e = mse(x_e[p] @ phi * dummy_w, y_e[p])
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()
    opt.zero_grad()
    (error + lambda_ * penalty).backward()  # `lambda` is a Python keyword, so renamed lambda_
    opt.step()
```
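The loop above references several names (`phi`, `dummy_w`, `opt`, `mse`, `environments`) that are set up elsewhere in the paper's appendix. A minimal self-contained setup, with synthetic data standing in for real environments (the `make_environment` helper and its parameters are my own illustration, not from the paper), might look like:

```python
import torch

torch.manual_seed(0)

def make_environment(n, dim_x, noise):
    # Hypothetical synthetic data: y depends causally on the first feature only;
    # each environment differs in its noise level
    x = torch.randn(n, dim_x)
    y = x[:, :1] + noise * torch.randn(n, 1)
    return x, y

dim_x = 2
environments = [make_environment(1000, dim_x, 0.1),
                make_environment(1000, dim_x, 0.5)]

phi = torch.nn.Parameter(torch.ones(dim_x, 1))     # the learned (here linear) featurizer
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))  # fixed "dummy" classifier used by the IRMv1 penalty
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")           # per-sample losses, needed so the penalty can split them
```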
Two aspects of this look a bit tricky to implement in fastai. The first is how the error and penalty are accumulated across environments. The second is the use of gradients (with respect to the dummy classifier) as a regularization term in the loss, computed by compute_penalty, along with the random splitting of each environment's data to compute that penalty. The coefficient on this term controls the trade-off between minimizing the sum of errors across environments and minimizing variation in performance across environments. I'm also not entirely sure how best to handle batch size in this training loop.
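On the batch-size question, one option (my own sketch, not from the paper, whose toy code uses full batches) is to draw a random minibatch from each environment at every step rather than permuting the whole dataset. The `irm_step` helper and its `batch_size` default below are hypothetical:

```python
import torch
from torch.autograd import grad

def compute_penalty(losses, dummy_w):
    # IRMv1 penalty from the paper: product of gradients over two random halves
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()

def irm_step(environments, phi, dummy_w, opt, mse, lambda_, batch_size=256):
    # One optimization step using a random minibatch from each environment
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        idx = torch.randperm(len(x_e))[:batch_size]  # random minibatch, also serves as the random split
        error_e = mse(x_e[idx] @ phi * dummy_w, y_e[idx])
        penalty = penalty + compute_penalty(error_e, dummy_w)
        error = error + error_e.mean()
    opt.zero_grad()
    (error + lambda_ * penalty).backward()
    opt.step()
    return float(error), float(penalty)
```

Whether per-environment minibatching changes the statistics of the penalty estimate is something I'd want to verify empirically.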
I’ve found posts on how to build custom workflows in fastai, but most are now dated. Does this seem like something that is relatively easy to do, or even worth it at this time? Can you point me to the best examples of how I might do it? The trade-off is the time needed to implement this in fastai vs. the time to code a more complete pipeline and utilities (e.g., a learning rate finder) in plain PyTorch.