Learnings with the Vector API
Viswanathan, Sandhya
sandhya.viswanathan at intel.com
Tue Feb 2 00:48:53 UTC 2021
Hi Ludovic,
Thanks a lot for these observations.
Regarding floating-point multiply-subtract, I think we can already generate "fms" today for both the Vector API and Math.fma.
The backend can recognize the IR patterns for va.fma(vb, vc.neg()) and Math.fma(a, b, -c), as in the following patch:
http://cr.openjdk.java.net/~sviswanathan/vectorIntrinsics/Fms/webrev.00/
Best Regards,
Sandhya
-----Original Message-----
From: panama-dev <panama-dev-retn at openjdk.java.net> On Behalf Of Ludovic Henry
Sent: Wednesday, January 27, 2021 10:13 PM
To: panama-dev at openjdk.java.net
Subject: Learnings with the Vector API
Hello,
As part of my exploration of the Vector API, I've run into the following issues. I'm turning to you to figure out which of the limitations I'm hitting in Hotspot can be easily alleviated, and whether, in your opinion, it is worth implementing the missing APIs and optimizations.
For context, I've been using two BLAS (Basic Linear Algebra Subprograms) implementations: one using the Vector API [1], and one using plain Java 8/11 APIs and code [2]. You'll note that the code is laid out very similarly, both to make it easier to transpose the algorithm by hand from the Java 8/11 implementation to the Vector API one, and in the hope of presenting enough information to the JIT that it can easily auto-vectorize the Java 8/11 implementation.
You can also find the appropriate JMH benchmarks at [3].
# `a * b + c` isn't optimized to `fma(a, b, c)` on supported hardware
Even though major platforms support FMA instructions (x86 and ARM at least), there is currently no optimization in Hotspot to automatically transform `a * b + c` into `fma(a, b, c)`. Hotspot _is_ aware of FMA, since there is the `UseFMA` flag, and `Math.fma` is intrinsified into the corresponding instruction. But there seems to be no "auto-fmaization".
This code pattern is particularly common in Linear Algebra workloads (as the number of `Math.fma` calls in [2] attests, even though there is still room for more). Using a single FMA instruction instead of the two MUL+ADD instructions brings a 15-20% speedup on the specific workloads I've been looking at.
The only significant behaviour change I can observe is increased precision: MUL+ADD rounds twice (once after the multiply and once after the add), whereas FMA rounds only once.
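To make the precision difference concrete, here is a small sketch (the class and method names are illustrative, not from the BLAS code). It picks inputs where the unfused product's rounding error is exactly what survives the addition, so the two-rounding version cancels to zero while `Math.fma` keeps the exact result 2^-60.

```java
// Compares a separate multiply-then-add (two roundings) with Math.fma (one rounding).
final class FmaPrecision {
    // a * b is rounded to double first, then the sum is rounded again.
    static double mulAdd(double a, double b, double c) {
        return a * b + c;
    }

    // Math.fma computes a * b + c as if with infinite precision and rounds once.
    static double fused(double a, double b, double c) {
        return Math.fma(a, b, c);
    }

    public static void main(String[] args) {
        double a = 1.0 + Math.scalb(1.0, -30);    // 1 + 2^-30, exactly representable
        double c = -(1.0 + Math.scalb(1.0, -29)); // -(1 + 2^-29), exactly representable
        // Exact a*a + c is 2^-60, but a*a rounds to 1 + 2^-29 before the add.
        System.out.println(mulAdd(a, a, c));                        // 0.0
        System.out.println(fused(a, a, c) == Math.scalb(1.0, -60)); // true
    }
}
```

Because the two forms can produce different bits, a JIT rewriting `a * b + c` into an FMA would be changing the observable result, not just the speed.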
# No `Math.dot` operation
The dot product is an extremely common operation in Linear Algebra. It is, for example, the most time-consuming operation in matrix-matrix multiplication, which is at the core of current ML and AI workloads.
In both the Vector API and the Java 8/11 implementations, I've laid out near-optimal unrolling of the different loops for the matrix-matrix multiply case, but I've been hitting limitations in the Hotspot register allocator that lead to unnecessary spilling. Having a dedicated `Math.dot` method (or an equivalent such as `LinearAlgebra.dot`) taking two "matrices of arbitrary sizes" would make it possible to intrinsify this very specific operation. (Although the matrices can be of arbitrary sizes, the intrinsic should be optimized for a 3xk matrix A and a kx4 matrix B, computing C += A * B.)
The maximum speed I'm observing on my machine with the Vector API implementation is on the order of 15 GFLOPS, while the OpenBLAS implementation (which uses handwritten assembly) runs on the order of 25 GFLOPS. The main differences between the code generated by Hotspot and the OpenBLAS kernel are the size of the inner-loop submatrix (3x3 for the Vector API implementation, 3x4 for OpenBLAS) and the unrolling of the inner loop (none for the Vector API implementation, more than 4x for OpenBLAS). Both a larger inner-loop submatrix and more unrolling alleviate the current bottleneck at the microarchitecture level, which is the high ratio of scalar instructions to vector instructions.
An intrinsic would make it possible to hand-roll a hardware-optimized dot product, reaching performance equivalent to native libraries and allowing any Java application doing ML and AI to accelerate some of its core operations in pure Java.
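For readers unfamiliar with the 3xk by kx4 shape, a plain-Java sketch of that micro-kernel may help. It is hypothetical: `MicroKernel3x4.kernel` is not an existing or proposed JDK API, and it assumes row-major flat arrays (A with row stride k, B with row stride 4), which differs from the layout used in [1]. Each step of k performs a rank-1 update of the 3x4 tile of C with 12 fused multiply-adds.

```java
// Hypothetical scalar micro-kernel, not an existing or proposed JDK API.
// Computes C (3x4) += A (3xk) * B (kx4); all matrices are row-major flat arrays.
final class MicroKernel3x4 {
    static void kernel(int k, double[] a, double[] b, double[] c) {
        // Twelve accumulators held in locals, one per element of the 3x4 tile of C.
        double c00 = c[0], c01 = c[1], c02 = c[2], c03 = c[3];
        double c10 = c[4], c11 = c[5], c12 = c[6], c13 = c[7];
        double c20 = c[8], c21 = c[9], c22 = c[10], c23 = c[11];
        for (int p = 0; p < k; p++) {
            // Column p of A: one element from each of its three rows (row stride k).
            double a0 = a[p], a1 = a[k + p], a2 = a[2 * k + p];
            // Row p of B: four elements (row stride 4).
            double b0 = b[4 * p], b1 = b[4 * p + 1], b2 = b[4 * p + 2], b3 = b[4 * p + 3];
            // Rank-1 update of the tile: 12 fused multiply-adds per step.
            c00 = Math.fma(a0, b0, c00); c01 = Math.fma(a0, b1, c01);
            c02 = Math.fma(a0, b2, c02); c03 = Math.fma(a0, b3, c03);
            c10 = Math.fma(a1, b0, c10); c11 = Math.fma(a1, b1, c11);
            c12 = Math.fma(a1, b2, c12); c13 = Math.fma(a1, b3, c13);
            c20 = Math.fma(a2, b0, c20); c21 = Math.fma(a2, b1, c21);
            c22 = Math.fma(a2, b2, c22); c23 = Math.fma(a2, b3, c23);
        }
        c[0] = c00; c[1] = c01; c[2] = c02;  c[3] = c03;
        c[4] = c10; c[5] = c11; c[6] = c12;  c[7] = c13;
        c[8] = c20; c[9] = c21; c[10] = c22; c[11] = c23;
    }

    public static void main(String[] args) {
        double[] a = {1, 2, 3, 4, 5, 6};       // 3x2: rows [1,2], [3,4], [5,6]
        double[] b = {1, 0, 0, 1, 0, 1, 1, 0}; // 2x4: rows [1,0,0,1], [0,1,1,0]
        double[] c = new double[12];           // 3x4, zero-initialized
        kernel(2, a, b, c);
        System.out.println(java.util.Arrays.toString(c));
        // [1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0, 5.0, 6.0, 6.0, 5.0]
    }
}
```

An intrinsic would implement the same contract with vector registers and a hardware-specific schedule; the scalar version only pins down the semantics.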
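The GFLOPS figures follow the usual convention for C += A * B with A of size m x k and B of size k x n: each of the m*n outputs needs k multiplies and k adds, so the product costs 2*m*n*k floating-point operations (one FMA counts as two). A small helper with an illustrative name:

```java
// GFLOPS for C += A * B, with A m x k and B k x n: 2*m*n*k flops over the elapsed time.
final class Gflops {
    static double gflops(long m, long n, long k, double seconds) {
        return 2.0 * m * n * k / seconds / 1e9;
    }

    public static void main(String[] args) {
        // A 1000 x 1000 x 1000 multiply (2e9 flops) finishing in 0.5 s
        System.out.println(gflops(1000, 1000, 1000, 0.5)); // 4.0
    }
}
```

By this measure, a multiply that takes 1 s at 25 GFLOPS takes about 25/15, or roughly 1.67 s, at 15 GFLOPS.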
# Unnecessary spilling in the register allocator
As alluded to in the previous point, I'm unable to reach peak performance when computing matrix-matrix products because Hotspot generates unnecessary register spilling.
Let's look at the specific piece of code:
```
// 3x4 tile of accumulators: vsum<r><c> accumulates row (row + r) of a against column (col + c) of b
DoubleVector vsum00 = DoubleVector.zero(DMAX);
DoubleVector vsum01 = DoubleVector.zero(DMAX);
DoubleVector vsum02 = DoubleVector.zero(DMAX);
DoubleVector vsum03 = DoubleVector.zero(DMAX);
DoubleVector vsum10 = DoubleVector.zero(DMAX);
DoubleVector vsum11 = DoubleVector.zero(DMAX);
DoubleVector vsum12 = DoubleVector.zero(DMAX);
DoubleVector vsum13 = DoubleVector.zero(DMAX);
DoubleVector vsum20 = DoubleVector.zero(DMAX);
DoubleVector vsum21 = DoubleVector.zero(DMAX);
DoubleVector vsum22 = DoubleVector.zero(DMAX);
DoubleVector vsum23 = DoubleVector.zero(DMAX);
for (; i < loopBound(ie, Ti * DMAX.length()); i += Ti * DMAX.length()) {
    // Load one vector from each of the three rows of a
    DoubleVector va00 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 0) * lda);
    DoubleVector va01 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 1) * lda);
    DoubleVector va02 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 2) * lda);
    // For each of the four columns of b: load one vector, then update that column's
    // three accumulators (x.fma(y, z) computes x * y + z lane-wise)
    DoubleVector vb00 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 0) * ldb);
    vsum00 = va00.fma(vb00, vsum00);
    vsum10 = va01.fma(vb00, vsum10);
    vsum20 = va02.fma(vb00, vsum20);
    DoubleVector vb01 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 1) * ldb);
    vsum01 = va00.fma(vb01, vsum01);
    vsum11 = va01.fma(vb01, vsum11);
    vsum21 = va02.fma(vb01, vsum21);
    DoubleVector vb02 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 2) * ldb);
    vsum02 = va00.fma(vb02, vsum02);
    vsum12 = va01.fma(vb02, vsum12);
    vsum22 = va02.fma(vb02, vsum22);
    DoubleVector vb03 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 3) * ldb);
    vsum03 = va00.fma(vb03, vsum03);
    vsum13 = va01.fma(vb03, vsum13);
    vsum23 = va02.fma(vb03, vsum23);
}
```
There is theoretically no need to spill any of the loaded vectors, since the number of live registers at any of the vector operations is at most 16 (12 for vsum, 3 for va, and 1 for vb). However, Hotspot does spill some vector registers, leading to suboptimal performance (even though the spilled values stay in L1, the extra instructions and the L1 latency still slow down the computation).
To ensure optimal performance, I've fallen back to a 3x3 inner submatrix (vsum[0-2][0-2]) instead of the above 3x4 inner submatrix (vsum[0-2][0-3]).
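The register count above, and the benefit of the larger tile, can be checked with a little arithmetic. This sketch (illustrative names, not from [1]) assumes the scheme of the listing: all mr vectors of a are loaded once per step, and b vectors are loaded one at a time, so mr*nr accumulators, mr vectors of a and one vector of b are live together.

```java
// Register and load/compute budget for an mr x nr accumulator tile, assuming all mr
// vectors of a stay live and only one vector of b is live at a time.
final class TileBudget {
    // Accumulators + vectors of a + one vector of b.
    static int liveVectorRegisters(int mr, int nr) {
        return mr * nr + mr + 1;
    }

    // FMAs per inner step divided by vector loads per inner step (mr + nr loads).
    static double fmasPerLoad(int mr, int nr) {
        return (double) (mr * nr) / (mr + nr);
    }

    public static void main(String[] args) {
        System.out.println(liveVectorRegisters(3, 4)); // 16
        System.out.println(liveVectorRegisters(3, 3)); // 13
        System.out.println(fmasPerLoad(3, 3));         // 1.5 (9 FMAs per 6 loads)
        System.out.println(fmasPerLoad(3, 4));         // 12 FMAs per 7 loads, about 1.71
    }
}
```

On a machine limited to 16 vector registers (for example x86-64 with AVX2), the 3x4 tile uses every one of them, so any spill choice by the allocator immediately costs memory traffic, while the 3x3 fallback leaves slack at the price of fewer FMAs per load.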
Less aggressive spilling would make it possible both to use the higher-performance 3x4 submatrix in this case and to unroll the loop more aggressively, reducing the overhead of "loop management" (incrementing and checking the indexes).
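To illustrate the "loop management" overhead that unrolling removes, here is a plain-Java dot product unrolled by 4 (an illustrative sketch, not code from [1] or [2]). The main loop pays one index increment and one bound check per four elements, uses four independent accumulators so consecutive FMAs do not wait on each other, and finishes with a scalar tail loop for the remaining n % 4 elements.

```java
// Scalar dot product unrolled by 4 with independent accumulators, plus a tail loop.
final class UnrolledDot {
    static double dot(double[] x, double[] y, int n) {
        double s0 = 0, s1 = 0, s2 = 0, s3 = 0;
        int i = 0;
        // Main loop: one increment and one bound check per four elements.
        for (; i + 3 < n; i += 4) {
            s0 = Math.fma(x[i],     y[i],     s0);
            s1 = Math.fma(x[i + 1], y[i + 1], s1);
            s2 = Math.fma(x[i + 2], y[i + 2], s2);
            s3 = Math.fma(x[i + 3], y[i + 3], s3);
        }
        // Tail: the remaining n % 4 elements.
        for (; i < n; i++) {
            s0 = Math.fma(x[i], y[i], s0);
        }
        return (s0 + s1) + (s2 + s3);
    }

    public static void main(String[] args) {
        double[] v = {1, 2, 3, 4, 5};
        System.out.println(dot(v, v, v.length)); // 55.0
    }
}
```

Because the four partial sums are combined in a different order than a sequential loop, the result can differ in the last bits; Java's floating-point semantics do not let the JIT reorder the additions by itself, so this multi-accumulator form has to be written explicitly.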
# Auto-vectorization not kicking in on trivial cases
Even though the algorithms laid out in the Java 8/11 implementation are directly reused and manually vectorized in the Vector API implementation, Hotspot isn't able to auto-vectorize the Java 8/11 implementation.
I haven't figured out the underlying reason why Hotspot isn't able to (lack of information? a corner case?), but I thought it useful to raise it here to get your take on the matter.
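When narrowing down why auto-vectorization does not kick in, it can help to start from the simplest possible counted loop and grow it toward the BLAS kernel. A minimal daxpy-style sketch (`Daxpy` is an illustrative name) is below. Assuming the hsdis disassembler plugin is installed, one way to inspect the result is to run it with `-XX:+UnlockDiagnosticVMOptions -XX:+PrintAssembly` and look for packed-double instructions (for example `vmulpd`/`vaddpd` on x86) in the compiled loop, comparing against a run with `-XX:-UseSuperWord`.

```java
// A minimal daxpy-style loop, y[i] += alpha * x[i], used as a probe for auto-vectorization.
final class Daxpy {
    static void daxpy(int n, double alpha, double[] x, double[] y) {
        for (int i = 0; i < n; i++) {
            y[i] += alpha * x[i];
        }
    }

    public static void main(String[] args) {
        double[] x = new double[1024];
        double[] y = new double[1024];
        java.util.Arrays.fill(x, 1.0);
        // Call repeatedly so that the method becomes hot and is compiled by C2.
        for (int iter = 0; iter < 20_000; iter++) {
            daxpy(x.length, 2.0, x, y);
        }
        System.out.println(y[0]); // 40000.0
    }
}
```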
# No `Math.fms` operation
As with `Math.dot`, FMS (Fused Multiply-Subtract, `a * b - c`) isn't directly available from Java. It isn't as commonly used as FMA but still has use cases in some Linear Algebra operations. There is also a corresponding instruction on major architectures (I've checked x86 and ARM).
It would be technically trivial to add support for it, but it would require an API change to add `Math.fms`.
Thank you, and I'm very happy to answer any further questions,
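Until such an API exists, the same result can be expressed with `Math.fma(a, b, -c)`, which is the pattern Sandhya's reply at the top of this thread says the backend can recognize. A sketch of a hypothetical helper (`java.lang.Math` has no `fms` method): negating c is exact, so the result is a * b - c rounded once.

```java
// Hypothetical helper: java.lang.Math has no fms method.
// Computes a * b - c with a single rounding, via Math.fma on the negated addend.
final class Fms {
    static double fms(double a, double b, double c) {
        return Math.fma(a, b, -c);
    }

    static float fms(float a, float b, float c) {
        return Math.fma(a, b, -c);
    }

    public static void main(String[] args) {
        System.out.println(fms(2.0, 3.0, 1.0));    // 5.0
        System.out.println(fms(2.0f, 3.0f, 1.0f)); // 5.0
    }
}
```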
Ludovic
[1] https://github.com/luhenry/netlib/blob/master/blas/src/main/java/dev/ludovic/netlib/blas/VectorizedBLAS.java
[2] https://github.com/luhenry/netlib/blob/master/blas/src/main/java/dev/ludovic/netlib/blas/JavaBLAS.java
[3] https://github.com/luhenry/netlib/tree/master/benchmarks