Feasibility of Implementing Invariant Risk Minimization

I am interested in experimenting with the recently published Invariant Risk Minimization (IRM) technique (see https://arxiv.org/abs/1907.02893). In short, the method aims to identify direct causes of an output by looking for correlations that remain invariant across datasets from different environments. Each environment consists of data with the same set of variables, but with different noise or interventions applied to one or more of them. The authors give the example of classifying photos of cows vs. camels, in which a classifier may learn to predict from the background (green for cows, brown sand for camels), because the background is confounded with the visual properties that are more causally responsible for an animal being a cow or a camel.

I am specifically interested in applying the technique to observational healthcare data, since such data naturally comes from different environments, and models trained on it should generalize to additional environments. IRM is explicitly designed to maximize out-of-distribution performance.

My question is whether it is feasible at this point to implement this with fastai. It appears this would require a custom training pipeline. The basic PyTorch version of the training loop is shown below:

import torch
from torch.autograd import grad

def compute_penalty(losses, dummy_w):
    # Split the per-example losses into two halves and take the gradient of
    # each half's mean with respect to the dummy classifier weight.
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]  # even-indexed losses
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]  # odd-indexed losses
    # The penalty is the inner product of the two gradient estimates.
    return (g1 * g2).sum()

for iteration in range(50000):
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        # Randomly permute the examples so the even/odd split inside
        # compute_penalty yields two independent minibatches.
        p = torch.randperm(len(x_e))
        # mse must return per-example losses (i.e., reduction='none'),
        # since compute_penalty indexes into them.
        error_e = mse(x_e[p] @ phi * dummy_w, y_e[p])
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()

    opt.zero_grad()
    (error + lam * penalty).backward()  # `lambda` is a Python keyword, so use `lam`
    opt.step()

Two aspects of this look tricky to implement in fastai. The first is how the error and penalty are calculated and summed for each environment. The second is the use of gradients in the regularization term (computed by compute_penalty), together with randomly splitting each environment's data to compute the penalty. The coefficient on the penalty controls the trade-off between minimizing the sum of errors across environments and minimizing variation in performance across environments. I'm also not entirely sure how best to handle batch size in this training loop.
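On the batch-size question, one possible approach (my own sketch, not from the paper) is to give each environment its own shuffled DataLoader and cycle the shorter ones, so every optimization step sees one minibatch from every environment. The environment sizes and batch size below are made up for illustration.

```python
import itertools
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical environments of unequal size: each is an (x, y) tensor pair.
environments = [
    (torch.randn(100, 2), torch.randn(100, 1)),
    (torch.randn(60, 2), torch.randn(60, 1)),
]

# One DataLoader per environment; shuffling replaces the explicit randperm.
loaders = [DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
           for x, y in environments]

# Step through the longest loader once per epoch, cycling the shorter ones.
n_steps = max(len(dl) for dl in loaders)
iters = [itertools.cycle(dl) for dl in loaders]

for step in range(n_steps):
    batches = [next(it) for it in iters]  # one (x_e, y_e) batch per environment
    # ...accumulate error_e and penalty over `batches` as in the loop above...
```

Note that itertools.cycle caches the batches from the first pass, so for a fully fresh shuffle each epoch you would recreate the iterators; for a sketch this is close enough.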

I’ve found posts on how to build custom workflows in fastai, but most are now dated. Does this seem relatively easy to do, or even worth it at this point? Can you point me to the best examples of how I might do it? The trade-off is the time needed to implement this in fastai vs. the time to code a more complete pipeline and utilities (e.g., a learning rate finder) in plain PyTorch.


I don’t see anything in this implementation that can’t be written as a standard Callback. Your model would be the x_e @ phi * dummy_w part (taking a list of xs and a list of ys, and returning a list of error_e), then your callback would:

  • do the randperms in on_batch_begin
  • compute the penalty and sum it with all of the errors in on_backward_begin

and the rest will be done by the fastai library
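To make the shape of that suggestion concrete, here is a plain-Python sketch of what such a callback could look like. This is not actual fastai code: the class below only mimics the hook names from the reply (on_batch_begin, on_backward_begin), and the real fastai Callback API passes state differently.

```python
import torch

class IRMCallback:
    """Toy stand-in for a fastai-style callback carrying the IRM logic."""
    def __init__(self, dummy_w, lam):
        self.dummy_w, self.lam = dummy_w, lam

    def on_batch_begin(self, xb, yb):
        # Shuffle within the batch so the even/odd split below yields
        # two pseudo-independent minibatches.
        p = torch.randperm(len(xb))
        return xb[p], yb[p]

    def on_backward_begin(self, losses):
        # Replace the per-example losses with error + lam * penalty
        # before backward() is called.
        g1 = torch.autograd.grad(losses[0::2].mean(), self.dummy_w, create_graph=True)[0]
        g2 = torch.autograd.grad(losses[1::2].mean(), self.dummy_w, create_graph=True)[0]
        penalty = (g1 * g2).sum()
        return losses.mean() + self.lam * penalty

# Hypothetical usage with a toy linear featurizer phi:
torch.manual_seed(0)
phi = torch.randn(1, 1)
dummy_w = torch.tensor(1.0, requires_grad=True)
x, y = torch.randn(8, 1), torch.randn(8, 1)

cb = IRMCallback(dummy_w, lam=1.0)
xp, yp = cb.on_batch_begin(x, y)
losses = ((xp @ phi * dummy_w - yp) ** 2).squeeze()  # per-example losses
loss = cb.on_backward_begin(losses)  # scalar: error + lam * penalty
```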

Thanks for the pointers. I think one thing I didn’t make entirely clear is that each environment in the loop above is a complete dataset. For example:

import numpy as np
import pandas as pd

def example_1(n=10000, env=1, env_id=1):
    x = np.random.random(n) * env
    y = x + np.random.random(n) * env
    z = y + np.random.random(n)
    df = pd.DataFrame({'env_id': env_id, 'x': x, 'y': y, 'z': z})
    return df

df = pd.concat([example_1(env=.1, env_id=1), example_1(env=.6, env_id=2), example_1(env=1, env_id=3)])

Here, I’ve added an env_id variable to the dataframe to indicate the environment. I need to be able to calculate and sum error_e and penalty for each environment in the inner loop of training. fastai’s data block API appears to treat the dataset as one monolithic set, apart from the train/validation split. Looking at the callback docs, the only approach I see is to use on_batch_begin to set the batch size to the number of rows in each environment, using num_batch to index the env_id. In the simple example above, each environment has the same number of rows, but that will not be true in most real-world examples. In addition, for more complex datasets, such as images, each environment will itself need to be broken into batches, so overloading the batch size is only a temporary fix.
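One way to recover the per-environment structure from the monolithic dataframe is to group it by env_id and convert each group back into tensors, so the training loop can iterate for x_e, y_e in environments as before. A rough sketch (the choice of x and z as inputs and y as the target is just for illustration, and I use unequal environment sizes to mimic the real-world case):

```python
import numpy as np
import pandas as pd
import torch

# Rebuild the synthetic data from above, with unequal environment sizes.
def example_1(n=10000, env=1, env_id=1):
    x = np.random.random(n) * env
    y = x + np.random.random(n) * env
    z = y + np.random.random(n)
    return pd.DataFrame({'env_id': env_id, 'x': x, 'y': y, 'z': z})

df = pd.concat([example_1(n=1000, env=.1, env_id=1),
                example_1(n=500, env=.6, env_id=2),
                example_1(n=750, env=1, env_id=3)])

# Split the monolithic frame into per-environment (x_e, y_e) tensor pairs.
environments = []
for _, g in df.groupby('env_id'):
    x_e = torch.tensor(g[['x', 'z']].values, dtype=torch.float32)
    y_e = torch.tensor(g[['y']].values, dtype=torch.float32)
    environments.append((x_e, y_e))
```

Each element of environments could then be wrapped in its own DataLoader for larger datasets, rather than treating a whole environment as one batch.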