
How to visualize nested `tf.keras.Model (SubClassed API)` GAN model with plot_model?

Models implemented as subclasses of keras.Model generally cannot be visualized with plot_model. There is a workaround, as described here, but it only applies to simple models. As soon as one model is enclosed by another, the nested models are not resolved.

I am looking for a way to resolve nested models implemented as subclasses of keras.Model. As an example, I have created a minimal GAN model:

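# Use the public tf.keras API throughout instead of mixing standalone
# keras with the private tensorflow.python.keras package.
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model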


class BaseModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super(BaseModel, self).__init__(*args, **kwargs)

    def call(self, inputs, training=None, mask=None):
        return super(BaseModel, self).call(inputs=inputs, training=training, mask=mask)

    def get_config(self):
        return super(BaseModel, self).get_config()

    def build_graph(self, raw_shape):
        """ Plot models that subclass `keras.Model`

        Adapted from https://stackoverflow.com/questions/61427583/how-do-i-plot-a-keras-tensorflow-subclassing-api-model

        :param raw_shape: Shape tuple not containing the batch_size
        :return: A functional `keras.Model` that wraps this model's `call`
        """
        x = keras.Input(shape=raw_shape)
        return keras.Model(inputs=[x], outputs=self.call(x))


class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = self.generator(input_tensor)
        x = self.discriminator(x)
        return x


class DiscriminatorModel(BaseModel):
    def __init__(self, name="Critic"):
        super(DiscriminatorModel, self).__init__(name=name)
        self.l1 = layers.Conv2D(64, 2, activation=layers.ReLU())
        self.flat = layers.Flatten()
        self.dense = layers.Dense(1)

    def call(self, inputs, training=False, mask=None):
        x = self.l1(inputs, training=training)
        x = self.flat(x)
        x = self.dense(x, training=training)
        return x


class GeneratorModel(BaseModel):
    def __init__(self, name="Generator"):
        super(GeneratorModel, self).__init__(name=name)
        self.dense = layers.Dense(128, activation=layers.ReLU())
        self.reshape = layers.Reshape((7, 7, 128))
        self.out = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")

    def call(self, inputs, training=False, mask=None):
        x = self.dense(inputs, training=training)
        x = self.reshape(x)
        x = self.out(x, training=training)
        return x


g = GeneratorModel()
d = DiscriminatorModel()

plot_model(g.build_graph((7, 7, 1)), to_file="generator_model.png",
           expand_nested=True, show_shapes=True)

gan = GANModel(generator=g, discriminator=d)
plot_model(gan.build_graph((7, 7, 1)), to_file="gan_model.png", 
           expand_nested=True, show_shapes=True)

Edit

Using the functional Keras API, I get the desired result (see here): the nested models are correctly resolved within the GAN model.

from tensorflow.keras import Model, layers, optimizers
from tensorflow.keras.utils import plot_model


def get_generator(input_dim):
    initial = layers.Input(shape=input_dim)

    x = layers.Dense(128, activation=layers.ReLU())(initial)
    x = layers.Reshape((7, 7, 128))(x)
    x = layers.Conv2D(1, (7, 7), activation='tanh', padding="same")(x)

    return Model(inputs=initial, outputs=x, name="Generator")


def get_discriminator(input_dim):
    initial = layers.Input(shape=input_dim)

    x = layers.Conv2D(64, 2, activation=layers.ReLU())(initial)
    x = layers.Flatten()(x)
    x = layers.Dense(1)(x)

    return Model(inputs=initial, outputs=x, name="Discriminator")

def get_gan(input_dim, latent_dim):
    initial = layers.Input(shape=input_dim)

    x = get_generator(input_dim)(initial)
    x = get_discriminator(latent_dim)(x)

    return Model(inputs=initial, outputs=x, name="GAN")



m = get_generator((7, 7, 1))
m.compile(optimizer=optimizers.Adam())

plot_model(m, expand_nested=True, show_shapes=True, to_file="generator_model_functional.png")

gan = get_gan((7, 7, 1), (7, 7, 1))
plot_model(gan, expand_nested=True, show_shapes=True, to_file="gan_model_functional.png")
Molitoris asked Oct 27 '25 06:10


1 Answer

When you pass the generator and the discriminator to GANModel, each one behaves like a single child layer that wraps its own n internal layers. So if you plot the generator through a GANModel instance, it shows up as follows (the same goes for the discriminator), unlike the plot you get when you plot it on its own.

The reason is that when data is passed through the call() method of GANModel, the input flows through the internal layers of the generator and the discriminator only implicitly, as their own design dictates. Below are two workarounds that produce the desired plot.

[image]


Option 1

You can probably guess the approach: inside GANModel, pass the input explicitly through every internal layer of each child model (generator, discriminator).

class GANModel(BaseModel):
    def __init__(self, generator, discriminator):
        super(GANModel, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def call(self, input_tensor, training=False, mask=None):
        x = input_tensor

        for gen_lyr in self.generator.layers:
            print(gen_lyr) # checking 
            x = gen_lyr(x)

        for disc_lyr in self.discriminator.layers:
            print(disc_lyr) # checking 
            x = disc_lyr(x)

        return x

If you plot now, the print statements list every internal layer and you get the plot below:

# All Internal Layers of self.generator, self.discriminator
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a472a3710>
<tensorflow.python.keras.layers.core.Reshape object at 0x7f2a461e8f50>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a44591f90>
<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f2a47317290>
<tensorflow.python.keras.layers.core.Flatten object at 0x7f2a47317ed0>
<tensorflow.python.keras.layers.core.Dense object at 0x7f2a57f42910>

[image]
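One caveat for the loops above, sketched here without Keras: chaining `model.layers` reproduces `call()` only if the model is a plain chain and the layers are listed in the order `call()` uses them. The stand-in functions `double` and `add_one` and the helper `chain` are hypothetical, and the statement that a subclassed model lists its layers in attribute-assignment order is my understanding of Keras, not something the answer states.

```python
from functools import reduce


def chain(layers, x):
    """Apply callables one after another, like the loops in Option 1."""
    return reduce(lambda acc, layer: layer(acc), layers, x)


# Hypothetical stand-ins for layers: plain functions on numbers, not Keras.
def double(v):
    return v * 2


def add_one(v):
    return v + 1


def model_call(v):
    """A 'model' whose call() applies add_one first, then double."""
    return double(add_one(v))


# Assumption: the layer list follows the order in which attributes were
# assigned in __init__, which need not match the order used in call().
defined_order = [double, add_one]
call_order = [add_one, double]

matches_call = chain(call_order, 3) == model_call(3)
matches_defined = chain(defined_order, 3) == model_call(3)
print(matches_call, matches_defined)
```

For the generator and discriminator in the question, the layers are assigned in the same order as they are called, so the loops give the same result as the nested call.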


Option 2

This approach is a bit ugly. First, collect every internal layer and build a Sequential model from them. Then call .build to create its input layer.

import tensorflow as tf

gan = GANModel(generator=g, discriminator=d)

all_layer = []
for layer in gan.layers: 
    all_layer.extend(layer.layers)

gan_plot = tf.keras.models.Sequential(all_layer)
gan_plot.build((None,7,7,1))
list(all_layer)

[<tensorflow.python.keras.layers.core.Dense at 0x7f2a461ab390>,
 <tensorflow.python.keras.layers.core.Reshape at 0x7f2a46156110>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461fedd0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2a461500d0>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f2a4613ea10>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f2a462cae10>]
tf.keras.utils.plot_model(gan_plot, expand_nested=True, show_shapes=True)
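The loop in Option 2 flattens exactly one level, so it would fail if a direct child had no `layers` attribute or if models were nested more deeply. A recursive variant could look like the sketch below. It is framework-free: `Leaf`, `Container`, and `flatten_layers` are hypothetical names, and it assumes that plain layers expose no `layers` attribute while models do.

```python
class Leaf:
    """Hypothetical stand-in for a plain layer (no `layers` attribute)."""

    def __init__(self, name):
        self.name = name


class Container:
    """Hypothetical stand-in for a model that tracks its child layers."""

    def __init__(self, name, layers):
        self.name = name
        self.layers = list(layers)


def flatten_layers(layer):
    """Return the leaf layers below `layer`, depth-first, in tracked order.

    Assumption: an object without a non-empty `layers` attribute is a leaf.
    """
    children = getattr(layer, "layers", None)
    if not children:
        return [layer]
    flat = []
    for child in children:
        flat.extend(flatten_layers(child))
    return flat


generator = Container("Generator", [Leaf("dense"), Leaf("reshape"), Leaf("conv_out")])
critic = Container("Critic", [Leaf("conv"), Leaf("flatten"), Leaf("dense_out")])
# The extra leaf sits directly under the GAN; the one-level loop would
# raise AttributeError on it because it has no `layers` attribute.
gan_like = Container("GAN", [generator, critic, Leaf("head")])

names = [layer.name for layer in flatten_layers(gan_like)]
print(names)
```

If the assumption holds for the real objects, the result of `flatten_layers(gan)` could be passed to `tf.keras.models.Sequential` in place of `all_layer`. As with Option 2, this only makes sense when the model is a plain chain of layers.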
M.Innat answered Oct 28 '25 19:10