 

How to log the model graph to TensorBoard when using a combination of the Functional API and tf.GradientTape() to train in TensorFlow 2.0?

Can someone please guide me on how to log the model graph to TensorBoard when I use the Keras Functional API or the model subclassing API to create the model and tf.GradientTape() to train it?

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Data: x_train / y_train are assumed to be flattened 28x28 images and
# integer labels, e.g. MNIST loaded like this:
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

# Get the model.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# The model ends in softmax, so the loss receives probabilities (from_logits=False).
loss_fn = keras.losses.SparseCategoricalCrossentropy()

batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Iterate over epochs.
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Record the forward pass so the tape can compute gradients.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train)
            loss_value = loss_fn(y_batch_train, logits)

        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        if step % 200 == 0:
            print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
            print('Seen so far: %s samples' % ((step + 1) * batch_size))

Where should I insert the tensorboard logging for the model graph?

asked Sep 18 '25 at 04:09 by prateek agrawal


1 Answer

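The best way to do this (TF 2.3.0, TB 2.3.0) is to use tf.summary: switch on graph tracing with tf.summary.trace_on(graph=True), call the model once inside a wrapper function decorated with @tf.function, and then write the recorded graph to a log directory with tf.summary.trace_export.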
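The question also asks about the subclassing API. Because tracing records whatever ops the model's call() executes inside the tf.function, the same pattern should apply to a subclassed model as well. A minimal sketch, assuming a hypothetical MLP subclass that mirrors the layers from the question, a hypothetical trace_forward wrapper, and a hypothetical ./logs/subclassed directory:

```python
import tensorflow as tf
from tensorflow import keras


class MLP(keras.Model):
    """Hypothetical subclassed version of the 784-64-64-10 network from the question."""

    def __init__(self):
        super().__init__()
        self.dense_1 = keras.layers.Dense(64, activation='relu', name='dense_1')
        self.dense_2 = keras.layers.Dense(64, activation='relu', name='dense_2')
        self.classifier = keras.layers.Dense(10, activation='softmax', name='predictions')

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return self.classifier(x)


model = MLP()
model(tf.zeros((1, 784)))  # create the weights eagerly, outside the traced function

writer = tf.summary.create_file_writer('./logs/subclassed')  # hypothetical log dir


@tf.function
def trace_forward(data):
    # Ops executed here are captured as a graph while tracing is switched on.
    return model(data)


tf.summary.trace_on(graph=True)            # start recording
probs = trace_forward(tf.zeros((1, 784)))  # first call traces the function
with writer.as_default():
    tf.summary.trace_export(name='subclassed_model', step=0)  # write and stop tracing
writer.flush()
```

Building the weights eagerly first is a precaution so that the traced function never has to create variables; it is not required by the answer's approach.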

In your case, to export the model graph to TensorBoard for inspection:

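import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)


# Solution code (run once, e.g. before the training loop):
writer = tf.summary.create_file_writer('./logs/')

@tf.function
def init_model(data, model):
    # Calling the model inside a tf.function makes TensorFlow build a graph for it.
    model(data)

tf.summary.trace_on(graph=True)         # start recording graph traces
init_model(tf.zeros((1, 784)), model)   # one dummy batch of shape (1, 784)

with writer.as_default():
    # Write the recorded graph under the tag 'name' and stop tracing.
    tf.summary.trace_export('name', step=0)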

Opening the log directory in TensorBoard (for example with tensorboard --logdir ./logs), you should see the model graph in the Graphs tab.
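As for where this goes relative to the GradientTape loop in the question: the graph only needs to be exported once, so it can simply run after the model is built and before the loop starts. An alternative, which is not part of the answer above, is to wrap the whole training step in @tf.function and trace only its first call; the exported graph then also contains the loss, gradient and optimizer ops, and the same writer can log the per-step loss with tf.summary.scalar. A minimal sketch, assuming random stand-in data instead of MNIST and a hypothetical ./logs/train directory:

```python
import tensorflow as tf
from tensorflow import keras

# Model, optimizer and loss as in the question.
inputs = keras.Input(shape=(784,), name='digits')
x = keras.layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = keras.layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = keras.layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# Assumption: random stand-in data (512 samples) to keep the sketch self-contained.
x_train = tf.random.uniform((512, 784))
y_train = tf.random.uniform((512,), maxval=10, dtype=tf.int32)
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .shuffle(buffer_size=1024)
                 .batch(64))

writer = tf.summary.create_file_writer('./logs/train')  # hypothetical log dir


@tf.function
def train_step(x_batch, y_batch):
    # Forward pass, loss, gradients and update all end up in one graph.
    with tf.GradientTape() as tape:
        probs = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, probs)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value


global_step = 0
for epoch in range(3):
    for x_batch, y_batch in train_dataset:
        trace_this_step = global_step == 0
        if trace_this_step:
            # The first call is when tf.function traces train_step.
            tf.summary.trace_on(graph=True)
        loss_value = train_step(x_batch, y_batch)
        with writer.as_default():
            if trace_this_step:
                tf.summary.trace_export(name='train_step', step=0)
            tf.summary.scalar('loss', loss_value, step=global_step)
        global_step += 1
writer.flush()
```

Tracing only the first step keeps the log small: later calls reuse the already-traced graph, so there would be nothing new to record.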

answered Sep 20 '25 at 18:09 by Cobes