In 2020, about 5.8 million Americans were living with Alzheimer’s, and that number is projected to reach 14 million by 2060. Currently, 1 in 5 Alzheimer’s patients are misdiagnosed, and diagnosing correctly at an early stage is vital to prevent irreversible brain damage. The cognitive decline visibly associated with Alzheimer’s typically appears in later stages, and the disease may be harder to detect on MRI in its early stages.¹
This is why machine learning, and image classification in particular, should be leveraged to classify and predict Alzheimer’s severity: to reduce misdiagnosis and potentially diagnose the disease at an earlier stage.
In my Deep Learning and Image Recognition course, I worked on a team project that combined image augmentation for class balancing with a custom convolutional neural network (CNN) to classify Alzheimer’s severity from MRI images. When we compared our custom model to the popular pretrained VGG16 model through transfer learning, we found that our custom CNN performed better, with a test accuracy of 74.87% vs. VGG16’s 24.64%.
The dataset we used was found on Kaggle and consists of 6,400 MRI images (128 × 128 pixels), divided into 4 categories:
- Mild Demented
- Very Mild Demented
- Moderate Demented
- Non Demented
We conducted exploratory analysis on the dataset and found that the classes were imbalanced, meaning that some classes had many more images than others.
This was a problem because a model trained on imbalanced data can become biased toward the classes that make up most of the dataset. Such a model would not generalize well to other datasets and would classify the under-represented classes poorly.
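The balancing step boils down to simple arithmetic. This is a minimal sketch assuming every minority class is augmented up to the size of the largest class (the post does not say which target size was used); the class counts are hypothetical placeholders, not the dataset’s actual distribution.

```python
from typing import Dict

# Hypothetical class counts for illustration only; replace them with the
# counts found during your own exploratory analysis.
class_counts = {
    "Mild Demented": 900,
    "Very Mild Demented": 2200,
    "Moderate Demented": 100,
    "Non Demented": 3200,
}


def imbalance_ratio(counts: Dict[str, int]) -> float:
    """Size of the largest class divided by the size of the smallest."""
    return max(counts.values()) / min(counts.values())


def augmentation_plan(counts: Dict[str, int]) -> Dict[str, int]:
    """Number of augmented images each class needs to match the largest class."""
    target = max(counts.values())
    return {label: target - n for label, n in counts.items()}


print(imbalance_ratio(class_counts))  # 32.0
for label, extra in augmentation_plan(class_counts).items():
    print(f"{label}: add {extra} augmented images")
```

With counts like these, the smallest class would need about 31 augmented variants per original image, so most of its training examples would be synthetic.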
In order to balance our dataset, we used a method called image augmentation.
Image augmentation creates new images for a dataset by altering existing ones. There are many ways to augment an image, including flipping, rotation, and scaling. We used the imgaug library and found, after some trial and error, that the augmentations that produced the best results were cropping, linear contrast, Gaussian blur, and additive Gaussian noise.
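To make the four augmentations concrete, here is a NumPy-only sketch of each one. It is not the imgaug code the team used (imgaug provides these as `iaa.Crop`, `iaa.LinearContrast`, `iaa.GaussianBlur`, and `iaa.AdditiveGaussianNoise`, chained with `iaa.Sequential`); the parameter ranges below are assumptions chosen to keep changes subtle, and the input is a random stand-in for a grayscale 128 × 128 slice.

```python
import numpy as np


def random_crop_resize(img, rng, max_frac=0.1):
    """Crop up to max_frac of the height/width from each side, then resize
    back to the original shape (nearest neighbour) so the size is unchanged."""
    h, w = img.shape
    top, bottom = rng.integers(0, int(max_frac * h) + 1, size=2)
    left, right = rng.integers(0, int(max_frac * w) + 1, size=2)
    cropped = img[top:h - bottom, left:w - right]
    rows = np.arange(h) * cropped.shape[0] // h
    cols = np.arange(w) * cropped.shape[1] // w
    return cropped[np.ix_(rows, cols)]


def linear_contrast(img, rng, low=0.75, high=1.25):
    """Scale each pixel's distance from mid-grey (128) by a random factor."""
    alpha = rng.uniform(low, high)
    return 128.0 + alpha * (img - 128.0)


def gaussian_blur(img, sigma):
    """Separable Gaussian blur with edge padding; output shape equals input shape."""
    radius = max(1, int(3 * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()
    padded = np.pad(img, radius, mode="edge")
    # 'valid' convolution trims the padding back off, first along rows, then columns.
    rows = np.apply_along_axis(np.convolve, 1, padded, kernel, mode="valid")
    return np.apply_along_axis(np.convolve, 0, rows, kernel, mode="valid")


def additive_gaussian_noise(img, rng, max_scale=0.03 * 255):
    """Add zero-mean Gaussian noise with a randomly chosen standard deviation."""
    scale = rng.uniform(0.0, max_scale)
    return img + rng.normal(0.0, scale, size=img.shape)


def augment(img, rng):
    """Apply the four augmentations in sequence and return a uint8 image."""
    out = random_crop_resize(img.astype(np.float64), rng)
    out = linear_contrast(out, rng)
    out = gaussian_blur(out, sigma=rng.uniform(0.1, 0.5))
    out = additive_gaussian_noise(out, rng)
    return np.clip(out, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
# Random pixels stand in for a 128x128 grayscale MRI slice.
image = rng.integers(0, 256, size=(128, 128)).astype(np.uint8)
augmented = augment(image, rng)
print(augmented.shape, augmented.dtype)  # (128, 128) uint8
```

Calling `augment` repeatedly on the same image yields different variants, which is how a minority class can be filled up to the target size.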
It makes sense that augmentations that alter the images only slightly produce the best results, since heavier alterations add variation between images that makes classification harder. This is especially important for Alzheimer’s MRI images, because classification accuracy is critical for diagnosis.
Convolutional Neural Network (CNN)
A convolutional neural network (CNN) is a type of artificial neural network that processes pixel data directly, which makes it well suited to image recognition and classification tasks.
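The core operations a CNN applies to pixel data can be shown on a tiny image. This NumPy sketch uses a hand-made vertical-edge kernel purely for illustration; in a real CNN the kernel values are learned during training.

```python
import numpy as np


def conv2d(image, kernel):
    """'Valid' 2-D cross-correlation, which is what CNN convolution layers compute."""
    kh, kw = kernel.shape
    oh, ow = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def relu(x):
    return np.maximum(x, 0.0)


def max_pool(x, size=2):
    """Non-overlapping max pooling; leftover rows/columns that don't fill a window are dropped."""
    h, w = x.shape[0] // size * size, x.shape[1] // size * size
    x = x[:h, :w]
    return x.reshape(h // size, size, w // size, size).max(axis=(1, 3))


# 6x6 image: dark left half, bright right half (a vertical edge).
image = np.zeros((6, 6))
image[:, 3:] = 1.0
# Hand-made vertical-edge detector (illustrative, not learned).
kernel = np.array([[-1.0, 0.0, 1.0]] * 3)

features = relu(conv2d(image, kernel))
print(features[0])  # [0. 3. 3. 0.]
print(max_pool(features))
```

The feature map lights up only where the edge is, and max pooling halves its size while keeping the strongest responses.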
We built a custom CNN with 2 convolutional layers using ReLU activation functions, max pooling layers that work with the convolutions to extract features, and a kernel initializer to set the model’s initial weights. Our final layer was a dense layer with a softmax activation function, which outputs a probability for each of our classes.
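Tracing shapes and parameter counts shows how data flows through an architecture like this. The post does not give filter counts or kernel sizes, so this sketch assumes 32 and 64 filters, 3 × 3 kernels without padding, 2 × 2 pooling, and a single grayscale input channel; the kernel initializer is not modeled because it affects starting values, not shapes.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Spatial size after a convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_custom_cnn(input_size=128, in_channels=1, filters=(32, 64),
                     kernel=3, pool=2, num_classes=4):
    """Return (layer, output shape, trainable parameters) rows and the total parameter count."""
    size, channels, total, rows = input_size, in_channels, 0, []
    for f in filters:
        size = conv_out(size, kernel)
        params = (kernel * kernel * channels + 1) * f  # weights + one bias per filter
        rows.append((f"conv {kernel}x{kernel}, {f} filters, ReLU", (size, size, f), params))
        total += params
        channels = f
        size //= pool
        rows.append((f"max pool {pool}x{pool}", (size, size, channels), 0))
    flat = size * size * channels
    rows.append(("flatten", (flat,), 0))
    dense = flat * num_classes + num_classes
    rows.append((f"dense, {num_classes} units, softmax", (num_classes,), dense))
    total += dense
    return rows, total


rows, total = trace_custom_cnn()
for name, shape, params in rows:
    print(f"{name:<32}{str(shape):<18}{params}")
print("total trainable parameters:", total)  # 249220
```

Under these assumptions, over 90% of the parameters sit in the final dense layer, because the flattened 30 × 30 × 64 feature map is large.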
For our optimizer, we used Adam, a widely used optimizer that adapts the step size for each weight during training, and for our loss we used categorical cross-entropy, the standard choice for multi-class classification tasks.
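Softmax and categorical cross-entropy work together: softmax turns the final layer’s scores into probabilities, and the loss penalizes low probability on the true class. In this plain-Python sketch the scores and label order are hypothetical. A useful reference point is that a 4-class model that predicts 0.25 for every class has a loss of ln 4 ≈ 1.386.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(one_hot, probs, eps=1e-12):
    """-sum(y_true * log(y_pred)); eps guards against log(0)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(one_hot, probs))


y_true = [0, 1, 0, 0]  # second class in a hypothetical label order
uncertain = softmax([0.0, 0.0, 0.0, 0.0])  # 0.25 for every class
confident = softmax([0.0, 4.0, 0.0, 0.0])  # most probability on the true class

print(round(categorical_cross_entropy(y_true, uncertain), 4))  # 1.3863 (= ln 4)
print(categorical_cross_entropy(y_true, confident) < 0.1)      # True
```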
Using our balanced dataset with the augmented images, we were able to achieve a test accuracy of 74.87%.
We also trained our model on the original, unbalanced dataset to see whether image augmentation was affecting its accuracy. With the unbalanced data, we achieved a test accuracy of 95%.
Since augmentation changes what the images look like, this suggests that image augmentation does affect model performance, and that other ways of balancing MRI image data are worth exploring.
VGG16 is a CNN with 16 weight layers that has been pretrained on the ImageNet database. We decided to compare our custom CNN to VGG16 because this model tends to achieve test accuracies of 90% and above on MRI imaging datasets like ours.
Since VGG16 was pretrained for a different task, we had to add new layers on top of it to tailor the model to predicting our 4 Alzheimer’s classes. This process of reusing an existing model and its learned knowledge for another task is called transfer learning.
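The size of the new layers depends on what the pretrained part of VGG16 outputs. This sketch walks the published VGG16 convolutional configuration (13 convolutional layers of 3 × 3 kernels with "same" padding, plus 5 max pooling layers; the 3 original fully connected layers bring the total to 16 weight layers) for a 128 × 128 input. It assumes a 3-channel input, which is what ImageNet weights expect, and corresponds to the convolutional base that Keras returns from `tf.keras.applications.VGG16(include_top=False, ...)`. The head the team actually added is not described in the post.

```python
# VGG16 convolutional configuration: numbers are output channels of 3x3
# convolutions (stride 1, 'same' padding); "M" is 2x2 max pooling, stride 2.
VGG16_CONV = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"]


def conv_base_summary(height, width, in_channels=3):
    """Return (output shape, parameter count, number of conv layers) of the VGG16 conv base."""
    h, w, c, params, conv_layers = height, width, in_channels, 0, 0
    for item in VGG16_CONV:
        if item == "M":
            h, w = h // 2, w // 2  # pooling halves the spatial size
        else:
            params += (3 * 3 * c + 1) * item  # 'same' padding keeps h and w
            c = item
            conv_layers += 1
    return (h, w, c), params, conv_layers


shape, params, n_conv = conv_base_summary(128, 128)
print(shape, params, n_conv)  # (4, 4, 512) 14714688 13
print("features fed to the new head:", shape[0] * shape[1] * shape[2])  # 8192
```

Any layers added for the 4 Alzheimer’s classes therefore start from a 4 × 4 × 512 feature map on 128 × 128 inputs.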
We used the same loss function and the Adam optimizer again, but this time set the learning rate to 0.01.
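Adam’s update rule shows what the learning rate controls. On the first step, the bias-corrected update is roughly the learning rate times the sign of the gradient, so each weight moves by about the learning rate no matter how large or small its gradient is. This scalar sketch follows the rule from the original Adam paper (Kingma & Ba) with its default betas and epsilon. For context, Keras’s default Adam learning rate is 0.001, so 0.01 makes early steps about 10 times larger; whether that contributed to VGG16’s results was not tested in the post.

```python
import math


def adam_step(param, grad, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter."""
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# The first step size is close to lr for tiny and huge gradients alike.
for lr in (0.001, 0.01):
    for grad in (0.001, 1000.0):
        new_param, _, _ = adam_step(param=0.0, grad=grad, m=0.0, v=0.0, t=1, lr=lr)
        print(f"lr={lr}, grad={grad}: first step moved the weight by {abs(new_param):.4f}")
```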
When we ran this model on our augmented images, we got a test accuracy of 24.64%. Across epochs, the validation accuracy kept fluctuating between 0.2484 and 0.2531.
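A trivial baseline helps interpret these numbers. With 4 equally sized classes, a model that always predicts the same class scores 25%, so, assuming the validation and test splits were balanced, accuracies between 0.2484 and 0.2531 are consistent with VGG16 making near-constant predictions. The post does not report per-class predictions, so this is an interpretation, not a confirmed diagnosis. The split sizes below are hypothetical.

```python
from collections import Counter


def majority_class_accuracy(labels):
    """Accuracy of a degenerate model that always predicts the most common label."""
    _, count = Counter(labels).most_common(1)[0]
    return count / len(labels)


# Hypothetical splits for illustration (not the post's actual split sizes).
balanced = ["Mild", "Very Mild", "Moderate", "Non"] * 100
unbalanced = ["Non"] * 50 + ["Very Mild"] * 35 + ["Mild"] * 14 + ["Moderate"] * 1

print(majority_class_accuracy(balanced))    # 0.25
print(majority_class_accuracy(unbalanced))  # 0.5
```

The same baseline is worth computing for any imbalanced test set, since it shows how much accuracy a model can get without learning anything about the minority classes.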
We believe the model was not tuned enough and therefore did not achieve optimal performance. We tried changing parameters such as the number of epochs, the batch size, and the learning rate to improve performance. The VGG16 model we ended up using has a learning rate of 0.01 because that setting gave the best results we were able to achieve.
We also ran this model on the unbalanced dataset and got a test accuracy of 51.25%, with the validation accuracy identical in every epoch, which again points to the need for hyperparameter tuning. Additionally, the model trained on the unbalanced dataset still outperformed the one trained on the augmented dataset, consistent with what we saw with our custom CNN.
To test our custom CNN live on different MRI images, we hosted the model in a Streamlit app where you can upload an Alzheimer’s MRI image and see the model’s predicted class.
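Behind an upload-and-predict app like this, two small steps surround the model call: turning the uploaded image into a batch the model accepts, and turning the softmax output back into a class name. In Streamlit the file would typically arrive through `st.file_uploader`. This NumPy sketch assumes a 128 × 128 grayscale input scaled to [0, 1] and a hypothetical alphabetical label order; both must match however the model was actually trained, and the probabilities are a stand-in for a real prediction.

```python
import numpy as np

# Hypothetical label order; it must match the order the model was trained with.
CLASS_NAMES = ["Mild Demented", "Moderate Demented", "Non Demented", "Very Mild Demented"]


def preprocess(gray_image):
    """Scale a 128x128 uint8 grayscale image to [0, 1] and add batch/channel axes."""
    x = gray_image.astype(np.float32) / 255.0
    return x[np.newaxis, :, :, np.newaxis]  # shape (1, 128, 128, 1)


def decode(probs):
    """Map the softmax output for one image to (class name, probability)."""
    idx = int(np.argmax(probs))
    return CLASS_NAMES[idx], float(probs[idx])


batch = preprocess(np.zeros((128, 128), dtype=np.uint8))
print(batch.shape)  # (1, 128, 128, 1)
# Stand-in for model.predict(batch)[0] from a trained Keras model.
print(decode(np.array([0.10, 0.05, 0.15, 0.70])))
```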
The demo above shows that when we drop in the mild and very mild Alzheimer’s images, our custom CNN accurately predicts their severity classes. However, since our model is not perfect (it is only 74.87% accurate), the last 2 examples, a moderate and a non-demented image, are misclassified.
This shows that more work is needed to improve CNNs that classify Alzheimer’s severity from MRI images.
Findings & Future Considerations
We found that our custom CNN performed significantly better on the Alzheimer’s MRI images than the pretrained VGG16 model. We believe this happened because the VGG16 model was stuck in a local minimum and requires hyperparameter tuning to improve its performance.
Moving forward, we think other ways of balancing image data should be explored, since image augmentation lowered the test accuracy. Ideally, more image data would be collected to remove the need to balance the data, and human labeling would be eliminated to create a fully hands-free process from data generation to MRI classification.
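One alternative that leaves the images untouched is class weighting: instead of creating new images, each class’s contribution to the loss is scaled inversely to its frequency. This is a different technique from the one used in the project. The formula below is the "balanced" heuristic also used by scikit-learn’s `compute_class_weight`, and Keras’s `Model.fit` accepts such weights through its `class_weight` argument (keyed by integer class index rather than by name). The counts are hypothetical.

```python
def balanced_class_weights(counts):
    """weight_c = n_samples / (n_classes * n_c): rarer classes get larger weights."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * n) for label, n in counts.items()}


# Hypothetical counts for illustration, not the dataset's real distribution.
counts = {"Mild Demented": 900, "Very Mild Demented": 2200,
          "Moderate Demented": 100, "Non Demented": 3200}
for label, w in balanced_class_weights(counts).items():
    print(f"{label}: {w:.2f}")
```

Every original image is kept exactly as it is, so this avoids the appearance changes that augmentation introduces, at the cost of each rare-class mistake weighing more heavily during training.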
Deep learning and image recognition could be the future of Alzheimer’s diagnosis, but only if these models achieve the highest possible test accuracy and performance.
¹ Centers for Disease Control and Prevention. (2020, October 26). What is Alzheimer’s disease? Retrieved May 10, 2022, from https://www.cdc.gov/aging/aginginfo/alzheimers.htm