Deciphering Explainable AI, with GPT-2

Using advanced text transformers with the ELI5 library

Two weeks ago, I wrote about classifying text using only SciKit-Learn, and used the Explainable AI library ‘ELI5’ to visualize how I was sorting Tweets into good and bad. This post revisits that with one significant change: upgrading the word tokenization / vectorization step.

Quick refresher: machine learning systems use multi-dimensional arrays of numbers. In this context, these arrays can be called vectors or tensors. It’s fairly easy to convert images into machine-readable data — start with a two-dimensional grid of pixels, and add more dimensions for colors.

For text data, it’s more complicated. Previously we used the SciKit-Learn method of counting words (which ELI5 directly supports). In the past few years, there has been a generation of projects using pre-trained word vectors (including word2vec, GloVe, and FastText), which have learned word associations in dozens of languages on a massive scale.

Now researchers are competing to build a new generation of word transformers, including OpenAI’s GPT-2, Google’s BERT and XLNet, and Facebook’s XLM. I don’t have a solid understanding of what makes these so much better, but do check out this illustrated post. Most are English-only, but you can find models for other languages, especially Chinese and German.

It’s a lot of technical options and details, but to sum it up: the word-vectorization step and the machine learning model step are independent of each other. When we choose a better word-vector model, the text is transformed into vectors with more dimensions and more meaningful values, which generally makes the downstream model more accurate.
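To make the “text in, vector out” step concrete, here is roughly what the word-counting approach from the previous post boils down to. This is a simplified stand-in in plain Python, not SciKit-Learn’s actual CountVectorizer (which also handles punctuation, n-grams and sparse storage), and the name count_vectorize is made up for illustration.

```python
from collections import Counter


def count_vectorize(texts):
    """Return (vocabulary, rows); each row counts the vocabulary words in one text."""
    tokenized = [text.lower().split() for text in texts]
    # One column per distinct word, in alphabetical order.
    vocabulary = sorted({word for words in tokenized for word in words})
    rows = []
    for words in tokenized:
        counts = Counter(words)  # missing words count as 0
        rows.append([counts[word] for word in vocabulary])
    return vocabulary, rows


vocabulary, rows = count_vectorize(["fake news", "good news is good"])
print(vocabulary)  # ['fake', 'good', 'is', 'news']
print(rows)        # [[1, 0, 0, 1], [0, 2, 1, 1]]
```

Each text becomes one row with a column per vocabulary word. A pre-trained model swaps these sparse counts for dense vectors, while the classifier that consumes the rows can stay the same.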

GPT-2

PyTorch-Transformers is a uniform interface for several of the latest text transformers. It downloads the models and pre-trained vectors for us.

Let’s pass the text through GPT-2 (medium size), capture the word vectors directly rather than training a classifier in PyTorch, and recreate the X and Y data for SciKit-Learn models:
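A minimal sketch of that step might look like the following. It is not the original script: the function names load_gpt2 and build_xy are hypothetical, rows is assumed to be a list of (text, label) pairs, and texts are assumed to be non-empty. The package has since been renamed from pytorch_transformers to transformers, where the same two classes exist. Note that GPT-2 works on sub-word tokens, so “word vector” here really means one vector per token.

```python
def load_gpt2(name="gpt2-medium"):
    """Return (tokenizer, model); the weights are downloaded on first use."""
    # Imported here because both the import and the download are slow.
    from pytorch_transformers import GPT2Model, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained(name)
    model = GPT2Model.from_pretrained(name)
    model.eval()  # inference only: switch off dropout
    return tokenizer, model


def average_word_vector(text, tokenizer, model):
    """Average GPT-2's per-token vectors into one vector for the whole text."""
    import torch

    token_ids = torch.tensor([tokenizer.encode(text)])  # shape (1, n_tokens)
    with torch.no_grad():  # we only read the outputs, no training
        hidden = model(token_ids)[0]  # shape (1, n_tokens, 1024) for gpt2-medium
    return hidden[0].mean(dim=0).tolist()  # one list of 1024 floats


def build_xy(rows, vectorize):
    """Turn (text, label) pairs into the X and y lists SciKit-Learn expects."""
    X = [vectorize(text) for text, _ in rows]
    y = [label for _, label in rows]
    return X, y
```

After calling load_gpt2 once, passing a small function that wraps average_word_vector as the vectorize argument of build_xy gives an X with one 1024-number row per Tweet, which any SciKit-Learn classifier can be fit on.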

Tweets have multiple words, and the model returns a 1024-length array for each word (strictly, for each token). SciKit-Learn expects one fixed-length, one-dimensional array per Tweet, so how can we fix this? You’ll see that I created one average_word_vector for each Tweet. Averaging might be hard to wrap your head around (what is the average word in this sentence?), but it is a common practice that several tutorials and papers follow.
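If the averaging feels abstract, a toy example with made-up numbers may help. The sketch below uses plain Python and a hypothetical helper called mean_pool; each “word” gets a 4-number vector here instead of the 1024 numbers mentioned above.

```python
def mean_pool(word_vectors):
    """Average a list of equal-length vectors, one dimension at a time."""
    n = len(word_vectors)
    # zip(*...) groups the first numbers together, then the second numbers, etc.
    return [sum(column) / n for column in zip(*word_vectors)]


# Three "words", each with a made-up 4-number vector.
tweet = [
    [1.0, 0.0, 2.0, 4.0],
    [3.0, 0.0, 2.0, 0.0],
    [2.0, 3.0, 2.0, 2.0],
]
print(mean_pool(tweet))  # [2.0, 1.0, 2.0, 2.0]
```

However many words go in, one vector of the same length comes out, which is the shape a SciKit-Learn model needs. With NumPy arrays the same operation is np.mean(word_vectors, axis=0).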

The classifiers jumped in accuracy (LogisticRegressionCV in particular) compared to any of my previous runs with only CountVectorizer.

ELI5 and GPT-2

We previously used ELI5 to highlight words with colors, showing words’ contribution toward a ‘good’ or ‘toxic’ score. That was relatively easy because we used text utilities which are built into SciKit-Learn and specifically supported by ELI5.

GPT-2 is a whole other system, but luckily ELI5 has a solution:

If a library is not supported by eli5 directly, or the text processing pipeline is too complex for eli5, eli5 can still help — it provides an implementation of LIME (Ribeiro et al., 2016) algorithm which allows to explain predictions of arbitrary classifiers, including text classifiers.

I set out to make a NextJournal demo for this post. The gpt2-medium model was large enough that NextJournal appeared to prevent the download, so I downgraded to small. Then I noticed that ELI5’s LIME / TextExplainer expects a single function that goes from raw text to class probabilities, which is what a SciKit-Learn pipeline provides, meaning I should reorganize my first clumsy script.
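A hedged sketch of that reorganization: wrap the vectorization code in a small transformer class so that the whole thing becomes one SciKit-Learn pipeline with a predict_proba that accepts raw text. This is not the notebook’s code. TextToVector, toy_embed and explain are made-up names, toy_embed is a two-number stand-in so the example runs without downloading GPT-2, the four Tweets are invented, and plain LogisticRegression replaces LogisticRegressionCV because four examples are too few for cross-validation.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline


class TextToVector(TransformerMixin, BaseEstimator):
    """Pipeline step that turns raw texts into one fixed-length vector each."""

    def __init__(self, embed):
        self.embed = embed  # any function: str -> list of floats

    def fit(self, X, y=None):
        return self  # nothing to learn, the vectors come from a pre-trained model

    def transform(self, X):
        return np.array([self.embed(text) for text in X])


def toy_embed(text):
    """Hypothetical stand-in for GPT-2: [word count, count of 'fake'/'news']."""
    words = text.lower().split()
    return [float(len(words)), float(sum(w in ("fake", "news") for w in words))]


def explain(pipeline, text):
    """Fit ELI5's LIME explainer on one text (eli5 is imported only if called)."""
    from eli5.lime import TextExplainer

    explainer = TextExplainer(random_state=42)
    explainer.fit(text, pipeline.predict_proba)
    return explainer  # in a notebook: explainer.show_prediction()


tweets = ["fake news again", "so much fake news", "lovely day today", "great game last night"]
labels = [1, 1, 0, 0]  # 1 = toxic, 0 = good (made-up examples)

pipeline = make_pipeline(TextToVector(toy_embed), LogisticRegression())
pipeline.fit(tweets, labels)
print(pipeline.predict_proba(["more fake news"]))  # one row: [P(good), P(toxic)]
```

Swapping toy_embed for a function that returns averaged GPT-2 vectors leaves the rest of the pipeline untouched, and TextExplainer never needs to know what happens inside.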

The full notebook is linked in the takeaways below, as this was too many lines for a proper embed, I guess.

It takes minutes for ELI5’s TextExplainer to try out multiple tweaks to the sentence. Most of that time is spent in vectorization, so there are likely things I could do here to speed it up.

A sample of text mutations sent from TextExplainer to the vectorizer
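To give a feel for what such mutations look like, here is a simplified sketch of the idea behind LIME’s sampling: make many copies of the text with random words removed, then ask the classifier about each copy. ELI5’s own sampler is more elaborate; sample_mutations and its parameters are made up for illustration.

```python
import random


def sample_mutations(text, n_samples=5, drop_probability=0.3, seed=0):
    """Return n_samples copies of text, each with a random subset of words removed."""
    rng = random.Random(seed)  # seeded so the same mutations come back every run
    words = text.split()
    mutations = []
    for _ in range(n_samples):
        # Keep each word independently with probability 1 - drop_probability.
        kept = [word for word in words if rng.random() >= drop_probability]
        mutations.append(" ".join(kept))
    return mutations


for mutated in sample_mutations("the fake news media is at it again"):
    print(repr(mutated))
```

Words whose removal shifts the predicted probability the most end up with the strongest highlights. Every mutated copy has to be vectorized again, which is why the explanation step is so much slower than a single prediction.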

Then we get some numbers (probability and score) to evaluate how good the prediction is, and highlighted words:

A different post:

Just like in the previous post, we see words (such as ‘fake’ and ‘news’) which are contributing to the toxic / ‘known weird’ classification, and some words (such as ‘dictators’) which count against toxic classification. There are additional parameters for TextExplainer which also test word order, etc.

Takeaways

  • ELI5 continued to be a good resource for Explainable AI.
  • NextJournal is the real MVP for downloading GPT-2 much faster than I can in Vietnam https://nextjournal.com/mapmeld/eli5-and-gpt-2
  • SciKit-Learn pipelines were a little annoying at first, especially making my own class to wrap vectorization code, but it made my code more logical and structured.
  • PyTorch-Transformers is an awesome module which made it relatively simple to drop in this advanced pre-trained model.
