How to create a sub-network reference in PyTorch?

A simple explanation of what I aim to do is the following:

Given a network O with the structure

---------           ----------            ----------            ----------            ----------
| Input | -> (x) -> | Blob A | -> (xa) -> | Blob B | -> (xb) -> | Blob C | -> (xc) -> | Output |
---------           ----------            ----------            ----------            ----------

I want to create a sub-network to calculate a noise loss function for Blob C. The operation is: given input xb with original output xc, pass xb + noise through Blob C again to get xc'. Then mse_loss is computed between xc and xc'.
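To illustrate, a minimal sketch of the computation I have in mind (blob_c and noise_loss here are just hypothetical stand-ins, not names from my actual model):

import torch
import torch.nn.functional as F

def noise_loss(blob_c, xb, xc, noise_scale=0.3):
    # Perturb the input to Blob C, run it through Blob C again, and compare
    # the perturbed output against the original output xc.
    noise = (torch.rand_like(xb) * 2 - 1) * noise_scale
    xc_noisy = blob_c(xb + noise)
    return F.mse_loss(xc, xc_noisy)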

I have tried creating an nn.Sequential from the original model, but I am not sure whether it creates a new deep copy or a reference to the original layers.

If I have missed anything, please comment.

Thank you


1 Answer

So after some testing, I found out that if the layer references are kept (e.g. in a variable) and a new model is built with nn.Sequential from those layers, the new model shares the same layer references. So when the original network is updated, the new model is updated as well.

The code I used to test my hypothesis is the following:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class TestNN(nn.Module):
    def __init__(self):
        super(TestNN, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        in_x = x
        h = self.relu1(self.conv1(in_x))
        h = self.relu2(self.conv2(h))
        h = self.relu3(self.conv3(h))
        return h

net = TestNN()

testInput = torch.from_numpy(np.random.rand(1, 3, 3, 3)).float()
target = torch.from_numpy(np.random.rand(1, 3, 3, 3)).float()

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

def subnetwork(layers, start_layer_idx, end_layer_idx):
    # Build an nn.Sequential from a slice of the given layers (end index inclusive).
    # The layers are added by reference, not copied.
    sub = nn.Sequential()
    for idx, layer in enumerate(list(layers)[start_layer_idx: end_layer_idx+1]):
        sub.add_module("layer_{}".format(idx), layer)
    return sub

start = subnetwork(net.children(), 0, 1)
middle = subnetwork(net.children(), 2, 3)
end = subnetwork(net.children(), 4, 5)

print(end(middle(start(testInput))))
print(net(testInput))

for idx in range(5):
    net.zero_grad()
    out = net(testInput)
    loss = criterion(out, target)

    print("[{}] {:4f}".format(idx, loss))
    loss.backward()
    optimizer.step()

print(end(middle(start(testInput))))
print(net(testInput))

The outputs from the chained sub-networks and from the original network are identical both before and after training, so I concluded that my hypothesis is correct: the sub-networks share the original layers rather than holding deep copies.
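A more direct check (not part of my test above, just a quick extra sanity check on the modules built there) is to compare the parameter objects themselves; if nn.Sequential kept references rather than copies, they are literally the same tensors:

print(net.conv1.weight is start.layer_0.weight)  # True -> shared reference, not a deep copy
print(net.conv3.weight is end.layer_0.weight)    # True as well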

To finish my objective, I created a 'transparent' loss module like the one from this tutorial.

import random as rng               # assumed: imports are not shown in the original; rng.randint/rng.random below match Python's random module
import torch.nn.functional as F

class NoiseLoss(nn.Module):

    def __init__(self, subnet, noise_count = 20, noise_range=0.3):
        super(NoiseLoss, self).__init__()
        self.net = subnet
        self.noise_count = noise_count
        self.noise_range = noise_range

    def add_noise(self, x):
        # Perturb noise_count random pixel positions (all channels) with values
        # in [-noise_range, noise_range), then add the same noise map to every
        # sample in the batch.
        b, c, h, w = x.size()
        noise = torch.zeros(c, h, w)
        for i in range(self.noise_count):
            row, col = rng.randint(0, h-1), rng.randint(0, w-1)
            for j in range(c):
                noise[j,row,col] = (2*(rng.random()%self.noise_range)) - self.noise_range
        noise = noise.float()
        xp = x.clone()
        for b_idx in range(b):
            xp[b_idx,:,:,:] = xp[b_idx,:,:,:] + noise
        return xp

    def forward(self, x):
        # 'Transparent' layer: store the loss and pass the input through unchanged.
        self.loss = F.mse_loss(x, self.add_noise(x))
        print(self.loss)
        return x

noise_losses = []
testLoss = NoiseLoss(subnetwork(net.children(), 2, 3))
middle.add_module('noise_loss_test', testLoss)
noise_losses.append(testLoss)

and modify my loop to

    ...
    print("[{}] {:4f}".format(idx, loss))
    for nl in noise_losses:
        loss += nl.loss
    loss.backward(retain_graph=True)
    ...
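One caveat I should point out (the snippet below is a sketch of what I mean, reusing the objects defined above): nl.loss is only populated when the sub-network that contains the NoiseLoss module is actually run, so the forward pass in the loop has to go through the chained sub-networks rather than calling net directly:

for idx in range(5):
    net.zero_grad()
    # Running the chained sub-networks executes the NoiseLoss module inside
    # `middle`, which stores its loss in nl.loss for this iteration.
    out = end(middle(start(testInput)))
    loss = criterion(out, target)

    for nl in noise_losses:
        loss = loss + nl.loss

    print("[{}] {:.4f}".format(idx, loss.item()))
    loss.backward(retain_graph=True)
    optimizer.step()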

If I missed something, please leave a comment.


