
Input_shape for build method in TensorFlow custom layer with multiple inputs

I have to design a neural network that takes two inputs, X_1 and X_2. The layer projects each of them to a fixed-size vector (10-D) and then sums the results, as follows:

import tensorflow as tf

class my_lyr(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()  # required so Keras can set up the layer
    def call(self, X_1, X_2):
        return X_1 @ self.w1 + X_2 @ self.w2

However, I need to know the input shapes of X_1 and X_2 before I can initialize w1 and w2, and I'm not sure how to declare w2 in build():

def build(self, input_shape):
    self.w1 = self.add_weight(name='w1', shape=[input_shape[-1], 10])
    # self.w2 = ?????

I want to know how build() methods are usually written in such cases.
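The shape requirement can be illustrated without TensorFlow. Below is a minimal NumPy sketch of the same computation, assuming arbitrary sizes (a batch of 4, 3 features in X_1, 5 features in X_2, 10 output units) chosen only for illustration. It shows that each weight's first dimension comes from a different input, which is why build() needs both shapes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes for illustration: batch of 4, X_1 has 3 features,
# X_2 has 5 features, and the layer projects to 10 units.
batch, d1, d2, units = 4, 3, 5, 10
X_1 = rng.normal(size=(batch, d1))
X_2 = rng.normal(size=(batch, d2))

# Each weight's first dimension must match the last dimension of its own
# input, so both input shapes are needed before the weights can be created.
w1 = rng.normal(size=(X_1.shape[-1], units))  # (3, 10)
w2 = rng.normal(size=(X_2.shape[-1], units))  # (5, 10)

out = X_1 @ w1 + X_2 @ w2
print(out.shape)  # (4, 10)

# The same result as a single projection of the concatenated inputs,
# with the two weight matrices stacked vertically: [X_1 X_2] @ [w1; w2].
concat_out = np.concatenate([X_1, X_2], axis=-1) @ np.vstack([w1, w2])
print(np.allclose(out, concat_out))  # True
```

Because of that equivalence, concatenating the inputs and applying one bias-free projection (for example tf.keras.layers.Concatenate followed by tf.keras.layers.Dense(10, use_bias=False)) is a possible alternative; keeping separate per-input weights, as asked here, makes w1 and w2 individually accessible.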

asked Jan 30 '26 16:01 by Lawhatre


1 Answer

If your layer takes two inputs, pass them to the layer as a single list. build() then receives a list with one shape per input, so you can initialize one weight matrix for each, something like this:

import tensorflow as tf 
from tensorflow import keras 

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

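    def build(self, input_shape):
        # The layer is called on a list [a, b], so input_shape is a list of
        # shapes: input_shape[0] belongs to a, input_shape[1] belongs to b.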
        self.wa = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

        self.wb = self.add_weight(
            shape=(input_shape[1][-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # inputs is the same two-element list; each element gets its own weights.
        return tf.matmul(inputs[0], self.wa) + tf.matmul(inputs[1], self.wb)

Passing the inputs as a list:

x = tf.random.normal(shape=(2,2))
linear_layer = Linear(32)
linear_layer([x, x])
<tf.Tensor: shape=(2, 32), dtype=float32, numpy=
array([[-0.08829461, -0.01605312, -0.04368614, -0.08116315, -0.01521384,
         0.01132785,  0.10704445, -0.10873697, -0.0525714 ,  0.07684848,
         0.04586978,  0.01315852,  0.01369547,  0.07404792,  0.10313608,
        -0.10851607,  0.04091477, -0.01723676, -0.0326797 ,  0.03598418,
        -0.11335816, -0.10044714,  0.13555384,  0.01689356,  0.02631954,
         0.08226107, -0.08765724, -0.05981663,  0.00531629,  0.02930426,
         0.04155847,  0.05339598],
       [ 0.20617458, -0.05936547,  0.01735754, -0.06575315,  0.10090968,
        -0.07796012, -0.1956767 , -0.03406558,  0.18604615, -0.03547171,
         0.02784208,  0.0471364 , -0.10712875, -0.07869454, -0.19457275,
         0.13593757, -0.14659101,  0.0384632 ,  0.02344182, -0.03861775,
         0.08948556,  0.09225713, -0.17395493,  0.10021958, -0.09210777,
        -0.09865301,  0.2536609 , -0.02547608,  0.02885125, -0.01271547,
        -0.10340843, -0.0338558 ]], dtype=float32)>
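Applied to the question's case, a sketch might look like the following. It assumes units=10 and uses the hypothetical name TwoInputLinear; the feature sizes 3 and 5 are arbitrary and chosen to show that the two inputs need not have the same width. As in the answer, both tensors are passed as one list, so build() receives a list of shapes.

```python
import tensorflow as tf
from tensorflow import keras


class TwoInputLinear(keras.layers.Layer):
    """Hypothetical layer computing X_1 @ w1 + X_2 @ w2 (one weight per input)."""

    def __init__(self, units=10, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Called on a list [X_1, X_2], so input_shape holds one shape per input.
        shape_1, shape_2 = input_shape
        self.w1 = self.add_weight(
            name="w1",
            shape=(shape_1[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.w2 = self.add_weight(
            name="w2",
            shape=(shape_2[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        x_1, x_2 = inputs
        return tf.matmul(x_1, self.w1) + tf.matmul(x_2, self.w2)


# X_1 and X_2 have different feature sizes (3 and 5) but the same batch size.
x_1 = tf.random.normal(shape=(4, 3))
x_2 = tf.random.normal(shape=(4, 5))
layer = TwoInputLinear(units=10)
out = layer([x_1, x_2])
print(out.shape)       # (4, 10)
print(layer.w1.shape)  # (3, 10)
print(layer.w2.shape)  # (5, 10)
```

The sketch passes name= to add_weight as a keyword because in Keras 3 the first positional parameter of add_weight is shape, not name.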
answered Feb 01 '26 05:02 by M.Innat