TL;DR: Memory usage of my implementation apparently grows with the number of samples passed through it, even though nothing in the network or the sample feeding should depend on how many samples have been passed so far.
When passing a large batch of high-dimensional data through a custom Keras model built with the functional API, I observe what appears to be constant growth in GPU memory usage as the number of processed instances increases. The following is a minimal example of how the samples are passed through the network:
import time

import gym
import tensorflow as tf
from tqdm import tqdm

sequence_length = 100
batch_size = 128

env = gym.make("ShadowHand-v1")
# build_shadow_brain is a custom model builder from my project (not shown here)
_, _, joint = build_shadow_brain(env, bs=batch_size)
optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.SGD()

start_time = time.time()
for t in tqdm(range(sequence_length), disable=False):
    # one random batch per modality: vision, proprioception, somatosensation, goal
    sample_batch = (
        tf.random.normal([batch_size, 1, 200, 200, 3]),
        tf.random.normal([batch_size, 1, 48]),
        tf.random.normal([batch_size, 1, 92]),
        tf.random.normal([batch_size, 1, 7])
    )

    with tf.GradientTape() as tape:
        out, v = joint(sample_batch)
        loss = tf.reduce_mean(out - v)

    grads = tape.gradient(loss, joint.trainable_variables)
    optimizer.apply_gradients(zip(grads, joint.trainable_variables))
    joint.reset_states()

print(f"Execution Time: {time.time() - start_time}")
I am aware that this is a large sample given the batch size. However, if it were in fact too large for my GPU, I would expect an instant OOM error, and I also assume that 6 GB of VRAM should actually suffice. The OOM error only occurs after 33 instances, which leads me to suspect that memory usage grows over time.
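A quick back-of-envelope check (my own calculation, assuming float32, i.e. 4 bytes per value, and the shapes from the snippet above) confirms that a single sample_batch is nowhere near 6 GB:

batch_size = 128
values_per_instance = 200 * 200 * 3 + 48 + 92 + 7            # = 120147 values per time step
bytes_per_batch = batch_size * 1 * values_per_instance * 4    # float32 -> 4 bytes per value
print(bytes_per_batch / 2 ** 20)                              # ~58.7 MiB per sample_batch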
The Keras summary of my model is as follows:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
visual_input (InputLayer) [(32, None, 200, 200 0
__________________________________________________________________________________________________
proprioceptive_input (InputLaye [(32, None, 48)] 0
__________________________________________________________________________________________________
somatosensory_input (InputLayer [(32, None, 92)] 0
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, None, 64) 272032 visual_input[0][0]
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, None, 8) 848 proprioceptive_input[0][0]
__________________________________________________________________________________________________
time_distributed_2 (TimeDistrib (None, None, 8) 3032 somatosensory_input[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, None, 80) 0 time_distributed[0][0]
time_distributed_1[0][0]
time_distributed_2[0][0]
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, None, 48) 3888 concatenate[0][0]
__________________________________________________________________________________________________
time_distributed_4 (TimeDistrib (None, None, 48) 0 time_distributed_3[0][0]
__________________________________________________________________________________________________
time_distributed_5 (TimeDistrib (None, None, 32) 1568 time_distributed_4[0][0]
__________________________________________________________________________________________________
time_distributed_6 (TimeDistrib (None, None, 32) 0 time_distributed_5[0][0]
__________________________________________________________________________________________________
goal_input (InputLayer) [(32, None, 7)] 0
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (32, None, 39) 0 time_distributed_6[0][0]
goal_input[0][0]
__________________________________________________________________________________________________
lstm (LSTM) (32, 32) 9216 concatenate_1[0][0]
__________________________________________________________________________________________________
dense_10 (Dense) (32, 20) 660 lstm[0][0]
__________________________________________________________________________________________________
dense_11 (Dense) (32, 20) 660 lstm[0][0]
__________________________________________________________________________________________________
activation (Activation) (32, 20) 0 dense_10[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (32, 20) 0 dense_11[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (32, 40) 0 activation[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
dense_12 (Dense) (32, 1) 33 lstm[0][0]
==================================================================================================
Total params: 291,937
Trainable params: 291,937
Non-trainable params: 0
__________________________________________________________________________________________________
As you can see, there is an LSTM layer in this network. It is normally stateful, but I have already turned that off because I suspected the problem lay there. In fact, I have already tried a number of potential fixes without eliminating the issue, and I have now reached the end of my ideas concerning potential causes.
I have also forced the process onto the CPU and inspected regular RAM usage (the OOM does not happen there because I have much more RAM than VRAM). Interestingly, the memory usage jumps up and down but shows an upward trend. For every instance, about 2 GB of memory is allocated, but when the memory is freed before the next sample is taken, roughly 200 MB less is released than was allocated.
EDIT 1: As mentioned in the comments, the issue might be that calling the model on the input adds to the computation graph. However, I cannot use joint.predict() because I need to compute the gradients.
EDIT 2: I monitored the memory growth a little more closely, and indeed every iteration keeps some memory reserved, as you can see here for the first 9 steps:
0: 8744054784
1: 8885506048
2: 9015111680
3: 9143611392
4: 9272619008
5: 9405591552
6: 9516531712
7: 9647988736
8: 9785032704
This was done with a batch size of 32. The size of one sample_batch is 256 * (200 * 200 * 3 + 48 + 92 + 7) * 32 = 984244224 bits (the precision is float32), which more or less confirms that the problem must be that, when the sample is passed through the network, it is added to the graph because it is symbolic, as @MatiasValdenegro suggested. So I guess the question now boils down to "how to make a tensor non-symbolic", if that even is a thing.
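The measurement code itself is not shown here; a minimal sketch of how such per-step numbers can be logged (assuming the psutil package and measuring the process' resident set size) would look like this:

import os

import psutil

process = psutil.Process(os.getpid())

for t in range(sequence_length):
    # ... run one training step exactly as in the loop above ...
    print(f"{t}: {process.memory_info().rss}")  # resident set size of this process, in bytes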
Disclaimer: I know that you cannot reproduce the issue with the given code because there are missing components, but I cannot provide the full project's code here.
It took me a while, but I have now solved the issue. As I edited into the question above: the problem is that Keras' functional API seems to add each sample to the computation graph without removing inputs that are no longer needed after the iteration. There appears to be no easy way of removing them explicitly, but the tf.function decorator solves the issue.
Taking my code example from above, it can be applied as follows:
import time

import gym
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from tqdm import tqdm

sequence_length = 100
batch_size = 256

env = gym.make("ShadowHand-v1")
_, _, joint = build_shadow_brain(env, bs=batch_size)
plot_model(joint, to_file="model.png")
optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.SGD()

@tf.function
def _train():
    start_time = time.time()

    for _ in tqdm(range(sequence_length), disable=False):
        sample_batch = (tf.convert_to_tensor(tf.random.normal([batch_size, 4, 224, 224, 3])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 48])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 92])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 7])))

        with tf.GradientTape() as tape:
            out, v = joint(sample_batch, training=True)
            loss = tf.reduce_mean(out - v)

        grads = tape.gradient(loss, joint.trainable_variables)
        optimizer.apply_gradients(zip(grads, joint.trainable_variables))

    print(f"Execution Time: {time.time() - start_time}")

_train()
That is, the training loop can be wrapped in a function carrying the tf.function decorator. This means the training is executed in graph mode, and for some reason this removes the issue, most likely because the graph is discarded once the function returns. For more on tf.function, see the TF 2.0 guide on the topic.
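If the data pipeline cannot live inside the decorated function, an alternative (a sketch I have not benchmarked, using the same names as above) is to wrap only the single training step in tf.function and keep the Python loop in eager mode; the traced graph is then reused across calls as long as the input signature stays the same:

@tf.function
def train_step(sample_batch):
    # Single forward/backward pass; tf.function traces this once and reuses the graph.
    with tf.GradientTape() as tape:
        out, v = joint(sample_batch, training=True)
        loss = tf.reduce_mean(out - v)
    grads = tape.gradient(loss, joint.trainable_variables)
    optimizer.apply_gradients(zip(grads, joint.trainable_variables))
    return loss

for _ in range(sequence_length):
    sample_batch = (tf.random.normal([batch_size, 4, 224, 224, 3]),
                    tf.random.normal([batch_size, 4, 48]),
                    tf.random.normal([batch_size, 4, 92]),
                    tf.random.normal([batch_size, 4, 7]))
    train_step(sample_batch)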