
TF2.0 Memory Leak From Applying Keras Model to Symbolic Tensor

tl;dr: The memory usage of my implementation apparently grows with the number of samples passed through it, but nothing in the network or the sample feeding should depend on how many samples have been passed so far.


When passing a large batch of high-dimensional data through a custom Keras model created with the functional API, I observe what appears to be steady growth in GPU memory usage as the number of processed instances increases. The following is a minimal example of how the samples are passed through the network:

import time

import gym
import tensorflow as tf
from tqdm import tqdm

sequence_length = 100
batch_size = 128

env = gym.make("ShadowHand-v1")
# build_shadow_brain is part of my project; it returns (among others) the joint model
_, _, joint = build_shadow_brain(env, bs=batch_size)
optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.SGD()

start_time = time.time()
for t in tqdm(range(sequence_length), disable=False):
    # one timestep of random data per modality: vision, proprioception, somatosensation, goal
    sample_batch = (
        tf.random.normal([batch_size, 1, 200, 200, 3]),
        tf.random.normal([batch_size, 1, 48]),
        tf.random.normal([batch_size, 1, 92]),
        tf.random.normal([batch_size, 1, 7])
    )

    with tf.GradientTape() as tape:
        out, v = joint(sample_batch)
        loss = tf.reduce_mean(out - v)

    grads = tape.gradient(loss, joint.trainable_variables)
    optimizer.apply_gradients(zip(grads, joint.trainable_variables))
    joint.reset_states()

print(f"Execution Time: {time.time() - start_time}")

I am aware that this is a large sample given the batch size; however, if it were in fact too large for my GPU, I would expect an immediate OOM error, and I also assume that 6 GB of VRAM should suffice. Instead, the OOM error occurs only after 33 instances, which leads me to suspect that memory usage grows over time.

Here is the Keras summary of my model:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
visual_input (InputLayer)       [(32, None, 200, 200 0                                            
__________________________________________________________________________________________________
proprioceptive_input (InputLaye [(32, None, 48)]     0                                            
__________________________________________________________________________________________________
somatosensory_input (InputLayer [(32, None, 92)]     0                                            
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, None, 64)     272032      visual_input[0][0]               
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, None, 8)      848         proprioceptive_input[0][0]       
__________________________________________________________________________________________________
time_distributed_2 (TimeDistrib (None, None, 8)      3032        somatosensory_input[0][0]        
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, None, 80)     0           time_distributed[0][0]           
                                                                 time_distributed_1[0][0]         
                                                                 time_distributed_2[0][0]         
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, None, 48)     3888        concatenate[0][0]                
__________________________________________________________________________________________________
time_distributed_4 (TimeDistrib (None, None, 48)     0           time_distributed_3[0][0]         
__________________________________________________________________________________________________
time_distributed_5 (TimeDistrib (None, None, 32)     1568        time_distributed_4[0][0]         
__________________________________________________________________________________________________
time_distributed_6 (TimeDistrib (None, None, 32)     0           time_distributed_5[0][0]         
__________________________________________________________________________________________________
goal_input (InputLayer)         [(32, None, 7)]      0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (32, None, 39)       0           time_distributed_6[0][0]         
                                                                 goal_input[0][0]                 
__________________________________________________________________________________________________
lstm (LSTM)                     (32, 32)             9216        concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_10 (Dense)                (32, 20)             660         lstm[0][0]                       
__________________________________________________________________________________________________
dense_11 (Dense)                (32, 20)             660         lstm[0][0]                       
__________________________________________________________________________________________________
activation (Activation)         (32, 20)             0           dense_10[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (32, 20)             0           dense_11[0][0]                   
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (32, 40)             0           activation[0][0]                 
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
dense_12 (Dense)                (32, 1)              33          lstm[0][0]                       
==================================================================================================
Total params: 291,937
Trainable params: 291,937
Non-trainable params: 0
__________________________________________________________________________________________________

As you can see, there is an LSTM layer in this network. It would usually be stateful, but I have already turned that off because I suspected the problem might lie there. In fact, I have already tried the following, without eliminating the issue:

  • Turning off statefulness
  • Entirely removing the LSTM
  • Not calculating any gradients
  • Rebuilding the model after every instance

and have now reached the end of my ideas concerning potential causes of the issue.

I have also forced the process onto the CPU and inspected regular RAM usage (the OOM does not happen there because I have a lot more RAM than VRAM). Interestingly, the memory usage jumps up and down, but with an upward trend: for every instance, about 2 GB of memory is taken, but when the memory is freed before the next sample, about 200 MB less is released than was taken, so roughly 200 MB is retained per instance.
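For illustration, the per-iteration memory can be logged with something like the following (a sketch assuming psutil is installed; the loop body is the training step from above):

import os
import psutil  # assumed available; used here only for monitoring

process = psutil.Process(os.getpid())

for t in range(sequence_length):
    # ... run one training step on sample_batch as in the loop above ...
    print(f"{t}: {process.memory_info().rss}")  # resident set size in bytes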

EDIT 1: As mentioned in the comments, the issue might be that calling the model on the input adds the sample to the computation graph. However, I cannot use joint.predict() because I need to calculate the gradients.
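To illustrate why predict() is not an option here (a sketch, not part of the actual project code): predict() returns NumPy arrays, so nothing is recorded on the gradient tape, whereas calling the model directly keeps the outputs connected to the trainable variables:

# joint.predict() runs inference outside any tape and returns NumPy arrays,
# so tape.gradient() on a loss built from them would yield None gradients.
out_np, v_np = joint.predict(sample_batch)

# Calling the model directly inside the tape keeps the outputs as tensors
# connected to joint.trainable_variables, so gradients can flow.
with tf.GradientTape() as tape:
    out, v = joint(sample_batch)
    loss = tf.reduce_mean(out - v)
grads = tape.gradient(loss, joint.trainable_variables)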

EDIT 2: I monitored the growth in memory a little more closely, and indeed every iteration keeps some memory reserved, as you can see here for the first 9 steps:

0: 8744054784
1: 8885506048
2: 9015111680
3: 9143611392
4: 9272619008
5: 9405591552
6: 9516531712
7: 9647988736
8: 9785032704

This was done with a batch size of 32. The size of one sample_batch is 256 * (200 * 200 * 3 + 48 + 92 + 7) * 32 = 984244224 bits (precision is float32), which more or less confirms that the problem must indeed be that, when the sample is passed through the network, it is added to the graph because it is symbolic, as @MatiasValdenegro suggested. So I guess the question now boils down to how to make a tensor non-symbolic, if that is even a thing.
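Written out, the calculation above is (a quick sketch; 32 bits per float32 element):

elements = 200 * 200 * 3 + 48 + 92 + 7   # 120147 values per instance
bits = 256 * elements * 32               # 984244224 bits for one sample_batch
print(bits / 8 / 1e6)                    # ≈ 123 MB

If the figures above are in bytes, that per-batch size of roughly 123 MB is in the same ballpark as the per-iteration growth of roughly 130 MB.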

Disclaimer: I know that you cannot reproduce the issue with the given code because there are missing components, but I cannot provide the full project's code here.

asked by weidler

1 Answer

It took me a while, but I have now solved the issue. As I already edited into the question above: the issue is that the Keras functional API seems to add each sample to the computation graph without removing the inputs that are no longer needed after the iteration. There seems to be no easy way of explicitly removing them, but the tf.function decorator can solve the issue.

Taking my code example from above, it can be applied as follows:

import time

import gym
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from tqdm import tqdm

sequence_length = 100
batch_size = 256

env = gym.make("ShadowHand-v1")
_, _, joint = build_shadow_brain(env, bs=batch_size)
plot_model(joint, to_file="model.png")
optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.SGD()

@tf.function
def _train():
    start_time = time.time()

    for _ in tqdm(range(sequence_length), disable=False):
        sample_batch = (tf.convert_to_tensor(tf.random.normal([batch_size, 4, 224, 224, 3])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 48])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 92])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 7])))

        with tf.GradientTape() as tape:
            out, v = joint(sample_batch, training=True)
            loss = tf.reduce_mean(out - v)

        grads = tape.gradient(loss, joint.trainable_variables)
        optimizer.apply_gradients(zip(grads, joint.trainable_variables))

    print(f"Execution Time: {time.time() - start_time}")

_train()

That is, the training loop can be wrapped in a function with the tf.function decorator. This means that the training is executed in graph mode, and for some reason this removes the issue, most likely because the graph is discarded after the function ends. For more on tf.function, see the TF2.0 guide on the topic.
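A common alternative structuring (a sketch, not the code I actually ran) is to decorate only the per-step computation instead of the whole loop, so that Python-side code such as tqdm stays outside the traced graph:

@tf.function
def train_step(sample_batch):
    # forward pass, loss, and parameter update traced as a single graph
    with tf.GradientTape() as tape:
        out, v = joint(sample_batch, training=True)
        loss = tf.reduce_mean(out - v)
    grads = tape.gradient(loss, joint.trainable_variables)
    optimizer.apply_gradients(zip(grads, joint.trainable_variables))
    return loss

for _ in tqdm(range(sequence_length), disable=False):
    sample_batch = (tf.random.normal([batch_size, 4, 224, 224, 3]),
                    tf.random.normal([batch_size, 4, 48]),
                    tf.random.normal([batch_size, 4, 92]),
                    tf.random.normal([batch_size, 4, 7]))
    train_step(sample_batch)

Since the argument tensors keep the same shape and dtype on every call, the function is traced only once; whether this variant also avoids the memory growth I cannot say for certain, as I only verified the version above.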

answered by weidler


