On 18th May 2022, PyTorch announced support for GPU-accelerated PyTorch training on Mac.
I used the following steps to set up PyTorch on my MacBook Air M1 (using Miniconda):
$ conda create -n torch-nightly python=3.8
$ conda activate torch-nightly
$ pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
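After installing, you can check whether the build includes MPS and whether an MPS device can be used. This is a sketch. `torch.backends.mps` only exists in PyTorch 1.12 and later, so the `getattr` guard handles older builds. The helper name `mps_status` is hypothetical:

```python
import torch


def mps_status() -> dict:
    """Report the PyTorch version and its MPS support on this machine."""
    mps = getattr(torch.backends, "mps", None)  # absent before PyTorch 1.12
    return {
        "torch_version": torch.__version__,
        # was this PyTorch binary compiled with MPS support?
        "mps_built": bool(mps is not None and mps.is_built()),
        # can an MPS device actually be used right now (macOS version, hardware)?
        "mps_available": bool(mps is not None and mps.is_available()),
    }


if __name__ == "__main__":
    print(mps_status())
```

On an M1 with a correct install, both flags should be `True`. If `mps_built` is `False`, the wrong wheel was installed.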
I am trying to run a script from Udacity's Deep Learning course.
The script moves the models to the GPU using the following code:
G.cuda()
D.cuda()
However, this does not work on M1 chips, since there is no CUDA.
What should we do instead to move both the models and the tensors to the M1 GPU and train entirely on it?
If relevant: G and D are the generator and discriminator of a GAN:
import torch
import torch.nn as nn
import torch.nn.functional as F

# conv() and deconv() are helper functions defined elsewhere in the course script (not shown).

class Discriminator(nn.Module):

    def __init__(self, conv_dim=32):
        super(Discriminator, self).__init__()
        self.conv_dim = conv_dim
        # three strided conv layers, each halving the spatial size
        self.cv1 = conv(in_channels=3, out_channels=conv_dim, kernel_size=4, stride=2, padding=1, batch_norm=False) # 32*32*3 -> 16*16*32
        self.cv2 = conv(in_channels=conv_dim, out_channels=conv_dim*2, kernel_size=4, stride=2, padding=1, batch_norm=True) # 16*16*32 -> 8*8*64
        self.cv3 = conv(in_channels=conv_dim*2, out_channels=conv_dim*4, kernel_size=4, stride=2, padding=1, batch_norm=True) # 8*8*64 -> 4*4*128
        self.fc1 = nn.Linear(in_features=4*4*conv_dim*4, out_features=1, bias=True)

    def forward(self, x):
        # each layer consumes the previous layer's output, not the raw input x
        out = F.leaky_relu(self.cv1(x), 0.2)
        out = F.leaky_relu(self.cv2(out), 0.2)
        out = F.leaky_relu(self.cv3(out), 0.2)
        # flatten using the instance attribute rather than a global conv_dim
        out = out.view(-1, 4*4*self.conv_dim*4)
        out = self.fc1(out)
        return out

D = Discriminator(conv_dim)
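The shape comments in the discriminator follow from the usual output-size formula for a strided convolution: out = floor((in + 2*padding - kernel_size) / stride) + 1. The stdlib-only sketch below checks the 32 -> 16 -> 8 -> 4 progression and the size of the flattened input to `fc1`. It assumes that `conv()` wraps `nn.Conv2d` with dilation 1. The helper name `conv_out_size` is hypothetical:

```python
def conv_out_size(size: int, kernel_size: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a 2D convolution (dilation=1) along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


sizes = [32]  # input images are 32x32
for _ in range(3):  # cv1, cv2, cv3
    sizes.append(conv_out_size(sizes[-1]))
print(sizes)  # [32, 16, 8, 4]

# features entering fc1 for conv_dim=32: 4*4 spatial * (conv_dim*4) channels
flat_features = 4 * 4 * 32 * 4
print(flat_features)  # 2048
```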
class Generator(nn.Module):

    def __init__(self, z_size, conv_dim=32):
        super(Generator, self).__init__()
        self.conv_dim = conv_dim
        self.z_size = z_size
        # project the latent vector to a 4x4 feature map, then upsample with three deconv layers
        self.fc1 = nn.Linear(in_features=z_size, out_features=4*4*conv_dim*4)
        self.dc1 = deconv(in_channels=conv_dim*4, out_channels=conv_dim*2, kernel_size=4, stride=2, padding=1, batch_norm=True)
        self.dc2 = deconv(in_channels=conv_dim*2, out_channels=conv_dim, kernel_size=4, stride=2, padding=1, batch_norm=True)
        self.dc3 = deconv(in_channels=conv_dim, out_channels=3, kernel_size=4, stride=2, padding=1, batch_norm=False)

    def forward(self, x):
        x = self.fc1(x)
        # reshape to (batch, conv_dim*4, 4, 4) using the instance attribute, not a global
        x = x.view(-1, self.conv_dim*4, 4, 4)
        x = F.relu(self.dc1(x))
        x = F.relu(self.dc2(x))
        x = torch.tanh(self.dc3(x))  # outputs in [-1, 1]
        return x

G = Generator(z_size=z_size, conv_dim=conv_dim)
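The generator reverses the discriminator's shape changes. This assumes `deconv()` wraps `nn.ConvTranspose2d`, whose output size with dilation 1 is out = (in - 1)*stride - 2*padding + kernel_size + output_padding. The stdlib-only sketch below checks that three layers with kernel 4, stride 2, and padding 1 take the 4x4 map back to 32x32. The helper name is hypothetical:

```python
def deconv_out_size(size: int, kernel_size: int = 4, stride: int = 2,
                    padding: int = 1, output_padding: int = 0) -> int:
    """Spatial output size of a 2D transposed convolution (dilation=1) along one dimension."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


sizes = [4]  # fc1 output reshaped to a 4x4 map
for _ in range(3):  # dc1, dc2, dc3
    sizes.append(deconv_out_size(sizes[-1]))
print(sizes)  # [4, 8, 16, 32]
```

The final 32x32x3 output matches the input shape that the discriminator expects.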
This is what I used:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    G.to(mps_device)
    D.to(mps_device)
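A slightly more general version of the same idea picks the device at runtime instead of hard-coding `G.cuda()` / `D.cuda()`. It prefers MPS, then CUDA, then the CPU, so the same script still runs on machines without an Apple GPU. This is a sketch. The helper name `get_device` and the stand-in `nn.Linear` model are illustrative and not part of the course code:

```python
import torch
import torch.nn as nn


def get_device() -> torch.device:
    """Prefer Apple's MPS backend, then CUDA, then fall back to the CPU."""
    mps = getattr(torch.backends, "mps", None)  # missing on PyTorch < 1.12
    if mps is not None and mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = get_device()

# Stand-in model; in the question this would be G and D.
model = nn.Linear(100, 1)
model.to(device)  # nn.Module.to() moves the parameters in place
print(device, next(model.parameters()).device)
```

With this in place, `G.to(device)` and `D.to(device)` work unchanged on an M1 Mac, a CUDA machine, or a CPU-only one.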
Similarly, for every tensor I wanted to move to the M1 GPU, I used:
tensor_ = tensor_.to(mps_device)
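Unlike modules, tensors are not moved in place. `tensor.to(device)` returns a new tensor, so the result has to be reassigned. The sketch below shows this in a GAN training step. It also casts to float32, because MPS does not support float64 and tensors built from NumPy arrays are float64 by default. The batch shape and `z_size` are illustrative values:

```python
import torch

mps = getattr(torch.backends, "mps", None)
device = torch.device("mps" if mps is not None and mps.is_available() else "cpu")

batch_size = 16
z_size = 100  # illustrative latent size

# A batch that starts on the CPU as float64 (e.g. converted from NumPy).
real_images = torch.rand(batch_size, 3, 32, 32, dtype=torch.float64)
# Cast to float32 first, then move; .to() returns a new tensor, so reassign it.
real_images = real_images.float().to(device)

# New tensors can be created directly on the target device instead of moved.
z = torch.empty(batch_size, z_size, device=device).uniform_(-1, 1)
print(real_images.device, real_images.dtype, z.shape)
```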
Some operations are not yet implemented for MPS, so you may need to set an environment variable that makes PyTorch fall back to the CPU for them. One error I hit while running the script was:
# NotImplementedError: The operator 'aten::_slow_conv2d_forward' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
To solve it, I set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 in the conda environment and then reactivated it:
conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1
conda activate <test-env>
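If you would rather not change the conda environment, you can set the same variable from inside the script. This sketch assumes that PyTorch reads the variable when it is first loaded, so it is set before `import torch`. Setting it after torch has already been imported in the same process may have no effect, so verify this on your PyTorch version:

```python
import os

# Enable CPU fallback for ops not yet implemented on MPS.
# Set before importing torch (assumption: the value is read when torch loads).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (deliberately imported after setting the variable)

print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"], torch.__version__)
```

For a one-off run, prefixing the shell command with `PYTORCH_ENABLE_MPS_FALLBACK=1` (e.g. `PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py`) has the same effect without touching any configuration.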
I'd like to add to the answer above: make sure you are using a native arm64 build of Python (3.9.x) on the M1 when installing the MPS build. Run:
import platform
print(platform.platform())
to check whether Python is running as x86 or arm64. The two errors I encountered were:
RuntimeError: Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: mps
AttributeError: module 'torch.backends' has no attribute 'mps'
This was because, even though I had installed the required PyTorch version, I was still running an x86 build of Python.
To fix these, do:
That works for me, although PyTorch on MPS is still very new and buggy. Hopefully it improves soon.
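The architecture check from this answer can also sit at the top of a training script, so it fails early with a clear message instead of raising the `mps` device-string or `torch.backends` errors quoted above. This stdlib-only sketch relies on `platform.machine()` returning `arm64` for a native Python on Apple silicon and `x86_64` for an Intel build running under Rosetta. The function name is hypothetical:

```python
import platform
from typing import Optional


def mps_arch_problem() -> Optional[str]:
    """Return a warning if Python on macOS is not a native arm64 build, else None."""
    if platform.system() != "Darwin":
        return None  # the check only matters on macOS
    machine = platform.machine()
    if machine != "arm64":
        return f"Python is running as {machine}; MPS needs a native arm64 Python."
    return None


if __name__ == "__main__":
    print(platform.platform())
    print(mps_arch_problem() or "No architecture problem detected.")
```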