Why is the Java vector API so slow compared to scalar?

I recently decided to play around with Java's new incubating Vector API, to see how fast it can get. I implemented two fairly simple methods, one for parsing an int and one for finding the index of a character in a string. In both cases, my vectorized methods were incredibly slow compared to their scalar equivalents.

Here's my code:

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.VectorOperators;

import java.nio.charset.StandardCharsets;

public class SIMDParse {

    // Place values for the 16 int lanes; the first 6 lanes are zeroed padding,
    // the last 10 cover the maximum 10 digits of an int.
    private static IntVector mul = IntVector.fromArray(
            IntVector.SPECIES_512,
            new int[] {0, 0, 0, 0, 0, 0, 1000000000, 100000000, 10000000,
                    1000000, 100000, 10000, 1000, 100, 10, 1},
            0
    );
    private static byte zeroChar = (byte) '0';
    private static int width = IntVector.SPECIES_512.length();
    private static byte[] filler;

    static {
        // '0' bytes used to left-pad short inputs without changing their value
        filler = new byte[16];
        for (int i = 0; i < 16; i++) {
            filler[i] = zeroChar;
        }
    }

    public static int parseInt(String str) {
        boolean negative = str.charAt(0) == '-';
        byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
        if (negative) {
            bytes[0] = zeroChar;
        }
        bytes = ensureSize(bytes, width);
        ByteVector vec = ByteVector.fromArray(ByteVector.SPECIES_128, bytes, 0);
        vec = vec.sub(zeroChar); // ASCII digit bytes -> numeric values 0..9
        // Widen the 16 bytes to 16 ints (one 512-bit vector)
        IntVector ints = (IntVector) vec.castShape(IntVector.SPECIES_512, 0);
        ints = ints.mul(mul); // multiply each digit by its place value
        return ints.reduceLanes(VectorOperators.ADD) * (negative ? -1 : 1);
    }

    // Left-pads arr with '0' bytes so its length is a multiple of per
    public static byte[] ensureSize(byte[] arr, int per) {
        int mod = arr.length % per;
        if (mod == 0) {
            return arr;
        }
        int length = arr.length - mod + per;
        byte[] newArr = new byte[length];
        System.arraycopy(arr, 0, newArr, per - mod, arr.length);
        System.arraycopy(filler, 0, newArr, 0, per - mod);
        return newArr;
    }

    // Right-pads arr with zero bytes so its length is a multiple of per
    public static byte[] ensureSize2(byte[] arr, int per) {
        int mod = arr.length % per;
        if (mod == 0) {
            return arr;
        }
        int length = arr.length - mod + per;
        byte[] newArr = new byte[length];
        System.arraycopy(arr, 0, newArr, 0, arr.length);
        return newArr;
    }

    public static int indexOf(String s, char c) {
        byte[] b = s.getBytes(StandardCharsets.UTF_8);
        int width = ByteVector.SPECIES_MAX.length();
        byte bChar = (byte) c;
        b = ensureSize2(b, width); // copies the whole array if the length isn't a multiple of width
        for (int i = 0; i < b.length; i += width) {
            ByteVector vec = ByteVector.fromArray(ByteVector.SPECIES_MAX, b, i);
            // firstTrue() returns the vector length when no lane matched
            int pos = vec.compare(VectorOperators.EQ, bChar).firstTrue();
            if (pos != width) {
                return pos + i;
            }
        }
        return -1;
    }
}

I fully expected my int parsing to be somewhat slower, since it will never handle more than one vector's worth of data (an int can never be more than 10 digits long).

By my benchmarks, parsing "123" as an int 10k times took 3081 microseconds with Integer.parseInt, and 80601 microseconds with my implementation. Searching for 'a' in a very long string ("____".repeat(4000) + "a" + "----".repeat(193)) took 7709 microseconds with my indexOf, versus 7 for String#indexOf.
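The timing loop was essentially this shape (a minimal sketch; the harness details here are illustrative rather than my exact code):

long start = System.nanoTime();
int sink = 0;
for (int i = 0; i < 10_000; i++) {
    sink += SIMDParse.parseInt("123"); // consume the result so the call isn't dead code
}
long micros = (System.nanoTime() - start) / 1_000;
System.out.println(micros + " us (sink=" + sink + ")");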

Why is it so unbelievably slow? I thought the entire point of SIMD is that it's faster than the scalar equivalents for tasks like these.

asked Oct 29 '25 by Redempt

2 Answers

You picked something SIMD is not great at (string->int), and something that JVMs are very good at optimizing out of loops. And you made an implementation with a bunch of extra copying work if the inputs aren't exact multiples of the vector width.


I'm assuming your times are totals (for 10k repeats each), not a per-call average.

7 us is impossibly fast for that.

"____".repeat(4000) is 16k chars (32k bytes) before the 'a', which I assume is what you're searching for. Even a well-tuned / unrolled wmemchr (aka indexOf) running at 2x 32-byte vectors per clock cycle, on a 4GHz CPU, would take 1250 us for 10k reps. (32000B / (64B/c) * 10000 reps / 4000 MHz), assuming that 32kB string stayed hot in 32KiB L1d cache.

I'd hope and expect a JVM would either call the native wmemchr or use something equally efficient for a commonly-used core library function like String#indexOf. For example, glibc's AVX2 memchr is pretty well-tuned with loop unrolling. (Strings in Java 8 and earlier are actually UTF-16, but C's wchar_t on Linux is 4 bytes wide, unlike on Windows, so JVMs would need their own implementation.)

Built-in String indexOf is also something the JIT "knows about". It's apparently able to hoist it out of loops when it can see that you're using the same string repeatedly as input. (But then what's it doing for the rest of those 7 us? I guess doing a not-quite-so-great memchr and then doing an empty 10k iteration loop at 1/clock could take about 7 microseconds, especially if your CPU isn't as fast as 4GHz.)

See Idiomatic way of performance evaluation? - if doubling the repeat-count to 20k doesn't double the time, your benchmark is broken and not measuring what you think it does.
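A quick way to run that check (a sketch; haystack stands in for whatever string the benchmark searches):

// If the work really scales with reps, doubling reps should roughly double the time.
for (int reps : new int[] {10_000, 20_000, 40_000}) {
    long t0 = System.nanoTime();
    int sink = 0;
    for (int i = 0; i < reps; i++) {
        sink += haystack.indexOf('a');
    }
    System.out.println(reps + " reps: " + (System.nanoTime() - t0) / 1_000
            + " us (sink=" + sink + ")");
}

If the three times come out nearly identical, the JIT has hoisted the loop-invariant indexOf call and you're timing an empty loop.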


But if you're saying 7 us is the per-iteration time, that would be improbably slow, except for a non-optimized first pass. So it's probably a sign of faulty benchmarking methodology, like a lack of warm-up runs.

If indexOf were checking one char at a time, 16k chars * 0.25 ns/char would be 4000 nanoseconds, or 4 microseconds, on a 4GHz CPU. 7 us is in that ballpark of checking 1 char per cycle, which is pathetically slow on a modern x86; I think it's unlikely that mainstream JVMs would use such a slow implementation once they were done JIT optimizing.


Your manual SIMD indexOf is very unlikely to get optimized out of a loop, and it makes a copy of the whole array every time the size isn't an exact multiple of the vector width (in ensureSize2). The normal technique is to fall back to scalar for the last size % width elements (see the sketch below), which is obviously much better for large arrays. Or even better, do an unaligned load that ends at the end of the array (if the total size is >= vector width), for something where overlap with previous work is not a problem.
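For illustration, here's what the scalar-tail version of your indexOf could look like (a sketch, under the same ASCII-only assumption as your original code):

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorOperators;
import java.nio.charset.StandardCharsets;

static int indexOfNoCopy(String s, char c) {
    byte[] b = s.getBytes(StandardCharsets.UTF_8);
    int width = ByteVector.SPECIES_PREFERRED.length();
    byte bChar = (byte) c;
    int i = 0;
    // Full vectors over the body of the array: no padded copy needed
    for (; i <= b.length - width; i += width) {
        ByteVector vec = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, b, i);
        int pos = vec.compare(VectorOperators.EQ, bChar).firstTrue();
        if (pos != width) {
            return i + pos;
        }
    }
    // Scalar fallback for the last length % width bytes
    for (; i < b.length; i++) {
        if (b[i] == bChar) {
            return i;
        }
    }
    return -1;
}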

A decent memchr on modern x86 (using an algorithm like your indexOf without unrolling) should go at about 1 vector (16/32/64 bytes) per maybe 1.5 clock cycles, with data hot in L1d cache. (Checking both the vector compare and the pointer bound as possible loop exit conditions takes extra asm instructions vs. a simple strlen, but see this answer for some microbenchmarks of a simple hand-written strlen that assumes aligned buffers.) Probably your indexOf loop bottlenecks on front-end throughput on a CPU like Skylake, with its pipeline width of 4 uops/clock.

So let's guess that your implementation takes 1.5 cycles per 16-byte vector, assuming perhaps a CPU without AVX2 (you didn't say).

16kB / 16B = 1000 vectors. At 1 vector per 1.5 clocks, that's 1500 cycles. On a 3GHz machine, 1500 cycles takes 500 ns = 0.5 us per call, or 5000 us per 10k reps. But since 16773 bytes isn't a multiple of 16, you're also copying the whole thing every call, so that costs some more time, and could plausibly account for your 7709 us total time.


What SIMD is good for

"I thought the entire point of SIMD is that it's faster than the scalar equivalents for tasks like these."

No, "horizontal" stuff like ints.reduceLanes is something SIMD is generally slow at. And even with something like How to implement atoi using SIMD? using x86 pmaddwd to multiply and add pairs horizontally, it's still a lot of work.

Note that to make the elements wide enough to multiply by place-values without overflow, you have to unpack, which costs some shuffling. ints.reduceLanes takes about log2(elements) shuffle/add steps, and if you're starting with 512-bit AVX-512 vectors of int, the first 2 of those shuffles are lane-crossing, 3 cycle latency (https://agner.org/optimize/). (Or if your machine doesn't even have AVX2, then a 512-bit integer vector is actually 4x 128-bit vectors. And you had to do separate work to unpack each part. But at least the reduction will be cheap, just vertical adds until you get down to a single 128-bit vector.)
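To make the contrast concrete, the pattern SIMD is good at keeps the horizontal work out of the loop entirely: vertical operations per iteration, one reduceLanes at the end. A sketch (array summation rather than atoi, purely illustrative):

import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

static double sum(double[] a) {
    VectorSpecies<Double> species = DoubleVector.SPECIES_PREFERRED;
    DoubleVector acc = DoubleVector.zero(species);
    int i = 0;
    for (; i <= a.length - species.length(); i += species.length()) {
        acc = acc.add(DoubleVector.fromArray(species, a, i)); // cheap vertical add
    }
    double total = acc.reduceLanes(VectorOperators.ADD); // one horizontal reduction, outside the loop
    for (; i < a.length; i++) {
        total += a[i]; // scalar tail
    }
    return total;
}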

answered Oct 31 '25 by Peter Cordes


This is an updated version of the post I originally put up on this subject on 29-Jan-2022. I'm very much obliged to @JimmyB and @rapaio for pointing out some major flaws there, and I think I've addressed them now. Also, I'm now in a position to compare Java 19 with Java 20.

I found this post because I thought I had hit something strange with the Vector performance for something it should ostensibly be ideal for: multiplying two double arrays. As updated (after correcting errors), the main routine is this (for the Vector case):

  static private void doVector(int iteration, double[] input1, double[] input2, double[] output) {
    long start = System.nanoTime();
    for (int i = 0; i < SPECIES.loopBound(ARRAY_LENGTH); i += SPECIES.length()) {
      DoubleVector va = DoubleVector.fromArray(SPECIES, input1, i);
      DoubleVector vb = DoubleVector.fromArray(SPECIES, input2, i);
      va.mul(vb).intoArray(output, i);
    }
    long finish = System.nanoTime();
    System.out.println("  vector (intoArray) duration (ns)\t" + iteration + "\t" + (finish - start));
  }

and this (for the scalar case):

  static private void doScalar(int iteration, double[] input1, double[] input2, double[] output) {
    long start = System.nanoTime();
    for (int i = 0; i < ARRAY_LENGTH; ++i) {
      output[i] = input1[i] * input2[i];
    }
    long finish = System.nanoTime();
    System.out.println("  scalar duration (ns)\t" + iteration + "\t" + (finish - start));
  }

For testing I'm using an array length of 65536 (random numbers) and 1024 iterations. The species length comes out at 4 on my machine (CPU is Intel i7-7700HQ at 2.8 GHz).
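For context, the declarations these methods rely on are along these lines (reconstructed here, since the snippets above omit them):

import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorSpecies;

static final int ARRAY_LENGTH = 65536;
// SPECIES_PREFERRED picks the widest shape the hardware supports;
// 256-bit vectors give a length of 4 doubles, matching what I see here
static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;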

The timings take a while to settle down. Firstly, the first 2 scalar and 6 Vector iterations are considerably (approx. 10 times) slower than the later ones; secondly, the first 180 or so iterations are very erratic in timing. There are occasional spikes thereafter (presumably due to other activity on the machine or in the JVM).

If I compare the median times (using the median rather than the mean to avoid outliers, and ignoring the first 200 iterations to remove the startup effects), the Vector method is roughly 8% faster than the scalar method in Java 19, and roughly 9% faster in Java 20. Which all means I'm still not seeing "four times faster". I wonder whether the optimizer is auto-vectorizing the scalar calculations... (or maybe I'm doing something else wrong this time :-) )
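The median calculation is nothing fancy, along these lines (a sketch; durations stands in for the per-iteration nanosecond timings printed above):

import java.util.Arrays;

// Median of the per-iteration timings, skipping the warm-up iterations
static long medianAfterWarmup(long[] durations, int skip) {
    long[] tail = Arrays.copyOfRange(durations, skip, durations.length);
    Arrays.sort(tail);
    return tail[tail.length / 2];
}

Called with skip = 200 to drop the startup effects described above.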

answered Oct 31 '25 by Tim V