tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time

I am trying to tune a CNN model using RandomSearch, but it is very slow and shows this warning:

tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time

I am running my code in Google Colab with hardware acceleration set to GPU. This is my code:

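# Imports assumed; the original post did not show them.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from keras_tuner import RandomSearch  # older releases: from kerastuner.tuners import RandomSearch

def model_builder(hp):
    model=Sequential([
        Conv2D(filters=hp.Int('conv_1_filter',min_value=32,max_value=128,step=32),
               # Kernel size needs its own name; reusing 'conv_1_filter' returns the filter count.
               kernel_size=hp.Int('conv_1_kernel',min_value=2,max_value=3,step=1),
               activation='relu',
               padding='same',
               input_shape=(200,200,3)),
        MaxPooling2D(pool_size=(2,2),strides=(2,2)),
        
        Conv2D(filters=hp.Int('conv_2_filter',min_value=32,max_value=128,step=32),
               # Same fix as above: a distinct name for the kernel size.
               kernel_size=hp.Int('conv_2_kernel',min_value=2,max_value=3,step=1),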
               padding='same',
               activation='relu'),
        MaxPooling2D(pool_size=(2,2),strides=(2,2)),
        
        Flatten(),
        
        Dense(units=hp.Int('dense_1_units',min_value=32,max_value=512,step=128),
              activation='relu'),
        
        Dense(units=10,
              activation='softmax')
               
    ])
    
    model.compile(optimizer=Adam(hp.Choice('learning_rate',values=[1e-1,1e-3,3e-2])),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

Then I run RandomSearch and fit:

tuner=RandomSearch(model_builder,
                   objective='val_accuracy',
                   max_trials=2,
                   directory='projects',
                   project_name='Hypercars CNN'
                  )
tuner.search(X_train,Y_train,epochs=2,validation_split=0.2)

Mohamed Abdullah asked Jan 27 '26 07:01


2 Answers

This warning appears when the operations that run at the end of each batch (the callbacks) take more time than the batch itself. A likely cause is very small batches: each training step finishes so quickly that any per-batch callback work looks slow by comparison.
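To make that comparison concrete, here is a minimal sketch of the kind of check behind the warning. It is a simplified model, not Keras's actual code: the function name `hook_is_slow`, the threshold, and all timings are hypothetical values chosen for illustration.

```python
from statistics import mean

def hook_is_slow(batch_times, hook_times, threshold=1.0):
    """Return True when the average hook time exceeds threshold * average batch time."""
    return mean(hook_times) > threshold * mean(batch_times)

# Hypothetical timings in seconds, invented for illustration only.
small_batch_times = [0.002] * 5  # tiny batches: each training step is very fast
large_batch_times = [0.040] * 5  # larger batches: each training step takes longer
hook_times = [0.005] * 5         # identical per-batch callback work in both cases

print(hook_is_slow(small_batch_times, hook_times))  # callback dominates the step
print(hook_is_slow(large_batch_times, hook_times))  # callback is a small fraction
```

The callback cost is the same in both runs; only the batch time changes, which is why the warning tends to show up with small batches.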

Increasing the batch size should solve this. Alternatively, you can set use_multiprocessing=True in model.fit() and choose an appropriate number of workers to generate your training batches more efficiently, but this only works for datasets that use a generator or keras.utils.Sequence.
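As a rough illustration of why a larger batch size helps, the sketch below counts how many times a per-batch callback fires in one epoch. The training-set size and per-call callback cost are assumed numbers, not measurements from the question, and the helper names are made up for this example.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    """Number of training steps, and so on_train_batch_end calls, in one epoch."""
    return math.ceil(num_samples / batch_size)

def callback_overhead(num_samples, batch_size, hook_seconds):
    """Total time spent in a per-batch callback over one epoch."""
    return batches_per_epoch(num_samples, batch_size) * hook_seconds

num_train = 8000      # hypothetical number of training samples
hook_seconds = 0.005  # hypothetical cost of one callback call, in seconds

for batch_size in (32, 128, 512):
    steps = batches_per_epoch(num_train, batch_size)
    overhead = callback_overhead(num_train, batch_size, hook_seconds)
    print(f"batch_size={batch_size}: {steps} steps, ~{overhead:.2f}s in callbacks per epoch")
```

In the question's setup this would mean passing a larger batch_size to tuner.search, which Keras Tuner forwards to model.fit. How large you can go depends on GPU memory.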

Two threads talking about this issue:

  1. Thread 1
  2. Thread 2

yudhiesh answered Jan 30 '26 07:01


Setting use_multiprocessing=True can remove that warning, but another warning then appears about using multiprocessing in TF2.

Shruti Mathew answered Jan 30 '26 06:01


