
How to write a generator with state for Keras fit_generator?

I am trying to feed a large dataset to a Keras model. The dataset does not fit into memory. It is currently stored as a series of HDF5 files.

I want to train my model using

model.fit_generator(my_gen, steps_per_epoch=30, epochs=10, verbose=1)

However, in all the examples I could find online, my_gen was only used to perform data augmentation on an already loaded dataset. For example:

import numpy as np

def generator(features, labels, batch_size):
    # Create empty arrays to contain a batch of features and labels
    batch_features = np.zeros((batch_size, 64, 64, 3))
    batch_labels = np.zeros((batch_size, 1))

    while True:
        for i in range(batch_size):
            # Choose a random index into features
            index = np.random.randint(len(features))
            # some_processing stands for the augmentation step
            batch_features[i] = some_processing(features[index])
            batch_labels[i] = labels[index]
        yield batch_features, batch_labels

In my case, it needs to be something like:

def generator(features, labels, batch_size):
    while True:
        for i in range(batch_size):
            # Pick the next file instead of a random index
            index = ...  # SELECT THE NEXT FILE
            batch_features[i] = some_processing(features[files[index]])
            batch_labels[i] = labels[files[index]]
        yield batch_features, batch_labels

How do I keep track of the files that were already read in previous batches?

Asked by 00__00__00 on Nov 26 '25 at 13:11


1 Answer

From the Keras documentation:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. [...]

This means you can write a class that inherits from keras.utils.Sequence:

class ProductSequence(keras.utils.Sequence):
    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

__init__ initializes the class. __len__ should return the number of batches per epoch; Keras uses it to know which indices can be passed to __getitem__. __getitem__ then returns the batch data for the given index. A simple example can be found here
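To make the three methods concrete, here is a minimal sketch, not a definitive implementation. It assumes every file holds the same number of samples, and that you supply a load_file function (hypothetical, for example one that opens the file with h5py) returning a (features, labels) pair that supports slicing. The class name FileBatchSequence and all parameter names are made up for illustration. The try/except only lets the sketch run where Keras is not installed.

```python
import math

try:
    from keras.utils import Sequence
except ImportError:  # fallback so the sketch runs without Keras installed
    Sequence = object


class FileBatchSequence(Sequence):
    """Maps a batch index to a slice of one file (illustrative name)."""

    def __init__(self, file_paths, samples_per_file, batch_size, load_file):
        super().__init__()
        self.file_paths = list(file_paths)
        self.batch_size = batch_size
        # load_file is a user-supplied callable: path -> (features, labels)
        self.load_file = load_file
        # Assumption: every file holds exactly samples_per_file samples
        self.batches_per_file = math.ceil(samples_per_file / batch_size)

    def __len__(self):
        # Number of batches per epoch
        return len(self.file_paths) * self.batches_per_file

    def __getitem__(self, idx):
        # Translate the global batch index into (file, batch within file)
        file_idx, batch_in_file = divmod(idx, self.batches_per_file)
        features, labels = self.load_file(self.file_paths[file_idx])
        start = batch_in_file * self.batch_size
        stop = start + self.batch_size
        return features[start:stop], labels[start:stop]
```

Because the batch index alone determines which file and which slice to read, nothing has to be remembered between calls. Note that this sketch reloads the file on every call, so caching the most recently opened file may be worthwhile.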

With this approach you can simply keep an attribute on the instance that records which files have already been read.
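A rough sketch of that idea follows, assuming one batch per file and a user-supplied load_file function (hypothetical). The class name TrackingSequence and the attribute files_read are invented for illustration. on_epoch_end is the hook Keras calls on a Sequence after each epoch; here it resets the record and reshuffles the file order.

```python
import random

try:
    from keras.utils import Sequence
except ImportError:  # fallback so the sketch runs without Keras installed
    Sequence = object


class TrackingSequence(Sequence):
    """Serves one batch per file and records which files were read."""

    def __init__(self, file_paths, load_file, seed=0):
        super().__init__()
        self.file_paths = list(file_paths)
        # load_file is a user-supplied callable: path -> (features, labels)
        self.load_file = load_file
        self.files_read = set()  # state: files served in the current epoch
        self._rng = random.Random(seed)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        self.files_read.add(path)
        return self.load_file(path)

    def on_epoch_end(self):
        # Start the next epoch with a clean record and a new file order
        self.files_read.clear()
        self._rng.shuffle(self.file_paths)
```

One caveat: if batches are produced by worker processes (use_multiprocessing=True), each worker most likely operates on its own copy of the object, so files_read in the main process may not reflect what the workers have read.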

Answered by dennis-w on Nov 28 '25 at 01:11


