I need to extract the weights, biases, and at least the type of activation function from a trained NN in PyTorch.
I know that to extract the weights and biases the command is:
model.parameters()
but I can't figure out how to also extract the activation function used on the layers. Here is my network:
import torch
import torch.nn.functional as F


class NetWithODE(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output, sampling_interval, scaler_features):
        super(NetWithODE, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)  # output layer
        self.sampling_interval = sampling_interval
        self.device = torch.device("cpu")
        self.dtype = torch.float
        self.scaler_features = scaler_features

    def forward(self, x):
        x0 = x.clone().requires_grad_(True)
        # activation function for hidden layer
        x = F.relu(self.hidden(x))
        # linear output, here r should be the output
        r = self.predict(x)
        # now r enters the integrator
        x = self.integrate(r, x0)
        return x

    def integrate(self, r, x0):
        # RK4 steps per interval
        M = 4
        DT = self.sampling_interval / M
        X = x0
        for j in range(M):
            k1 = self.ode(X, r)
            k2 = self.ode(X + DT / 2 * k1, r)
            k3 = self.ode(X + DT / 2 * k2, r)
            k4 = self.ode(X + DT * k3, r)
            X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return X

    def ode(self, x0, r):
        qF = r[0, 0]
        qA = r[0, 1]
        qP = r[0, 2]
        mu = r[0, 3]
        FRU = x0[0, 0]
        AMC = x0[0, 1]
        PHB = x0[0, 2]
        TBM = x0[0, 3]
        fFRU = qF * TBM
        fAMC = qA * TBM
        fPHB = qP - mu * PHB
        fTBM = mu * TBM
        return torch.stack((fFRU, fAMC, fPHB, fTBM), 0)
If I run the command
print(model)
I get
NetWithODE(
  (hidden): Linear(in_features=4, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=4, bias=True)
)
But where can I get the activation function (in this case ReLU)?
I am using PyTorch 1.4.
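For reference, this is roughly how I get the weights and biases at the moment (the constructor arguments below are just placeholders); what I am still missing is the activation function:

# placeholder constructor arguments, only for illustration
model = NetWithODE(n_feature=4, n_hidden=10, n_output=4,
                   sampling_interval=0.1, scaler_features=None)

# named_parameters() yields (name, tensor) pairs such as "hidden.weight", "hidden.bias"
for name, param in model.named_parameters():
    print(name, tuple(param.shape))

# individual layers can also be read directly
hidden_weights = model.hidden.weight.detach().numpy()
hidden_bias = model.hidden.bias.detach().numpy()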
There are two ways of adding operations to the network graph: the low-level functional way and the more advanced object way. You need the latter to make your structure observable; in the first case you are just calling (not exactly, but roughly) a function, without storing any information about it in the module. So, instead of
def forward(self, x):
    ...
    x = F.relu(self.hidden(x))
it should be something like
def __init__(...):
    ...
    self.myFirstRelu = torch.nn.ReLU()

def forward(self, x):
    ...
    x1 = self.hidden(x)
    x2 = self.myFirstRelu(x1)
Anyway, mixing these two ways is generally a bad idea, although even torchvision models have such inconsistencies: models.inception_v3 does not register the poolings, for example >:-( (EDIT: this was fixed in June 2020, thanks, mitmul!).
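For example, here is a minimal sketch of your network rewritten in the object way (trimmed to the relevant lines); after this change, print(model) also lists the ReLU:

class NetWithODE(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output, sampling_interval, scaler_features):
        super(NetWithODE, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.myFirstRelu = torch.nn.ReLU()                    # activation registered as a submodule
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer
        self.sampling_interval = sampling_interval
        self.scaler_features = scaler_features

    def forward(self, x):
        x0 = x.clone().requires_grad_(True)
        x = self.myFirstRelu(self.hidden(x))   # same computation as F.relu(self.hidden(x))
        r = self.predict(x)
        return self.integrate(r, x0)

# print(model) should now show something like:
# NetWithODE(
#   (hidden): Linear(in_features=4, out_features=10, bias=True)
#   (myFirstRelu): ReLU()
#   (predict): Linear(in_features=10, out_features=4, bias=True)
# )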
UPD: - Thanks, that works; now if I print the model I see ReLU(). But this seems to only print the functions in the same order they are defined in the __init__. Is there a way to get the associations between layers and activation functions? For example, I want to know which activation was applied to layer 1, which to layer 2, and so on...
There is no uniform way, but here are some tricks:
object way:
- just init them in order
- use torch.nn.Sequential (see the sketch after the hook example below)
- hook callbacks on nodes, like this -
def hook(module, inp, out):
    # prints the class name of each module as its forward pass finishes
    print(module._get_name())

for module in model.modules():
    module.register_forward_hook(hook)
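As promised above, here is a minimal sketch of the torch.nn.Sequential variant (with hypothetical sizes 4 -> 10 -> 4 taken from your print output): since the container stores its children in execution order, iterating over them (or simply printing the model) tells you which activation follows which layer:

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(4, 10),   # layer 1
    torch.nn.ReLU(),          # activation applied to layer 1
    torch.nn.Linear(10, 4),   # layer 2 (linear output, no activation)
)

# children come back in the order they are executed, so a ReLU
# right after a Linear is the activation of that Linear
for name, module in net.named_children():
    print(name, module)

nn.Sequential is the simplest way to keep the layer/activation association explicit, but it only works when the forward pass is a plain chain, which is not the case for the ODE integration part of your network.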
functional and object way:
- make use of the internal model graph, built on the forward pass, as torchviz does (https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py), or just use the plot generated by said torchviz.
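For instance, a minimal sketch with torchviz (assuming it is installed, e.g. via pip install torchviz, and that model takes a 1x4 input as in the question):

import torch
from torchviz import make_dot

# hypothetical dummy input with the 4 features from the question
x = torch.randn(1, 4)
y = model(x)

# make_dot walks the autograd graph built during the forward pass,
# so even functional calls like F.relu appear as nodes in the plot
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("net_graph", format="png")  # writes net_graph.png (needs the graphviz binary)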