TensorFlow Tutorial: A Guide to Retraining Object Detection Models
I was inspired to write this TensorFlow tutorial after developing the SIMI project, an object recognition app for the visually impaired. This time around I wanted to spend my week retraining the object detection model and writing up a guide so that other developers can do the same thing.
TensorFlow’s object detection technology can provide huge opportunities for app developers and brands alike to use a range of tools for different purposes. The use of mobile devices only furthers this potential, as people have access to incredibly powerful computers and only have to reach as far as their pockets to find them.
In this tutorial, I’ll cover the steps you need to take to retrain object detection models in TensorFlow. Each stage includes a breakdown of the different approaches, such as using existing models and data, and links out to helpful resources that give more detail on steps not everyone will need.
Steps in Retraining Object Detection Models with TensorFlow:
In TensorFlow’s GitHub repository you can find a large variety of pre-trained models for various machine learning tasks, and one excellent resource is their object detection API. The object detection API makes it reasonably straightforward to train your own object detection model to fit your requirements.
Whether you need a high-speed model to work on live stream, high-frames-per-second (fps) applications, or high-accuracy desktop models, the API makes it possible to train and export the model.
To start off, make sure you have TensorFlow installed on your computer (how to install TensorFlow). Next, we have to clone and install the object detection API on our PC. Installing the object detection API is simple: you just need to clone the TensorFlow Models repository, or you can download the zip file for TensorFlow Models from GitHub.
Then we save everything under our main folder, ‘The Pip Model’, before opening a terminal, moving to the research folder, compiling the protobuf files and adding the API to the Python path by typing:
cd ~/pipModel/models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
Before we get started, let’s create a folder named TensorFlow on our PC, and from now on everything we download will be stored in this root folder. Ideally, create this folder inside your main user folder (e.g. where the Desktop, Documents, Downloads, and Movies folders are stored).
In order to train your custom object detection class, you have to create (collect) and label (tag) your own data set.
The dataset should contain all the objects you want to detect. Ideally, a dataset contains at least 200 images of each object in question. That set is only the training dataset, though: unfortunately, you also need a test dataset, which should be about 30 percent of the size of the training dataset (around 60 images).
So in total, we need approximately 260 images per object.
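The arithmetic above can be sketched in a few lines of Python. This is only an illustration under my own assumptions: the function name `split_dataset`, the `pip_###.jpg` file names and the fixed random seed are made up for the example and are not part of the original workflow.

```python
import random


def split_dataset(filenames, test_fraction_of_train=0.3, seed=42):
    """Split filenames into (train, test) so that the test set is roughly
    test_fraction_of_train times the size of the training set."""
    files = sorted(filenames)
    random.Random(seed).shuffle(files)  # fixed seed keeps the split repeatable
    # total = train + fraction * train, so train = total / (1 + fraction)
    n_train = round(len(files) / (1 + test_fraction_of_train))
    return files[:n_train], files[n_train:]


# Hypothetical file names, for illustration only.
images = [f"pip_{i:03d}.jpg" for i in range(260)]
train_files, test_files = split_dataset(images)
print(len(train_files), len(test_files))
```

With 260 images this gives the 200/60 split described above, and no image ends up in both sets.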
This is where the hard work starts: you need to search the web and manually download images of the object you want to detect. You’ll want to make sure the images are no larger than 1280×720 pixels for two reasons. Firstly, all the images will be resized to 300×300 during the training process, and secondly, storage space is limited (you’ll need 250MB of free disk space, and that’s for the object images alone).
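If some of your downloads are bigger than that, you can work out a target size that keeps the aspect ratio. The helper below is a minimal sketch; `fit_within` is a name I made up, and it only computes the new dimensions rather than touching any image file.

```python
def fit_within(width, height, max_width=1280, max_height=720):
    """Return (width, height) scaled down to fit the limits while keeping
    the aspect ratio. Images already within the limits are left unchanged."""
    scale = min(max_width / width, max_height / height, 1.0)
    return max(1, round(width * scale)), max(1, round(height * scale))


print(fit_within(4000, 3000))  # a large 4:3 photo
print(fit_within(640, 480))    # already small enough
```

The resulting pair can then be passed to whatever image library you use for the actual resize.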
I just wish there was an automatic way to download all the images! Here’s the simplest way to do this:
Now that we have acquired all the images, the next step is labelling them.
This process, unfortunately, cannot be automated, so we have to label all the images manually. We need to draw a bounding box around the object to let the system know that the “thing” inside the box is the actual object we want it to learn to detect. You can use RectLabel or LabelImg to label all the images.
LabelImg is an excellent open-source (and free) tool that makes the labelling process much easier. It saves an individual XML label file for each image you process.
Create an ‘annotations’ folder within the ‘TensorFlow’ folder, and save all the XML files into that folder.
Finally, after labelling the images, we need to create the TFRecords. We need train.record (built from the training images) and test.record (built from the test images).
Convert the labels to the TFRecord format.
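When training models with TensorFlow, using TFRecord files helps optimise your data feed. The first step towards generating a TFRecord file is the xml_to_csv.py script, which, as its name suggests, gathers the XML labels into a CSV file that is then converted into the TFRecord format.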
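To give a feel for what the XML-to-CSV step involves, here is a minimal sketch using only the Python standard library. It is not the xml_to_csv.py script itself: the function names, the column order and the sample annotation are my own assumptions, and it expects the Pascal VOC style XML that LabelImg writes.

```python
import csv
import io
import xml.etree.ElementTree as ET

CSV_COLUMNS = ["filename", "width", "height", "class",
               "xmin", "ymin", "xmax", "ymax"]


def xml_to_rows(xml_text):
    """Return one row per labelled bounding box in a Pascal VOC style XML string."""
    root = ET.fromstring(xml_text.strip())
    filename = root.findtext("filename")
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    rows = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        rows.append([
            filename, width, height, obj.findtext("name"),
            int(box.findtext("xmin")), int(box.findtext("ymin")),
            int(box.findtext("xmax")), int(box.findtext("ymax")),
        ])
    return rows


def rows_to_csv(rows):
    """Serialise the rows to CSV text with a header line."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(CSV_COLUMNS)
    writer.writerows(rows)
    return buffer.getvalue()


# Hypothetical annotation, for illustration only.
SAMPLE_XML = """
<annotation>
  <filename>pip_001.jpg</filename>
  <size><width>1280</width><height>720</height><depth>3</depth></size>
  <object>
    <name>Pip</name>
    <bndbox><xmin>100</xmin><ymin>150</ymin><xmax>400</xmax><ymax>500</ymax></bndbox>
  </object>
</annotation>
"""

print(rows_to_csv(xml_to_rows(SAMPLE_XML)))
```

In practice you would loop over every XML file in the ‘annotations’ folder and write one CSV for the training images and one for the test images.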
Now that we have created the required input files for the API, we can move on to training our model.
There are many models in the TensorFlow API you can use depending on your needs. If you want a high-speed model that can detect objects in a video feed at high fps, the single shot detection (SSD) network is the best choice.
Some other object detection networks detect objects by sliding different sized boxes across the image and running the classifier many times on different sections. As you can imagine this is very resource-consuming.
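To put a rough number on that cost, the sketch below counts how many times a classifier would have to run if square boxes were slid across a single image. The box sizes and stride are arbitrary values I picked for illustration, not settings taken from any particular network.

```python
def count_windows(image_width, image_height, box_sizes, stride):
    """Count the classifier calls needed to slide square boxes across an image."""
    total = 0
    for size in box_sizes:
        if size > image_width or size > image_height:
            continue  # this box does not fit inside the image
        steps_x = (image_width - size) // stride + 1
        steps_y = (image_height - size) // stride + 1
        total += steps_x * steps_y
    return total


# Three box sizes on a 300x300 image, moving 10 pixels at a time.
print(count_windows(300, 300, box_sizes=[50, 100, 150], stride=10))
```

Even this small setup needs well over a thousand classifier runs per image, which is why a single-pass approach is attractive for video.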
The SSD network determines all bounding box probabilities in one go, hence it is a vastly faster model. However, with single shot detection, you gain speed but lose accuracy. In our tutorial, we will use the MobileNet model, which is designed to be used in mobile applications.
We’ve already configured the .config file for SSD MobileNet and included it in the GitHub repository for this post, named ssd_mobilenet_v1_pets.config.
We could train the entire SSD MobileNet model on our own data from scratch, but that would require thousands of training images and roughly 4-7 days’ worth of training time.
So we might as well take an existing model that is already trained on a large dataset and replace its last layer, which holds the classes/objects from the original training, with our own classes/objects. By doing that, we can reuse all the feature detectors trained in that model to detect our new classes/objects.
A key thing in this step is to stop the training once our loss is consistently below 1, or you can wait until it finishes. To stop TensorFlow training, simply press ctrl+c (on Mac). To train, we simply run the `train.py` file in the object detection API directory, pointing it to our data. So let’s move train.record and test.record into a new folder called ‘data’.
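What “consistently” means is a judgment call. One possible reading, sketched below, is that the last few reported loss values all sit under the threshold. The function name, the window of five values and the sample numbers are my own assumptions.

```python
def loss_consistently_below(losses, threshold=1.0, window=5):
    """Return True if the last `window` loss values are all below `threshold`."""
    if len(losses) < window:
        return False  # not enough history to call it consistent
    return all(loss < threshold for loss in losses[-window:])


# Hypothetical loss values read off the training log.
history = [4.2, 2.9, 1.8, 1.1, 0.9, 0.95, 0.8, 0.85, 0.7]
print(loss_consistently_below(history))
```

A single dip below 1 followed by a spike would not pass this check, which is the point of looking at a window rather than one value.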
First, download ssd_mobilenet_v1_coco_11_06_2017. Then unzip it into your Downloads folder; do not save this file in the TensorFlow folder. Also, change the path in ssd_mobilenet_v1_pets.config to point to the model.ckpt file of the ssd_mobilenet_v1_coco model we just downloaded.
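You can make that change in any text editor, but it can also be scripted. The sketch below assumes the config stores the path in a `fine_tune_checkpoint` entry, as the sample configs shipped with the object detection API do; the function name, the shortened sample config and the `/Users/me/...` path are hypothetical.

```python
import re


def set_fine_tune_checkpoint(config_text, checkpoint_path):
    """Replace the quoted value of fine_tune_checkpoint in pipeline config text."""
    pattern = r'(fine_tune_checkpoint:\s*)"[^"]*"'
    new_text, count = re.subn(
        pattern,
        lambda match: f'{match.group(1)}"{checkpoint_path}"',
        config_text,
    )
    if count == 0:
        raise ValueError("no fine_tune_checkpoint entry found")
    return new_text


# Shortened, hypothetical excerpt of a pipeline config.
SAMPLE_CONFIG = '''train_config: {
  batch_size: 24
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  from_detection_checkpoint: true
}
'''

# Hypothetical location of the unzipped model.
CKPT = "/Users/me/Downloads/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"
updated = set_fine_tune_checkpoint(SAMPLE_CONFIG, CKPT)
print(updated)
```

Only the checkpoint line changes; the rest of the config text is passed through untouched.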
You may have several files with the same format but with different checkpoint numbers. We will use the latest one, ckpt-#### (e.g., model.ckpt-102), from our data directory and execute the following command:
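If the data directory holds many checkpoints, the latest number can also be found programmatically. Here is a minimal sketch, assuming the files follow the model.ckpt-#### naming described above; the helper name and the sample file names are my own. It is a helper for picking the checkpoint, not the command itself.

```python
import re


def latest_checkpoint_prefix(filenames):
    """Return the 'model.ckpt-N' prefix with the highest N, or None if absent."""
    steps = []
    for name in filenames:
        match = re.match(r"model\.ckpt-(\d+)", name)
        if match:
            steps.append(int(match.group(1)))  # compare as numbers, not text
    return f"model.ckpt-{max(steps)}" if steps else None


# Hypothetical directory listing, for illustration only.
files = [
    "model.ckpt-57.index", "model.ckpt-57.meta",
    "model.ckpt-102.index", "model.ckpt-102.meta",
    "checkpoint", "graph.pbtxt",
]
print(latest_checkpoint_prefix(files))
```

Comparing the step numbers as integers matters here, because sorting them as text would put 57 after 102.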
Open it up and import a new project using the directory called “Android” from the TensorFlow repo we downloaded/cloned; it should be stored at the ~/tensorflow/tensorflow/examples/android path.
Then, change the nativeBuildSystem variable of the Gradle build to none.
Change the model’s configuration in DetectorActivity.java and move the two files into the assets folder.
private static final String TF_OD_API_MODEL_FILE = "file:///android_asset/frozen_inference_graph.pb";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/labels.txt";
In the file called “labels.txt”, the first line stays blank (remember when I said the first line is reserved), and in the second line write the label of your object (in my case, it was Pip).
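If you have more than one class, the same layout can be generated rather than typed by hand. This is a small sketch that follows the description above (blank first line, then one label per line); the function name is made up for the example.

```python
def build_labels_text(labels):
    """Return labels.txt content: a blank first line, then one label per line."""
    return "\n".join([""] + list(labels)) + "\n"


content = build_labels_text(["Pip"])
print(repr(content))
```

Writing `content` to labels.txt in the assets folder gives the file that TF_OD_API_LABELS_FILE points to.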
And that’s it! Happy detection!
Meet SIMI: The Object Recognition App for the Visually Impaired:
Take a look at the SIMI project that inspired this tutorial. The object detection model was set up to recognise a range of different and unique objects, from plant pots to people, laptops, books, bicycles and many, many more. With voice interactions and emergency contacts included, the app utilises TensorFlow object detection technology to improve the lives of those living with visual impairments or disabilities.
We dedicate 15% of our team's time to exploring emerging technology and working on projects they're passionate about. So far we've developed a Jenga game in augmented reality, an app that mimics the human eye and an interactive map that tracks natural disasters in real-time.