From 5873bce962cfae5007925d376c9e46b6ff08da9a Mon Sep 17 00:00:00 2001
From: Spencer Sherwin <s.sherwin@imperial.ac.uk>
Date: Sat, 27 Aug 2022 22:01:21 +0000
Subject: [PATCH] feature/avx512 float

---
 .gitlab-ci.yml                               |   8 +
 .gitlab-ci/build-and-test.sh                 |   2 +
 CHANGELOG.md                                 |   5 +
 cmake/NektarSIMD.cmake                       |  18 +-
 library/LibUtilities/SimdLib/avx2.hpp        |   8 +-
 library/LibUtilities/SimdLib/avx512.hpp      | 485 +++++++++++++++++--
 library/LibUtilities/SimdLib/tinysimd.hpp    |   2 +-
 library/UnitTests/SIMD/TestSimdLibSingle.cpp |  30 +-
 8 files changed, 496 insertions(+), 62 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index beab4fc007..f4dee32500 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -203,6 +203,14 @@ bullseye-full_py3-avx2-build-and-test:
   tags:
     - avx2
 
+bullseye-full_py3-avx512-build-and-test:
+  <<: *build-and-test-template
+  needs: ["bullseye-full_py3-build-env"]
+  variables:
+    BUILD_SIMD: avx512
+  tags:
+    - avx512
+
 buster-default-build-and-test:
   <<: *build-and-test-template
   needs: ["buster-default-build-env"]
diff --git a/.gitlab-ci/build-and-test.sh b/.gitlab-ci/build-and-test.sh
index f0281b540b..70ebb03811 100644
--- a/.gitlab-ci/build-and-test.sh
+++ b/.gitlab-ci/build-and-test.sh
@@ -26,6 +26,8 @@ elif [[ $BUILD_TYPE == "full" ]] || [[ $BUILD_TYPE == "full_py3" ]]; then
         -DNEKTAR_ERROR_ON_WARNINGS=OFF"
     if [[ $BUILD_SIMD == "avx2" ]]; then
         BUILD_OPTS="$BUILD_OPTS -DNEKTAR_ENABLE_SIMD_AVX2:BOOL=ON"
+    elif [[ $BUILD_SIMD == "avx512" ]]; then
+        BUILD_OPTS="$BUILD_OPTS -DNEKTAR_ENABLE_SIMD_AVX512:BOOL=ON"
     fi
 fi
 
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f1557b4ddc..4e25db9539 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,11 @@
 Changelog
 =========
 
+v5.3
+------
+**Library**
+- Added float support and restored the avx512 back-end for SimdLib (!1387)
+
 v5.2.0
 ------
 **Library**
diff --git a/cmake/NektarSIMD.cmake b/cmake/NektarSIMD.cmake
index bbaccef8ea..8ce7231fd1 100644
--- a/cmake/NektarSIMD.cmake
+++ b/cmake/NektarSIMD.cmake
@@ -15,16 +15,16 @@ ENDIF()
 IF(_SYSTEM_PROCESSOR STREQUAL "x86_64")
     # OPTION(NEKTAR_ENABLE_SIMD_SSE2 "Enable sse2 vector types" OFF)
     OPTION(NEKTAR_ENABLE_SIMD_AVX2 "Enable avx2 vector types" OFF)
-    # OPTION(NEKTAR_ENABLE_SIMD_AVX512 "Enable avx512 vector types" OFF)
+    OPTION(NEKTAR_ENABLE_SIMD_AVX512 "Enable avx512 vector types" OFF)
     MARK_AS_ADVANCED(FORCE NEKTAR_ENABLE_SIMD_SSE2 NEKTAR_ENABLE_SIMD_AVX2 NEKTAR_ENABLE_SIMD_AVX512)
-    # IF (NEKTAR_ENABLE_SIMD_AVX512)
-    #     MESSAGE(STATUS "Enabling avx512, you might need to clear CMAKE_CXX_FLAGS or add the appriopriate flags")
-    #     ADD_DEFINITIONS(-DNEKTAR_ENABLE_SIMD_AVX512)
-    #     SET(CMAKE_CXX_FLAGS "-mavx512f -mfma" CACHE STRING
-    #         "Flags used by the CXX compiler during all build types.")
-    #     SET(NEKTAR_ENABLE_SIMD_AVX2 "ON" FORCE)
-    #     SET(NEKTAR_ENABLE_SIMD_SSE2 "ON" FORCE)
-    # ENDIF()
+    IF (NEKTAR_ENABLE_SIMD_AVX512)
+        MESSAGE(STATUS "Enabling avx512, you might need to clear CMAKE_CXX_FLAGS or add the appropriate flags")
+        ADD_DEFINITIONS(-DNEKTAR_ENABLE_SIMD_AVX512)
+        SET(CMAKE_CXX_FLAGS "-mavx512f -mfma" CACHE STRING
+            "Flags used by the CXX compiler during all build types.")
+        SET(NEKTAR_ENABLE_SIMD_AVX2 "ON" CACHE BOOL "Enable avx2 vector types" FORCE)
+        SET(NEKTAR_ENABLE_SIMD_SSE2 "ON" CACHE BOOL "Enable sse2 vector types" FORCE)
+    ENDIF()
     IF (NEKTAR_ENABLE_SIMD_AVX2)
         MESSAGE(STATUS "Enabling avx2, you might need to clear CMAKE_CXX_FLAGS or add the appriopriate flags")
         ADD_DEFINITIONS(-DNEKTAR_ENABLE_SIMD_AVX2)
diff --git a/library/LibUtilities/SimdLib/avx2.hpp b/library/LibUtilities/SimdLib/avx2.hpp
index 116b286a03..ad1204288d 100644
--- a/library/LibUtilities/SimdLib/avx2.hpp
+++ b/library/LibUtilities/SimdLib/avx2.hpp
@@ -63,8 +63,8 @@ template <typename scalarType, int width = 0> struct avx2
 #if defined(__AVX2__) && defined(NEKTAR_ENABLE_SIMD_AVX2)
 
 // forward declaration of concrete types
-template <typename T> struct avx2Int8;
 template <typename T> struct avx2Long4;
+template <typename T> struct avx2Int8;
 struct avx2Double4;
 struct avx2Float8;
 struct avx2Mask4;
@@ -539,7 +539,7 @@ inline avx2Double4 log(avx2Double4 in)
 }
 
 inline void load_interleave(
-    const double *in, size_t dataLen,
+    const double *in, std::uint32_t dataLen,
     std::vector<avx2Double4, allocator<avx2Double4>> &out)
 {
     alignas(avx2Double4::alignment)
@@ -574,8 +574,8 @@ inline void load_interleave(
 }
 
 inline void deinterleave_store(
-    const std::vector<avx2Double4, allocator<avx2Double4>> &in, size_t dataLen,
-    double *out)
+    const std::vector<avx2Double4, allocator<avx2Double4>> &in,
+    std::uint32_t dataLen, double *out)
 {
     alignas(avx2Double4::alignment)
         size_t tmp[avx2Double4::width] = {0, dataLen, 2 * dataLen, 3 * dataLen};
diff --git a/library/LibUtilities/SimdLib/avx512.hpp b/library/LibUtilities/SimdLib/avx512.hpp
index d4cb052862..398ea6192d 100644
--- a/library/LibUtilities/SimdLib/avx512.hpp
+++ b/library/LibUtilities/SimdLib/avx512.hpp
@@ -52,7 +52,7 @@ namespace tinysimd
 namespace abi
 {
 
-template <typename scalarType> struct avx512
+template <typename scalarType, int width = 0> struct avx512
 {
     using type = void;
 };
@@ -63,8 +63,11 @@ template <typename scalarType> struct avx512
 
 // forward declaration of concrete types
 template <typename T> struct avx512Long8;
+template <typename T> struct avx512Int16;
 struct avx512Double8;
-struct avx512Mask;
+struct avx512Float16;
+struct avx512Mask8;
+struct avx512Mask16;
 
 namespace abi
 {
@@ -74,6 +77,10 @@ template <> struct avx512<double>
 {
     using type = avx512Double8;
 };
+template <> struct avx512<float>
+{
+    using type = avx512Float16;
+};
 template <> struct avx512<std::int64_t>
 {
     using type = avx512Long8<std::int64_t>;
@@ -82,14 +89,140 @@ template <> struct avx512<std::uint64_t>
 {
     using type = avx512Long8<std::uint64_t>;
 };
-template <> struct avx512<bool>
+template <> struct avx512<std::int32_t>
+{
+    using type = avx512Int16<std::int32_t>;
+};
+template <> struct avx512<std::uint32_t>
+{
+    using type = avx512Int16<std::uint32_t>;
+};
+template <> struct avx512<bool, 8>
+{
+    using type = avx512Mask8;
+};
+template <> struct avx512<bool, 16>
 {
-    using type = avx512Mask;
+    using type = avx512Mask16;
 };
 
 } // namespace abi
 
-// concrete types, could add enable if to allow only unsigned long and long...
+// concrete types
+
+// could add enable if to allow only unsigned int and int...
+template <typename T> struct avx512Int16
+{
+    static_assert(std::is_integral<T>::value && sizeof(T) == 4,
+                  "4 bytes Integral required.");
+
+    static constexpr unsigned int width     = 16;
+    static constexpr unsigned int alignment = 64;
+
+    using scalarType  = T;
+    using vectorType  = __m512i;
+    using scalarArray = scalarType[width];
+
+    // storage
+    vectorType _data;
+
+    // ctors
+    inline avx512Int16()                       = default;
+    inline avx512Int16(const avx512Int16 &rhs) = default;
+    inline avx512Int16(const vectorType &rhs) : _data(rhs)
+    {
+    }
+    inline avx512Int16(const scalarType rhs)
+    {
+        _data = _mm512_set1_epi32(rhs);
+    }
+    explicit inline avx512Int16(scalarArray &rhs)
+    {
+        _data = _mm512_load_epi32(rhs);
+    }
+
+    // store
+    inline void store(scalarType *p) const
+    {
+        _mm512_store_epi32(p, _data);
+    }
+
+    template <class flag,
+              typename std::enable_if<is_requiring_alignment<flag>::value &&
+                                          !is_streaming<flag>::value,
+                                      bool>::type = 0>
+    inline void store(scalarType *p, flag) const
+    {
+        _mm512_store_epi32(p, _data);
+    }
+
+    template <class flag,
+              typename std::enable_if<!is_requiring_alignment<flag>::value,
+                                      bool>::type = 0>
+    inline void store(scalarType *p, flag) const
+    {
+        _mm512_storeu_epi32(p, _data);
+    }
+
+    inline void load(const scalarType *p)
+    {
+        _data = _mm512_load_epi32(p);
+    }
+
+    template <class flag,
+              typename std::enable_if<is_requiring_alignment<flag>::value &&
+                                          !is_streaming<flag>::value,
+                                      bool>::type = 0>
+    inline void load(const scalarType *p, flag)
+    {
+        _data = _mm512_load_epi32(p);
+    }
+
+    template <class flag,
+              typename std::enable_if<!is_requiring_alignment<flag>::value,
+                                      bool>::type = 0>
+    inline void load(const scalarType *p, flag)
+    {
+        // even though the Intel intrinsics manual lists
+        // __m512i _mm512_loadu_epi32 (void const* mem_addr)
+        // it is not implemented in some compilers (gcc);
+        // since this is a bitwise load with no extension,
+        // the following instruction is equivalent
+        _data = _mm512_loadu_si512(p);
+    }
+
+    inline void broadcast(const scalarType rhs)
+    {
+        _data = _mm512_set1_epi32(rhs);
+    }
+
+    // subscript
+    // subscript operators are convenient but expensive
+    // should not be used in optimized kernels
+    inline scalarType operator[](size_t i) const
+    {
+        alignas(alignment) scalarArray tmp;
+        store(tmp, is_aligned);
+        return tmp[i];
+    }
+};
+
+template <typename T>
+inline avx512Int16<T> operator+(avx512Int16<T> lhs, avx512Int16<T> rhs)
+{
+    return _mm512_add_epi32(lhs._data, rhs._data);
+}
+
+template <
+    typename T, typename U,
+    typename = typename std::enable_if<std::is_arithmetic<U>::value>::type>
+inline avx512Int16<T> operator+(avx512Int16<T> lhs, U rhs)
+{
+    return _mm512_add_epi32(lhs._data, _mm512_set1_epi32(rhs));
+}
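+
+// A brief sketch of how this index type is used further below (values are
+// illustrative only): an index vector is built once and advanced lane-wise.
+//   avx512Int16<std::uint32_t> idx(0); // broadcast: all 16 lanes hold 0
+//   idx = idx + 4;                     // every lane now holds 4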
+
+////////////////////////////////////////////////////////////////////////////////
+
 template <typename T> struct avx512Long8
 {
     static_assert(std::is_integral<T>::value && sizeof(T) == 8,
@@ -207,9 +340,10 @@ struct avx512Double8
     static constexpr unsigned int width     = 8;
     static constexpr unsigned int alignment = 64;
 
-    using scalarType  = double;
-    using vectorType  = __m512d;
-    using scalarArray = scalarType[width];
+    using scalarType      = double;
+    using scalarIndexType = std::uint64_t;
+    using vectorType      = __m512d;
+    using scalarArray     = scalarType[width];
 
     // storage
     vectorType _data;
@@ -401,15 +535,16 @@ inline avx512Double8 log(avx512Double8 in)
 }
 
 inline void load_interleave(
-    const double *in, size_t dataLen,
+    const double *in, std::uint32_t dataLen,
     std::vector<avx512Double8, allocator<avx512Double8>> &out)
 {
 
-    alignas(avx512Double8::alignment) size_t tmp[avx512Double8::width] = {
-        0,           dataLen,     2 * dataLen, 3 * dataLen,
-        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};
+    alignas(avx512Double8::alignment)
+        avx512Double8::scalarIndexType tmp[avx512Double8::width] = {
+            0,           dataLen,     2 * dataLen, 3 * dataLen,
+            4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};
 
-    using index_t = avx512Long8<size_t>;
+    using index_t = avx512Long8<avx512Double8::scalarIndexType>;
     index_t index0(tmp);
     index_t index1 = index0 + 1;
     index_t index2 = index0 + 2;
@@ -440,39 +575,278 @@ inline void load_interleave(
 
 inline void deinterleave_store(
     const std::vector<avx512Double8, allocator<avx512Double8>> &in,
-    size_t dataLen, double *out)
+    std::uint32_t dataLen, double *out)
 {
     // size_t nBlocks = dataLen / 4;
 
-    alignas(avx512Double8::alignment) size_t tmp[avx512Double8::width] = {
-        0,           dataLen,     2 * dataLen, 3 * dataLen,
-        4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};
-    using index_t = avx512Long8<size_t>;
+    alignas(avx512Double8::alignment)
+        avx512Double8::scalarIndexType tmp[avx512Double8::width] = {
+            0,           dataLen,     2 * dataLen, 3 * dataLen,
+            4 * dataLen, 5 * dataLen, 6 * dataLen, 7 * dataLen};
+    using index_t = avx512Long8<avx512Double8::scalarIndexType>;
     index_t index0(tmp);
-    // index_t index1 = index0 + 1;
-    // index_t index2 = index0 + 2;
-    // index_t index3 = index0 + 3;
+    for (size_t i = 0; i < dataLen; ++i)
+    {
+        in[i].scatter(out, index0);
+        index0 = index0 + 1;
+    }
+}
 
-    // // 4x unrolled loop
-    // for (size_t i = 0; i < nBlocks; ++i)
-    // {
-    //     in[i].scatter(out, index0);
-    //     in[i+1].scatter(out, index1);
-    //     in[i+2].scatter(out, index2);
-    //     in[i+3].scatter(out, index3);
-    //     index0 = index0 + 4;
-    //     index1 = index1 + 4;
-    //     index2 = index2 + 4;
-    //     index3 = index3 + 4;
-    // }
+////////////////////////////////////////////////////////////////////////////////
 
-    // // spillover loop
-    // for (size_t i = 4 * nBlocks; i < dataLen; ++i)
-    // {
-    //     in[i].scatter(out, index0);
-    //     index0 = index0 + 1;
-    // }
+struct avx512Float16
+{
+    static constexpr unsigned int width     = 16;
+    static constexpr unsigned int alignment = 64;
+
+    using scalarType      = float;
+    using scalarIndexType = std::uint32_t;
+    using vectorType      = __m512;
+    using scalarArray     = scalarType[width];
+
+    // storage
+    vectorType _data;
+
+    // ctors
+    inline avx512Float16()                         = default;
+    inline avx512Float16(const avx512Float16 &rhs) = default;
+    inline avx512Float16(const vectorType &rhs) : _data(rhs)
+    {
+    }
+    inline avx512Float16(const scalarType rhs)
+    {
+        _data = _mm512_set1_ps(rhs);
+    }
+
+    // store
+    inline void store(scalarType *p) const
+    {
+        _mm512_store_ps(p, _data);
+    }
+
+    template <class flag,
+              typename std::enable_if<is_requiring_alignment<flag>::value &&
+                                          !is_streaming<flag>::value,
+                                      bool>::type = 0>
+    inline void store(scalarType *p, flag) const
+    {
+        _mm512_store_ps(p, _data);
+    }
+
+    template <class flag,
+              typename std::enable_if<!is_requiring_alignment<flag>::value,
+                                      bool>::type = 0>
+    inline void store(scalarType *p, flag) const
+    {
+        _mm512_storeu_ps(p, _data);
+    }
+
+    template <class flag, typename std::enable_if<is_streaming<flag>::value,
+                                                  bool>::type = 0>
+    inline void store(scalarType *p, flag) const
+    {
+        _mm512_stream_ps(p, _data);
+    }
+
+    // load packed
+    inline void load(const scalarType *p)
+    {
+        _data = _mm512_load_ps(p);
+    }
+
+    template <class flag,
+              typename std::enable_if<is_requiring_alignment<flag>::value,
+                                      bool>::type = 0>
+    inline void load(const scalarType *p, flag)
+    {
+        _data = _mm512_load_ps(p);
+    }
+
+    template <class flag,
+              typename std::enable_if<!is_requiring_alignment<flag>::value,
+                                      bool>::type = 0>
+    inline void load(const scalarType *p, flag)
+    {
+        _data = _mm512_loadu_ps(p);
+    }
+
+    // broadcast
+    inline void broadcast(const scalarType rhs)
+    {
+        _data = _mm512_set1_ps(rhs);
+    }
 
+    // gather/scatter
+    template <typename T>
+    inline void gather(scalarType const *p, const avx512Int16<T> &indices)
+    {
+        _data = _mm512_i32gather_ps(indices._data, p, sizeof(scalarType));
+    }
+
+    template <typename T>
+    inline void scatter(scalarType *out, const avx512Int16<T> &indices) const
+    {
+        _mm512_i32scatter_ps(out, indices._data, _data, sizeof(scalarType));
+    }
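+
+    // Both methods use a byte scale of sizeof(scalarType), so indices count
+    // whole elements: gather(p, idx) fills lane i with p[idx[i]] and
+    // scatter(out, idx) writes lane i to out[idx[i]].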
+
+    // fma
+    // this = this + a * b
+    inline void fma(const avx512Float16 &a, const avx512Float16 &b)
+    {
+        _data = _mm512_fmadd_ps(a._data, b._data, _data);
+    }
+
+    // subscript
+    // subscript operators are convenient but expensive
+    // should not be used in optimized kernels
+    inline scalarType operator[](size_t i) const
+    {
+        alignas(alignment) scalarArray tmp;
+        store(tmp, is_aligned);
+        return tmp[i];
+    }
+
+    inline scalarType &operator[](size_t i)
+    {
+        scalarType *tmp = reinterpret_cast<scalarType *>(&_data);
+        return tmp[i];
+    }
+
+    // unary ops
+    inline void operator+=(avx512Float16 rhs)
+    {
+        _data = _mm512_add_ps(_data, rhs._data);
+    }
+
+    inline void operator-=(avx512Float16 rhs)
+    {
+        _data = _mm512_sub_ps(_data, rhs._data);
+    }
+
+    inline void operator*=(avx512Float16 rhs)
+    {
+        _data = _mm512_mul_ps(_data, rhs._data);
+    }
+
+    inline void operator/=(avx512Float16 rhs)
+    {
+        _data = _mm512_div_ps(_data, rhs._data);
+    }
+};
+
+inline avx512Float16 operator+(avx512Float16 lhs, avx512Float16 rhs)
+{
+    return _mm512_add_ps(lhs._data, rhs._data);
+}
+
+inline avx512Float16 operator-(avx512Float16 lhs, avx512Float16 rhs)
+{
+    return _mm512_sub_ps(lhs._data, rhs._data);
+}
+
+inline avx512Float16 operator*(avx512Float16 lhs, avx512Float16 rhs)
+{
+    return _mm512_mul_ps(lhs._data, rhs._data);
+}
+
+inline avx512Float16 operator/(avx512Float16 lhs, avx512Float16 rhs)
+{
+    return _mm512_div_ps(lhs._data, rhs._data);
+}
+
+inline avx512Float16 sqrt(avx512Float16 in)
+{
+    return _mm512_sqrt_ps(in._data);
+}
+
+inline avx512Float16 abs(avx512Float16 in)
+{
+    return _mm512_abs_ps(in._data);
+}
+
+inline avx512Float16 log(avx512Float16 in)
+{
+#if defined(TINYSIMD_HAS_SVML)
+    return _mm512_log_ps(in._data);
+#else
+    // there is no avx512 log intrinsic
+    // this is a dreadful implementation and is simply a stop gap measure
+    alignas(avx512Float16::alignment) avx512Float16::scalarArray tmp;
+    in.store(tmp);
+    tmp[0]  = std::log(tmp[0]);
+    tmp[1]  = std::log(tmp[1]);
+    tmp[2]  = std::log(tmp[2]);
+    tmp[3]  = std::log(tmp[3]);
+    tmp[4]  = std::log(tmp[4]);
+    tmp[5]  = std::log(tmp[5]);
+    tmp[6]  = std::log(tmp[6]);
+    tmp[7]  = std::log(tmp[7]);
+    tmp[8]  = std::log(tmp[8]);
+    tmp[9]  = std::log(tmp[9]);
+    tmp[10] = std::log(tmp[10]);
+    tmp[11] = std::log(tmp[11]);
+    tmp[12] = std::log(tmp[12]);
+    tmp[13] = std::log(tmp[13]);
+    tmp[14] = std::log(tmp[14]);
+    tmp[15] = std::log(tmp[15]);
+    avx512Float16 ret;
+    ret.load(tmp);
+    return ret;
+#endif
+}
+
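+// load_interleave / deinterleave_store transpose between width independent,
+// contiguously stored streams and a vector-per-entry layout: after
+// load_interleave, out[i][j] == in[j * dataLen + i], and deinterleave_store
+// performs the inverse scatter.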
+inline void load_interleave(
+    const float *in, std::uint32_t dataLen,
+    std::vector<avx512Float16, allocator<avx512Float16>> &out)
+{
+
+    alignas(avx512Float16::alignment)
+        avx512Float16::scalarIndexType tmp[avx512Float16::width] = {
+            0,            dataLen,      2 * dataLen,  3 * dataLen,
+            4 * dataLen,  5 * dataLen,  6 * dataLen,  7 * dataLen,
+            8 * dataLen,  9 * dataLen,  10 * dataLen, 11 * dataLen,
+            12 * dataLen, 13 * dataLen, 14 * dataLen, 15 * dataLen};
+
+    using index_t = avx512Int16<avx512Float16::scalarIndexType>;
+    index_t index0(tmp);
+    index_t index1 = index0 + 1;
+    index_t index2 = index0 + 2;
+    index_t index3 = index0 + 3;
+
+    // 4x unrolled loop
+    constexpr uint16_t unrl = 4;
+    size_t nBlocks          = dataLen / unrl;
+    for (size_t i = 0; i < nBlocks; ++i)
+    {
+        out[unrl * i + 0].gather(in, index0);
+        out[unrl * i + 1].gather(in, index1);
+        out[unrl * i + 2].gather(in, index2);
+        out[unrl * i + 3].gather(in, index3);
+        index0 = index0 + unrl;
+        index1 = index1 + unrl;
+        index2 = index2 + unrl;
+        index3 = index3 + unrl;
+    }
+
+    // spillover loop
+    for (size_t i = unrl * nBlocks; i < dataLen; ++i)
+    {
+        out[i].gather(in, index0);
+        index0 = index0 + 1;
+    }
+}
+
+inline void deinterleave_store(
+    const std::vector<avx512Float16, allocator<avx512Float16>> &in,
+    std::uint32_t dataLen, float *out)
+{
+    // size_t nBlocks = dataLen / 4;
+
+    alignas(avx512Float16::alignment)
+        avx512Float16::scalarIndexType tmp[avx512Float16::width] = {
+            0,            dataLen,      2 * dataLen,  3 * dataLen,
+            4 * dataLen,  5 * dataLen,  6 * dataLen,  7 * dataLen,
+            8 * dataLen,  9 * dataLen,  10 * dataLen, 11 * dataLen,
+            12 * dataLen, 13 * dataLen, 14 * dataLen, 15 * dataLen};
+    using index_t = avx512Int16<avx512Float16::scalarIndexType>;
+
+    index_t index0(tmp);
     for (size_t i = 0; i < dataLen; ++i)
     {
         in[i].scatter(out, index0);
@@ -489,7 +863,7 @@ inline void deinterleave_store(
 //
 // VERY LIMITED SUPPORT...just enough to make cubic eos work...
 //
-struct avx512Mask : avx512Long8<std::uint64_t>
+struct avx512Mask8 : avx512Long8<std::uint64_t>
 {
     // bring in ctors
     using avx512Long8::avx512Long8;
@@ -498,20 +872,43 @@ struct avx512Mask : avx512Long8<std::uint64_t>
     static constexpr scalarType false_v = 0;
 };
 
-inline avx512Mask operator>(avx512Double8 lhs, avx512Double8 rhs)
+inline avx512Mask8 operator>(avx512Double8 lhs, avx512Double8 rhs)
 {
     __mmask8 mask = _mm512_cmp_pd_mask(lhs._data, rhs._data, _CMP_GT_OQ);
-    return _mm512_maskz_set1_epi64(mask, avx512Mask::true_v);
+    return _mm512_maskz_set1_epi64(mask, avx512Mask8::true_v);
 }
 
-inline bool operator&&(avx512Mask lhs, bool rhs)
+inline bool operator&&(avx512Mask8 lhs, bool rhs)
 {
-    __m512i val_true = _mm512_set1_epi64(avx512Mask::true_v);
+    __m512i val_true = _mm512_set1_epi64(avx512Mask8::true_v);
     __mmask8 mask    = _mm512_test_epi64_mask(lhs._data, val_true);
     unsigned int tmp = _cvtmask16_u32(mask);
     return tmp && rhs;
 }
 
+struct avx512Mask16 : avx512Int16<std::uint32_t>
+{
+    // bring in ctors
+    using avx512Int16::avx512Int16;
+
+    static constexpr scalarType true_v  = -1;
+    static constexpr scalarType false_v = 0;
+};
+
+inline avx512Mask16 operator>(avx512Float16 lhs, avx512Float16 rhs)
+{
+    __mmask16 mask = _mm512_cmp_ps_mask(lhs._data, rhs._data, _CMP_GT_OQ);
+    return _mm512_maskz_set1_epi32(mask, avx512Mask16::true_v);
+}
+
+inline bool operator&&(avx512Mask16 lhs, bool rhs)
+{
+    __m512i val_true = _mm512_set1_epi32(avx512Mask16::true_v);
+    __mmask16 mask   = _mm512_test_epi32_mask(lhs._data, val_true);
+    unsigned int tmp = _cvtmask16_u32(mask);
+    return tmp && rhs;
+}
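+
+// Minimal usage sketch (values are illustrative only): a comparison yields a
+// lane-wise mask and operator&& collapses it, returning true when any lane
+// of the left operand compares greater and the bool is true:
+//   avx512Float16 a(2.0f), b(1.0f);
+//   bool any = (a > b) && true; // true: every lane of a exceeds b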
+
 #endif // defined(__avx512__)
 
 } // namespace tinysimd
diff --git a/library/LibUtilities/SimdLib/tinysimd.hpp b/library/LibUtilities/SimdLib/tinysimd.hpp
index 1629825868..aef55456f3 100644
--- a/library/LibUtilities/SimdLib/tinysimd.hpp
+++ b/library/LibUtilities/SimdLib/tinysimd.hpp
@@ -65,7 +65,7 @@ namespace abi
 template <typename T, int width> struct default_abi
 {
     using type = typename first_not_void_of<
-        typename sve<T>::type, typename avx512<T>::type,
+        typename sve<T>::type, typename avx512<T, width>::type,
         typename avx2<T, width>::type, typename sse2<T>::type,
         typename scalar<T>::type>::type;
 
diff --git a/library/UnitTests/SIMD/TestSimdLibSingle.cpp b/library/UnitTests/SIMD/TestSimdLibSingle.cpp
index 758b09a2d9..83ef995951 100644
--- a/library/UnitTests/SIMD/TestSimdLibSingle.cpp
+++ b/library/UnitTests/SIMD/TestSimdLibSingle.cpp
@@ -118,8 +118,8 @@ BOOST_AUTO_TEST_CASE(SimdLibSingle_width_alignment)
     // std::int32_t aka (usually) int (avx2int8)
     width     = simd<std::int32_t>::width;
     alignment = simd<std::int32_t>::alignment;
-    BOOST_CHECK_EQUAL(width, 8);
-    BOOST_CHECK_EQUAL(alignment, 32);
+    BOOST_CHECK_EQUAL(width, NUM_LANES_32BITS);
+    BOOST_CHECK_EQUAL(alignment, 64);
     // float
     width     = simd<float>::width;
     alignment = simd<float>::alignment;
@@ -344,12 +344,23 @@ BOOST_AUTO_TEST_CASE(SimdLibFloat_gather32)
         aindex[6] = 20;
         aindex[7] = 23;
     }
+    if (vec_t::width > 8)
+    {
+        aindex[8]  = 24;
+        aindex[9]  = 28;
+        aindex[10] = 33;
+        aindex[11] = 40;
+        aindex[12] = 41;
+        aindex[13] = 45;
+        aindex[14] = 60;
+        aindex[15] = 61;
+    }
 
     // load index
     aindexvec.load(aindex.data(), is_not_aligned);
 
     // create and fill scalar array
-    constexpr size_t scalarArraySize = 32;
+    constexpr size_t scalarArraySize = 64;
     std::array<float, scalarArraySize> ascalararr;
     for (size_t i = 0; i < scalarArraySize; ++i)
     {
@@ -390,12 +401,23 @@ BOOST_AUTO_TEST_CASE(SimdLibFloat_scatter32)
         aindex[6] = 20;
         aindex[7] = 30;
     }
+    if (vec_t::width > 8)
+    {
+        aindex[8]  = 31;
+        aindex[9]  = 32;
+        aindex[10] = 35;
+        aindex[11] = 40;
+        aindex[12] = 41;
+        aindex[13] = 45;
+        aindex[14] = 60;
+        aindex[15] = 61;
+    }
 
     // load index
     aindexvec.load(aindex.data(), is_not_aligned);
 
     // create scalar array
-    constexpr size_t scalarArraySize = 32;
+    constexpr size_t scalarArraySize = 64;
     std::array<float, scalarArraySize> ascalararr;
 
     // fill vector
-- 
GitLab