TensorFlow.js

Greetings! This is Jun, and I’m a front-end developer at LINE. The front-end development scene has recently been booming with new technologies, and it’s becoming difficult to keep up with all of them. Personally, I’m into machine learning (ML). Under today’s topic, “implementing machine learning in front-end web development”, let me share a small machine learning experiment I did with TensorFlow.js.

Getting started

I used TensorFlow.js to build a simple classification model that runs in the browser. Before going into the details, I’ll briefly introduce TensorFlow and classification.

About TensorFlow

TensorFlow is an open-source machine learning framework. It offers well-abstracted models and functions that are frequently used in machine learning, which makes it easy for application programmers to implement machine learning. In 2018, the TensorFlow project released TensorFlow.js, a JavaScript version of TensorFlow. A recent trend is to deploy machine learning models on the client side, as with Apple’s Core ML and Google’s ML Kit. Because TensorFlow.js runs ML models in the browser, the user experience (UX) can be managed more finely and content can be tailored to individual users.

About classification

In machine learning, classification is a supervised learning task in which input data is assigned to one of a set of predefined classes. Spam filtering is a typical example. The distribution diagram below plots the number of typos on the x-axis and the number of special characters on the y-axis to distinguish between spam (O) and ham (X).

You can see that messages with more typos and special characters are more likely to be spam. A machine learning model learns such patterns from the data and can then assign the relevant class to new data. For example, the data point marked as a yellow triangle in the following diagram can be classified as spam.
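To make the idea concrete, here is a minimal sketch of how a classifier could turn the two features into a decision. It assumes a logistic model with hypothetical weights chosen by hand for illustration (they are not learned from the diagram’s data), and `spamProbability` and `classify` are illustrative names.

```javascript
// Hypothetical parameters, picked by hand for illustration only.
// A real model would learn these values from labeled examples.
const WEIGHTS = { typos: 0.8, specialChars: 0.9 }
const BIAS = -5

// Logistic (sigmoid) function: maps any real number into (0, 1)
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z))
}

// Probability that a message is spam, given its two feature values
function spamProbability(typos, specialChars) {
  const z = WEIGHTS.typos * typos + WEIGHTS.specialChars * specialChars + BIAS
  return sigmoid(z)
}

// Assign a class by thresholding the probability at 0.5
function classify(typos, specialChars) {
  return spamProbability(typos, specialChars) >= 0.5 ? 'spam' : 'ham'
}

console.log(classify(8, 6)) // prints spam
console.log(classify(1, 0)) // prints ham
```

With these weights, the decision boundary is the line 0.8 · typos + 0.9 · specialChars = 5, because the sigmoid reaches 0.5 exactly where its argument is 0; points above the line are classified as spam.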

Building a color recommendation system

Now it’s time to share my experiment. I implemented a color recommendation system that recommends colors to users based on their feedback about which colors they like.

How to implement

For input, I used a 3-dimensional feature vector made up of the RGB values. As the classification model, I chose logistic regression, mainly because the system runs in a web browser, where it is difficult to collect a large data set. Logistic regression is a simple, low-variance model, so it tends to deliver better results than more complex models when the sample size is small. I used the sigmoid function as the activation function and binary cross-entropy as the loss function. As the optimizer, I went with stochastic gradient descent (SGD), since the model has to be fitted on each user action. This can be implemented with TensorFlow.js as follows:

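import * as tf from '@tensorflow/tfjs'

// Logistic regression: a single dense unit with a sigmoid activation
const model = tf.sequential({
  layers: [
    // 3 inputs (normalized R, G, B) -> 1 output (probability that the user likes the color)
    tf.layers.dense({ inputShape: [3], units: 1, activation: 'sigmoid' })
  ]
})

model.compile({
  optimizer: tf.train.sgd(1), // plain SGD with learning rate = 1
  loss: 'binaryCrossentropy'
})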

Since each input holds the three RGB values, inputShape is set to [3]. The other settings follow the choices explained above. With the following code, the model starts learning from a training data set.

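// x: inputs of shape [n, 3]; y: labels of shape [n, 1] (1 = liked, 0 = not liked)
await model.fit(x, y, {
  batchSize: 1, // one example per gradient update
  epochs: 3     // pass over the training data three times
})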

x and y are tensors holding the training inputs and labels, respectively. batchSize sets how many examples are used for each gradient update; I set it to 1 so that the data is fed one example at a time. The input is normalized so that SGD works better. You can make a prediction for new data with the following code:

// newExample: a tensor of shape [1, 3] holding a normalized RGB color
model.predict(newExample)
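To show what the TensorFlow.js model above computes, here is a dependency-free sketch of the same idea: logistic regression on three inputs with a sigmoid output, binary cross-entropy loss, and plain SGD with a batch size of 1. It is an illustration under simplifying assumptions, not a reimplementation of TensorFlow.js: weights start at zero (TensorFlow.js initializes dense kernels randomly by default), examples are not shuffled between epochs (model.fit shuffles by default), and `trainLogisticRegression`, `predict`, and the toy data are hypothetical.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z))

// Probability of the positive class for one input vector
function predict({ weights, bias }, x) {
  let z = bias
  for (let i = 0; i < x.length; i++) z += weights[i] * x[i]
  return sigmoid(z)
}

// Binary cross-entropy for one example (p clipped to avoid log(0))
function binaryCrossentropy(p, y) {
  const eps = 1e-7
  const q = Math.min(Math.max(p, eps), 1 - eps)
  return -(y * Math.log(q) + (1 - y) * Math.log(1 - q))
}

// SGD with batch size 1: update the parameters after every single example
function trainLogisticRegression(xs, ys, { learningRate = 1, epochs = 3 } = {}) {
  const model = { weights: new Array(xs[0].length).fill(0), bias: 0 }
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (let n = 0; n < xs.length; n++) {
      const p = predict(model, xs[n])
      // For sigmoid + binary cross-entropy, dLoss/dz simplifies to (p - y)
      const error = p - ys[n]
      for (let i = 0; i < model.weights.length; i++) {
        model.weights[i] -= learningRate * error * xs[n][i]
      }
      model.bias -= learningRate * error
    }
  }
  return model
}

// Toy, hand-made data: normalized RGB, 1 = liked (reds), 0 = not liked
const xs = [[1, 0, 0], [0, 0, 1], [0.9, 0.1, 0.1], [0.1, 0.2, 0.9], [0.8, 0.2, 0.1], [0.1, 0.8, 0.2]]
const ys = [1, 0, 1, 0, 1, 0]
const model = trainLogisticRegression(xs, ys)
console.log(predict(model, [0.95, 0.05, 0.05]) > 0.5) // prints true
```

The (p - y) shortcut is why sigmoid and binary cross-entropy pair so well: the gradient stays simple and never vanishes just because the sigmoid saturates.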

Demo

Let me show you how it works on the front end. I prepared the web app shown below. When the user clicks the like button, that color is recorded as positive. I didn’t provide a dislike button, since users usually don’t bother to give feedback on what they don’t like. Instead, a color is recorded as negative when the user scrolls past it without clicking the like button. I used the Intersection Observer API to implement this.
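The scroll-based negative feedback could be wired up roughly as follows. This is a sketch, not the author’s code: `createScrollAwayHandler`, `isLiked`, and `onNegative` are hypothetical names, and it assumes that a color card counts as scrolled away once it has been visible and then leaves the viewport through the top edge (its `boundingClientRect.top` becomes negative). The returned function has the shape of an `IntersectionObserver` callback, so the decision logic can be exercised without a browser.

```javascript
// Returns a callback with the IntersectionObserver callback signature: (entries) => void
function createScrollAwayHandler({ isLiked, onNegative }) {
  const seen = new Set() // cards that have been visible at least once

  return (entries) => {
    for (const entry of entries) {
      const card = entry.target
      if (entry.isIntersecting) {
        seen.add(card)
        continue
      }
      // No longer visible: treat it as negative only if the user saw it,
      // scrolled past it (the card is now above the viewport), and did not like it
      const scrolledPast = entry.boundingClientRect.top < 0
      if (seen.has(card) && scrolledPast && !isLiked(card)) {
        seen.delete(card) // report each card at most once
        onNegative(card)
      }
    }
  }
}
```

In the browser, the handler would be attached with `const observer = new IntersectionObserver(createScrollAwayHandler({ isLiked, onNegative }))`, followed by `observer.observe(card)` for each color card element.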

Whenever user feedback is collected, either from the like button or from scrolling, data like the following is gathered. Each data point is a 3-dimensional vector whose components are the color’s RGB values.

This distribution shows that the user clicked the like button for various shades of red. You can see that positive data (marked with a blue “O”) and negative data (marked with an orange “X”) each tend to form their own cluster. A well-trained model should capture this pattern.
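Each feedback event becomes one training example. A minimal sketch, assuming the normalization divides each 0 to 255 channel by 255 (the text only says the input is normalized) and using the hypothetical helper names `toFeatures`, `toExample`, and `toTrainingSet`:

```javascript
// Scale each 8-bit channel (0 to 255) into the range [0, 1]
function toFeatures({ r, g, b }) {
  return [r / 255, g / 255, b / 255]
}

// Turn one feedback event into a training example.
// liked = true: the like button was clicked (positive, label 1)
// liked = false: the user scrolled past without liking (negative, label 0)
function toExample(color, liked) {
  return { x: toFeatures(color), y: liked ? 1 : 0 }
}

// Collect examples into the nested arrays that become the x and y tensors
function toTrainingSet(events) {
  const examples = events.map((e) => toExample(e.color, e.liked))
  return {
    xs: examples.map((ex) => ex.x),   // shape [n, 3]
    ys: examples.map((ex) => [ex.y])  // shape [n, 1]
  }
}
```

With TensorFlow.js, the two arrays could then be turned into tensors with `tf.tensor2d(xs)` and `tf.tensor2d(ys)` and passed to `model.fit`.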

Calling model.predict() returns a value between 0 and 1, where values closer to 1 mean positive. When I applied a threshold of 0.5 to the predictions for new colors, the filtered results were different shades of red, as shown below, indicating that the model was doing its job.
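The filtering step amounts to scoring each candidate color and keeping those at or above the threshold. Here is a sketch with a hypothetical `recommend` helper. `score` stands in for a function that wraps the model’s prediction (in TensorFlow.js, something like `model.predict(tf.tensor2d([features])).dataSync()[0]`), and sorting by score is an added assumption.

```javascript
// Keep candidates whose predicted probability is at least the threshold,
// most likely matches first
function recommend(candidates, score, threshold = 0.5) {
  return candidates
    .map((color) => ({ color, p: score(color) }))
    .filter(({ p }) => p >= threshold)
    .sort((a, b) => b.p - a.p)
    .map(({ color }) => color)
}

// Stand-in scorer for demonstration only: the "redness" of the color.
// In the app, this would call the trained model instead.
const fakeScore = ({ r }) => r / 255

const candidates = [
  { r: 200, g: 30, b: 40 },
  { r: 20, g: 40, b: 220 },
  { r: 250, g: 80, b: 90 }
]
// Keeps the two reddish colors, the brighter red first
console.log(recommend(candidates, fakeScore))
```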

Challenges

This example showed positive results, but some challenges remain. At first, I implemented a shallow neural network (NN) with one hidden layer. No matter what I tried, the generalization error stayed high, probably because a sample size of 50 wasn’t big enough.

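// Shallow NN: one hidden layer with 3 sigmoid units, then one sigmoid output unit
const model = tf.sequential({
  layers: [
    tf.layers.dense({ inputShape: [3], units: 3, activation: 'sigmoid' }), // hidden layer
    tf.layers.dense({ units: 1, activation: 'sigmoid' })                   // output layer
  ]
})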
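For comparison with the single-unit model, this sketch spells out what the two-layer network computes for one input: three sigmoid hidden units feeding one sigmoid output. The weights are passed in explicitly; the zero weights below are placeholders rather than trained values, and `dense` and `forward` are hypothetical names. Kernels use the [inputs][units] layout.

```javascript
const sigmoid = (z) => 1 / (1 + Math.exp(-z))

// One dense layer: out[j] = activation(bias[j] + sum_i x[i] * kernel[i][j])
function dense(x, kernel, bias, activation) {
  return bias.map((b, j) => {
    let z = b
    for (let i = 0; i < x.length; i++) z += x[i] * kernel[i][j]
    return activation(z)
  })
}

// 3 inputs -> 3 hidden units -> 1 output, all with sigmoid activations
function forward(x, params) {
  const hidden = dense(x, params.hiddenKernel, params.hiddenBias, sigmoid)
  const [output] = dense(hidden, params.outputKernel, params.outputBias, sigmoid)
  return output
}

// Placeholder parameters (not trained)
const params = {
  hiddenKernel: [[0, 0, 0], [0, 0, 0], [0, 0, 0]], // 3 x 3
  hiddenBias: [0, 0, 0],
  outputKernel: [[0], [0], [0]],                   // 3 x 1
  outputBias: [0]
}

console.log(forward([1, 0, 0], params)) // prints 0.5, since every pre-activation is 0
```

The hidden layer lets the network learn non-linear boundaries, but it also raises the parameter count from 4 (3 weights + 1 bias) to 16 (12 in the hidden layer + 4 in the output layer), which is more to fit from only 50 samples.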

When I added more training data and increased batchSize, it worked fine. Since it is difficult to secure a large training data set on the front end, this is one of the challenges to tackle. Downloading weights trained on a server doesn’t fully solve the problem either, because more layers or units mean more weights, which significantly increases the download size. I need to find a more efficient solution to this issue. If you have any suggestions or ideas, please contact me here!
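A rough way to see how the download size grows: a dense layer with n inputs and m units has n × m weights plus m biases. This sketch counts the parameters of a stack of dense layers and estimates their raw size, assuming 32-bit floats (4 bytes per value, the default dtype in TensorFlow.js); it ignores compression, quantization, and the model topology file. `countParams` and `estimateBytes` are hypothetical names, and the 64-unit network is only an example size.

```javascript
// layerSizes: [inputs, units of layer 1, units of layer 2, ...]
function countParams(layerSizes) {
  let total = 0
  for (let k = 1; k < layerSizes.length; k++) {
    const inputs = layerSizes[k - 1]
    const units = layerSizes[k]
    total += inputs * units + units // kernel + bias
  }
  return total
}

// Raw size of the weights, assuming float32 (4 bytes per parameter)
const estimateBytes = (layerSizes) => countParams(layerSizes) * 4

console.log(countParams([3, 1]))           // 4: logistic regression
console.log(countParams([3, 3, 1]))        // 16: one hidden layer of 3 units
console.log(countParams([3, 64, 64, 1]))   // 4481: two hidden layers of 64 units
console.log(estimateBytes([3, 64, 64, 1])) // 17924 bytes
```

Because the count grows with the product of adjacent layer widths, widening or adding hidden layers quickly dominates the download size.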

Closing

That’s all I’ve prepared for this post. How did you like it? I’m still a novice at machine learning, so there is a lot to learn. I’ll keep pursuing machine learning, and I’d really appreciate it if you shared your opinions and ideas for improvement. Personally, I think machine learning on the client side is difficult because of its many constraints, but it has the upside of integrating directly with the UX. Libraries are already available on many client platforms, including iOS, Android, and Unity, and real-world use cases are increasing. For your reference, I posted the code on GitHub. Bye for now!