
Neural Network SMS Text Classifier

Note

Training the model for this problem is painfully slow!

Here's a good tutorial on RNNs and TextVectorization that you can follow to solve this problem. However, make sure to follow it closely. I wasted a lot of time by skipping these two lines, thinking they were just an optimization. They're not. The shuffle() and prefetch() calls only change the order and loading speed of the examples, but batch() actually reshapes the datasets by adding a leading batch dimension. Without it, the model can't fit.

train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
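To see why batch() matters, here is a small pure-Python sketch of what batching does to element shapes. It does not use TensorFlow, and the helpers batch and rank are hypothetical names for illustration: a stream of individual examples becomes a stream of groups, so every element gains one leading axis.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")

def batch(examples: Iterable[T], batch_size: int) -> List[List[T]]:
    """Group consecutive examples into lists of at most batch_size items.
    The smaller final batch is kept, like tf.data's batch() with its
    default drop_remainder=False."""
    batches: List[List[T]] = []
    current: List[T] = []
    for example in examples:
        current.append(example)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)
    return batches

def rank(x) -> int:
    """Number of axes, counted as the nesting depth of lists/tuples."""
    depth = 0
    while isinstance(x, (list, tuple)):
        depth += 1
        x = x[0] if x else None
    return depth

# Unbatched: each element is one token-id sequence (made-up ids).
messages = [[4, 8, 15], [16, 23, 42], [1, 2, 3], [5, 6, 7], [9, 9, 9]]
batched = batch(messages, batch_size=2)

print(rank(messages[0]))          # one sequence, no batch axis
print(rank(batched[0]))           # (batch, sequence)
print([len(b) for b in batched])  # sizes of each batch
```

Keras's fit() treats each element of a tf.data dataset as one batch, so without this step the model sees one axis fewer than it expects. After an embedding layer turns each token id into a vector, a batched element becomes (batch, sequence, features), the 3-D shape a recurrent layer such as an LSTM takes as input.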

Total time spent on this: 3.5 hours

Problem description

Copied and modified from this Google Colab

In this challenge, you need to create a machine learning model that will classify SMS messages as either "ham" or "spam". A "ham" message is a normal message sent by a friend. A "spam" message is an advertisement, or a message sent by a company.

You should create a function called predict_message that takes a message string as an argument and returns a list. The first element in the list should be a number between zero and one that indicates the likelihood of "ham" (0) or "spam" (1). The second element in the list should be the word "ham" or "spam", depending on which is more likely.
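The required return format can be pinned down with a small wrapper. This is a sketch, not the solution's code: it assumes a hypothetical score_fn that returns the model's raw output for one message, and applies a sigmoid when that output is a logit (the TensorFlow RNN tutorial ends its model with Dense(1) and trains with from_logits=True). The 0.5 threshold is also an assumption.

```python
import math
from typing import Callable, List, Union

def sigmoid(x: float) -> float:
    """Numerically stable logistic function: maps a logit to a probability in [0, 1]."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0 here, so this cannot overflow
    return z / (1.0 + z)

def make_predict_message(score_fn: Callable[[str], float],
                         output_is_logit: bool = True,
                         threshold: float = 0.5) -> Callable[[str], List[Union[float, str]]]:
    """Build a predict_message function around a hypothetical score_fn."""
    def predict_message(message: str) -> List[Union[float, str]]:
        raw = float(score_fn(message))
        p_spam = sigmoid(raw) if output_is_logit else raw  # 0 = ham, 1 = spam
        return [p_spam, "spam" if p_spam >= threshold else "ham"]
    return predict_message

# Toy scorer standing in for the trained model (hypothetical, for illustration only).
predict_message = make_predict_message(lambda m: 3.0 if "prize" in m.lower() else -3.0)
print(predict_message("You have won a PRIZE!"))
print(predict_message("see you at lunch"))
```

In the notebook, score_fn would wrap model.predict on a one-element batch, for example model.predict(np.array([message]))[0][0] when the vectorizer is the model's first layer, as in the tutorial. If the last layer already uses a sigmoid activation, pass output_is_logit=False.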

For this challenge, you will use the SMS Spam Collection dataset. The dataset has already been grouped into train data and test data.

The first two cells import the libraries and data. The final cell tests your model and function. Add your code in between these cells.

Solution

Get data files

Import libraries

Prepare data

Train data frame

Test data frame
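The data-frame cells are not reproduced here. As a stdlib-only sketch, assuming each data file is tab-separated with no header row and two columns (the label "ham"/"spam", then the message text), the rows can be read as below. With pandas, the rough equivalent is pd.read_csv(path, sep="\t", names=["label", "message"]).

```python
import io
from typing import Iterable, List, Tuple

def read_tsv_pairs(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Parse 'label<TAB>message' lines into (label, message) pairs.
    Assumes no header row. Splits on the first tab only, so any tab inside
    the message survives; raises ValueError on a line with no tab at all."""
    pairs = []
    for line in lines:
        line = line.rstrip("\r\n")
        if not line:
            continue  # skip blank lines
        label, message = line.split("\t", 1)
        pairs.append((label, message))
    return pairs

# Hypothetical two-line sample in the assumed format.
sample = io.StringIO("ham\tsee you at lunch\nspam\tWIN a free prize now!\n")
rows = read_tsv_pairs(sample)
print(rows)
```

For a real file, open it with open(path, encoding="utf-8") and pass the file object in place of sample.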

Prepare labels for training

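This step is important. It's not only about batching the datasets; batching also reshapes them so that fitting the model works. Otherwise, we get an incompatible-layer error: expecting ndim=3, got ndim=2. The reason is that fit() treats each element of a tf.data dataset as a whole batch, so an unbatched dataset is missing the batch axis, while the recurrent layer (an LSTM in the tutorial) expects 3-D input: (batch, timesteps, features).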
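Before any batching, the string labels have to become numbers, matching the challenge's convention of ham = 0 and spam = 1. A minimal sketch with hypothetical names; in the notebook, the two resulting lists would then be wrapped with tf.data.Dataset.from_tensor_slices((messages, labels)) before calling batch().

```python
from typing import List, Sequence, Tuple

LABEL_TO_ID = {"ham": 0, "spam": 1}  # convention from the challenge: ham = 0, spam = 1

def encode_labels(pairs: Sequence[Tuple[str, str]]) -> Tuple[List[str], List[int]]:
    """Split (label, message) pairs into parallel lists of messages and 0/1 labels.
    Raises KeyError on any label other than 'ham' or 'spam'."""
    messages = [message for _, message in pairs]
    labels = [LABEL_TO_ID[label] for label, _ in pairs]
    return messages, labels

msgs, ys = encode_labels([("ham", "see you at lunch"), ("spam", "WIN a free prize now!")])
print(ys)
```

Failing loudly on an unknown label is deliberate: a stray header row or typo in the data shows up here rather than as a silently wrong label.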

Create a TextVectorization layer for our model

Let's show the vocabulary that our vectorizer has learned.
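For intuition about what the vectorizer learns, here is a rough pure-Python imitation of TextVectorization's default integer mode: lowercase and strip punctuation, split on whitespace, reserve index 0 for padding ('') and index 1 for the out-of-vocabulary token '[UNK]', then list the remaining words by descending frequency. It is a simplified sketch, not Keras's implementation; the punctuation set (approximated with string.punctuation) and the alphabetical tie-breaking are assumptions.

```python
import string
from collections import Counter
from typing import Iterable, List, Optional

PAD, OOV = "", "[UNK]"
_STRIP = str.maketrans("", "", string.punctuation)

def standardize(text: str) -> str:
    """Lowercase and remove punctuation characters."""
    return text.lower().translate(_STRIP)

def adapt(texts: Iterable[str], max_tokens: Optional[int] = None) -> List[str]:
    """Build a vocabulary: padding, OOV, then words by descending frequency.
    max_tokens caps the total size, including the two reserved entries."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    # Ties broken alphabetically here (an assumption, not Keras's documented rule).
    ranked = sorted(counts, key=lambda w: (-counts[w], w))
    vocab = [PAD, OOV] + ranked
    return vocab if max_tokens is None else vocab[:max_tokens]

def vectorize(text: str, vocab: List[str]) -> List[int]:
    """Map each token to its vocabulary index; unknown words map to 1 ('[UNK]')."""
    index = {w: i for i, w in enumerate(vocab)}
    return [index.get(tok, 1) for tok in standardize(text).split()]

vocab = adapt(["Call now! Call today.", "see you today", "call me"])
print(vocab[:4])
print(vectorize("Call you tomorrow!", vocab))
```

Here "call" appears three times and "today" twice, so they take indices 2 and 3, while "tomorrow" was never seen and becomes 1. The same pattern (padding, '[UNK]', then the most frequent words) is what to expect at the start of the real vocabulary listing.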

Create the model

Train the model on our datasets.

Plot the accuracy and loss metrics

Create a helper function to plot

Plot the graphs
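A plotting helper along the lines of the one in the TensorFlow RNN tutorial could look like the sketch below. It assumes history is the object returned by model.fit(...), that the model was compiled with metrics=["accuracy"], and that validation data was passed, so Keras also recorded "val_"-prefixed entries. matplotlib is imported inside plot_graphs only, so the series-extraction part runs without it.

```python
from typing import Dict, List, Tuple

def metric_series(history: Dict[str, List[float]], metric: str) -> Tuple[List[float], List[float]]:
    """Return (training, validation) values for one metric from a Keras-style
    History.history dict; the validation list is empty if it was not recorded."""
    return history[metric], history.get("val_" + metric, [])

def plot_graphs(history, metric: str) -> None:
    """Plot training vs. validation curves for `metric` against epoch number.
    Accepts either a Keras History object or its .history dict."""
    import matplotlib.pyplot as plt  # imported lazily: only needed for plotting

    data = getattr(history, "history", history)
    train, val = metric_series(data, metric)
    plt.plot(range(1, len(train) + 1), train, label=metric)
    if val:
        plt.plot(range(1, len(val) + 1), val, label="val_" + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend()

# Made-up numbers, for illustration only.
fake_history = {"loss": [0.5, 0.2], "val_loss": [0.4, 0.3]}
print(metric_series(fake_history, "loss"))
```

With a real run, calling plot_graphs(history, "accuracy") and plot_graphs(history, "loss") on two side-by-side subplots gives the two graphs; a widening gap between the training and validation curves is the usual sign of overfitting.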

Test
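The challenge's final cell checks predict_message against a fixed list of messages and expected labels. That list is not reproduced here; the sketch below shows the same kind of check, using hypothetical messages and a deliberately naive keyword stand-in for the trained model so that it runs on its own.

```python
from typing import Callable, List, Sequence, Tuple, Union

Prediction = List[Union[float, str]]

def check_predictions(predict: Callable[[str], Prediction],
                      cases: Sequence[Tuple[str, str]]) -> List[str]:
    """Return the messages whose predicted label differs from the expected one.
    Raises ValueError if a prediction is not [probability in [0, 1], "ham" | "spam"]."""
    failures = []
    for message, expected in cases:
        prob, label = predict(message)
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"probability out of range for {message!r}: {prob}")
        if label not in ("ham", "spam"):
            raise ValueError(f"unexpected label {label!r}")
        if label != expected:
            failures.append(message)
    return failures

# Hypothetical stand-in for the trained model: flags obvious spam keywords.
SPAM_WORDS = ("prize", "won", "call now", "free")

def naive_predict(message: str) -> Prediction:
    text = message.lower()
    p_spam = 0.9 if any(w in text for w in SPAM_WORDS) else 0.1
    return [p_spam, "spam" if p_spam >= 0.5 else "ham"]

cases = [
    ("are we still on for dinner tonight?", "ham"),
    ("Congratulations, you have won a prize! Call now.", "spam"),
]
failed = check_predictions(naive_predict, cases)
print("All test messages classified correctly." if not failed else f"Failed on: {failed}")
```

Swapping naive_predict for the real predict_message runs the same check against the trained model.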