 

Modifying a PyTorch Pretrained Model's Parameters in `forward()` Makes Training Slower

Tags:

python

pytorch

I have two parameters, A and B, that I need to use to replace all the weights of a pre-trained model. In other words, I want to reuse the pre-trained model's forward computation, but not its weights.

I want to set the model's weights to W = A + B, where A is a fixed (non-trainable) tensor and B is a trainable parameter. In the end, my aim is to train B inside the structure of the pre-trained model.
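For a single layer, the reparametrization W = A + B can be sketched as follows. This is a minimal illustration that assumes a linear layer's weight and uses `torch.nn.functional.linear`; the names `A`, `B`, `x`, and `W` are illustrative only. Because `A + B` is computed out of place, the gradient reaches `B` while the fixed `A` is never tracked.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

A = torch.rand(4, 3)                       # fixed tensor: requires_grad=False
B = torch.nn.Parameter(torch.zeros(4, 3))  # trainable offset with the weight's shape
x = torch.randn(2, 3)                      # a batch of 2 inputs with 3 features

W = A + B                # effective weight, built out of place for this forward pass
out = F.linear(x, W)     # equivalent to x @ W.T
out.sum().backward()

print(B.grad)            # populated: the gradient flows into B
print(A.grad)            # None: A is never part of the graph
```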

This is my attempt:

import copy
import torch
import torch.nn as nn

class Net(nn.Module):
    
    def __init__(self, pre_model, B):
    
        super(Net, self).__init__()
        self.B = B
        self.pre_model = copy.deepcopy(pre_model)
        for params in self.pre_model.parameters(): 
            params.requires_grad = False
    
    def forward(self, x, A):
        for i, params in enumerate(self.pre_model.parameters()):
            params.copy_(A[i].detach().clone()) # detached because I don't need to train A
            params.add_(self.B[i])      # I need to train B so no detach
            params.retain_grad()

        x = self.pre_model(x)
        return x

And this is how I calculate A and B:

b = []
A = []
for params in list(pre_model.parameters()):
    A.append(torch.rand_like(params))
    b_temp = nn.Parameter(torch.rand_like(params))
    b.append(b_temp.detach().clone())
B = nn.ParameterList(b)
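The same two lists can also be built without creating an `nn.Parameter` only to detach and clone it. This is a sketch that assumes `pre_model` is any `nn.Module`; a small `nn.Sequential` serves as a hypothetical stand-in here. `A` stays a plain list of tensors that do not require gradients, and `B` holds one trainable `nn.Parameter` per pre-trained parameter.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pre-trained model
pre_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# A: fixed random tensors, one per parameter (rand_like does not copy requires_grad)
A = [torch.rand_like(p) for p in pre_model.parameters()]

# B: trainable random offsets with matching shapes
B = nn.ParameterList([nn.Parameter(torch.rand_like(p)) for p in pre_model.parameters()])
```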

I checked during training, and B is being updated. The problem is that training keeps getting slower with every iteration:

Epoch 1:
24%|██▍ | 47/196 [00:05<00:23, 6.44it/s]
57%|█████▋ | 111/196 [00:18<00:19, 4.28it/s]
96%|█████████▋| 189/196 [00:41<00:02, 2.90it/s]
Epoch 2:
6%|▌ | 11/196 [00:04<01:14, 2.50it/s]

I think I have detached all the tensors correctly, so I am not sure why this happens.

UPDATE:

Credit to ptrblck from the PyTorch Forums: you can run the minimal example in Google Colab here, or use the code below for the main loop. You will see each training iteration get slower and slower.

import time
import torch
import torch.nn as nn
from torchvision import models
from tqdm import tqdm

device = 'cuda'
pre_model = models.resnet18().to(device)
b = []
A = []
for params in list(pre_model.parameters()):
    A.append(torch.rand_like(params))
    b_temp = nn.Parameter(torch.rand_like(params))
    b.append(b_temp.detach().clone())
B = nn.ParameterList(b)

modelwithAB = Net(pre_model, B)
optimizer = torch.optim.Adam(modelwithAB.parameters(), lr=1e-3)

image = torch.randn(2, 3, 224, 224).to(device)
print(torch.cuda.memory_allocated()/1024**2)

for i in tqdm(range(300)):
    optimizer.zero_grad()
    out = modelwithAB(image, A)
    start = time.time()
    out.mean().backward()
    torch.cuda.synchronize()
    optimizer.step()
    if i%40==0:
        print("-", torch.cuda.memory_allocated()/1024**2, "-", time.time()-start)
asked Nov 06 '25 12:11 by malioboro


1 Answer

ptrblck: Based on your code snippet, you are detaching A, which is the fixed tensor, while you are adding B to params, potentially including its entire computation graph. Could you double-check this, please?

Well, let's investigate! This code is the main suspect:

def forward(self, x, A):
    for i, params in enumerate(self.pre_model.parameters()):
        params.copy_(A[i].detach().clone()) # detached because I don't need to train A
        params.add_(self.B[i])      # I need to train B so no detach
        params.retain_grad()

    x = self.pre_model(x)
    return x
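To see why this forward pass is suspicious, here is a stripped-down sketch of the same pattern on a single tensor. `w` stands in for one pre-trained parameter, and `count_nodes` is a hypothetical helper that counts the autograd nodes reachable from a `grad_fn` (shared nodes are counted once per path). Because `w` survives between iterations, each in-place `copy_` and `add_` is recorded on top of the previous iteration's history instead of replacing it.

```python
import torch

def count_nodes(fn):
    """Count autograd nodes reachable from fn, once per path (hypothetical helper)."""
    if fn is None:
        return 0
    return 1 + sum(count_nodes(next_fn) for next_fn, _ in fn.next_functions)

w = torch.zeros(3)                     # stands in for one pre-trained parameter
a = torch.rand(3)                      # fixed tensor A (no grad)
b = torch.rand(3, requires_grad=True)  # trainable tensor B

sizes = []
for step in range(4):
    w.copy_(a.detach().clone())  # same pattern as the forward() above
    w.add_(b)                    # w's grad_fn now also links to older steps
    sizes.append(count_nodes(w.grad_fn))

print(sizes)  # grows with every iteration
```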

We will vary the number of iterations in your `for i in tqdm(range(300))` loop while watching the graph size.

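First, check the model body with `torchinfo.summary`. Since `summary` calls the model with only the input tensor, `forward()` needs a default for `A` (here it falls back to `self.B`):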

!pip install torchinfo

from torchinfo import summary

...
    def forward(self, x, A=None):
        if A is None:
            A = self.B
        ...

summary(modelwithAB, input_size=image.shape, depth=10)

After running the training loop with `for i in tqdm(range(2)):`

===============================================================================================
Layer (type:depth-idx)                        Output Shape              Param #
===============================================================================================
Net                                           [2, 1000]                 11,689,512
├─ResNet: 1-1                                 [2, 1000]                 --
│    └─Conv2d: 2-1                            [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                       [2, 64, 112, 112]         128
│    └─ReLU: 2-3                              [2, 64, 112, 112]         --    
...
===============================================================================================
Total params: 23,379,024
Trainable params: 23,379,024
Non-trainable params: 0
Total mult-adds (G): 3.63
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 79.49
Params size (MB): 46.76
Estimated Total Size (MB): 127.46
===============================================================================================

After running the training loop with `for i in tqdm(range(30)):`

===============================================================================================
Total params: 23,379,024
Trainable params: 23,379,024
Non-trainable params: 0
Total mult-adds (G): 3.63
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 79.49
Params size (MB): 46.76
Estimated Total Size (MB): 127.46
===============================================================================================

Well, everything looks fine here, so we need to go deeper. Next, let's inspect the gradient computation graph and count the backward nodes of each kind.

xxx = dict()  # backward-node type name -> count

def add_nodes(var, consumer=None, grad=None):
    # Recursively walk the autograd graph, starting from a grad_fn node
    if hasattr(var, 'next_functions'):
        try:
            grads = var(grad)
            grads = grads if isinstance(grads, tuple) else [grads]
            if not hasattr(grads, '__iter__'):
                grads = [grads]
        except Exception:
            grads = map(lambda x: None, var.next_functions)
        for i, (u, grad) in enumerate(zip(var.next_functions, grads)):
            # increment the count for this node type on each visited edge
            xxx[type(var).__name__] = xxx.get(type(var).__name__, 0) + 1
            # u is a (next_node, input_index) pair; only the node has next_functions
            for uu in u:
                add_nodes(uu, var, grad)

add_nodes(modelwithAB(image, A).grad_fn)
print(xxx)

1:
{'AddmmBackward0': 3, 'AddBackward0': 10214, 'CopyBackwards': 9704, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}

2:
{'AddmmBackward0': 3, 'AddBackward0': 15066, 'CopyBackwards': 14556, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}

3:
{'AddmmBackward0': 3, 'AddBackward0': 19918, 'CopyBackwards': 19408, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}

30:
{'AddmmBackward0': 3, 'AddBackward0': 150922, 'CopyBackwards': 150412, 'ReshapeAliasBackward0': 1, 'MeanBackward1': 1, 'ReluBackward0': 766, 'CudnnBatchNormBackward0': 2424, 'ConvolutionBackward0': 2424, 'MaxPool2DWithIndicesBackward0': 256, 'TBackward0': 1}

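Bingo. The counts of AddBackward0 and CopyBackwards nodes rise with every training step. Because the pre-trained parameters persist across iterations and are modified in place, each forward pass stacks new CopyBackwards and AddBackward0 nodes on top of the history left by earlier iterations, so every backward pass has more graph to traverse. So, whatever you expected to achieve, you can't do it with this kind of parameter manipulation. I can't suggest a fix because, like ptrblck, I am not sure what you are trying to do. Why aren't you happy with the standard approach of training a head on top of a fixed, pre-trained backbone?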
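One possible way to get W = A + B without a growing graph is to leave the pre-trained weights untouched and build every effective weight out of place in each forward pass. The sketch below assumes PyTorch 2.x, where `torch.func.functional_call` runs a module with substitute parameters. `TinyNet`, `NetAB`, and `count_nodes` are hypothetical names, and `TinyNet` is a small stand-in for `resnet18`. For a real ResNet, BatchNorm running statistics are buffers, so this sketch leaves them as they are.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch 2.x


class TinyNet(nn.Module):
    """Hypothetical stand-in for the pre-trained model (e.g. resnet18)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


class NetAB(nn.Module):
    """Runs pre_model's forward pass with weights W = A + B; only B is trained."""

    def __init__(self, pre_model):
        super().__init__()
        self.pre_model = pre_model
        for p in self.pre_model.parameters():
            p.requires_grad_(False)  # the original weights stay frozen and untouched
        named = list(pre_model.named_parameters())
        self.names = [name for name, _ in named]
        # A: fixed random tensors, excluded from training via requires_grad=False
        self.A = nn.ParameterList(
            [nn.Parameter(torch.rand_like(p), requires_grad=False) for _, p in named]
        )
        # B: trainable random offsets with matching shapes
        self.B = nn.ParameterList([nn.Parameter(torch.rand_like(p)) for _, p in named])

    def forward(self, x):
        # Fresh weights on every call: no in-place edits, so no history carries over
        weights = {n: a + b for n, a, b in zip(self.names, self.A, self.B)}
        return functional_call(self.pre_model, weights, (x,))


def count_nodes(fn):
    """Count autograd nodes reachable from fn, once per path (hypothetical helper)."""
    if fn is None:
        return 0
    return 1 + sum(count_nodes(next_fn) for next_fn, _ in fn.next_functions)


torch.manual_seed(0)
model = NetAB(TinyNet())
optimizer = torch.optim.Adam(model.B.parameters(), lr=1e-3)
x = torch.randn(2, 8)

A_init = [a.detach().clone() for a in model.A]
B_init = [b.detach().clone() for b in model.B]

graph_sizes = []
for step in range(5):
    optimizer.zero_grad()
    out = model(x)
    graph_sizes.append(count_nodes(out.grad_fn))
    out.mean().backward()
    optimizer.step()

print(graph_sizes)  # the same size at every step
```

Because `weights` is rebuilt from `A` and `B` on every call, each iteration's graph starts fresh and its size stays flat, unlike the in-place version in the question.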

answered Nov 08 '25 10:11 by Alexey Birukov


