Over the last decade, the use of machine learning has grown at an astounding pace and seen widespread adoption across industries. It has quickly become integrated into our lives through features such as chatbots, natural language processing, voice recognition, forecasting and document classification, to name just a few areas.
One reason behind machine learning's popularity is its ability to learn latent patterns and relationships from data without hard-coding fixed rules, helping achieve high levels of accuracy and performance across many applications. By the nature of these models, however, even when they achieve high accuracy, it can be difficult to fully understand how their predictions are made.
Moreover, for specific problems and domains, the ability to explain these predictions is a precursor to delivering real business value. We often see data scientists fall back on traditional models (e.g., linear regression) and avoid state-of-the-art methods because of their inherent complexity and lack of explainability, which leads to poorer predictive performance. The goal of explainable artificial intelligence (AI) is to address this problem, among many others.
What is explainable AI?
We need machine learning models not only to make predictions but also to be transparent about how those predictions are made. Explainable AI, or XAI, is a field of research that enables users and stakeholders to interpret and understand how a machine learning model arrives at its predictions. By implementing explainable AI, we can answer questions such as:
- Why does the model predict that result?
- What are the reasons for this prediction?
- What are the most vital contributors to the prediction?
- How does the model work?
By answering these questions, we not only give business users and stakeholders an interpretable explanation of a prediction, but we also start building trust in these models by showing how and why predictions are made at a local (individual prediction) and global (all predictions) level. By understanding how predictions are made, we also know better how the model should behave when predicting on the wider population or on unseen data.
By seeing how machine learning models rank the importance of features and what governs their predictions, we can begin to see if bias has been introduced to the model at some point. Leveraging explainable AI can bring transparency, trust, safety and bias detection to your machine learning models.
How do explainable AI algorithms work?
There are many explainable AI algorithms, and while they all aim to explain models, how they do so differs slightly from one algorithm to another. Here, I will review three of them as examples of how they work:
- Shapley additive explanations
- Local interpretable model-agnostic explanations
- Integrated gradients
What is the Shapley additive explanations (SHAP) algorithm?
SHAP is a popular way to explain a model. It takes a game theory approach, treating each feature as a player and quantifying each feature's contribution (its Shapley value) to the prediction made by the model.
One way to think of this is with a sports analogy. Suppose you were a huge soccer fan who watched every one of your favorite team's games, carefully studying when players were substituted and how each substitution affected the result. You could soon deduce how much each player contributes to the team. SHAP does this hard work for you: by carefully observing how predictions change as it explores the possible combinations of features, SHAP learns how each feature impacts the model and assigns it a SHAP value.
One of the great advantages of SHAP is that, by looking at all the possible cases, it provides explanations for an individual prediction (local level) as well as explanations of each feature's overall impact across the entire dataset (global level). More importantly, these local and global explanations are consistent with each other, providing one unified explanation.
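To make this concrete, here is a minimal sketch of how SHAP values might be computed with the open-source shap Python library. The scikit-learn dataset and XGBoost model used here are purely illustrative stand-ins for your own data and model.

```python
import shap
import xgboost
from sklearn.datasets import load_breast_cancer

# Illustrative data and model -- substitute your own.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = xgboost.XGBClassifier().fit(X, y)

# shap.Explainer selects a suitable algorithm for the model and
# returns one SHAP value per feature for every prediction.
explainer = shap.Explainer(model, X)
shap_values = explainer(X)

# Local view: how each feature pushed one individual prediction up or down.
shap.plots.waterfall(shap_values[0])

# Global view: each feature's impact aggregated over all predictions.
shap.plots.beeswarm(shap_values)
```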
How does the local interpretable model-agnostic explanations (LIME) algorithm work?
Since SHAP has to work through the many possible combinations of features to create its explanation, it can become computationally heavy and time-consuming. LIME addresses this issue. For any given instance being predicted, LIME creates a sample of perturbed data points around that instance and then gets the model's predictions for them.
Next, weighting this sample by its proximity to the instance, LIME fits a linear regression model and uses its coefficients to determine the impact of each feature on the prediction. It is vital to note that because this explanation is built only on a sample of data near the instance being explained, LIME is not globally faithful. So, you can think of LIME as building sparse linear models around individual instances/predictions based on the data points in their vicinity.
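As an illustration, here is a minimal sketch using the open-source lime Python library to explain a single prediction. The random forest classifier and the scikit-learn dataset are illustrative placeholders for your own model and data.

```python
from lime.lime_tabular import LimeTabularExplainer
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Illustrative data and model -- substitute your own.
data = load_breast_cancer()
model = RandomForestClassifier().fit(data.data, data.target)

explainer = LimeTabularExplainer(
    training_data=data.data,
    feature_names=data.feature_names,
    class_names=data.target_names,
    mode="classification",
)

# LIME perturbs points around this one instance, weights them by proximity,
# and fits a sparse linear model whose coefficients form the explanation.
explanation = explainer.explain_instance(
    data.data[0], model.predict_proba, num_features=5
)
print(explanation.as_list())  # top features and their local weights
```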
Understanding the integrated gradients algorithm
Another method for explaining machine learning models is integrated gradients. Integrated gradients aim to describe a model's prediction by explaining, in terms of the input features, how the prediction for an instance differs from the prediction for a baseline or masked instance. In other words, integrated gradients start with a completely empty baseline (all the features zeroed out or masked) and gradually scale the input toward the actual instance, accumulating the gradients of the prediction along the way. By doing this, they can identify:
- When the changes were the largest
- The direction of the changes
- The key features in a model
- The impact on the prediction
A simple example of this is text classification. If we started with a sentence containing no words, the probability of it being classified into group A would be essentially nothing, since there is no information to act on. However, as we gradually add the words back, we can track how the probability changes and thus identify which words pushed the prediction toward group A.
Another example would be image classification. If we took an image, turned the brightness all the way down until nothing was left, and then slowly brought the brightness back up, we could watch how the prediction probability changes as the pixels re-emerge and thus determine which pixels impact the prediction. This method is well suited to deep learning models and offers much faster computation than SHAP values. However, the model must be differentiable (gradients must be available).
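To illustrate the idea, here is a minimal, hand-rolled sketch of integrated gradients in PyTorch, assuming a differentiable model. The tiny untrained network and random input are placeholders; in practice, a library such as Captum provides a production-ready implementation of the same idea.

```python
import torch

def integrated_gradients(model, x, baseline=None, steps=50, target=0):
    """Approximate integrated gradients for a single input x."""
    if baseline is None:
        baseline = torch.zeros_like(x)  # the "empty" starting point

    # Interpolate along a straight path from the baseline to the real input.
    alphas = torch.linspace(0.0, 1.0, steps).view(-1, *([1] * x.dim()))
    path = baseline + alphas * (x - baseline)   # shape: (steps, *x.shape)
    path.requires_grad_(True)

    # Gradient of the target output with respect to each interpolated input.
    outputs = model(path)[:, target]
    grads = torch.autograd.grad(outputs.sum(), path)[0]

    # Average the gradients along the path and scale by the input difference.
    return (x - baseline) * grads.mean(dim=0)

# Illustrative usage with a tiny untrained network and a random input.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
)
x = torch.rand(4)
print(integrated_gradients(model, x))  # one attribution per input feature
```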
Explainable AI use cases
As machine learning models continue to grow in complexity, the applications of explainable AI are virtually limitless. It can be used in cost forecasting: for example, how does having a specific chronic illness affect the length of stay for a patient visiting the ER? Equally, it can be used in classification: what is the effect of an interest rate on the likelihood of someone defaulting on their loan?
Additionally, explainable AI can be applied to more than just tabular data. For example, in text classification, suppose you had a model that determines which department an email should be sent to for follow-up; which keywords led the model to route it to a certain department? In image classification, if you were trying to detect a pool in a backyard from aerial photography, what in the image made the model think there was a pool? While these are only a handful of examples, they show that regardless of the type of problem, implementing explainable AI can quickly increase the transparency, trust and safety of your model.
Which explainable AI algorithm is best?
The short answer is that it depends. Each algorithm has advantages and disadvantages; while all the algorithms try to reach the same goal of explaining your model, they all do it slightly differently. Thus, depending on your situation, one might work better.
For example, since LIME looks only at data points around the instance being explained rather than working through the whole dataset as SHAP does, LIME could be the better alternative when resources are limited or the explanation needs to be produced very quickly. Conversely, if you need a global explanation that is unified with the local explanations, then SHAP would better suit your application. In short, the best algorithm for you can be determined only after truly understanding your data, model and situation.
Takeaway
By leveraging explainable AI, we can achieve the same, if not better, explainability from machine learning models as from traditional statistical models, but with much higher accuracy. Building explainable AI on top of a machine learning model allows the business to see how predictions are made and to trust the model. This frees businesses that need a high level of explainability to try new machine learning algorithms instead of being trapped with traditional statistical models.
If a business is already using machine learning, explainable AI offers data scientists the ability to work alongside subject matter experts to review what is impacting the model and to detect potential flaws. In addition, businesses can now better understand what a model will do when it sees new data and detect bias in the model. Thus, regardless of where you are in your AI journey, explainable AI can help you along the way.
Connect with a CGI expert to discuss how implementing AI and model explainability can help automate, forecast, and accelerate savings in your organization or explore our health services and solutions.