Fruit classification using CNN and PyTorch

Praful Thangappa
Jun 29, 2020
Photo by Pixabay from Pexels

Why is it important to eat fruit?

Eating fruit provides health benefits: people who eat more fruits and vegetables as part of an overall healthy diet are likely to have a reduced risk of some chronic diseases. Fruits provide nutrients vital for the health and maintenance of your body.

A growing body of research shows that fruits and vegetables are critical to promoting good health. In fact, fruits and vegetables should be the foundation of a healthy diet. They are packed with essential vitamins, minerals, fiber, and disease-fighting phytochemicals.

Here, I have built an image classifier using a CNN model that classifies 131 different types of fruit from the Fruits-360 dataset.

Preparing the Dataset

The dataset consists of training data and test data, with 67,692 images in the training set and 22,688 images in the test set.

Each of the two folders contains 131 classes of fruit.

We can visualize the data using seaborn (sns):
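The original plot is not reproduced here. As an assumption about what it showed, the sketch below counts the images in each class folder with a hypothetical helper, `count_images_per_class`, and draws the counts as a seaborn bar chart with `plot_class_counts`. With 131 classes, a horizontal layout keeps the class names readable.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir):
    """Return {class_name: image_count} for a folder laid out as split_dir/<class>/<image>."""
    return {
        class_dir.name: sum(1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
        for class_dir in sorted(Path(split_dir).iterdir())
        if class_dir.is_dir()
    }


def plot_class_counts(counts):
    """Horizontal seaborn bar chart of images per class; returns the matplotlib figure."""
    # Imported here so the counting helper works even without the plotting libraries.
    import matplotlib.pyplot as plt
    import seaborn as sns

    names = list(counts)
    fig, ax = plt.subplots(figsize=(8, 0.15 * len(names) + 1))
    sns.barplot(x=[counts[n] for n in names], y=names, orient="h", ax=ax)
    ax.set_xlabel("Number of images")
    return fig


# Usage (path is hypothetical):
# plot_class_counts(count_images_per_class("fruits-360/Training"))
```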

Next, we split the data in the training folder into a training set and a validation set. The training set is the sample of data used to fit the model, whereas the validation set provides an unbiased evaluation of the model fit on the training set while we tune the model's hyperparameters.

We split the data in a 75:25 ratio between the training and validation sets.
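A minimal sketch of the 75:25 split with `torch.utils.data.random_split`. The helper name `train_val_split` and the seed of 42 are assumptions; fixing the generator's seed makes the split reproducible between runs. For the 67,692 training images, a 75:25 split gives 50,769 training and 16,923 validation images.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def train_val_split(dataset, val_fraction=0.25, seed=42):
    """Randomly split `dataset` into (train, val) subsets, 75:25 by default."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)  # same split on every run
    return random_split(dataset, [n_train, n_val], generator=generator)


# Stand-in dataset with one item per training image (replace with the real ImageFolder).
dummy = TensorDataset(torch.arange(67_692))
train_ds, val_ds = train_val_split(dummy)
print(len(train_ds), len(val_ds))  # 50769 16923
```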

Data Augmentation

Data augmentation increases the amount and diversity of the training data. Rather than collecting new data, we transform the data we already have. You can read more about augmentation in data-augmentation-for-deep-learning.

The next step is to load the image data into data loaders, which give us batches of images. Let's have a look at a batch of training images after augmentation:

Transfer Learning

Transfer learning is a popular method in computer vision because it allows us to build accurate models in a time-saving way.

In computer vision, transfer learning is usually expressed through the use of pre-trained models. A pre-trained model is a model that was trained on a large benchmark dataset to solve a problem similar to the one that we want to solve.

ResNet34 was used to train this image classifier. ResNet was the state of the art in computer vision in 2015 and is still hugely popular. PyTorch makes ResNet models easy to use: its torchvision package provides several pre-trained ResNet architectures, and you can also build your own.

You can learn more about ResNets here.

Training the Model

The following snippet was used to train the model. We can freeze some layers during training. By default, every parameter in the model has requires_grad = True, which means all the parameters are learnable and are updated as the model trains. If requires_grad is set to False for a specific parameter, that parameter's weights are not updated during training.

Evaluating the model

After initializing the model, we need to check how well it performs on the validation dataset.
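A minimal validation step, assuming cross-entropy loss (the usual choice for multi-class classification) and a data loader that yields `(images, labels)` batches. `evaluate` is a hypothetical helper; `model.eval()` makes BatchNorm use its running statistics and disables dropout, while `torch.no_grad()` skips gradient tracking.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def evaluate(model, loader):
    """Return the mean cross-entropy loss and accuracy of `model` over every batch in `loader`."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        logits = model(images)
        # Sum per-sample losses so the final mean is exact even if the last batch is smaller.
        total_loss += F.cross_entropy(logits, labels, reduction="sum").item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return {"val_loss": total_loss / seen, "val_acc": correct / seen}


# Usage: evaluate(model, val_dl) -> {"val_loss": ..., "val_acc": ...}
```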

Since the model has not been trained on our data yet and there are 131 classes, we should expect a poor accuracy score at this point: random guessing would be right only about 1 time in 131 (roughly 0.8%). Let's train the model on the training data and check how well it performs on the validation set. First, we freeze some layers in the model and train the rest.
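A sketch of one training epoch, assuming the Adam optimizer and cross-entropy loss (the original optimizer and learning rates are not given). The hypothetical helper `make_optimizer` passes only parameters with `requires_grad=True` to the optimizer, so frozen layers are left untouched. After unfreezing, build a new optimizer so the backbone parameters are included too.

```python
import torch
import torch.nn.functional as F


def fit_one_epoch(model, loader, optimizer):
    """Run one pass over `loader` and return the mean training loss per sample."""
    model.train()  # enable dropout and batch statistics in BatchNorm
    running = 0.0
    for images, labels in loader:
        loss = F.cross_entropy(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item() * labels.size(0)
    return running / len(loader.dataset)


def make_optimizer(model, lr):
    """Adam over the currently trainable (unfrozen) parameters only."""
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=lr)


# Usage (names and learning rate are illustrative):
# opt = make_optimizer(model, lr=1e-3)
# for epoch in range(5):
#     train_loss = fit_one_epoch(model, train_dl, opt)
```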

Now let's unfreeze the layers and train some more:

One thing to note is that normalization works poorly on this kind of image. With normalization, the accuracy was extremely low, because in some examples it became harder to tell the fruits apart by color.

Using nn.Dropout() and nn.BatchNorm1d() in the final layer of the model did not help much either.
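One way to add those layers is to replace ResNet34's final `fc` layer with a small head. The layer order and the dropout rate of 0.5 are assumptions, and `make_regularized_head` is a hypothetical helper. Note that `nn.BatchNorm1d` needs more than one sample per batch in training mode.

```python
import torch
from torch import nn


def make_regularized_head(in_features=512, num_classes=131, p_drop=0.5):
    """Classification head with BatchNorm1d and Dropout before the final linear layer."""
    return nn.Sequential(
        nn.BatchNorm1d(in_features),  # normalize the pooled ResNet34 features over the batch
        nn.Dropout(p=p_drop),         # randomly zero features, in training mode only
        nn.Linear(in_features, num_classes),
    )


# Usage with a torchvision ResNet34:
# model.fc = make_regularized_head(model.fc.in_features)
```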

Now, let’s check the performance on some test data images.
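A sketch of single-image prediction: add a batch dimension, run the model, and take the index of the largest logit. `predict_image` is a hypothetical helper; `classes` is assumed to be the list of class names, such as the `classes` attribute of a torchvision `ImageFolder` dataset.

```python
import torch


@torch.no_grad()
def predict_image(img, model, classes):
    """Return the predicted class name for one (C, H, W) image tensor."""
    model.eval()
    logits = model(img.unsqueeze(0))  # add a batch dimension: (1, C, H, W)
    return classes[logits.argmax(dim=1).item()]


# Usage (names are illustrative):
# img, label = test_ds[0]
# print("Label:", test_ds.classes[label], "Predicted:", predict_image(img, model, test_ds.classes))
```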

The model shows extremely good results on random images from the test dataset.

We can also check how well the model performs on the whole test dataset.

The test accuracy is close to the validation accuracy, which suggests the model generalizes well to unseen images and isn't overfitting.

This project was done as part of the PyTorch-zero-to-GANs course on https://jovian.ml/ and freeCodeCamp.
