I am doing research on batch normalization and need to make some modifications to the PyTorch BN code. I dug into the PyTorch source and got stuck at torch.nn.functional.batch_norm, which calls torch.batch_norm.
The problem is that torch.batch_norm cannot be found anywhere in the torch library's Python source. Is there any way I can find the source code of this built-in function and re-implement it? Thanks!
It's there, but it's not defined in Python. It's defined in C++ in the aten/ directory.
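A quick way to confirm this from Python is the standard inspect module. This is a small sketch, assuming a reasonably recent PyTorch build: F.batch_norm is an ordinary Python function whose source you can read, while torch.batch_norm is a builtin exposed from the compiled C++ extension, so inspect cannot find Python source for it.

```python
import inspect

import torch
import torch.nn.functional as F

# F.batch_norm is a plain Python function, so its source is available
# and shows the call into torch.batch_norm.
print(inspect.isfunction(F.batch_norm))
print("torch.batch_norm" in inspect.getsource(F.batch_norm))

# torch.batch_norm is a builtin bound from C++, not a Python function.
print(inspect.isfunction(torch.batch_norm))
print(inspect.isbuiltin(torch.batch_norm))

try:
    inspect.getsource(torch.batch_norm)
except TypeError as err:
    # inspect raises TypeError for builtins that have no Python source file.
    print("no Python source:", err)
```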
For CPU, the implementation (one of several; which one runs depends on whether or not the input is contiguous) is here: https://github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization.cpp#L61-L126
For CUDA, the implementation is here: https://github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143
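If the goal is to modify BN for research, it is often easier to re-implement the forward pass with ordinary tensor ops than to patch the C++. Below is a minimal sketch, not the ATen code: my_batch_norm is a hypothetical name, the input is assumed to have shape (N, C, *) with channels in dim 1 (as for nn.BatchNorm1d/2d/3d), and it follows the documented PyTorch behaviour that the batch is normalized with the biased variance while running_var is updated with the unbiased variance. It does not reproduce the fused CPU kernels or the cuDNN path.

```python
import torch
import torch.nn.functional as F


def my_batch_norm(x, running_mean, running_var, weight=None, bias=None,
                  training=False, momentum=0.1, eps=1e-5):
    """Hypothetical pure-PyTorch sketch mirroring F.batch_norm's signature."""
    # Reduce over every dim except the channel dim (dim 1).
    dims = (0,) + tuple(range(2, x.dim()))
    # Shape that broadcasts per-channel vectors against x.
    shape = [1, x.shape[1]] + [1] * (x.dim() - 2)

    if training:
        n = x.numel() // x.shape[1]  # values per channel
        if n <= 1:
            raise ValueError("need more than 1 value per channel when training")
        mean = x.mean(dim=dims)
        # Biased variance is used to normalize the current batch.
        var = ((x - mean.view(shape)) ** 2).mean(dim=dims)
        if running_mean is not None and running_var is not None:
            with torch.no_grad():
                # Exponential moving averages; running_var stores the
                # unbiased estimate, i.e. var * n / (n - 1).
                running_mean.mul_(1 - momentum).add_(momentum * mean)
                running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        # Inference: normalize with the stored running statistics.
        mean, var = running_mean, running_var

    out = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
    if weight is not None:
        out = out * weight.view(shape)  # per-channel scale (gamma)
    if bias is not None:
        out = out + bias.view(shape)    # per-channel shift (beta)
    return out
```

Because it is built from ordinary tensor ops, autograd derives the backward pass for you, so you can change the statistics or the normalization freely. The trade-off is speed and memory compared with the fused kernels linked above. You can sanity-check it against F.batch_norm on random data in both training and eval mode before making changes.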