Hi can someone improve this code ? The code is about Adaptive Median Filter. When working on large image the code is so slow.
import numpy as np
def padding(img,pad):
padded_img = np.zeros((img.shape[0]+2*pad,img.shape[1]+2*pad))
padded_img[pad:-pad,pad:-pad] = img
return padded_img
def AdaptiveMedianFilter(img,s=3,sMax=7):
if len(img.shape) == 3:
raise Exception ("Single channel image only")
H,W = img.shape
a = sMax//2
padded_img = padding(img,a)
f_img = np.zeros(padded_img.shape)
for i in range(a,H+a+1):
for j in range(a,W+a+1):
value = Lvl_A(padded_img,i,j,s,sMax)
f_img[i,j] = value
return f_img[a:-a,a:-a]
def Lvl_A(mat,x,y,s,sMax):
window = mat[x-(s//2):x+(s//2)+1,y-(s//2):y+(s//2)+1]
Zmin = np.min(window)
Zmed = np.median(window)
Zmax = np.max(window)
A1 = Zmed - Zmin
A2 = Zmed - Zmax
if A1 > 0 and A2 < 0:
return Lvl_B(window)
else:
s += 2
if s <= sMax:
return Lvl_A(mat,x,y,s,sMax)
else:
return Zmed
def Lvl_B(window):
h,w = window.shape
Zmin = np.min(window)
Zmed = np.median(window)
Zmax = np.max(window)
Zxy = window[h//2,w//2]
B1 = Zxy - Zmin
B2 = Zxy - Zmax
if B1 > 0 and B2 < 0 :
return Zxy
else:
return Zmed
Is there any way to improve this code ? For example using vectorized sliding window ? I dont know how to use what numpy function. Ps: For boundary checking its using padding so it dont have to check for out of bounds.
The numba's njit is perfect for such kind of computation. Mixed with the parallel=True+prange it can be much faster. Moreover, you can pass the minimum, maximum and median values to Lvl_B rather than recomputing them as @CrisLuengo pointed out.
Here is the modified code:
import numpy as np
from numba import njit,prange
@njit
def padding(img,pad):
padded_img = np.zeros((img.shape[0]+2*pad,img.shape[1]+2*pad))
padded_img[pad:-pad,pad:-pad] = img
return padded_img
@njit(parallel=True)
def AdaptiveMedianFilter(img,s=3,sMax=7):
if len(img.shape) == 3:
raise Exception ("Single channel image only")
H,W = img.shape
a = sMax//2
padded_img = padding(img,a)
f_img = np.zeros(padded_img.shape)
for i in prange(a,H+a+1):
for j in range(a,W+a+1):
value = Lvl_A(padded_img,i,j,s,sMax)
f_img[i,j] = value
return f_img[a:-a,a:-a]
@njit
def Lvl_A(mat,x,y,s,sMax):
window = mat[x-(s//2):x+(s//2)+1,y-(s//2):y+(s//2)+1]
Zmin = np.min(window)
Zmed = np.median(window)
Zmax = np.max(window)
A1 = Zmed - Zmin
A2 = Zmed - Zmax
if A1 > 0 and A2 < 0:
return Lvl_B(window, Zmin, Zmed, Zmax)
else:
s += 2
if s <= sMax:
return Lvl_A(mat,x,y,s,sMax)
else:
return Zmed
@njit
def Lvl_B(window, Zmin, Zmed, Zmax):
h,w = window.shape
Zxy = window[h//2,w//2]
B1 = Zxy - Zmin
B2 = Zxy - Zmax
if B1 > 0 and B2 < 0 :
return Zxy
else:
return Zmed
This code is 500 times faster on my machine with a 256x256 random image.
Note that the first call will not be much faster due to the (included) compilation time.
Note also that the computation can be even faster by not recomputing the min/max/median for each value as the sliding windows share many values (see the paper constant time median filtering (Perreault et al, 2007)).
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With