Deep neural networks (DNNs) are becoming increasingly sophisticated, especially in areas like computer vision, natural language understanding, and speech recognition. With DNN platforms maturing quickly and data and computing power becoming widely available, DNNs are going to become increasingly prevalent.
Limitations of DNNs
One of the limitations of DNNs is that they are black boxes and difficult to interpret. Because neural networks pass the data through multiple layers of non-linear transformations, it is practically impossible for a human to follow the logic behind their predictions. In a business context, however, it is necessary to understand how a prediction is made, for example to avoid racial bias or discrimination. Understanding the model also helps avoid typical pitfalls of DNNs, such as adversarial examples in classification and mode collapse in generative modeling. The research paper at https://arxiv.org/ftp/arxiv/papers/1903/1903.07282.pdf discusses such examples.
Model interpretability is of paramount importance for deep neural networks, because these models can achieve high accuracy at the expense of high abstraction (the accuracy vs. interpretability trade-off). High accuracy alone, however, does not make a model trustworthy, and a model that cannot be trusted will not be used by practitioners in their field.
By interpretation we mean both functional understanding and an understanding of the inner workings (algorithmic understanding). Captum helps us along both lines.
Using Captum to understand how DNNs work
Captum is an open-source, extensible library for model interpretability built on PyTorch. Here we share a technique for interpreting models using Captum. Note that Captum can be applied only to deep learning models built with PyTorch.
Model Interpretability using Captum can be done at 3 different levels:
1. Primary Attribution:
It evaluates the contribution of each input feature to the output of the model. Primary attribution uses algorithms such as Integrated Gradients and DeepLIFT for model interpretation.
2. Layer Attribution:
It evaluates the contribution of each neuron in a given layer to the output of the model. Layer attribution uses algorithms such as Layer Conductance and Layer Gradient X Activation for layer-level model interpretation.
3. Neuron Attribution:
It evaluates the contribution of each input feature to the activation of a particular hidden neuron. Neuron attribution uses algorithms such as Neuron Conductance and Neuron Gradient for neuron-level attribution.
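For reference, the quantities behind the two conductance-style algorithms used in this case study can be written compactly. Here F is the model output for the class being explained, x the input, x' a baseline input, and y a hidden neuron. Integrated Gradients accumulates the gradient along the straight path from x' to x, and its attributions sum to the change in output (the completeness property):

$$
\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F\big(x' + \alpha(x - x')\big)}{\partial x_i}\,d\alpha, \qquad \sum_i \mathrm{IG}_i(x) = F(x) - F(x')
$$

Conductance routes the same integral through a hidden neuron y. Neuron Conductance reports the per-input terms, and Layer Conductance reports the total for each neuron of a layer:

$$
\mathrm{Cond}^{y}_{i}(x) = (x_i - x'_i)\int_0^1 \frac{\partial F\big(x' + \alpha(x - x')\big)}{\partial y}\,\frac{\partial y\big(x' + \alpha(x - x')\big)}{\partial x_i}\,d\alpha, \qquad \mathrm{Cond}^{y}(x) = \sum_i \mathrm{Cond}^{y}_{i}(x)
$$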
In this blog, we demonstrate how to interpret a model at all three levels using a case study.
Predicting the onset of diabetes in women of Pima Indian heritage with a feed-forward neural network
For this article we sourced the Pima Indians Diabetes Database to predict, using PyTorch, the chances of women of Pima Indian heritage developing diabetes. We built a feed-forward neural network model with the PyTorch framework for the prediction.
A sample of the dataset is shown below:
Sample dataset for the onset of diabetes in women of Pima Indian heritage
The dataset contains 768 observations and 8 independent features. It is split 70:30 into training and testing data.
The features used as input to the model were:
- Pregnancies – The number of pregnancies the woman had.
- Glucose Level – Plasma glucose concentration at 2 hours in an oral glucose tolerance test
- BloodPressure – Diastolic blood pressure (mm Hg)
- SkinThickness – Triceps skin fold thickness (mm)
- Insulin – 2-hour serum insulin (μU/ml)
- BMI – Body mass index (weight in kg/(height in m)²)
- DiabetesPedigreeFunction – Diabetes pedigree function (a function that scores the likelihood of diabetes based on family history)
- Age – Age of the woman in years
The dependent variable, Outcome, indicates whether the patient develops diabetes. It has only two classes (Yes = 1, No = 0).
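A minimal sketch of loading the data and making the 70:30 split, assuming the data is a CSV whose header uses the column names below (those of the commonly distributed version of the dataset; adjust them if your file differs). The helpers `load_pima` and `split_70_30` are hypothetical names, and pandas is needed only for loading.

```python
import numpy as np

# Assumed column names; check them against the header of your own CSV file.
FEATURE_NAMES = ["Pregnancies", "Glucose", "BloodPressure", "SkinThickness",
                 "Insulin", "BMI", "DiabetesPedigreeFunction", "Age"]
TARGET_NAME = "Outcome"


def load_pima(path):
    """Read the CSV and return a float32 feature matrix and int64 labels."""
    import pandas as pd  # imported here so the split helper needs only NumPy
    df = pd.read_csv(path)
    return (df[FEATURE_NAMES].to_numpy(dtype=np.float32),
            df[TARGET_NAME].to_numpy(dtype=np.int64))


def split_70_30(X, y, seed=0):
    """Shuffle the rows once and return X_train, X_test, y_train, y_test (70:30)."""
    perm = np.random.default_rng(seed).permutation(len(X))
    n_train = int(0.7 * len(X))
    train_idx, test_idx = perm[:n_train], perm[n_train:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


# Placeholder arrays with the dataset's shape (768 rows x 8 features).
X_train, X_test, y_train, y_test = split_70_30(np.zeros((768, 8)), np.zeros(768))
print(X_train.shape, X_test.shape)  # (537, 8) (231, 8)
```

With the real file, `X, y = load_pima("diabetes.csv")` (a hypothetical path) followed by `split_70_30(X, y)` would give 537 training and 231 test rows.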
We defined the neural network architecture as follows:
A feed-forward neural network with four layers: an input layer, two hidden layers, and an output layer. A sigmoid activation function is applied at the end of each hidden layer, and a softmax classifier is applied at the output layer.
The Neural Network Architecture for Diabetes Prediction
The Neural Network architecture is coded as below:
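Since the original code is not reproduced here, the following is a sketch of the architecture described above, not the authors' code. The hidden sizes 12 and 8 are assumptions. The module names, in particular `sigmoid1`, are chosen so that the first hidden layer can be addressed as `net.sigmoid1`, which is how it is referenced later for Layer Conductance.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def build_model(n_features=8, hidden1=12, hidden2=8, n_classes=2):
    """Feed-forward classifier: two sigmoid hidden layers and a softmax output.

    hidden1 and hidden2 are assumed sizes; naming the modules lets Captum
    address a layer as net.sigmoid1.
    """
    return nn.Sequential(OrderedDict([
        ("linear1", nn.Linear(n_features, hidden1)),
        ("sigmoid1", nn.Sigmoid()),          # output of the first hidden layer
        ("linear2", nn.Linear(hidden1, hidden2)),
        ("sigmoid2", nn.Sigmoid()),          # output of the second hidden layer
        ("linear3", nn.Linear(hidden2, n_classes)),
        ("softmax", nn.Softmax(dim=1)),      # class probabilities (No = 0, Yes = 1)
    ]))


net = build_model()
probs = net(torch.randn(4, 8))  # four placeholder rows of 8 features
print(probs.shape)  # torch.Size([4, 2])
```

Keeping the sizes as parameters makes it easy to try other layer widths if the attributions later suggest the architecture should change.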
The model was then trained for 200 epochs; its output is shown below:
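A sketch of a 200-epoch training loop for a model of this shape. The optimizer (Adam), learning rate (0.01) and full-batch updates are assumptions, and random placeholder tensors stand in for the real training split. Because the network already ends in a softmax, the loss is computed with NLLLoss on log-probabilities rather than with CrossEntropyLoss.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical architecture with assumed hidden sizes 12 and 8.
net = nn.Sequential(OrderedDict([
    ("linear1", nn.Linear(8, 12)), ("sigmoid1", nn.Sigmoid()),
    ("linear2", nn.Linear(12, 8)), ("sigmoid2", nn.Sigmoid()),
    ("linear3", nn.Linear(8, 2)), ("softmax", nn.Softmax(dim=1)),
]))


def train(net, X_train, y_train, epochs=200, lr=0.01):
    """Full-batch training loop; returns the loss recorded at every epoch."""
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = nn.NLLLoss()
    losses = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        probs = net(X_train)
        # The network ends in softmax, so feed log-probabilities to NLLLoss
        # (CrossEntropyLoss would apply a second softmax on top).
        loss = loss_fn(torch.log(probs.clamp(min=1e-12)), y_train)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, loss {loss.item():.4f}")
    return losses


# Placeholder tensors shaped like the 70% training split (537 x 8); replace
# them with the real (preferably standardized) training features and labels.
X_train = torch.randn(537, 8)
y_train = torch.randint(0, 2, (537,))
losses = train(net, X_train, y_train)
```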
The accuracy of the model on the testing dataset is 67%, as shown below:
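Test accuracy is the fraction of rows whose most probable class matches the label. A minimal helper for a model with a softmax output (the name `accuracy` is hypothetical):

```python
import torch


def accuracy(probs, labels):
    """Share of rows whose highest-probability class equals the true label."""
    preds = probs.argmax(dim=1)
    return (preds == labels).float().mean().item()


# Toy example: 3 of 4 predictions match the labels.
probs = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
labels = torch.tensor([0, 1, 1, 1])
print(accuracy(probs, labels))  # 0.75
```

With a trained model, `with torch.no_grad(): accuracy(net(X_test_tensor), y_test_tensor)` would compute the test accuracy; `X_test_tensor` and `y_test_tensor` are hypothetical names for the converted test split.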
After the training/testing split, the test data has to be converted into PyTorch tensors before the model can predict on it. Hence, we converted the testing features to testing tensors.
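A sketch of that conversion, assuming the test split is held in NumPy arrays; `to_tensors` is a hypothetical helper. Features are cast to float32 because that is the default dtype of `nn.Linear` weights, and labels to int64, the dtype PyTorch classification losses expect.

```python
import numpy as np
import torch


def to_tensors(X_test, y_test):
    """Convert NumPy test arrays into tensors the model can consume."""
    X_tensor = torch.from_numpy(np.asarray(X_test, dtype=np.float32))
    y_tensor = torch.from_numpy(np.asarray(y_test, dtype=np.int64))
    return X_tensor, y_tensor


# Placeholder arrays shaped like the 30% test split (231 x 8).
X_test_tensor, y_test_tensor = to_tensors(np.zeros((231, 8)), np.zeros(231))
print(X_test_tensor.dtype, X_test_tensor.shape)  # torch.float32 torch.Size([231, 8])
```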
Interpretation of the prediction using Captum
Now that we have built the model, the next step is to make the neural network's predictions interpretable. Hence, we look at the attributions of the network we have built.
There are three levels of attribution, as mentioned earlier: Primary, Layer, and Neuron Attribution.
Primary Attribution: The first main component of Captum is primary attribution, which evaluates the contribution of each input feature to the output of the model. For this purpose, we apply the Integrated Gradients algorithm to arrive at the primary attributions.
In the code snippet below, we apply the Integrated Gradients Algorithm to arrive at the important attributes at the feature level and visualize the output.
Visualization of the Output:
Thus, from the plot above, we infer that Blood Pressure and Glucose were the major contributors to the diabetes prediction. The positive attribution for Glucose suggests that the higher the glucose level, the higher the predicted chance of diabetes. The negative attribution for Blood Pressure suggests that higher blood pressure pulls the model's prediction away from diabetes, so, for this model, lower BP levels are associated with a higher predicted chance of diabetes.
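To see why the sign of an attribution carries this meaning, here is a dependency-free toy computation of Integrated Gradients, written from scratch rather than with Captum. The two-feature logistic model and its weights (+2.0 and -1.5) are hypothetical stand-ins for "glucose" and "blood pressure". A feature whose value lies above the baseline receives an attribution with the same sign as its effect on the output, and the attributions sum to the change in output from the baseline.

```python
import math

W, B = (2.0, -1.5), -0.5  # hypothetical weights and bias of the toy model


def model(x):
    """Toy logistic model: p = sigmoid(w . x + b)."""
    z = sum(wi * xi for wi, xi in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-z))


def grad(x):
    """Analytic gradient of the toy model: p * (1 - p) * w_i."""
    p = model(x)
    return [p * (1.0 - p) * wi for wi in W]


def integrated_gradients(x, baseline, steps=200):
    """Midpoint Riemann approximation of IG along the path baseline -> x."""
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point)):
            avg_grad[i] += g / steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


x = [1.0, 1.0]         # hypothetical standardized glucose and blood pressure
baseline = [0.0, 0.0]  # all-zero baseline
attr = integrated_gradients(x, baseline)
print([round(a, 4) for a in attr])
print(round(sum(attr), 6), round(model(x) - model(baseline), 6))
```

The two printed sums agree, which is the completeness property; the first attribution is positive and the second negative, mirroring the signs of the weights.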
Layer Attribution: Layer attributions allow us to understand the importance of each neuron in the output of a particular layer.
To use Layer Conductance, we create a LayerConductance object passing in the model as well as the module (layer) whose output we would like to understand. In this case, we choose the output of the first hidden layer.
The code below applies the Layer Conductance algorithm at the first hidden layer (net.sigmoid1 is the output of that layer) and visualizes the output.
We can infer from the plot above that Neuron 3 learns the most substantial features, while Neurons 0 and 2 are roughly similar to each other in their ability to capture substantial features.
Neuron Attribution: This allows us to understand which parts of the input contribute to activating a particular hidden neuron. For this example, we apply Neuron Conductance, which divides the neuron's total conductance into the contribution from each individual input feature.
To use Neuron Conductance, we create a NeuronConductance object, analogously to LayerConductance, passing in the model as well as the module (layer) whose output we would like to understand; as before, this is the output of the first hidden layer.
Code for interpreting Neuron 0:
The code snippet below applies the Neuron Conductance algorithm to Neuron 0 of the first hidden layer, attributing that neuron's conductance to each input feature.
From the data above, it appears that the primary input features used by Neuron 0 are Blood Pressure and Glucose, with limited importance for all other features.
Code for interpreting Neuron 3:
The code snippet below applies the Neuron Conductance algorithm to Neuron 3 of the first hidden layer, attributing that neuron's conductance to each input feature.
Both neurons (0 and 3) learn substantial features from the input, namely Glucose and Blood Pressure.
In this blog, we used a healthcare AI use case to demonstrate how to interpret PyTorch neural network models using Captum. Model interpretability is an active area of research, and Captum currently supports only PyTorch models. With this library, complex neural networks can now be interpreted. Once the interpretability analysis is done, if the attributions do not explain the feature importance, the architecture can be changed, for example by changing the number of neurons in an existing layer or by adding new layers, and the interpretability analysis can then be repeated.