 

Create combined numpy array without explicit loop/comprehension

I have a long NumPy array that I need to organize in a "pattern" like so:

import numpy as np

# long array:
a = np.array(
    [
        1.3,
        -1.8,
        0.3,
        11.4,
        # ...
    ]
)


def pattern(x: float):
    return np.array(
        [
            [x, 0, 0],
            [0, x, 0],
            [0, 0, x],
            [+x, +x, +x],
            [-x, -x, -x],
        ]
    )


out = np.array([pattern(x) for x in a])
print(out.shape)
# (4, 5, 3)

I'm wondering if there's a way to construct out without the explicit loop over a.

Any ideas?

asked Sep 03 '25 17:09 by Nico Schlömer


2 Answers

Replace the pattern function with a constant array of -1/0/1 coefficients and broadcast a multiplication:

pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [+1, +1, +1],
                    [-1, -1, -1]])

out = a[:,None,None] * pattern[None]
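A self-contained sketch of the broadcasting idea, assuming the four sample values from the question. Because every entry of the question's pattern(x) is 0, +x or -x, the function is linear in x, so the coefficient template can also be obtained by evaluating it once at x = 1.0. The function is renamed `pattern_fn` here (a hypothetical name) so it does not clash with the array. The inserted axes give `a` the shape (4, 1, 1), which broadcasts against the (5, 3) template to (4, 5, 3):

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])


def pattern_fn(x: float) -> np.ndarray:
    # The question's per-element pattern, renamed to avoid clashing with the array.
    return np.array([
        [x, 0, 0],
        [0, x, 0],
        [0, 0, x],
        [+x, +x, +x],
        [-x, -x, -x],
    ])


# Every entry is 0 or +/-x, so pattern_fn(1.0) is the -1/0/1 coefficient template.
template = pattern_fn(1.0)

# (4, 1, 1) * (5, 3) broadcasts to (4, 5, 3)
out = a[:, None, None] * template

# Reference result built with the explicit loop from the question.
expected = np.array([pattern_fn(x) for x in a])

print(out.shape)                      # (4, 5, 3)
print(np.array_equal(out, expected))  # True
```

The template trick only works while the pattern stays linear in x; an entry such as x**2 would need its own term.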

Variant with einsum:

out = np.einsum('i,jk->ijk', a, pattern)
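The subscript string 'i,jk->ijk' keeps every index in the output, so nothing is summed: it is a plain outer product, out[i, j, k] = a[i] * pattern[j, k]. A minimal sketch comparing it with `np.multiply.outer`, which computes the same outer product (this alternative is not part of the original answer):

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])
pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [+1, +1, +1],
                    [-1, -1, -1]])

# No index is summed: every i is paired with every (j, k).
out_einsum = np.einsum('i,jk->ijk', a, pattern)

# Same outer product via the ufunc method; result shape is a.shape + pattern.shape.
out_outer = np.multiply.outer(a, pattern)

# Spot-check one element against out[i, j, k] == a[i] * pattern[j, k].
i, j, k = 3, 4, 0
print(out_einsum[i, j, k], a[i] * pattern[j, k])  # -11.4 -11.4
print(np.array_equal(out_einsum, out_outer))      # True
```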

Another approach with np.select (less efficient):

m1 = np.array([[ True, False, False],
               [False,  True, False],
               [False, False,  True],
               [ True,  True,  True],
               [False, False, False]])
m2 = np.array([[False, False, False],
               [False, False, False],
               [False, False, False],
               [False, False, False],
               [ True,  True,  True]])

x = a[:,None,None]
out = np.select([m1, m2], [x, -x], 0)
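The two boolean masks mark exactly where the coefficient template is +1 and -1, so they need not be typed out by hand. A sketch deriving them from the `pattern` array of the first snippet (same name assumed) and checking the np.select result against broadcasting:

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])
pattern = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [+1, +1, +1],
                    [-1, -1, -1]])

# Equivalent to the hand-written m1/m2: positions of the +1 and -1 entries.
m1 = pattern == 1
m2 = pattern == -1

x = a[:, None, None]  # shape (4, 1, 1), broadcasts against the (5, 3) masks
out_select = np.select([m1, m2], [x, -x], 0)

out_broadcast = x * pattern
# select fills with +0.0 where broadcasting may give -0.0; array_equal treats them as equal.
print(np.array_equal(out_select, out_broadcast))  # True
```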

Output:

array([[[  1.3,   0. ,   0. ],
        [  0. ,   1.3,   0. ],
        [  0. ,   0. ,   1.3],
        [  1.3,   1.3,   1.3],
        [ -1.3,  -1.3,  -1.3]],

       [[ -1.8,  -0. ,  -0. ],
        [ -0. ,  -1.8,  -0. ],
        [ -0. ,  -0. ,  -1.8],
        [ -1.8,  -1.8,  -1.8],
        [  1.8,   1.8,   1.8]],

       [[  0.3,   0. ,   0. ],
        [  0. ,   0.3,   0. ],
        [  0. ,   0. ,   0.3],
        [  0.3,   0.3,   0.3],
        [ -0.3,  -0.3,  -0.3]],

       [[ 11.4,   0. ,   0. ],
        [  0. ,  11.4,   0. ],
        [  0. ,   0. ,  11.4],
        [ 11.4,  11.4,  11.4],
        [-11.4, -11.4, -11.4]]])

Timings:

# list comprehension
14.8 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# broadcasting
1.87 µs ± 85.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

# einsum
3.07 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# select
42.2 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
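The timings above do not state the length of a or the NumPy version, so they are best read as indicative. A minimal harness for re-measuring on your own data, assuming the standard-library timeit module and a random input of length n; the helper name `time_all`, the `METHODS` table and the default sizes are hypothetical choices, and the kron entry comes from the other answer. No results are quoted because they depend on hardware and array length:

```python
import timeit

import numpy as np

PATTERN = np.array([[ 1,  0,  0],
                    [ 0,  1,  0],
                    [ 0,  0,  1],
                    [+1, +1, +1],
                    [-1, -1, -1]])


def pattern(x):
    # The question's per-element builder.
    return np.array([[x, 0, 0], [0, x, 0], [0, 0, x],
                     [+x, +x, +x], [-x, -x, -x]])


METHODS = {
    "list comprehension": lambda a: np.array([pattern(x) for x in a]),
    "broadcasting": lambda a: a[:, None, None] * PATTERN,
    "einsum": lambda a: np.einsum("i,jk->ijk", a, PATTERN),
    "select": lambda a: np.select([PATTERN == 1, PATTERN == -1],
                                  [a[:, None, None], -a[:, None, None]], 0),
    "kron": lambda a: np.kron(a.reshape(-1, 1), PATTERN).reshape(-1, *PATTERN.shape),
}


def time_all(n=1_000, number=20, seed=0):
    """Return mean seconds per call for each method on a random array of length n."""
    a = np.random.default_rng(seed).standard_normal(n)
    # Bind f as a default argument so each lambda times its own method.
    return {name: timeit.timeit(lambda f=f: f(a), number=number) / number
            for name, f in METHODS.items()}


if __name__ == "__main__":
    for name, seconds in time_all().items():
        print(f"{name:>20}: {seconds * 1e6:9.1f} µs")
```

Varying n is worthwhile: for a 4-element array the fixed per-call overhead dominates, while for long arrays the per-element Python cost of the comprehension grows linearly.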
answered Sep 05 '25 07:09 by mozway


You could try np.kron followed by a reshape, e.g.:

pat = np.array([[ 1,  0,  0],
                [ 0,  1,  0],
                [ 0,  0,  1],
                [ 1,  1,  1],
                [-1, -1, -1]])

np.kron(a.reshape(-1, 1), pat).reshape(-1, *np.shape(pat))
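Why the reshape is needed: np.kron of the (4, 1) column with the (5, 3) template returns a 2-D (20, 3) array in which rows 5*i to 5*i + 4 hold a[i] * pat; reshaping to (-1, 5, 3) splits those row blocks off into a new leading axis. A short check, assuming the sample values from the question:

```python
import numpy as np

a = np.array([1.3, -1.8, 0.3, 11.4])
pat = np.array([[ 1,  0,  0],
                [ 0,  1,  0],
                [ 0,  0,  1],
                [ 1,  1,  1],
                [-1, -1, -1]])

# Kronecker product of a (4, 1) column and a (5, 3) block: shape (4*5, 1*3).
stacked = np.kron(a.reshape(-1, 1), pat)
print(stacked.shape)  # (20, 3)

# Rows 5*i .. 5*i+4 of the 2-D result form the block a[i] * pat.
i = 2
block = stacked[5 * i:5 * i + 5]
print(np.array_equal(block, a[i] * pat))  # True

# Split the row blocks into a leading axis of length len(a).
out = stacked.reshape(-1, *pat.shape)
print(out.shape)                                    # (4, 5, 3)
print(np.array_equal(out, a[:, None, None] * pat))  # True
```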

and you will obtain

array([[[  1.3,   0. ,   0. ],
        [  0. ,   1.3,   0. ],
        [  0. ,   0. ,   1.3],
        [  1.3,   1.3,   1.3],
        [ -1.3,  -1.3,  -1.3]],

       [[ -1.8,  -0. ,  -0. ],
        [ -0. ,  -1.8,  -0. ],
        [ -0. ,  -0. ,  -1.8],
        [ -1.8,  -1.8,  -1.8],
        [  1.8,   1.8,   1.8]],

       [[  0.3,   0. ,   0. ],
        [  0. ,   0.3,   0. ],
        [  0. ,   0. ,   0.3],
        [  0.3,   0.3,   0.3],
        [ -0.3,  -0.3,  -0.3]],

       [[ 11.4,   0. ,   0. ],
        [  0. ,  11.4,   0. ],
        [  0. ,   0. ,  11.4],
        [ 11.4,  11.4,  11.4],
        [-11.4, -11.4, -11.4]]])
answered Sep 05 '25 08:09 by ThomasIsCoding