Feature importance in deep learning

FYI, here’s how I got at least the model-agnostic KernelExplainer from shap to run without errors in a notebook on a tabular learner (model and data on the GPU, with both categorical and continuous variables):

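# learn = tabular_learner(...)
# learn.fit_one_cycle(...)

import numpy as np
import pandas as pd
import shap
import torch
shap.initjs()  # loads the JavaScript that force_plot needs

def pred(data):
    # KernelExplainer passes a 2-D numpy array: categorical columns first, then continuous
    device = learn.data.device
    cat_cols = len(learn.data.train_ds.x.cat_names)
    x_cat = torch.from_numpy(data[:, :cat_cols]).to(device, torch.int64)
    # slice from cat_cols: data[:, -cont_cols:] would select every column if there were no continuous variables
    x_cont = torch.from_numpy(data[:, cat_cols:]).to(device, torch.float32)
    learn.model.eval()  # disable dropout, use running batchnorm statistics
    with torch.no_grad():
        out = learn.model(x_cat, x_cont)
    # raw activations (logits for a classifier); wrap in torch.softmax(out, dim=-1) for probabilities
    return out.cpu().numpy()

def shap_data(data):
    # two random training batches: background data (X_train) and rows to explain (X_test)
    X_train, y_train = data.one_batch(denorm=False, cpu=False)
    X_test, y_test = data.one_batch(denorm=False, cpu=False)
    cols = data.train_ds.x.col_names
    # each batch is [x_cat, x_cont]; join them side by side into one table
    X_train = pd.DataFrame(np.concatenate([v.to('cpu').numpy() for v in X_train], axis=1), columns=cols)
    X_test = pd.DataFrame(np.concatenate([v.to('cpu').numpy() for v in X_test], axis=1), columns=cols)
    return X_train, X_test

X_train, X_test = shap_data(learn.data)
e = shap.KernelExplainer(pred, X_train)
shap_values = e.shap_values(X_test, nsamples=100, l1_reg=False)
# for a multi-output model (e.g. a classifier), [0] selects the first output
shap.force_plot(e.expected_value[0], shap_values[0], X_test)

shap.summary_plot(shap_values, X_test, plot_type="bar")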

This grabs two batches from the training set as X_train and X_test for shap.

Sadly, it looks like Google Colab doesn’t support the JavaScript library :frowning:

Have you seen this: https://github.com/slundberg/shap/issues/279#issuecomment-427240107? The JS should work; it just needs to be initialised (shap.initjs()) in every cell that produces visual output.

Hi all. Has ANYONE gotten SHAP DeepExplainer to work with a FastAI Tabular DataBlock? SHAP seems to expect plain PyTorch primitives, which of course differ from the FastAI wrappers. SHAP seems to be a wonderful approach to some interpretability of the NN. Or is there an upcoming FastAI extension that provides this sort of functionality? It seems fairly easy to do with PyTorch and Keras, but if anyone has this working with FastAI, please let me know. Thanks!

I’ve ported shap to fastai2: https://github.com/muellerzr/fastinference

Has SHAP’s assumption of feature independence been discussed?
In tabular data it is quite an issue.

Have any of you read the papers behind GradientExplainer and/or DeepExplainer and come to a clear conclusion that we can use them without the independence assumption?

In the “Some Baselines for other Tabular Datasets with fastai2” thread, the discussion branched into model interpretation, so I am bringing the conversation here to make it easier to find; hopefully more people will share their thoughts on the subject.

This is a response to the question “How to conduct feature importance without assumption of feature independence?”

We can extract feature importance (FI) without assuming feature independence using the attention from TabNet, SHAP explainers that do not use interventions (GradientExplainer might be one; sorry, I haven’t checked that yet), @nestorDemeure’s idea, or, more generally, methods that learn feature importance during training. Perhaps a future SHAP update will be robust to the problem, since the latest paper (1911.11888) describes improvements, but that is not clear to me.

Any other thoughts?

@hubert.misztela I guess the one question I still have is:

Do we still have this independence issue with our permutation importance (on the raw values, not SHAP)? I’d assume we do, but I want to be sure, because if so, our FI tells us absolutely nothing, no?

If features are correlated, then permutation importance can give biased results. In Interpretable Machine Learning, Christoph Molnar discusses this in the feature importance chapter, particularly in the disadvantages section.
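To make the bias concrete, here is a minimal NumPy sketch of plain permutation importance (shuffle one column, measure the drop in score). The helper names `permutation_importance` and `neg_mse`, the two models, and the data are hypothetical, chosen only for illustration; this is not the fastai snippet discussed in this thread. With two near-duplicate features, two models that fit about equally well receive very different importances, and the one that splits its weight across both copies makes each copy look much less important than the underlying signal.

```python
import numpy as np


def permutation_importance(predict, X, y, metric, n_repeats=5, seed=0):
    """Mean drop in `metric` when each column of X is shuffled on its own."""
    rng = np.random.default_rng(seed)
    base = metric(y, predict(X))
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # shuffle column j only; every other column keeps its original row order
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops.append(base - metric(y, predict(X_perm)))
        importances[j] = np.mean(drops)
    return importances


def neg_mse(y_true, y_pred):
    # higher is better, so a positive drop means the feature mattered
    return -np.mean((y_true - y_pred) ** 2)


# Hypothetical toy data: x1 is a near-copy of x0, x2 is unrelated, y depends on x0 only
rng = np.random.default_rng(42)
n = 2000
x0 = rng.normal(size=n)
x1 = x0 + 0.05 * rng.normal(size=n)
x2 = rng.normal(size=n)
X = np.column_stack([x0, x1, x2])
y = 2.0 * x0 + 0.1 * rng.normal(size=n)


def model_a(X):
    # puts all the weight on x0
    return 2.0 * X[:, 0]


def model_b(X):
    # splits the weight between the two correlated copies; fits almost as well
    return X[:, 0] + X[:, 1]


imp_a = permutation_importance(model_a, X, y, neg_mse)
imp_b = permutation_importance(model_b, X, y, neg_mse)
print("model_a:", imp_a.round(2))
print("model_b:", imp_b.round(2))
```

Shuffling x1 alone also creates rows where x1 no longer tracks x0, i.e. combinations that never occur in the data, which is the "holding some features constant while modifying others" problem.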

Unless otherwise stated, I’d expect the assumption of feature independence to be a requirement in any method that involves holding some features constant while modifying other features. TreeSHAP doesn’t always have this assumption, although it looks like certain output types do require setting feature_dependence="independent".

In regression analysis, feature independence (or in statistics terms: a lack of multicollinearity between independent variables, predictors, or covariates) is usually a required assumption as we are interested in interpreting the coefficients of the covariates. There are multiple methods for detecting multicollinearity, which we could use to check on our data.

Depending on the circumstances, multicollinearity isn’t always a problem. For example, through feature engineering or domain knowledge we might have a model whose inputs include age and age_squared, which by definition will be correlated with each other. In regression analysis we would always interpret the two coefficients together and never independently. For our tabular neural networks, we’d want to do the same, so perhaps we’d modify permutation importance to always permute age and age_squared together. Likewise if we have an interaction term or other combinations of features.
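Permuting age and age_squared together only takes a small change to the usual loop: draw one row permutation per group and apply it to all of that group's columns. A minimal NumPy sketch; `grouped_permutation_importance`, `neg_mse`, the stand-in `model`, and the synthetic data are hypothetical illustrations, not an existing fastai or library API.

```python
import numpy as np


def grouped_permutation_importance(predict, X, y, metric, groups, n_repeats=5, seed=0):
    """Permutation importance where every column in a group is shuffled with the
    SAME row permutation, so related columns (age, age_squared) stay consistent."""
    rng = np.random.default_rng(seed)
    base = metric(y, predict(X))
    result = {}
    for name, cols in groups.items():
        drops = []
        for _ in range(n_repeats):
            idx = rng.permutation(len(X))
            X_perm = X.copy()
            X_perm[:, cols] = X[np.ix_(idx, cols)]  # move the whole group together
            drops.append(base - metric(y, predict(X_perm)))
        result[name] = float(np.mean(drops))
    return result


def neg_mse(y_true, y_pred):
    return -np.mean((y_true - y_pred) ** 2)


def model(X):
    # hypothetical stand-in for a trained network using age, age_squared, income
    return 0.5 * X[:, 0] - 0.004 * X[:, 1] + X[:, 2]


rng = np.random.default_rng(0)
n = 2000
age = rng.uniform(20, 70, n)
income = rng.normal(size=n)  # already standardised
X = np.column_stack([age, age ** 2, income])
y = model(X) + 0.5 * rng.normal(size=n)

groups = {"age + age_squared": [0, 1], "income": [2]}
res = grouped_permutation_importance(model, X, y, neg_mse, groups)
print(res)
```

Because each permuted row keeps age_squared equal to age squared, the model is never evaluated on impossible inputs, and the pair gets a single importance score, mirroring how the two regression coefficients would be interpreted together.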

Small amounts of multicollinearity between features we want to be independent might not be completely problematic [1]. The real world is messy, and practitioners don’t always have ideal data. Unfortunately, there are no hard and fast rules on what counts as acceptable multicollinearity, only various rules of thumb. For example, if we are modeling children’s health, age, weight, and height are probably going to be correlated with each other. But if the correlation isn’t too large, we can still look at their feature importance, provided we are careful in our reporting and interpretation and we recognize and acknowledge that our results might be biased. Or, depending on the method used, we could treat them as control variables and limit our analysis to other features.

From a statistical practitioner’s perspective, if you want to interpret the feature importance of tabular neural networks I’d recommend this non-exhaustive list:

  1. Start with a pairwise scatter plot and a correlation matrix of all the data. This is more of an eyeball test for collinearity, as it can only reveal pairwise correlation, not multicollinearity.
  2. Normalize the data. Data normalization can remove certain types of collinearity. Keep in mind that domain knowledge might suggest something other than straight normalization. For example, when working with economics data, the natural log of income is often more useful for interpretation than normalized income.
  3. Run at least one multicollinearity test, preferably several. A non-exhaustive list of options includes the variance inflation factor (VIF), the Farrar–Glauber test, perturbing the data, and the condition number test. Of these, only VIF appears to have a Python implementation (in statsmodels); the rest have R packages. Be careful with VIF in statsmodels: it appears not to include a constant term by default, so you’ll need to add a column of ones to your data.
  4. Remember that some forms of multicollinearity are not deal breakers if handled correctly. This depends on the type of collinearity and the type of feature importance analysis being applied.
  5. Even if all the statistical tests look good, there could still be undetected multicollinearity, so always be careful when presenting results.
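For steps 1 to 3, a minimal NumPy sketch of a correlation matrix, VIF, and a simple condition number check, written from scratch so the intercept column is explicit rather than something you have to remember to add. The function names `vif` and `condition_number` and the data are hypothetical; the condition number here is one simple variant computed on standardised columns, not a specific published test. The example also illustrates step 2: centring age before squaring removes most of the structural collinearity between age and age_squared.

```python
import numpy as np


def vif(X):
    """Variance inflation factor per column: VIF_j = 1 / (1 - R^2_j), where R^2_j
    comes from regressing column j on all other columns plus an intercept."""
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    out = np.empty(p)
    for j in range(p):
        target = X[:, j]
        # explicit intercept column, then every other feature
        others = np.column_stack([np.ones(n), np.delete(X, j, axis=1)])
        coef, *_ = np.linalg.lstsq(others, target, rcond=None)
        resid = target - others @ coef
        r2 = 1.0 - resid.var() / target.var()
        out[j] = np.inf if r2 >= 1.0 else 1.0 / (1.0 - r2)
    return out


def condition_number(X):
    """Condition number of the column-standardised data; large values signal near-dependence."""
    X = np.asarray(X, dtype=float)
    Z = (X - X.mean(axis=0)) / X.std(axis=0)
    return np.linalg.cond(Z)


rng = np.random.default_rng(0)
n = 1000
age = rng.uniform(20, 70, n)
other = rng.normal(size=n)

raw = np.column_stack([age, age ** 2, other])
print(np.corrcoef(raw, rowvar=False).round(3))  # step 1: pairwise correlations only

age_c = age - age.mean()  # step 2: centre before squaring
centred = np.column_stack([age_c, age_c ** 2, other])

vif_raw = vif(raw)
vif_centred = vif(centred)
print("VIF raw:    ", vif_raw.round(1))
print("VIF centred:", vif_centred.round(1))
print("cond raw / centred:", round(condition_number(raw), 1), round(condition_number(centred), 1))
```

A common rule of thumb flags VIF values above 5 or 10; as point 5 says, passing these checks does not prove the features are independent.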

Keep in mind that even with completely independent features, there are other factors that could bias feature interpretation. Some examples include omitted-variable bias and dealing with repeated measurements from a longitudinal study (measuring patients over time) or measurements made on clusters of related items (studying students in schools).

Any feature importance package, or add-on to fast.ai, should clearly mention the assumption of feature independence if required.

Let me know if you have any questions or corrections.


  1. Somewhere there is probably a statistics theorist disagreeing with this statement :blush:

Thank you both; this has helped me understand why TabNet is such a big deal (and given me a better understanding of the bigger issues with FCNN interpretability on tabular data). I appreciate the thorough thought you both put into this; it’s given me much to think about :slight_smile:

Here is an idea I suggested in the other thread:

Given a trained model f(x0,...,xn), we might be able to learn the feature importance by simulating attention.

For each feature xi, we can associate an uninformative value mi (the mean, or the mean of the embedding). We can then create a dampened feature xi' = ai*xi + (1-ai)*mi, with the sum of all ai equal to 1 (this can be enforced by a softmax).

We then maximize the accuracy of f(x0',...,xn') by optimizing the ai while keeping f fixed; the resulting ai are our importances.
(That’s a way to ask the network to make a decision: which features can be dropped, and which features bring information to the table?)

There is one hidden parameter here, namely the fact that the ai sum to one: I see no particular reason to use 1 rather than another number between 1 and n-1 (given n features).
Let’s call this parameter k; in a way, it is our guess at the number of important features (my gut feeling is that sqrt(n) would be a good default).
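Written as equations, the proposal might look like this. This is a sketch in my own notation: the logits w_i behind the softmax are an assumption, the factor k covers the generalised sum (k = 1 recovers the sum-to-one constraint), and replacing "maximize accuracy" with "minimize a differentiable loss L" is also an assumption, since accuracy gives no useful gradient. The trained model f stays frozen; only w is optimised.

$$
x_i' = a_i\,x_i + (1 - a_i)\,m_i, \qquad
a_i = k\,\frac{\exp(w_i)}{\sum_{j=0}^{n}\exp(w_j)}, \qquad
\sum_{i=0}^{n} a_i = k
$$

$$
\hat{w} = \arg\min_{w}\; \mathbb{E}_{(x,\,y)}\!\left[\mathcal{L}\big(f(x_0', \dots, x_n'),\; y\big)\right],
\qquad \text{importance}_i = \hat{a}_i
$$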

I think that another way to deal with k would be to sum each weight over all possible values of k from 1 to n-1.

@hubert.misztela To answer your questions: I have not implemented it (I had the idea as I was typing it, and I don’t have enough time in the short term to build a prototype), and this would indeed be about global feature importance, not sample-specific quantities.
I believe that, as written, it would focus on features that are important for most samples rather than features that are rarely important.
It might even be possible to add self-attention to select different features for different samples. We would then have all the interpretability properties of an attention-based model, but for arbitrary tabular models.

@Bwarner’s warnings are important, but do note that the vast majority of methods assume feature independence (which is false more often than not) and that they work well enough nevertheless.

(I work with people doing exactly that when analysing simulation codes)

Hi there,
I think you might get some ideas from the website below; it could be useful for your question.

I had the same issue before and found some answers there.

Good luck.

Moved this here as it’s more relevant:

On the topic of attention, this was recently published:

And the authors claim it’s competitive :slight_smile: (The model itself is in SAN.py :wink: )

I just extracted @muellerzr’s permutation feature importance snippet into a dedicated repository.

While doing so, I found a nicely documented book on the state of the art in interpretability and machine learning (it details the theory, pros, and cons of various methods and references their implementations):

Is there a library that supports partial dependence plots for neural networks (e.g. with PyTorch integration)?

I’d check out chapter 9 of fastbook. It describes the general methodology for partial dependence, and you should then be able to integrate it much as we have done for feature importance (i.e. for each value A from n to n+t, where n is the starting value and n+t the ending value, set every value in the column to A and average the model’s predictions). It is fairly straightforward from there. (Of course, you then face the question of whether it is “truly” representative, the same argument as for feature importance in this thread, so take it with a grain of salt.)
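A minimal NumPy sketch of that recipe, assuming a `predict` function that maps a 2-D array to one prediction per row (for a fastai tabular model, something like the `pred` wrapper from the KernelExplainer snippet earlier in the thread should fit). `partial_dependence`, `toy_model`, and the data are hypothetical; `start` and `width` play the roles of n and t.

```python
import numpy as np


def partial_dependence(predict, X, col, grid):
    """Partial dependence of `predict` on column `col`: for each grid value,
    set that column to the value in every row and average the predictions."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, col] = value  # same value for every row
        averages.append(float(np.mean(predict(X_mod))))
    return np.array(averages)


def toy_model(X):
    # hypothetical stand-in for a trained network's predict function
    return 3.0 * X[:, 0] + X[:, 1] ** 2


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))

start, width = 0.0, 1.0  # the thread's n and t: sweep from n to n + t
grid = np.linspace(start, start + width, 5)
pdp = partial_dependence(toy_model, X, col=0, grid=grid)
for g, v in zip(grid, pdp):
    print(f"{g:.2f} -> {v:.3f}")
```

For a categorical column, the grid would be the category codes rather than a linspace. Because every row is forced to the same value, the averaged rows can include unrealistic combinations of correlated features, which is the representativeness caveat above.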

I understand it is a general (model-agnostic) method, but most libraries seem to integrate it heavily with tree models. I don’t see people using it for neural networks; IMO there is generally too much focus on SHAP.

The work with SHAP seems very interesting, but the link gives me a 404 error. Has the repo been moved or renamed?

It’s now in my fastinference library :slight_smile:

Docs: muellerzr.github.io/fastinference

You can look into my version of partial dependence implementation for fastai2 https://github.com/Pak911/fastai2-tabular-interpretation