I ran the following benchmark from here.
#!/usr/bin/env python3
import torch

def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)

def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)

# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x): {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):     {t1.timeit(100) / 100 * 1e6:>5.1f} us')
I got

mul_sum(x, x): 1065.9 us
bmm(x, x):      134.5 us

on a Mac, and

mul_sum(x, x):  52.3 us
bmm(x, x):     120.1 us

on a Linux CPU.
I'm seeing a huge performance difference; is this expected?
I first noticed this difference on a more serious program, and am trying to replicate it here.
John Zavialov's answer goes over the general issues; this portion briefly summarizes that answer and then gets into how to speed things up.
1. Optimization progress: PyTorch's adaptation to the Apple Silicon architecture is still being refined and is not as mature as its Linux support.
2. Architecture-specific tuning: PyTorch is tuned for specific architectures, so its performance is not equally solid on every system.
3. Instruction set variation: the instruction set architecture significantly affects how efficiently various operations execute; ARM-based systems (such as the M1 Pro) are distinct from x86_64 (the Linux machine here), which can lead to big differences in performance.
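To see how points 2 and 3 apply to your own machines, you can inspect PyTorch's compile-time configuration (which BLAS backend and CPU capability the build was tuned for) and its thread count; a small diagnostic sketch (the exact flags printed vary from build to build):

```python
import torch

# Compile-time configuration: shows the BLAS backend (e.g. MKL,
# OpenBLAS, Accelerate) and detected CPU capability for this build.
print(torch.__config__.show())

# Number of threads used for intra-op parallelism; a mismatch here
# between machines can also explain large timing differences.
print(torch.get_num_threads())
```

Comparing this output between the Mac and the Linux box often reveals that the two builds are using entirely different BLAS libraries, which is where much of the `bmm` gap comes from.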
To speed this up on a Mac, Apple has added the Metal Performance Shaders (MPS) backend for PyTorch. If you activate it, you should see a performance boost. Installation instructions are in the link; here is a code example to test and activate it:
import torch

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")
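To apply this to the benchmark above, create the input tensor on the MPS device when it is available, falling back to CPU elsewhere (e.g. on Linux); a minimal sketch reusing the mul-and-sum variant from the question:

```python
import torch

def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)

# Use the MPS backend when available, otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
x = torch.randn(10000, 64, device=device)

result = batched_dot_mul_sum(x, x)
print(result.shape, result.device)  # torch.Size([10000]) on the chosen device
```

One caveat when timing on MPS: kernels run asynchronously, so on recent PyTorch versions you should synchronize the device (e.g. `torch.mps.synchronize()`) before reading the clock, or the measurements will be misleadingly fast.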