 

How to use class weights in Keras for image segmentation

I am trying to segment medical images using a version of U-Net implemented with Keras. The inputs of my network are 3D images and the outputs are two one-hot-encoded 3D segmentation maps. My dataset is very imbalanced (only a small part of each image needs to be segmented), so I want to use class weights in my loss function (currently binary_crossentropy). With class weights, I hope the model will pay more attention to the small structures it has to segment.

I know that if a dataset is imbalanced, you can pass the class_weight parameter to model.fit(). Does this also work for my use case?

Jan Willem asked Oct 30 '25 22:10


1 Answer

With the help of the above-mentioned GitHub issue, I managed to solve the problem for my particular use case, and I want to share the solution here. An extra hurdle was that I am using a custom generator for my data. A simplified version of this class is the following:

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, batch_size=2, dim=(144,144,144), n_classes=2):
        'Initialization'
        super().__init__()  # base-class init (expected by newer Keras versions)
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def on_epoch_end(self):
        'Creates and shuffles the sample indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        x, y = self.__data_generation(list_IDs_temp)

        return x, y

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # x : (n_samples, *dim, 1)
        # Initialization
        x = np.empty((self.batch_size, *self.dim, 1))
        y = np.empty((self.batch_size, *self.dim, 1))

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Load dataset: array of shape (*dim, 2) holding image and mask
            data = np.load('data/' + ID + '.npy')

            # Store x and y, keeping a trailing channel axis of size 1
            x[i,] = data[:, :, :, 0:1]  # Image
            y[i,] = data[:, :, :, 1:2]  # Mask

        # One-hot-encoding: y becomes (batch_size, *dim, n_classes)
        y = keras.utils.to_categorical(y, num_classes=self.n_classes)

        return x, y

In the end, only a few lines of code did the trick: an extra input argument class_weights for my generator, a line in the __getitem__() method that converts the class weights into sample weights for each individual batch, and returning those sample weights from the same method. The class weights are passed as a list with the following structure: class_weights = [weight_class_0, weight_class_1]. My basic generator class now looks like this (changes are marked with a comment):

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, class_weights, batch_size=2, dim=(144,144,144),
                 n_classes=2):
        'Initialization'
        super().__init__()  # base-class init (expected by newer Keras versions)
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_classes = n_classes
        self.class_weights = class_weights  # CLASS WEIGHTS FIX
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def on_epoch_end(self):
        'Creates and shuffles the sample indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        x, y = self.__data_generation(list_IDs_temp)

        # Compute sample weights CLASS WEIGHTS FIX
        # One weight per voxel, shape (batch_size, *dim)
        sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int'))

        return x, y, sample_weights  # CLASS WEIGHTS FIX

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # x : (n_samples, *dim, 1)
        # Initialization
        x = np.empty((self.batch_size, *self.dim, 1))
        y = np.empty((self.batch_size, *self.dim, 1))

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Load dataset: array of shape (*dim, 2) holding image and mask
            data = np.load('data/' + ID + '.npy')

            # Store x and y, keeping a trailing channel axis of size 1
            x[i,] = data[:, :, :, 0:1]  # Image
            y[i,] = data[:, :, :, 1:2]  # Mask

        # One-hot-encoding: y becomes (batch_size, *dim, n_classes)
        y = keras.utils.to_categorical(y, num_classes=self.n_classes)

        return x, y
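
The generator expects class_weights = [weight_class_0, weight_class_1], but how the weights are chosen is left open. One common heuristic (an assumption here, not necessarily what was used above) is inverse-frequency weighting, the same formula scikit-learn uses for class_weight="balanced": n_voxels / (n_classes * voxels_in_class). A minimal NumPy sketch with a hypothetical helper inverse_frequency_weights:

```python
import numpy as np

def inverse_frequency_weights(masks, n_classes=2):
    """Hypothetical helper: one weight per class, inversely proportional to
    how many voxels carry that label.

    `masks` is an integer label array of any shape, e.g.
    (n_samples, 144, 144, 144). A perfectly balanced dataset gives 1.0
    for every class.
    """
    counts = np.bincount(masks.ravel().astype(int), minlength=n_classes)
    total = counts.sum()
    # np.maximum guards against division by zero for a class that never occurs
    weights = total / (n_classes * np.maximum(counts, 1))
    return weights.tolist()

# Toy mask: 90 background voxels and 10 foreground voxels
masks = np.zeros(100, dtype=int)
masks[:10] = 1
class_weights = inverse_frequency_weights(masks)
print(class_weights)  # [0.5555555555555556, 5.0]
```

The resulting list can be passed directly as the class_weights argument of the generator. Very rare classes can receive very large weights this way, so in practice the weights are sometimes capped or smoothed.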

It might look like a magic one-liner, but here is what sample_weights = np.take(np.array(self.class_weights), np.round(y[:, :, :, :, 1]).astype('int')) does: it takes channel 1 of the one-hot-encoded y, which belongs to the less common class (in my case the structure to segment) and is 1 for voxels of that class and 0 otherwise. After rounding and casting to integers, these values act as class indices, and np.take uses them to look up a weight for each voxel of the 3D image. Each voxel therefore gets either the class weight of the common class or that of the uncommon class, depending on which class it belongs to.
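To make the one-liner concrete, here is a tiny worked example on a hand-made one-hot batch shaped like the generator output, (batch, depth, height, width, n_classes). The class weights [1.0, 10.0] and the toy labels are made up for illustration:

```python
import numpy as np

# Hypothetical class weights: background = 1.0, rare foreground = 10.0
class_weights = [1.0, 10.0]

# Class index per voxel, shape (batch, depth, height, width) = (1, 1, 2, 2)
labels = np.array([[[[0, 1], [0, 0]]]])
# One-hot encode -> shape (1, 1, 2, 2, 2), like to_categorical's output
y = np.eye(2)[labels]

# Channel 1 is 1.0 for rare-class voxels and 0.0 otherwise; rounding and
# casting turn it into an integer index used to look up each voxel's weight
sample_weights = np.take(np.array(class_weights),
                         np.round(y[:, :, :, :, 1]).astype('int'))

# For more than two classes, the argmax over the one-hot axis gives the index
sample_weights_argmax = np.take(np.array(class_weights), np.argmax(y, axis=-1))

print(sample_weights.shape)  # (1, 1, 2, 2)
```

Only the single rare-class voxel receives weight 10.0; all other voxels get 1.0. Because the one-liner reads channel 1 only, it is specific to the two-class case, whereas the argmax variant works for any number of classes.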

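The output of this generator class can then be used in the model.fit() method of the Keras model, as long as sample_weight_mode="temporal" is passed to model.compile(). (Recent Keras releases such as Keras 3 no longer have a sample_weight_mode argument in compile(), so check the documentation for your version.)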
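To see why per-voxel sample weights shift the model's attention, the following NumPy sketch approximates the weighted loss being minimised. Assumptions: Keras' binary_crossentropy averages over the last (class) axis, and the default reduction averages the weighted per-voxel losses over all voxels in the batch; exact reduction details can differ between Keras versions. The function name weighted_binary_crossentropy is hypothetical:

```python
import numpy as np

def weighted_binary_crossentropy(y_true, y_pred, sample_weights, eps=1e-7):
    """Approximate weighted loss for one-hot targets of shape (batch, *dim, n_classes)
    and per-voxel sample weights of shape (batch, *dim)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    # Per-voxel loss: binary cross-entropy averaged over the class axis
    per_voxel = -np.mean(y_true * np.log(y_pred)
                         + (1 - y_true) * np.log(1 - y_pred), axis=-1)
    # Scale each voxel's loss by its weight, then average over all voxels
    return float(np.mean(per_voxel * sample_weights))

# Two voxels: one background, one rare foreground; the model predicts
# "background" (0.9) for both, so the foreground voxel is wrong
y_true = np.array([[[1.0, 0.0], [0.0, 1.0]]])
y_pred = np.array([[[0.9, 0.1], [0.9, 0.1]]])

unweighted = weighted_binary_crossentropy(y_true, y_pred, np.ones((1, 2)))
weighted = weighted_binary_crossentropy(y_true, y_pred, np.array([[1.0, 10.0]]))
```

With a weight of 10 on the rare class, the misclassified foreground voxel dominates the loss, so the gradient pushes much harder towards segmenting the small structures than in the unweighted case.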

Jan Willem answered Nov 01 '25 12:11


