I have a long NumPy array that I need to organize in a "pattern" like so:
import numpy as np
# long array:
a = np.array(
    [
        1.3,
        -1.8,
        0.3,
        11.4,
        # ...
    ]
)

def pattern(x: float):
    return np.array(
        [
            [x, 0, 0],
            [0, x, 0],
            [0, 0, x],
            [+x, +x, +x],
            [-x, -x, -x],
        ]
    )

out = np.array([pattern(x) for x in a])
print(out.shape)
(4, 5, 3)
I'm wondering if there's a way to construct the array out without the explicit loop over a.
Any ideas?
Replace the pattern function with an array of -1/0/1 values and broadcast a multiplication:
pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [+1, +1, +1],
                    [-1, -1, -1]])
out = a[:,None,None] * pattern[None]
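To see why the broadcast works, it may help to look at the shapes: a[:, None, None] has shape (n, 1, 1) and pattern[None] has shape (1, 5, 3), so their product broadcasts to (n, 5, 3). Below is a minimal sketch that checks this against the list comprehension from the question. It assumes the four sample values stand in for the full array, and pattern_func is a hypothetical name for the question's function, renamed so it does not clash with the pattern array.

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])  # sample values from the question

pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [ 1,  1,  1],
                    [-1, -1, -1]])

def pattern_func(x):
    # The question's original function (hypothetical new name).
    return np.array([[x, 0, 0], [0, x, 0], [0, 0, x], [x, x, x], [-x, -x, -x]])

lhs = a[:, None, None]  # shape (4, 1, 1)
rhs = pattern[None]     # shape (1, 5, 3)
out = lhs * rhs         # broadcasts to shape (4, 5, 3)

# Reference result built with the explicit loop.
expected = np.array([pattern_func(x) for x in a])

print(lhs.shape, rhs.shape, out.shape)
print(np.array_equal(out, expected))
```

The [None] on pattern is optional: broadcasting prepends missing leading axes automatically, so a[:, None, None] * pattern should give the same result.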
Variant with einsum:
out = np.einsum('i,jk->ijk', a, pattern)
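In the subscripts 'i,jk->ijk' no index is repeated or dropped, so nothing is summed and the call amounts to an outer product of a with pattern. The sketch below compares it with the broadcast product and with np.multiply.outer, which is not part of the answer above and is included only as another way to write the same outer product. The sample values are the four shown in the question.

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])  # sample values from the question

pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [ 1,  1,  1],
                    [-1, -1, -1]])

via_broadcast = a[:, None, None] * pattern
via_einsum = np.einsum("i,jk->ijk", a, pattern)
via_outer = np.multiply.outer(a, pattern)  # not in the original answer

print(via_einsum.shape, via_outer.shape)
print(np.allclose(via_einsum, via_broadcast))
print(np.allclose(via_outer, via_broadcast))
```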
Another approach with select (less efficient):
m1 = np.array([[ True, False, False],
               [False,  True, False],
               [False, False,  True],
               [ True,  True,  True],
               [False, False, False]])
m2 = np.array([[False, False, False],
               [False, False, False],
               [False, False, False],
               [False, False, False],
               [ True,  True,  True]])
x = a[:,None,None]
out = np.select([m1, m2], [x, -x], 0)
Output:
array([[[ 1.3, 0. , 0. ],
[ 0. , 1.3, 0. ],
[ 0. , 0. , 1.3],
[ 1.3, 1.3, 1.3],
[ -1.3, -1.3, -1.3]],
[[ -1.8, -0. , -0. ],
[ -0. , -1.8, -0. ],
[ -0. , -0. , -1.8],
[ -1.8, -1.8, -1.8],
[ 1.8, 1.8, 1.8]],
[[ 0.3, 0. , 0. ],
[ 0. , 0.3, 0. ],
[ 0. , 0. , 0.3],
[ 0.3, 0.3, 0.3],
[ -0.3, -0.3, -0.3]],
[[ 11.4, 0. , 0. ],
[ 0. , 11.4, 0. ],
[ 0. , 0. , 11.4],
[ 11.4, 11.4, 11.4],
[-11.4, -11.4, -11.4]]])
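For the select approach, the two boolean masks do not have to be typed by hand. As a sketch, assuming the -1/0/1 pattern array from the broadcasting variant is available, they can be derived by comparison:

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])  # sample values from the question

pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [ 1,  1,  1],
                    [-1, -1, -1]])

m1 = pattern == 1    # positions that receive +x
m2 = pattern == -1   # positions that receive -x

x = a[:, None, None]                    # shape (4, 1, 1)
out = np.select([m1, m2], [x, -x], 0)   # remaining positions get the default 0

print(out.shape)
print(np.array_equal(out, x * pattern))
```

One small difference from the broadcast product: select fills the remaining positions with a plain 0, so the -0. entries shown above for negative inputs would appear as 0. instead. The two compare equal either way.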
Timings:
# list comprehension
14.8 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# broadcasting
1.87 µs ± 85.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# einsum
3.07 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# select
42.2 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
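These numbers come from the four-element sample, where fixed call overhead dominates, so the ranking may look different for a genuinely long array. Here is a rough sketch for repeating the comparison on your own machine with the standard timeit module. The array length of 2,000, the random data, and the names pattern_func, candidates and timings are assumptions made for illustration.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=2_000)  # assumed length, chosen only for illustration

pattern = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1], [-1, -1, -1]])
m1 = pattern == 1
m2 = pattern == -1

def pattern_func(x):
    # The question's original function (hypothetical new name).
    return np.array([[x, 0, 0], [0, x, 0], [0, 0, x], [x, x, x], [-x, -x, -x]])

candidates = {
    "list comprehension": lambda: np.array([pattern_func(x) for x in a]),
    "broadcasting": lambda: a[:, None, None] * pattern[None],
    "einsum": lambda: np.einsum("i,jk->ijk", a, pattern),
    "select": lambda: np.select(
        [m1, m2], [a[:, None, None], -a[:, None, None]], 0
    ),
}

timings = {}
for name, fn in candidates.items():
    # Best of 3 repeats with 5 calls each, converted to seconds per call.
    timings[name] = min(timeit.repeat(fn, number=5, repeat=3)) / 5

for name, seconds in timings.items():
    print(f"{name:>20}: {seconds * 1e6:10.1f} µs per call")
```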
Maybe you can try np.kron + np.reshape, e.g.,
pat = np.array(
[
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 1, 1],
[-1, -1, -1],
])
np.kron(a.reshape(-1, 1), pat).reshape(-1, *np.shape(pat))
and you will obtain
array([[[ 1.3, 0. , 0. ],
[ 0. , 1.3, 0. ],
[ 0. , 0. , 1.3],
[ 1.3, 1.3, 1.3],
[ -1.3, -1.3, -1.3]],
[[ -1.8, -0. , -0. ],
[ -0. , -1.8, -0. ],
[ -0. , -0. , -1.8],
[ -1.8, -1.8, -1.8],
[ 1.8, 1.8, 1.8]],
[[ 0.3, 0. , 0. ],
[ 0. , 0.3, 0. ],
[ 0. , 0. , 0.3],
[ 0.3, 0.3, 0.3],
[ -0.3, -0.3, -0.3]],
[[ 11.4, 0. , 0. ],
[ 0. , 11.4, 0. ],
[ 0. , 0. , 11.4],
[ 11.4, 11.4, 11.4],
[-11.4, -11.4, -11.4]]])
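To make the intermediate step visible: the Kronecker product of the (n, 1) column with the (5, 3) pattern is a (5n, 3) array in which each block of five rows is one scaled copy of the pattern, and the reshape then splits those blocks along a new leading axis. A short sketch using the four sample values from the question, with a check against the plain broadcast product:

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])  # sample values from the question

pat = np.array([[ 1,  0,  0],
                [ 0,  1,  0],
                [ 0,  0,  1],
                [ 1,  1,  1],
                [-1, -1, -1]])

col = a.reshape(-1, 1)                  # shape (4, 1)
stacked = np.kron(col, pat)             # shape (20, 3): four 5x3 blocks stacked
out = stacked.reshape(-1, *pat.shape)   # shape (4, 5, 3)

print(stacked.shape, out.shape)
print(np.allclose(out, a[:, None, None] * pat))
```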