Text classification conflicts and hierarchical models
Pre-requisites
- We will use this dataset for our study, taken from this repository.
- Some experience with Python
Story
In my experience of building classifiers for commercial applications, the problems I face often apart from:
- Inadequate data.
- Noisy data.
- Incorrectly tagged data.
- Data with poor representation of the actual problem.
is classification labels with conflicting vocabulary. If you don't know what that means, here are a few examples:
# Examples:
| sentences | tag |
|--------------------------------------------------------------------|----------------------|
| No I didn't say play music | oos |
| I said play music | play_music |
| No I said play music | oos |
| What is my credit card balance | credit_card_balance |
| What is my credit card limit | credit_card_limit |
| I just paid from my credit card is my transaction confirmed? | payment_confirmation |
| Is my booking confirmed? I did my transaction with my credit card. | booking_confirmation |
If we run classifiers on such a dataset, we may, with a lot of effort get to a certain degree of precision but model's prediction will reflect the confusion in the dataset. Keeping the model aside, tagging such data-points has shown us more than 20% tagging errors as the aspect of general language the taggers understand is not enough to tag business specific labels. This leads to poor model accuracy, spending more time than necessary on tagging, which translates to frequent train-test-analyse-deploy cycles.
So how do we solve this?
Hierarchical models
Since the problem here is shared vocabulary, let's try to share the labels as well. A design angle also drops by in this case classifiers generally have a fallback label with triggers like:
- Predicted label doesn't cross the threshold. (Threshold too high)
- Couple of labels cross the threshold with confidence scores close by. (Threshold too low)
- The fallback label gets predicted.
Since we have data that is sharing vocabulary, it means that predicted labels will also share confidence. Making it hard to cross the threshold.
Intuition
Let's say we have labels present in our data, we will assess if they can be broadly classified into a set of labels where:
Let's assume there is just one degree of ascension, i.e. each label in the original data-set has only one generic label (many to one mapping). In this case we will train two models, one as usual with data as-is and another with the generalized labels.
Data analysis
We will explore the dataset linked above. We will try to see all the labels, come up with possible generalizations (degree=1).
# - Analyse tags
In [1]: import json
In [2]: import pydash as py_
In [3]: with open("voice_commands/data_full.json", "r") as f:
...: full_data = json.load(f)
In [4]: full_data.keys()
Out[4]: dict_keys(['oos_val', 'val', 'train', 'oos_test', 'test', 'oos_train'])
In [5]: training_data = full_data["train"] + full_data["oos_train"]
In [6]: label_groups = py_.group_by(training_data, lambda el: el[1])
In [7]: label_dist = {k: len(v) for k, v in label_groups.items()}
There are 151 labels, each with 100 data points (a very balanced dataset!). Since the list would make it really long to print, we will use tabulate
.
# - Pretty print a long list of labels. [horizontally scrollable!]
In [8]: from tabulate import tabulate
In [9]: labels = py_.chunk(list(label_dist.keys()), 6)
In [10]: print(tabulate(labels))
--------------------- -------------------- ------------------------- ------------------ ------------------- ---------------------- ---------------------- -----------------------
accept_reservations account_blocked alarm application_status apr are_you_a_bot balance bill_balance
bill_due book_flight book_hotel calculator calendar calendar_update calories cancel
cancel_reservation car_rental card_declined carry_on change_accent change_ai_name change_language change_speed
change_user_name change_volume confirm_reservation cook_time credit_limit credit_limit_change credit_score current_location
damaged_card date definition direct_deposit directions distance do_you_have_pets exchange_rate
expiration_date find_phone flight_status flip_coin food_last freeze_account fun_fact gas
gas_type goodbye greeting how_busy how_old_are_you improve_credit_score income ingredient_substitution
ingredients_list insurance insurance_change interest_rate international_fees international_visa jump_start last_maintenance
lost_luggage make_call maybe meal_suggestion meaning_of_life measurement_conversion meeting_schedule min_payment
mpg new_card next_holiday next_song no nutrition_info oil_change_how oil_change_when
oos order order_checks order_status pay_bill payday pin_change play_music
plug_type pto_balance pto_request pto_request_status pto_used recipe redeem_rewards reminder
reminder_update repeat replacement_card_duration report_fraud report_lost_card reset_settings restaurant_reservation restaurant_reviews
restaurant_suggestion rewards_balance roll_dice rollover_401k routing schedule_maintenance schedule_meeting share_location
shopping_list shopping_list_update smart_home spelling spending_history sync_device taxes tell_joke
text thank_you time timer timezone tire_change tire_pressure todo_list
todo_list_update traffic transactions transfer translate travel_alert travel_notification travel_suggestion
uber update_playlist user_name vaccines w2 weather what_are_your_hobbies what_can_i_ask_you
what_is_your_name what_song where_are_you_from whisper_mode who_do_you_work_for who_made_you yes
--------------------- -------------------- ------------------------- ------------------ ------------------- ---------------------- ---------------------- -----------------------
These labels suggest we have a very chatbot-like data. If you want, at this point you can pause reading and try generalizing these labels on your own.
Programmatic Assessment
There are 151 intents, getting a sense of them would require going through each label, revisit, sometimes even check what a few sentences look like. That's pretty slow. Let's use some tricks to get there faster. We will first save the vocabulary of our training data, i.e. every unique word and its frequency. Once that is done, we can check how correlated are these labels with one another.
# - Please don't use the functions here in production!
In [11]: from sklearn.feature_extraction.text import TfidfVectorizer
In [12]: tfidf = TfidfVectorizer(stop_words="english", lowercase=True)
In [13]: sentences, labels = zip(*training_data)
In [14]: labels = list(set(labels))
In [15]: tfm = tfidf.fit_transform(sentences)
In [16]: vocabulary_ = tfidf.vocabulary_
In [17]: len(vocabulary_)
Out[17]: 5092
In [18]: vocabulary_set = set(vocabulary_)
In [19]: def make_label_maps():
...: label_vocab_map = {}
...: for label, content in label_groups.items():
...: label_vocab_map[label] = set()
...: label_sentences, _ = zip(*content)
...: for word in vocabulary_set:
...: for sentence in label_sentences:
...: if word in sentence:
...: label_vocab_map[label].add(word)
...: return label_vocab_map
In [20]: label_vocab_map = make_label_maps()
In [21]: def create_label_correlation(correlation_threshold=10):
...: correlation = []
...: for i in tqdm(range(len(labels))):
...: alabel = labels[i]
...: for j in range(i, len(labels)):
...: otherlabel = labels[j]
...: if alabel != otherlabel:
...: score = len(label_vocab_map[alabel].intersection(label_vocab_map[otherlabel]))
...: if score > correlation_threshold:
...: correlation.append((alabel, otherlabel, score))
...: return correlation
In [22]: correlation = sorted(create_label_correlation(), key=lambda el: el[1], reverse=True)
In [23]: len(correlation)
Out[23]: 11325
pardon the murder of functional concepts and poor programming for the sake of speed
That's a lot of data points to scan through! Let's try that again, let's include ngram-range=(1,3)
in TfidfVectorizer(stop_words="english", lowercase=True, ngram_range=(1,3))
and pruning vocabulary off words less than 3 characters. Also, we will use correlation_threshold=40
create_label_correlation
.
# -
In [24]: correlation = sorted(create_label_correlation(40), key=lambda el: el[1], reverse=True)
100%|███████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 4174.51it/s]
In [25]: len(correlation)
Out[26]: 127
That's better, let's see what we have here.
# -
In [27]: correlation_chunked = py_.chunk(correlation, 3)
In [28]: print(tabulate(correlation_chunked))
----------------------------------------------------- ------------------------------------------------------ ------------------------------------------------------
('recipe and ingredients_list', 85) ('confirm_reservation and restaurant_reservation', 80) ('car_rental and book_flight', 79)
('book_flight and book_hotel', 78) ('car_rental and book_hotel', 75) ('todo_list_update and todo_list', 71)
('directions and distance', 69) ('credit_limit and credit_limit_change', 68) ('accept_reservations and restaurant_reviews', 66)
('calendar and calendar_update', 64) ('rewards_balance and redeem_rewards', 64) ('nutrition_info and calories', 63)
('restaurant_reservation and cancel_reservation', 61) ('balance and interest_rate', 61) ('carry_on and book_flight', 59)
('freeze_account and account_blocked', 58) ('meal_suggestion and restaurant_suggestion', 58) ('order_checks and interest_rate', 57)
('recipe and cook_time', 56) ('international_visa and vaccines', 55) ('credit_limit and apr', 55)
('international_visa and plug_type', 54) ('routing and interest_rate', 54) ('rewards_balance and credit_limit', 54)
('confirm_reservation and cancel_reservation', 54) ('ingredients_list and cook_time', 54) ('routing and balance', 53)
('uber and distance', 53) ('order_checks and routing', 52) ('nutrition_info and recipe', 51)
('todo_list_update and calendar_update', 51) ('lost_luggage and book_flight', 51) ('update_playlist and play_music', 51)
('pto_request and book_hotel', 51) ('calories and recipe', 51) ('food_last and cook_time', 51)
('order_checks and pin_change', 50) ('international_visa and travel_alert', 50) ('international_visa and travel_notification', 50)
('nutrition_info and food_last', 50) ('lost_luggage and carry_on', 50) ('travel_suggestion and restaurant_suggestion', 50)
('book_hotel and restaurant_suggestion', 50) ('order_checks and balance', 49) ('how_busy and restaurant_reviews', 49)
('redeem_rewards and apr', 49) ('account_blocked and balance', 49) ('pay_bill and interest_rate', 49)
('expiration_date and application_status', 49) ('order_checks and direct_deposit', 48) ('report_fraud and transactions', 48)
('rewards_balance and apr', 48) ('directions and book_hotel', 48) ('nutrition_info and ingredients_list', 47)
('report_fraud and redeem_rewards', 47) ('pin_change and interest_rate', 47) ('rewards_balance and balance', 47)
----------------------------------------------------- ------------------------------------------------------ ------------------------------------------------------
We can see from label names that we have a reasonable confusion matter. Let's keep trying to solve for vocab with len(word) >= 3
.
In [29]: def group_correlations(correlation):
...: correlated_group = {}
...: for item in correlation:
...: if item[0] not in correlated_group and item[1] not in correlated_group:
...: correlated_group[item[0]] = [(item[0], item[2]), (item[1], item[2])]
...: elif item[0] in correlated_group:
...: correlated_group[item[0]].append((item[1], item[2]))
...: elif item[1] in correlated_group:
...: correlated_group[item[1]].append((item[0], item[2]))
...: return correlated_group
In [30]: correlated_group = group_correlations(correlation)
At this point for each intent we have easy to visualize overlaps but we can do better than this. So far we have only seen correlations, and not the differences which could be useful. Let's try to see the most common words in the vocabulary.
In [30]: def create_label_blacklist():
...: blacklist = []
...: for i in tqdm(range(len(labels))):
...: alabel = labels[i]
...: for j in range(i, len(labels)):
...: otherlabel = labels[j]
...: if alabel != otherlabel:
...: blacklist.extend(list(label_vocab_map[alabel].intersection(label_vocab_map[otherlabel])))
...: return blacklist
In [31]: blacklist_ctr = Counter(blacklist)
In [32]: from collections import Counter
In [33]: blacklist_ctr = Counter(blacklist)
In [34]: blacklist_ctr.most_common(10) Out[33]:
[('need', 10296),
('nee', 10296),
('tel', 8385),
('tell', 7260),
('ate', 7021),
('know', 6786),
('want', 6670),
('like', 6670),
('ill', 5995),
('let', 4950)]
Discard the top 500 words from the vocab.
# - `if x not in ...` works faster on sets and dicts.
In [34]: filter_vocab = set([w[0] for w in blacklist_ctr.most_common(500)])
In [35]: vocabulary = [word for word in vocabulary_ if len(word) >= 3 and word not in filter_vocab]
In [36]: vocabulary_set = set(vocabulary)
In [37]: label_vocab_map = make_label_maps()
We still have 41983 words in vocabulary but a lot of common words are removed. We could also do this using TfidfVectorizer
's max_df
param.
# Let's modify our correlation function to show the diff vocab as well
In [38]: correlation = sorted(create_label_correlation(10), key=lambda el: el[2], reverse=True)
In [39]: len(correlation)
Out[40]: 200
In [41]: def create_label_correlation(threshold=10):
...: correlation = []
...: for i in tqdm(range(len(labels))):
...: alabel = labels[i]
...: for j in range(i, len(labels)):
...: otherlabel = labels[j]
...: if alabel != otherlabel:
...: corr = len(label_vocab_map[alabel].intersection(label_vocab_map[otherlabel]))
...: diff = len(label_vocab_map[alabel].symmetric_difference(label_vocab_map[otherlabel]))
...: if corr > threshold:
...: correlation.append((alabel, otherlabel, corr, diff))
...: return correlation
In [42]: print(tabulate(py_.chunk(correlation, 3)))
------------------------------------------------------- ---------------------------------------------------------- ------------------------------------------------------------
('recipe', 'ingredients_list', 66, 408) ('confirm_reservation', 'restaurant_reservation', 60, 269) ('accept_reservations', 'restaurant_reviews', 46, 408)
('book_flight', 'book_hotel', 43, 524) ('recipe', 'cook_time', 42, 436) ('car_rental', 'book_flight', 41, 412)
('nutrition_info', 'calories', 40, 296) ('todo_list_update', 'todo_list', 40, 287) ('update_playlist', 'play_music', 40, 387)
('directions', 'distance', 38, 361) ('car_rental', 'book_hotel', 35, 436) ('uber', 'distance', 35, 344)
('calendar', 'calendar_update', 33, 289) ('restaurant_reservation', 'cancel_reservation', 33, 259) ('nutrition_info', 'food_last', 32, 284)
('ingredients_list', 'cook_time', 32, 358) ('carry_on', 'book_flight', 31, 426) ('pto_request', 'book_hotel', 31, 382)
('calories', 'recipe', 31, 464) ('confirm_reservation', 'cancel_reservation', 30, 280) ('calories', 'ingredients_list', 30, 368)
('credit_limit', 'credit_limit_change', 30, 159) ('freeze_account', 'account_blocked', 29, 152) ('calories', 'cook_time', 29, 350)
('meal_suggestion', 'restaurant_suggestion', 29, 331) ('nutrition_info', 'ingredients_list', 28, 334) ('accept_reservations', 'confirm_reservation', 28, 433)
('ingredient_substitution', 'calories', 28, 464) ('international_visa', 'vaccines', 27, 217) ('nutrition_info', 'recipe', 27, 434)
('food_last', 'cook_time', 27, 326) ('international_visa', 'plug_type', 26, 282) ('car_rental', 'restaurant_suggestion', 26, 374)
('uber', 'traffic', 26, 348) ('accept_reservations', 'restaurant_reservation', 26, 422) ('calories', 'food_last', 26, 334)
('international_visa', 'travel_alert', 25, 239) ('international_visa', 'travel_notification', 25, 250) ('lost_luggage', 'carry_on', 25, 303)
('directions', 'uber', 25, 367) ('book_flight', 'pto_request', 25, 382) ('how_busy', 'restaurant_reviews', 25, 331)
...
------------------------------------------------------- ---------------------------------------------------------- ------------------------------------------------------------
That looks decent. We have all these overlapping labels but all of them seem to have sufficient vocab to distinguish from each other on a one-on-one basis.
We can have something even more useful. What if we could see label-level vocabulary? We could know the set of words that only exist for the label.def
In [51]: def create_label_vocab():
...: label_vocab = {}
...: for i in tqdm(range(len(labels))):
...: alabel = labels[i]
...: label_vocab[alabel] = None
...: vocab = copy.deepcopy(label_vocab_map[alabel])
...: for j in range(i, len(labels)):
...: otherlabel = labels[j]
...: if alabel != otherlabel:
...: vocab = vocab.difference(label_vocab_map[otherlabel])
...: if vocab:
...: label_vocab[alabel] = vocab
...: return label_vocab
...:
In [52]: print(tabulate(py_.chunk(sorted([(k, len(v)) for k, v in label_vocab.items()], key=lambda el: el[1], reverse=True), 3)))
--------------------------------- ------------------------------- --------------------------------
('oos', 370) ('book_hotel', 254) ('ingredient_substitution', 232)
('accept_reservations', 226) ('update_playlist', 204) ('recipe', 201)
('play_music', 191) ('restaurant_suggestion', 187) ('calculator', 165)
('book_flight', 164) ('calendar_update', 154) ('ingredients_list', 154)
('traffic', 154) ('insurance_change', 150) ('calories', 150)
...
--------------------------------- ------------------------------- --------------------------------
We can see the set sizes for each of the 151 classes. Let's try a maximum overlap algorithm for our classification task. Here's what it looks like in my head:
- Break a sentence into a set of ngram tokens (n = [1 - 3]).
- Search each token in our filtered vocabulary.
- If a token is not found we will re-label it as OOV.
- Compare the set of tokens to our labels and take a note of the intersection.
- Select the maximum overlap of all labels.
In [53]: def get_sentence_ngrams(sentence):
...: return set([" ".join(token) for token in everygrams(sentence.split(" "), max_len=3)])
...:
In [54]: def label_missing_token_oov(tokens, vocab):
...: return set([token if token in vocab else "OOV" for token in tokens])
In [55]: def match_sentence_to_label(sentence, vocab_set, label_vocab):
...: overlaps = {}
...: tokens = get_sentence_ngrams(sentence)
...: tokens = label_missing_token_oov(tokens, vocab_set)
...: for label, vocab in label_vocab.items():
...: overlaps[label] = len(vocab.intersection(tokens))
...: return overlaps
In [56]: label_missing_token_oov(get_sentence_ngrams("peanut butter and jelly"), vocabulary_set)
Out[56]: {'OOV', 'jelly', 'peanut', 'peanut butter'}
In [57]: overlaps = match_sentence_to_label("peanut butter and jelly", vocabulary_set, label_vocab)
In [58]: sorted_overlap = sorted([(k, v) for k, v in overlaps.items() if v], key=lambda el: len(el[1]), reverse=True)
In [59]: sorted_overlap[:5]
Out[59]: [('shopping_list', {'peanut', 'peanut butter'}), ('calories', {'jelly'})]