Augment Breast Cancer Diagnosis with Image Classification and LLMs

Olalekan Elesin
Apr 1, 2023

It’s important to note that this post is not a diagnostic tool and its outputs should not be interpreted as a definitive diagnosis; it is meant to show what is possible with recent advances in AI. This post was inspired by someone close to me whom I recently lost to breast cancer.

In this blog post, using the RSNA Screening Mammography Breast Cancer Detection dataset available on Kaggle, we will explore how image classification models can be combined with Large Language Models to aid in the detection and diagnosis of breast cancer. Specifically, we will fine-tune a pretrained TensorFlow model with Amazon SageMaker JumpStart on this mammography dataset. The inference results from the deployed model are then passed to OpenAI gpt-3.5-turbo, which generates a detailed natural-language explanation.

Below is an example inference from the model:

{
'biopsy': 0.131018966,
'cancer': 0.12599,
'density_A': 0.0577527769,
'density_B': 0.309944034,
'density_C': 0.23029092,
'density_D': 0.088826865,
'invasive': 0.0561763197
}

Response generated by OpenAI gpt-3.5-turbo:

The machine learning classification model you used to predict cancer likelihood returned the following probabilities:

biopsy: 0.131 — This suggests that there is a 13.1% chance that the biopsy results indicate the presence of cancer.

cancer: 0.126 — This indicates an overall likelihood of cancer based on all the analyzed features. The probability returned by the model suggests that there is a 12.6% chance that the patient has cancer.

density_A: 0.058 — This suggests that there is a 5.8% chance that the density of tissue in area A is indicative of cancer.

density_B: 0.31 — This indicates a 31% likelihood that the density of tissue in area B is indicative of cancer.

density_C: 0.23 — This indicates a 23% likelihood that the density of tissue in area C is indicative of cancer.

density_D: 0.089 — This suggests that there is an 8.9% chance that the density of tissue in area D is indicative of cancer.

invasive: 0.056 — This suggests that there is a 5.6% chance that the cancer is invasive, meaning that it has spread beyond the initial site of the tumor.

It’s important to note that these probabilities are not a definitive diagnosis and should be used in conjunction with other clinical data and diagnostic tests.
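Before relying on the explanation, it helps to check the raw scores. The seven values in the example inference sum to 1 (up to rounding), which is what a softmax layer over mutually exclusive classes produces. Read that way, cancer: 0.126 is the share of probability the model puts on the cancer class rather than on one of the other six labels, not an independent probability that the patient has cancer, and density_A to density_D correspond to the dataset's breast-density categories A to D (see the column definitions under Data Preparation) rather than tissue areas. A quick check of the numbers:

```python
# Raw scores copied from the example inference above.
scores = {
    "biopsy": 0.131018966,
    "cancer": 0.12599,
    "density_A": 0.0577527769,
    "density_B": 0.309944034,
    "density_C": 0.23029092,
    "density_D": 0.088826865,
    "invasive": 0.0561763197,
}

# A softmax output over mutually exclusive classes sums to 1.
total = sum(scores.values())
# The highest-scoring class is the model's single predicted label.
top_class = max(scores, key=scores.get)

print(f"sum of scores: {total:.4f}")  # sum of scores: 1.0000
print(f"top class: {top_class}")      # top class: density_B
```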

With the results up front, it is time to get into the weeds.

The Work

If you are interested in reproducing this work, you can download the Jupyter Notebook from my AWS Samples GitHub repository following the link below:

Hypothesis

The goal was not only to predict the likelihood of breast cancer from an image. The leading hypothesis for this work was:

Given a mammogram, can we identify likely indicators of breast cancer and present the information in a human-understandable way to physicians to augment their diagnosis process?

The hypothesis above sets the stage for the data preparation.

Data Preparation

To get started, install the required Python libraries and download the RSNA Screening Mammography Breast Cancer Detection Kaggle competition dataset to your Amazon SageMaker Notebook instance:

!pip install sagemaker ipywidgets pydicom --upgrade --quiet

The mammography images are stored in DICOM format in the train_images/ and test_images/ folders. Alongside the image data, the train.csv and test.csv files contain anonymized patient metadata for the images.
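Because the training model needs PNG or JPEG input (step 3 below), each DICOM file has to be decoded and rescaled to 8-bit grayscale. The post does not show its conversion code, so the following is a minimal sketch under stated assumptions: plain min-max scaling (windowing with the file's VOI LUT may preserve contrast better), inversion for MONOCHROME1 images, and the hypothetical helper names `to_uint8` and `dicom_to_png`. If the files use a compressed transfer syntax, pydicom may need an extra decoder plugin (for example pylibjpeg) before `pixel_array` can be read.

```python
import numpy as np


def to_uint8(pixels: np.ndarray, photometric: str = "MONOCHROME2") -> np.ndarray:
    """Min-max scale a DICOM pixel array to 8-bit grayscale for PNG export."""
    pixels = pixels.astype(np.float32)
    lo, hi = float(pixels.min()), float(pixels.max())
    if hi > lo:
        pixels = (pixels - lo) / (hi - lo)
    else:
        # A constant image carries no contrast; map it to black.
        pixels = np.zeros_like(pixels)
    # MONOCHROME1 stores bright tissue as low values, so invert it
    # to match the usual MONOCHROME2 appearance.
    if photometric == "MONOCHROME1":
        pixels = 1.0 - pixels
    return np.round(pixels * 255).astype(np.uint8)


def dicom_to_png(dcm_path: str, png_path: str) -> None:
    """Convert one .dcm file to a .png file (requires pydicom and Pillow)."""
    # Imported here so to_uint8 stays usable without these packages installed.
    import pydicom
    from PIL import Image

    ds = pydicom.dcmread(dcm_path)
    photometric = getattr(ds, "PhotometricInterpretation", "MONOCHROME2")
    Image.fromarray(to_uint8(ds.pixel_array, photometric)).save(png_path)
```

Looping `dicom_to_png` over train_images/ and writing each PNG into a per-class folder produces the layout used in the steps below.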


We prepare our dataset with the assumption that we want to identify likely indicators of breast cancer.

Steps:

1. Select image classes from the available columns in the train.csv file. The selected columns are density, invasive, biopsy, BIRADS, and cancer. Their definitions are as follows:
- `density` - A rating for how dense the breast tissue is, with A being the least dense and D being the most dense. Extremely dense tissue can make diagnosis more difficult.
- `invasive` - If the breast is positive for cancer, whether or not the cancer proved to be invasive.
- `biopsy` - Whether or not a follow-up biopsy was performed on the breast.
- `BIRADS` - 0 if the breast required follow-up, 1 if the breast was rated as negative for cancer, and 2 if the breast was rated as normal.
- `cancer` - Whether or not the breast was positive for malignant cancer.

2. Create a balanced dataset across the selected classes to avoid biasing the model towards the class with the highest number of images:

3. Convert .dcm to .png and upload to Amazon S3: the images in the RSNA Screening Mammography Breast Cancer Detection dataset are DICOM (.dcm) files, while the selected Amazon SageMaker JumpStart model requires PNG or JPEG images saved to an Amazon S3 bucket that the SageMaker training instances can access.
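Steps 1 and 2 can be sketched as follows. Assuming the folder-per-class layout that JumpStart image-classification fine-tuning uses, each image needs exactly one class, and the class names below mirror the labels in the example inference (BIRADS is left out because it does not appear among those labels). The post does not say how an image with several positive flags was assigned to one class, so the priority order in the hypothetical `assign_label` is purely illustrative; `balance` then downsamples every class to the size of the smallest one.

```python
import random
from collections import defaultdict
from typing import Optional

DENSITY_CLASSES = {"A", "B", "C", "D"}


def assign_label(row: dict) -> Optional[str]:
    """Map one train.csv row to a single class name (hypothetical rule).

    The priority invasive > cancer > biopsy > density is an illustrative
    assumption, not the method used in the post.
    """
    for column in ("invasive", "cancer", "biopsy"):
        if row.get(column) == "1":
            return column
    density = row.get("density", "")
    return f"density_{density}" if density in DENSITY_CLASSES else None


def balance(rows: list, seed: int = 42) -> dict:
    """Group rows by class, then downsample every class to the smallest class size."""
    by_class = defaultdict(list)
    for row in rows:
        label = assign_label(row)
        if label is not None:  # rows without a usable label are skipped
            by_class[label].append(row)
    n = min(len(items) for items in by_class.values())
    rng = random.Random(seed)
    return {label: rng.sample(items, n) for label, items in sorted(by_class.items())}


# Usage, assuming train.csv from the competition is in the working directory:
# import csv
# with open("train.csv", newline="") as f:
#     balanced = balance(list(csv.DictReader(f)))
```

Downsampling discards images from the larger classes; class weights or augmenting the smaller classes are alternatives if that loss matters.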

The data uploaded to the Amazon S3 bucket for model training was split into training and validation sets using the split-folders Python library.
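split-folders does this split on disk. To show what a per-class ratio split does, here is a small standard-library sketch (the function name `split_per_class` and the 80/20 ratio are assumptions, not the post's settings): each class is shuffled with a fixed seed and a fraction is held out, so the validation set keeps the same class balance as the training set.

```python
import random


def split_per_class(files_by_class: dict, val_ratio: float = 0.2, seed: int = 1337):
    """Shuffle each class's files and hold out val_ratio of them for validation."""
    rng = random.Random(seed)
    train, val = {}, {}
    for label, files in sorted(files_by_class.items()):
        shuffled = sorted(files)  # sort first so the result depends only on the seed
        rng.shuffle(shuffled)
        n_val = int(len(shuffled) * val_ratio)
        val[label] = shuffled[:n_val]
        train[label] = shuffled[n_val:]
    return train, val


# split-folders performs the same kind of split directly on a folder-per-class
# directory (the folder names here are hypothetical):
# import splitfolders
# splitfolders.ratio("png_images", output="png_split", seed=1337, ratio=(0.8, 0.2))
```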

Model Training and Deployment

Details on the model training setup can be found in the GitHub project repository. The ImageNet pre-trained model was fine-tuned on the custom dataset uploaded to S3 with the following parameters:

Using spot instances (Managed Spot Training), I saved 64.8% on training costs, as the job log shows:

2023-03-18 09:19:54 Completed - Training job completed
ProfilerReport-1679130750: NoIssuesFound
Training seconds: 307
Billable seconds: 108
Managed Spot Training savings: 64.8%
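The savings figure matches the fraction of training time that was not billed, computed from the two durations in the log above:

```python
training_seconds = 307  # "Training seconds" from the job log
billable_seconds = 108  # "Billable seconds" from the job log

# Share of the training time that was not billed.
savings = 1 - billable_seconds / training_seconds
print(f"Managed Spot Training savings: {savings:.1%}")  # Managed Spot Training savings: 64.8%
```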

Interpreting Model Training Metrics

...
Setting weights to model with maximum val_accuracy at epoch 2/20:
- loss: 1.8219252824783325
- accuracy: 0.24945294857025146
- top_5_accuracy: 0.9037199020385742
- val_loss: 1.803653359413147
- val_accuracy: 0.27731093764305115
- val_top_5_accuracy: 0.9327731132507324
...

During training, the model was evaluated on both the training data and a separate validation dataset. The model reached its highest validation accuracy at epoch 2 of 20, so the training job restored the weights from that epoch.

The output provides several performance metrics for the model at this epoch. The loss and accuracy metrics describe how well the model performed on the training data, while the val_loss and val_accuracy metrics describe its performance on the validation data.

The top_5_accuracy and val_top_5_accuracy metrics measure how often the true class was among the model’s five highest-scoring classes. This metric is most informative when there are many classes; with only a handful of classes, even a weak model can score high, so top-5 accuracy should be read alongside the top-1 accuracy above.

Overall, at epoch 2/20 the model achieved a validation accuracy of 0.277 and a validation top-5 accuracy of 0.933.
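To put these numbers in context, it helps to compare them with a classifier that ranks classes uniformly at random. Assuming the seven labels in the example inference are the model’s full set of classes, such a classifier would reach a top-1 accuracy of 1/7 (about 0.143) and a top-5 accuracy of 5/7 (about 0.714). The sketch below computes top-k accuracy and these baselines; the function names are illustrative.

```python
def top_k_accuracy(scores: list, labels: list, k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


def random_top_k_baseline(num_classes: int, k: int) -> float:
    """Expected top-k accuracy of a classifier that ranks classes uniformly at random."""
    return min(k, num_classes) / num_classes


print(f"{random_top_k_baseline(7, 1):.3f}")  # 0.143
print(f"{random_top_k_baseline(7, 5):.3f}")  # 0.714
```

Against these baselines, a validation accuracy of 0.277 and a validation top-5 accuracy of 0.933 are above chance, but by a much smaller margin than the raw top-5 figure suggests.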

Inference

Once the model is deployed to an Amazon SageMaker endpoint, it returns class probabilities for each image. These raw probabilities are hard to interpret, especially for doctors. To make them useful, we pass the class predictions to OpenAI gpt-3.5-turbo to generate an explanation:
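The endpoint call and the prompt are not shown in this post, so here is one hedged way to turn the endpoint output into a prompt. It assumes the endpoint returns a JSON body with parallel `labels` and `probabilities` lists (check your own endpoint’s response), and `CLASS_NOTES`, `parse_endpoint_response`, and `build_messages` are hypothetical names. Including the class definitions in the prompt may help the LLM avoid reading density_A to density_D as tissue areas, as the example response above did; the prompt also states that the scores sum to 1, which holds for the example inference.

```python
import json

# Hypothetical class notes, paraphrased from the train.csv column definitions.
CLASS_NOTES = {
    "biopsy": "a follow-up biopsy was performed on the breast",
    "cancer": "the breast was positive for malignant cancer",
    "invasive": "the cancer proved to be invasive",
    "density_A": "breast-density category A (least dense tissue)",
    "density_B": "breast-density category B",
    "density_C": "breast-density category C",
    "density_D": "breast-density category D (most dense tissue)",
}


def parse_endpoint_response(body: bytes) -> dict:
    """Pair labels with probabilities.

    Assumes a JSON body with parallel "labels" and "probabilities" lists.
    """
    payload = json.loads(body)
    return dict(zip(payload["labels"], payload["probabilities"]))


def build_messages(probabilities: dict) -> list:
    """Build a chat-style prompt (role/content dicts) explaining the scores."""
    notes = "\n".join(f"- {label}: {text}" for label, text in CLASS_NOTES.items())
    scores = "\n".join(f"{label}: {p:.3f}" for label, p in sorted(probabilities.items()))
    return [
        {
            "role": "system",
            "content": "You explain image-classifier outputs to physicians. "
            "Do not state a diagnosis.",
        },
        {
            "role": "user",
            "content": f"Class definitions:\n{notes}\n\n"
            f"Classifier scores (one softmax output, so they sum to 1):\n{scores}\n\n"
            "Explain what these scores suggest about this mammogram.",
        },
    ]


# The resulting list can be sent as the `messages` of a gpt-3.5-turbo
# Chat Completions request.
```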

Conclusion

Breast cancer is the most commonly diagnosed cancer worldwide, accounting for approximately 11.7% of all new cancer cases, and about 1 in 8 women in the United States will develop breast cancer at some point in their lives. Early detection is crucial for successful treatment, and mammography is currently the gold standard for breast cancer screening.

The model reached a validation top-5 accuracy of 93% but a top-1 validation accuracy of only 27.7%, so the classifier needs substantial improvement before its predictions can be relied on. Even so, the pipeline shows how an image classifier combined with an LLM could augment, in near real time, the information available to oncologists immediately after mammogram scans are run.

You can reach me via email, follow me on Twitter or connect with me on LinkedIn.


Enterprise technologist with experience across technical leadership, architecture, cloud, machine learning, big-data and other cool stuff.