From d7bd32ce3cc424dd8c1574a061a87a14feff90d4 Mon Sep 17 00:00:00 2001
From: Jacques Xing <jacques.xing@kcl.ac.uk>
Date: Tue, 30 Jan 2024 06:38:49 +0000
Subject: [PATCH] Tidy up CUDA implementation of IProductWRTBase and add CUDA
 kernels with additional parallelism
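
Tidy-up changes: reorder includes, add missing "#pragma once" guards to
the CUDA kernel headers, pass kernel sizes as "unsigned int" instead of
"size_t", mark pointer arguments "__restrict", replace the early-return
element indexing with grid-stride loops, stage basis and weight data in
shared memory, evaluate the SCALE/APPEND/DEFORMED flags with
"if constexpr", and fix the delete/delete[] mismatch on the workspace
arrays.

Additional parallelism: each shape now has a *_QP kernel variant that
assigns one element per block and parallelises over quadrature points
and modes within the block, staging intermediate products in
dynamically sized shared memory. Selection between the baseline and
*_QP kernels is a compile-time switch (the FLAG_QP macro in
IProductWRTBaseCUDA.hpp).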

---
 Operators/AssmbScatr/AssmbScatrCUDA.hpp       |    6 +-
 .../AssmbScatr/AssmbScatrCUDAKernels.cuh      |   96 +-
 Operators/AssmbScatr/AssmbScatrStdMat.hpp     |    4 +-
 Operators/DirBndCond/DirBndCondCUDA.hpp       |    8 +-
 .../DirBndCond/DirBndCondCUDAKernels.cuh      |   48 +-
 Operators/DirBndCond/DirBndCondStdMat.hpp     |    4 +-
 Operators/Helmholtz/HelmholtzCUDAKernels.cuh  |   27 +-
 .../IProductWRTBase/IProductWRTBaseCUDA.cu    |    5 +-
 .../IProductWRTBase/IProductWRTBaseCUDA.hpp   |  234 +-
 .../IProductWRTBaseCUDAKernels.cuh            | 2233 +++++++++++++----
 .../IProductWRTBaseMatFree.hpp                |    5 +-
 .../IProductWRTBase/IProductWRTBaseStdMat.hpp |    5 +-
 .../IProductWRTBase/IProductWRTBaseSumFac.hpp |    6 +-
 Operators/NeuBndCond/NeuBndCondCUDA.hpp       |    8 +-
 .../NeuBndCond/NeuBndCondCUDAKernels.cuh      |   38 +-
 Operators/NeuBndCond/NeuBndCondStdMat.hpp     |    4 +-
 16 files changed, 2115 insertions(+), 616 deletions(-)
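
For reference, the new *_QP kernels follow the one-element-per-block
pattern sketched below. This is an illustrative reduction of the
scheme, assuming a regular (non-deformed) geometry and eliding the
SCALE/APPEND/DEFORMED template handling; the kernel name is a
placeholder, not one of the kernels in this patch.

    template <typename TData>
    __global__ void ExampleIProductKernel_QP(const unsigned int nm0,
                                             const unsigned int nq0,
                                             const unsigned int nelmt,
                                             const TData *__restrict basis0,
                                             const TData *__restrict w0,
                                             const TData *__restrict jac,
                                             const TData *__restrict in,
                                             TData *__restrict out)
    {
        // Dynamic shared memory, sized at launch as sizeof(TData) * nq0.
        extern __shared__ TData s_wsp[];

        // One element per block; blocks stride over the element range.
        for (unsigned int e = blockIdx.x; e < nelmt; e += gridDim.x)
        {
            // Threads cooperatively stage the Jacobian-weighted input.
            for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
            {
                s_wsp[i] = in[nq0 * e + i] * jac[e];
            }
            __syncthreads();

            // The mode loop is then parallelised across the block.
            for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
            {
                TData sum = 0.0;
                for (unsigned int i = 0; i < nq0; ++i)
                {
                    sum += s_wsp[i] * basis0[p * nq0 + i] * w0[i];
                }
                out[nm0 * e + p] = sum;
            }
            __syncthreads();
        }
    }

A matching launch would be <<<gridSize, 32, sizeof(TData) * nq0>>>, so
the dynamic shared-memory allocation covers the per-element staging
buffer, mirroring the nshared computations in IProductWRTBaseCUDA.hpp.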

diff --git a/Operators/AssmbScatr/AssmbScatrCUDA.hpp b/Operators/AssmbScatr/AssmbScatrCUDA.hpp
index 944e178d..5f97ae4e 100644
--- a/Operators/AssmbScatr/AssmbScatrCUDA.hpp
+++ b/Operators/AssmbScatr/AssmbScatrCUDA.hpp
@@ -1,13 +1,13 @@
 #pragma once
 
+#include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
+#include <MultiRegions/ContField.h>
+
 #include "MemoryRegionCUDA.hpp"
 #include "Operators/AssmbScatr/AssmbScatrCUDAKernels.cuh"
 #include "Operators/OperatorAssmbScatr.hpp"
 #include "Operators/OperatorHelper.cuh"
 
-#include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
-#include <MultiRegions/ContField.h>
-
 using namespace Nektar;
 using namespace Nektar::MultiRegions;
 using namespace Nektar::Operators;
diff --git a/Operators/AssmbScatr/AssmbScatrCUDAKernels.cuh b/Operators/AssmbScatr/AssmbScatrCUDAKernels.cuh
index 5112cf96..a510cce8 100644
--- a/Operators/AssmbScatr/AssmbScatrCUDAKernels.cuh
+++ b/Operators/AssmbScatr/AssmbScatrCUDAKernels.cuh
@@ -1,19 +1,24 @@
+#pragma once
+
 namespace Nektar::Operators::detail
 {
 
 template <typename TData>
-__global__ void AssembleKernel(const size_t ncoeff, const size_t nelmt,
-                               const size_t offset, const int *assmbptr,
-                               const TData *signptr, const TData *inptr,
-                               TData *outptr)
+__global__ void AssembleKernel(const unsigned int ncoeff,
+                               const unsigned int nelmt,
+                               const unsigned int offset,
+                               const int *__restrict assmbptr,
+                               const TData *__restrict signptr,
+                               const TData *__restrict inptr,
+                               TData *__restrict outptr)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (e < nelmt)
     {
-        size_t index = offset + e * ncoeff;
+        unsigned int index = offset + e * ncoeff;
 
-        for (size_t i = 0; i < ncoeff; i++)
+        for (unsigned int i = 0; i < ncoeff; i++)
         {
             atomicAdd(outptr + assmbptr[index + i],
                       signptr[index + i] * inptr[index + i]);
@@ -23,18 +28,20 @@ __global__ void AssembleKernel(const size_t ncoeff, const size_t nelmt,
 }
 
 template <typename TData>
-__global__ void AssembleKernel(const size_t ncoeff, const size_t nelmt,
-                               const size_t offset, const int *assmbptr,
-                               const TData sign, const TData *inptr,
-                               TData *outptr)
+__global__ void AssembleKernel(const unsigned int ncoeff,
+                               const unsigned int nelmt,
+                               const unsigned int offset,
+                               const int *__restrict assmbptr, const TData sign,
+                               const TData *__restrict inptr,
+                               TData *__restrict outptr)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (e < nelmt)
     {
-        size_t index = offset + e * ncoeff;
+        unsigned int index = offset + e * ncoeff;
 
-        for (size_t i = 0; i < ncoeff; i++)
+        for (unsigned int i = 0; i < ncoeff; i++)
         {
             atomicAdd(outptr + assmbptr[index + i], sign * inptr[index + i]);
         }
@@ -43,17 +50,20 @@ __global__ void AssembleKernel(const size_t ncoeff, const size_t nelmt,
 }
 
 template <typename TData>
-__global__ void AssembleKernel(const size_t ncoeff, const size_t nelmt,
-                               const size_t offset, const int *assmbptr,
-                               const TData *inptr, TData *outptr)
+__global__ void AssembleKernel(const unsigned int ncoeff,
+                               const unsigned int nelmt,
+                               const unsigned int offset,
+                               const int *__restrict assmbptr,
+                               const TData *__restrict inptr,
+                               TData *__restrict outptr)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (e < nelmt)
     {
-        size_t index = offset + e * ncoeff;
+        unsigned int index = offset + e * ncoeff;
 
-        for (size_t i = 0; i < ncoeff; i++)
+        for (unsigned int i = 0; i < ncoeff; i++)
         {
             atomicAdd(outptr + assmbptr[index + i], inptr[index + i]);
         }
@@ -62,18 +72,21 @@ __global__ void AssembleKernel(const size_t ncoeff, const size_t nelmt,
 }
 
 template <typename TData>
-__global__ void GlobalToLocalKernel(const size_t ncoeff, const size_t nelmt,
-                                    const size_t offset, const int *assmbptr,
-                                    const TData *signptr, const TData *inptr,
-                                    TData *outptr)
+__global__ void GlobalToLocalKernel(const unsigned int ncoeff,
+                                    const unsigned int nelmt,
+                                    const unsigned int offset,
+                                    const int *__restrict assmbptr,
+                                    const TData *__restrict signptr,
+                                    const TData *__restrict inptr,
+                                    TData *__restrict outptr)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (e < nelmt)
     {
-        size_t index = offset + e * ncoeff;
+        unsigned int index = offset + e * ncoeff;
 
-        for (size_t i = 0; i < ncoeff; i++)
+        for (unsigned int i = 0; i < ncoeff; i++)
         {
             outptr[index + i] = signptr[index + i] * inptr[assmbptr[index + i]];
         }
@@ -82,18 +95,18 @@ __global__ void GlobalToLocalKernel(const size_t ncoeff, const size_t nelmt,
 }
 
 template <typename TData>
-__global__ void GlobalToLocalKernel(const size_t ncoeff, const size_t nelmt,
-                                    const size_t offset, const int *assmbptr,
-                                    const TData sign, const TData *inptr,
-                                    TData *outptr)
+__global__ void GlobalToLocalKernel(
+    const unsigned int ncoeff, const unsigned int nelmt,
+    const unsigned int offset, const int *__restrict assmbptr, const TData sign,
+    const TData *__restrict inptr, TData *__restrict outptr)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (e < nelmt)
     {
-        size_t index = offset + e * ncoeff;
+        unsigned int index = offset + e * ncoeff;
 
-        for (size_t i = 0; i < ncoeff; i++)
+        for (unsigned int i = 0; i < ncoeff; i++)
         {
             outptr[index + i] = sign * inptr[assmbptr[index + i]];
         }
@@ -102,17 +115,20 @@ __global__ void GlobalToLocalKernel(const size_t ncoeff, const size_t nelmt,
 }
 
 template <typename TData>
-__global__ void GlobalToLocalKernel(const size_t ncoeff, const size_t nelmt,
-                                    const size_t offset, const int *assmbptr,
-                                    const TData *inptr, TData *outptr)
+__global__ void GlobalToLocalKernel(const unsigned int ncoeff,
+                                    const unsigned int nelmt,
+                                    const unsigned int offset,
+                                    const int *__restrict assmbptr,
+                                    const TData *__restrict inptr,
+                                    TData *__restrict outptr)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (e < nelmt)
     {
-        size_t index = offset + e * ncoeff;
+        unsigned int index = offset + e * ncoeff;
 
-        for (size_t i = 0; i < ncoeff; i++)
+        for (unsigned int i = 0; i < ncoeff; i++)
         {
             outptr[index + i] = inptr[assmbptr[index + i]];
         }
diff --git a/Operators/AssmbScatr/AssmbScatrStdMat.hpp b/Operators/AssmbScatr/AssmbScatrStdMat.hpp
index 7f51eb94..70bdd2d5 100644
--- a/Operators/AssmbScatr/AssmbScatrStdMat.hpp
+++ b/Operators/AssmbScatr/AssmbScatrStdMat.hpp
@@ -1,10 +1,10 @@
 #pragma once
 
-#include "Operators/OperatorAssmbScatr.hpp"
-
 #include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
 #include <MultiRegions/ContField.h>
 
+#include "Operators/OperatorAssmbScatr.hpp"
+
 using namespace Nektar;
 using namespace Nektar::MultiRegions;
 using namespace Nektar::Operators;
diff --git a/Operators/DirBndCond/DirBndCondCUDA.hpp b/Operators/DirBndCond/DirBndCondCUDA.hpp
index a57d3ed4..6cf11a9b 100644
--- a/Operators/DirBndCond/DirBndCondCUDA.hpp
+++ b/Operators/DirBndCond/DirBndCondCUDA.hpp
@@ -1,14 +1,14 @@
 #pragma once
 
+#include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
+#include <MultiRegions/ContField.h>
+#include <SpatialDomains/Conditions.h>
+
 #include "MemoryRegionCUDA.hpp"
 #include "Operators/DirBndCond/DirBndCondCUDAKernels.cuh"
 #include "Operators/OperatorDirBndCond.hpp"
 #include "Operators/OperatorHelper.cuh"
 
-#include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
-#include <MultiRegions/ContField.h>
-#include <SpatialDomains/Conditions.h>
-
 using namespace Nektar;
 using namespace Nektar::MultiRegions;
 using namespace Nektar::SpatialDomains;
diff --git a/Operators/DirBndCond/DirBndCondCUDAKernels.cuh b/Operators/DirBndCond/DirBndCondCUDAKernels.cuh
index be9ac0a9..167d659d 100644
--- a/Operators/DirBndCond/DirBndCondCUDAKernels.cuh
+++ b/Operators/DirBndCond/DirBndCondCUDAKernels.cuh
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <SpatialDomains/Conditions.h>
 
 using namespace Nektar;
@@ -7,20 +9,21 @@ namespace Nektar::Operators::detail
 {
 
 template <typename TData>
-__global__ void DirBndCondKernel(const size_t nsize, const int *offsetptr,
-                                 const BoundaryConditionType *bctypeptr,
-                                 const int *ncoeffptr, const int *mapptr,
-                                 const TData *inptr, TData *outptr)
+__global__ void DirBndCondKernel(
+    const unsigned int nsize, const int *__restrict offsetptr,
+    const BoundaryConditionType *__restrict bctypeptr,
+    const int *__restrict ncoeffptr, const int *__restrict mapptr,
+    const TData *__restrict inptr, TData *__restrict outptr)
 {
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
         if (bctypeptr[i] == eDirichlet)
         {
-            size_t offset = offsetptr[i];
-            size_t ncoeff = ncoeffptr[i];
-            for (size_t j = 0; j < ncoeff; j++)
+            unsigned int offset = offsetptr[i];
+            unsigned int ncoeff = ncoeffptr[i];
+            for (unsigned int j = 0; j < ncoeff; j++)
             {
                 outptr[mapptr[offset + j]] = inptr[offset + j];
             }
@@ -30,21 +33,22 @@ __global__ void DirBndCondKernel(const size_t nsize, const int *offsetptr,
 }
 
 template <typename TData>
-__global__ void DirBndCondKernel(const size_t nsize, const int *offsetptr,
-                                 const BoundaryConditionType *bctypeptr,
-                                 const int *ncoeffptr, const TData *signptr,
-                                 const int *mapptr, const TData *inptr,
-                                 TData *outptr)
+__global__ void DirBndCondKernel(
+    const unsigned int nsize, const int *__restrict offsetptr,
+    const BoundaryConditionType *__restrict bctypeptr,
+    const int *__restrict ncoeffptr, const TData *__restrict signptr,
+    const int *__restrict mapptr, const TData *__restrict inptr,
+    TData *__restrict outptr)
 {
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
         if (bctypeptr[i] == eDirichlet)
         {
-            size_t offset = offsetptr[i];
-            size_t ncoeff = ncoeffptr[i];
-            for (size_t j = 0; j < ncoeff; j++)
+            unsigned int offset = offsetptr[i];
+            unsigned int ncoeff = ncoeffptr[i];
+            for (unsigned int j = 0; j < ncoeff; j++)
             {
                 outptr[mapptr[offset + j]] =
                     signptr[offset + j] * inptr[offset + j];
@@ -55,11 +59,13 @@ __global__ void DirBndCondKernel(const size_t nsize, const int *offsetptr,
 }
 
 template <typename TData>
-__global__ void LocalDirBndCondKernel(const size_t nsize, const int *id0ptr,
-                                      const int *id1ptr, const TData *signptr,
-                                      TData *outptr)
+__global__ void LocalDirBndCondKernel(const unsigned int nsize,
+                                      const int *__restrict id0ptr,
+                                      const int *__restrict id1ptr,
+                                      const TData *__restrict signptr,
+                                      TData *__restrict outptr)
 {
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
diff --git a/Operators/DirBndCond/DirBndCondStdMat.hpp b/Operators/DirBndCond/DirBndCondStdMat.hpp
index 1a5a1de3..5674e8fb 100644
--- a/Operators/DirBndCond/DirBndCondStdMat.hpp
+++ b/Operators/DirBndCond/DirBndCondStdMat.hpp
@@ -1,10 +1,10 @@
 #pragma once
 
-#include "Operators/OperatorDirBndCond.hpp"
-
 #include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
 #include <MultiRegions/ContField.h>
 
+#include "Operators/OperatorDirBndCond.hpp"
+
 using namespace Nektar;
 using namespace Nektar::MultiRegions;
 
diff --git a/Operators/Helmholtz/HelmholtzCUDAKernels.cuh b/Operators/Helmholtz/HelmholtzCUDAKernels.cuh
index 65b37905..001e7d6f 100644
--- a/Operators/Helmholtz/HelmholtzCUDAKernels.cuh
+++ b/Operators/Helmholtz/HelmholtzCUDAKernels.cuh
@@ -1,27 +1,31 @@
+#pragma once
+
 namespace Nektar::Operators::detail
 {
 
 template <typename TData>
-__global__ void DiffusionCoeff1DKernel(const size_t nsize,
+__global__ void DiffusionCoeff1DKernel(const unsigned int nsize,
                                        const TData *diffCoeff, TData *deriv0)
 {
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
         deriv0[i] *= diffCoeff[0];
+
         i += blockDim.x * gridDim.x;
     }
 }
 
 template <typename TData>
-__global__ void DiffusionCoeff2DKernel(const size_t nsize,
+__global__ void DiffusionCoeff2DKernel(const unsigned int nsize,
                                        const TData *diffCoeff, TData *deriv0,
                                        TData *deriv1)
 {
     __shared__ TData s_diffCoeff[4];
 
-    size_t ind = threadIdx.x;
+    // Copy to shared memory.
+    unsigned int ind = threadIdx.x;
     if (ind < 4)
     {
         s_diffCoeff[ind] = diffCoeff[ind];
@@ -29,7 +33,7 @@ __global__ void DiffusionCoeff2DKernel(const size_t nsize,
 
     __syncthreads();
 
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
@@ -37,18 +41,20 @@ __global__ void DiffusionCoeff2DKernel(const size_t nsize,
 
         deriv0[i] = s_diffCoeff[0] * deriv[0] + s_diffCoeff[1] * deriv[1];
         deriv1[i] = s_diffCoeff[2] * deriv[0] + s_diffCoeff[3] * deriv[1];
+
         i += blockDim.x * gridDim.x;
     }
 }
 
 template <typename TData>
-__global__ void DiffusionCoeff3DKernel(const size_t nsize, TData *diffCoeff,
-                                       TData *deriv0, TData *deriv1,
-                                       TData *deriv2)
+__global__ void DiffusionCoeff3DKernel(const unsigned int nsize,
+                                       TData *diffCoeff, TData *deriv0,
+                                       TData *deriv1, TData *deriv2)
 {
     __shared__ TData s_diffCoeff[9];
 
-    size_t ind = threadIdx.x;
+    // Copy to shared memory.
+    unsigned int ind = threadIdx.x;
     if (ind < 9)
     {
         s_diffCoeff[ind] = diffCoeff[ind];
@@ -56,7 +62,7 @@ __global__ void DiffusionCoeff3DKernel(const size_t nsize, TData *diffCoeff,
 
     __syncthreads();
 
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
@@ -68,6 +74,7 @@ __global__ void DiffusionCoeff3DKernel(const size_t nsize, TData *diffCoeff,
                     s_diffCoeff[5] * deriv[2];
         deriv2[i] = s_diffCoeff[6] * deriv[0] + s_diffCoeff[7] * deriv[1] +
                     s_diffCoeff[8] * deriv[2];
+
         i += blockDim.x * gridDim.x;
     }
 }
diff --git a/Operators/IProductWRTBase/IProductWRTBaseCUDA.cu b/Operators/IProductWRTBase/IProductWRTBaseCUDA.cu
index 5e7286bf..b914b1af 100644
--- a/Operators/IProductWRTBase/IProductWRTBaseCUDA.cu
+++ b/Operators/IProductWRTBase/IProductWRTBaseCUDA.cu
@@ -2,9 +2,12 @@
 
 namespace Nektar::Operators::detail
 {
+
+// Register the CUDA IProductWRTBase implementation with the factory.
 template <>
 std::string OperatorIProductWRTBaseImpl<double, ImplCUDA>::className =
     GetOperatorFactory<double>().RegisterCreatorFunction(
         "IProductWRTBaseCUDA",
         OperatorIProductWRTBaseImpl<double, ImplCUDA>::instantiate, "...");
-}
+
+} // namespace Nektar::Operators::detail
diff --git a/Operators/IProductWRTBase/IProductWRTBaseCUDA.hpp b/Operators/IProductWRTBase/IProductWRTBaseCUDA.hpp
index dc2632c0..d738ce65 100644
--- a/Operators/IProductWRTBase/IProductWRTBaseCUDA.hpp
+++ b/Operators/IProductWRTBase/IProductWRTBaseCUDA.hpp
@@ -5,40 +5,42 @@
 #include "Operators/OperatorHelper.cuh"
 #include "Operators/OperatorIProductWRTBase.hpp"
 
+// Compile-time switch between the baseline kernels (false) and the
+// *_QP kernels with intra-element parallelism (true).
+#define FLAG_QP false
+
 namespace Nektar::Operators::detail
 {
 
 template <typename TData, bool SCALE = false, bool APPEND = false,
           bool DEFORMED = false>
-void IProductWRTBase1DKernel(const size_t gridSize, const size_t blockSize,
-                             const size_t nm0, const size_t nq0,
-                             const size_t nElmts, const TData *basis0,
+void IProductWRTBase1DKernel(const unsigned int gridSize,
+                             const unsigned int blockSize,
+                             const unsigned int nm0, const unsigned int nq0,
+                             const unsigned int nElmts, const TData *basis0,
                              const TData *w0, const TData *jac, const TData *in,
                              TData *out, TData scale = 1.0);
 
 template <typename TData, bool SCALE = false, bool APPEND = false,
           bool DEFORMED = false>
-void IProductWRTBase2DKernel(const size_t gridSize, const size_t blockSize,
+void IProductWRTBase2DKernel(const unsigned int gridSize,
+                             const unsigned int blockSize,
                              LibUtilities::ShapeType shapetype,
-                             const size_t nm0, const size_t nm1,
-                             const size_t nq0, const size_t nq1,
-                             const size_t nElmts, const bool correct,
+                             const unsigned int nm0, const unsigned int nm1,
+                             const unsigned int nq0, const unsigned int nq1,
+                             const unsigned int nElmts, const bool correct,
                              const TData *basis0, const TData *basis1,
                              const TData *w0, const TData *w1, const TData *jac,
                              const TData *in, TData *out, TData scale = 1.0);
 
 template <typename TData, bool SCALE = false, bool APPEND = false,
           bool DEFORMED = false>
-void IProductWRTBase3DKernel(const size_t gridSize, const size_t blockSize,
-                             LibUtilities::ShapeType shapetype,
-                             const size_t nm0, const size_t nm1,
-                             const size_t nm2, const size_t nq0,
-                             const size_t nq1, const size_t nq2,
-                             const size_t nElmts, const bool correct,
-                             const TData *basis0, const TData *basis1,
-                             const TData *basis2, const TData *w0,
-                             const TData *w1, const TData *w2, const TData *jac,
-                             const TData *in, TData *out, TData scale = 1.0);
+void IProductWRTBase3DKernel(
+    const unsigned int gridSize, const unsigned int blockSize,
+    LibUtilities::ShapeType shapetype, const unsigned int nm0,
+    const unsigned int nm1, const unsigned int nm2, const unsigned int nq0,
+    const unsigned int nq1, const unsigned int nq2, const unsigned int nElmts,
+    const bool correct, const TData *basis0, const TData *basis1,
+    const TData *basis2, const TData *w0, const TData *w1, const TData *w2,
+    const TData *jac, const TData *in, TData *out, TData scale = 1.0);
 
 // IProductWRTBase implementation
 template <typename TData>
@@ -270,91 +272,193 @@ private:
     std::map<std::vector<LibUtilities::BasisKey>, std::vector<TData *>>
         m_weight;
     TData *m_jac;
+    size_t m_gridSize  = 1024;
     size_t m_blockSize = 32;
-    size_t m_gridSize;
 };
 
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-void IProductWRTBase1DKernel(const size_t gridSize, const size_t blockSize,
-                             const size_t nm0, const size_t nq0,
-                             const size_t nElmts, const TData *basis0,
+void IProductWRTBase1DKernel(const unsigned int gridSize,
+                             const unsigned int blockSize,
+                             const unsigned int nm0, const unsigned int nq0,
+                             const unsigned int nElmts, const TData *basis0,
                              const TData *w0, const TData *jac, const TData *in,
                              TData *out, TData scale)
 {
-    IProductWRTBaseSegKernel<TData, SCALE, APPEND, DEFORMED>
-        <<<gridSize, blockSize>>>(nm0, nq0, nElmts, basis0, w0, jac, in, out,
-                                  scale);
+    if (!FLAG_QP)
+    {
+        unsigned int nshared = sizeof(TData) * (nm0 * nq0 + nq0);
+        IProductWRTBaseSegKernel<TData, SCALE, APPEND, DEFORMED>
+            <<<gridSize, blockSize, nshared>>>(nm0, nq0, nElmts, basis0, w0,
+                                               jac, in, out, scale);
+    }
+    else
+    {
+        unsigned int nshared = sizeof(TData) * (nq0);
+        IProductWRTBaseSegKernel<TData, SCALE, APPEND, DEFORMED>
+            <<<gridSize, dim3(32), nshared>>>(nm0, nq0, nElmts, basis0, w0, jac,
+                                              in, out, scale);
+    }
 }
 
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-void IProductWRTBase2DKernel(const size_t gridSize, const size_t blockSize,
+void IProductWRTBase2DKernel(const unsigned int gridSize,
+                             const unsigned int blockSize,
                              LibUtilities::ShapeType shapetype,
-                             const size_t nm0, const size_t nm1,
-                             const size_t nq0, const size_t nq1,
-                             const size_t nElmts, const bool correct,
+                             const unsigned int nm0, const unsigned int nm1,
+                             const unsigned int nq0, const unsigned int nq1,
+                             const unsigned int nElmts, const bool correct,
                              const TData *basis0, const TData *basis1,
                              const TData *w0, const TData *w1, const TData *jac,
                              const TData *in, TData *out, TData scale)
 {
     if (shapetype == LibUtilities::Quad)
     {
-        IProductWRTBaseQuadKernel<TData, SCALE, APPEND, DEFORMED>
-            <<<gridSize, blockSize>>>(nm0, nm1, nq0, nq1, nElmts, basis0,
-                                      basis1, w0, w1, jac, in, out, scale);
+        unsigned int nmTot =
+            LibUtilities::StdQuadData::getNumberOfCoefficients(nm0, nm1);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nm0 * nq0 + nm1 * nq1 + nq0 + nq1);
+            IProductWRTBaseQuadKernel<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, blockSize, nshared>>>(nm0, nm1, nmTot, nq0, nq1,
+                                                   nElmts, basis0, basis1, w0,
+                                                   w1, jac, in, out, scale);
+        }
+        else
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 * nq1 + nm0 * nq1);
+            IProductWRTBaseQuadKernel_QP<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, dim3(8, 8), nshared>>>(nm0, nm1, nmTot, nq0, nq1,
+                                                    nElmts, basis0, basis1, w0,
+                                                    w1, jac, in, out, scale);
+        }
     }
     else if (shapetype == LibUtilities::Tri)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdTriData::getNumberOfCoefficients(nm0, nm1);
-        IProductWRTBaseTriKernel<TData, SCALE, APPEND, DEFORMED>
-            <<<gridSize, blockSize>>>(nm0, nm1, nmTot, nq0, nq1, nElmts,
-                                      correct, basis0, basis1, w0, w1, jac, in,
-                                      out, scale);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 + nq1);
+            IProductWRTBaseTriKernel<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, blockSize, nshared>>>(
+                    nm0, nm1, nmTot, nq0, nq1, nElmts, correct, basis0, basis1,
+                    w0, w1, jac, in, out, scale);
+        }
+        else
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 * nq1 + nm0 * nq1 + 1);
+            IProductWRTBaseTriKernel_QP<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, dim3(8, 8), nshared>>>(
+                    nm0, nm1, nmTot, nq0, nq1, nElmts, correct, basis0, basis1,
+                    w0, w1, jac, in, out, scale);
+        }
     }
 }
 
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
 void IProductWRTBase3DKernel(
-    const size_t gridSize, const size_t blockSize,
-    LibUtilities::ShapeType shapetype, const size_t nm0, const size_t nm1,
-    const size_t nm2, const size_t nq0, const size_t nq1, const size_t nq2,
-    const size_t nElmts, const bool correct, const TData *basis0,
-    const TData *basis1, const TData *basis2, const TData *w0, const TData *w1,
-    const TData *w2, const TData *jac, const TData *in, TData *out, TData scale)
+    const unsigned int gridSize, const unsigned int blockSize,
+    LibUtilities::ShapeType shapetype, const unsigned int nm0,
+    const unsigned int nm1, const unsigned int nm2, const unsigned int nq0,
+    const unsigned int nq1, const unsigned int nq2, const unsigned int nElmts,
+    const bool correct, const TData *basis0, const TData *basis1,
+    const TData *basis2, const TData *w0, const TData *w1, const TData *w2,
+    const TData *jac, const TData *in, TData *out, TData scale)
 {
     if (shapetype == LibUtilities::Hex)
     {
-        IProductWRTBaseHexKernel<TData, SCALE, APPEND, DEFORMED>
-            <<<gridSize, blockSize>>>(nm0, nm1, nm2, nq0, nq1, nq2, nElmts,
-                                      basis0, basis1, basis2, w0, w1, w2, jac,
-                                      in, out, scale);
+        unsigned int nmTot =
+            LibUtilities::StdHexData::getNumberOfCoefficients(nm0, nm1, nm2);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared =
+                sizeof(TData) *
+                (nm0 * nq0 + nm1 * nq1 + nm2 * nq2 + nq0 + nq1 + nq2);
+            IProductWRTBaseHexKernel<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, blockSize, nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, basis0, basis1,
+                    basis2, w0, w1, w2, jac, in, out, scale);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) *
+                (nq0 * nq1 * nq2 + nm0 * nq1 * nq2 + nm0 * nm1 * nq2);
+            IProductWRTBaseHexKernel_QP<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, dim3(4, 4, 4), nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, basis0, basis1,
+                    basis2, w0, w1, w2, jac, in, out, scale);
+        }
     }
     else if (shapetype == LibUtilities::Tet)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdTetData::getNumberOfCoefficients(nm0, nm1, nm2);
-        IProductWRTBaseTetKernel<TData, SCALE, APPEND, DEFORMED>
-            <<<gridSize, blockSize>>>(nm0, nm1, nm2, nmTot, nq0, nq1, nq2,
-                                      nElmts, correct, basis0, basis1, basis2,
-                                      w0, w1, w2, jac, in, out, scale);
-    }
-    else if (shapetype == LibUtilities::Pyr)
-    {
-        size_t nmTot =
-            LibUtilities::StdPyrData::getNumberOfCoefficients(nm0, nm1, nm2);
-        IProductWRTBasePyrKernel<TData, SCALE, APPEND, DEFORMED>
-            <<<gridSize, blockSize>>>(nm0, nm1, nm2, nmTot, nq0, nq1, nq2,
-                                      nElmts, correct, basis0, basis1, basis2,
-                                      w0, w1, w2, jac, in, out, scale);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 + nq1 + nq2);
+            IProductWRTBaseTetKernel<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, blockSize, nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct,
+                    basis0, basis1, basis2, w0, w1, w2, jac, in, out, scale);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nq0 * nq1 * nq2 + nm0 * nq1 * nq2 +
+                                 ((2 * nm1 - nm0 + 1) * nm0 / 2) * nq2 + nm2);
+            IProductWRTBaseTetKernel_QP<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, dim3(4, 4, 4), nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct,
+                    basis0, basis1, basis2, w0, w1, w2, jac, in, out, scale);
+        }
     }
     else if (shapetype == LibUtilities::Prism)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdPrismData::getNumberOfCoefficients(nm0, nm1, nm2);
-        IProductWRTBasePrismKernel<TData, SCALE, APPEND, DEFORMED>
-            <<<gridSize, blockSize>>>(nm0, nm1, nm2, nmTot, nq0, nq1, nq2,
-                                      nElmts, correct, basis0, basis1, basis2,
-                                      w0, w1, w2, jac, in, out, scale);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 + nq1 + nq2);
+            IProductWRTBasePrismKernel<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, blockSize, nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct,
+                    basis0, basis1, basis2, w0, w1, w2, jac, in, out, scale);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) *
+                (nq0 * nq1 * nq2 + nm0 * nq1 * nq2 + nm0 * nm1 * nq2 + nm1);
+            IProductWRTBasePrismKernel_QP<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, dim3(4, 4, 4), nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct,
+                    basis0, basis1, basis2, w0, w1, w2, jac, in, out, scale);
+        }
+    }
+    else if (shapetype == LibUtilities::Pyr)
+    {
+        unsigned int nmTot =
+            LibUtilities::StdPyrData::getNumberOfCoefficients(nm0, nm1, nm2);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 + nq1 + nq2);
+            IProductWRTBasePyrKernel<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, blockSize, nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct,
+                    basis0, basis1, basis2, w0, w1, w2, jac, in, out, scale);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) *
+                (nq0 * nq1 * nq2 + nm0 * nq1 * nq2 + nm0 * nm1 * nq2 + 1);
+            IProductWRTBasePyrKernel_QP<TData, SCALE, APPEND, DEFORMED>
+                <<<gridSize, dim3(4, 4, 4), nshared>>>(
+                    nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct,
+                    basis0, basis1, basis2, w0, w1, w2, jac, in, out, scale);
+        }
     }
 }
 
diff --git a/Operators/IProductWRTBase/IProductWRTBaseCUDAKernels.cuh b/Operators/IProductWRTBase/IProductWRTBaseCUDAKernels.cuh
index e4fd4f06..2e625975 100644
--- a/Operators/IProductWRTBase/IProductWRTBaseCUDAKernels.cuh
+++ b/Operators/IProductWRTBase/IProductWRTBaseCUDAKernels.cuh
@@ -1,662 +1,2015 @@
+#pragma once
+
 namespace Nektar::Operators::detail
 {
+
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBaseSegKernel(const size_t nm0, const size_t nq0,
-                                         const size_t nelmt,
-                                         const TData *basis0, const TData *w0,
-                                         const TData *jac, const TData *in,
-                                         TData *out, TData scale = 1.0)
+__global__ void IProductWRTBaseSegKernel(
+    const unsigned int nm0, const unsigned int nq0, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict w0,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
-
-    if (e >= nelmt)
-    {
-        return;
-    }
-
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * e);
-    TData *outptr       = out + (nm0 * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * e) : jac + e;
+    extern __shared__ TData shared[];
+    TData *s_basis0 = shared;
+    TData *s_w0     = s_basis0 + nm0 * nq0;
 
-    // Compute inner product.
-    for (size_t p = 0; p < nm0; ++p)
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nm0 * nq0)
     {
-        TData sum = 0.0;
-        for (size_t i = 0; i < nq0; ++i)
-        {
-            TData jac_val = DEFORMED ? jacptr[i] : jacptr[0];
-            sum += inptr[i] * basis0[p * nq0 + i] * jac_val * w0[i];
-        }
-        if (SCALE)
-        {
-            sum *= scale;
-        }
-        outptr[p] = APPEND ? outptr[p] + sum : sum;
+        s_basis0[sIndex] = basis0[sIndex];
+        sIndex += blockDim.x;
     }
-}
-
-template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBaseQuadKernel(
-    const size_t nm0, const size_t nm1, const size_t nq0, const size_t nq1,
-    const size_t nelmt, const TData *basis0, const TData *basis1,
-    const TData *w0, const TData *w1, const TData *jac, const TData *in,
-    TData *out, TData scale = 1.0)
-{
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
 
-    if (e >= nelmt)
+    sIndex = threadIdx.x;
+    while (sIndex < nq0)
     {
-        return;
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
     }
 
-    // Allocate workspace memory.
-    TData *wsp = new TData[nq1];
+    __syncthreads();
 
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * nq1 * e);
-    TData *outptr       = out + (nm0 * nm1 * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * nq1 * e) : jac + e;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
-    // Compute inner product.
-    for (size_t p = 0; p < nm0; ++p)
+    while (e < nelmt)
     {
-        size_t cnt_ji = 0;
-        for (size_t j = 0; j < nq1; ++j)
+        unsigned int inoffset  = nq0 * e;
+        unsigned int outoffset = nm0 * e;
+
+        for (unsigned int p = 0; p < nm0; ++p)
         {
             TData sum = 0.0;
-            for (size_t i = 0; i < nq0; ++i)
+            for (unsigned int i = 0; i < nq0; ++i)
             {
-                TData jac_val = DEFORMED ? jacptr[j * nq0 + i] : jacptr[0];
-                sum += inptr[cnt_ji++] * basis0[p * nq0 + i] * jac_val * w0[i];
+                unsigned int index    = inoffset + i;
+                unsigned int jacindex = DEFORMED ? index : e;
+                sum +=
+                    in[index] * s_basis0[p * nq0 + i] * jac[jacindex] * s_w0[i];
             }
-            wsp[j] = sum;
-        }
 
-        for (size_t q = 0; q < nm1; ++q)
-        {
-            TData sum = 0.0;
-            for (size_t j = 0; j < nq1; ++j)
+            if constexpr (SCALE)
             {
-                sum += wsp[j] * basis1[q * nq1 + j] * w1[j];
+                sum *= scale;
             }
-            if (SCALE)
+
+            unsigned int index = outoffset + p;
+            if constexpr (APPEND)
             {
-                sum *= scale;
+                out[index] += sum;
+            }
+            else
+            {
+                out[index] = sum;
             }
-            outptr[q * nm0 + p] = APPEND ? outptr[q * nm0 + p] + sum : sum;
         }
+
+        e += blockDim.x * gridDim.x;
     }
-    delete wsp;
 }
 
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBaseTriKernel(
-    const size_t nm0, const size_t nm1, const size_t nmTot, const size_t nq0,
-    const size_t nq1, const size_t nelmt, const bool correct,
-    const TData *basis0, const TData *basis1, const TData *w0, const TData *w1,
-    const TData *jac, const TData *in, TData *out, TData scale = 1.0)
+__global__ void IProductWRTBaseSegKernel_QP(
+    const unsigned int nm0, const unsigned int nq0, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict w0,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
 
-    if (e >= nelmt)
+    unsigned int e = blockIdx.x;
+
+    while (e < nelmt)
     {
-        return;
-    }
+        unsigned int inoffset  = nq0 * e;
+        unsigned int outoffset = nm0 * e;
 
-    // Allocate workspace memory.
-    TData *wsp = new TData[nq1];
+        // Copy to shared memory.
+        for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+        {
+            unsigned int index    = inoffset + i;
+            unsigned int jacindex = DEFORMED ? index : e;
+            s_wsp0[i]             = in[index] * jac[jacindex];
+        }
 
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * nq1 * e);
-    TData *outptr       = out + (nmTot * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * nq1 * e) : jac + e;
+        __syncthreads();
 
-    // Compute inner product.
-    size_t mode = 0;
-    for (size_t p = 0; p < nm0; ++p)
-    {
-        size_t eta_idx = 0;
-        for (size_t eta1 = 0; eta1 < nq1; ++eta1)
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
             TData sum = 0.0;
-            for (size_t eta0 = 0; eta0 < nq0; ++eta0)
+            for (unsigned int i = 0; i < nq0; ++i)
             {
-                TData jac_val =
-                    DEFORMED ? jacptr[eta1 * nq0 + eta0] : jacptr[0];
-                sum += inptr[eta_idx++] * basis0[p * nq0 + eta0] * jac_val *
-                       w0[eta0];
+                sum += s_wsp0[i] * basis0[p * nq0 + i] * w0[i];
             }
-            wsp[eta1] = sum;
-        }
 
-        for (size_t q = 0; q < nm1 - p; ++q)
-        {
-            TData sum = 0.0;
-            for (size_t eta1 = 0; eta1 < nq1; ++eta1)
+            if constexpr (SCALE)
+            {
+                sum *= scale;
+            }
+
+            unsigned int index = outoffset + p;
+            if constexpr (APPEND)
             {
-                sum += wsp[eta1] * basis1[mode * nq1 + eta1] * w1[eta1];
+                out[index] += sum;
             }
-            if (SCALE)
+            else
             {
-                sum *= scale;
+                out[index] = sum;
             }
-            outptr[mode++] = APPEND ? outptr[mode] + sum : sum;
         }
+
+        __syncthreads();
+
+        e += gridDim.x;
+    }
+}
+
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBaseQuadKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict w0, const TData *__restrict w1,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_basis0 = shared;
+    TData *s_basis1 = s_basis0 + nm0 * nq0;
+    TData *s_w0     = s_basis1 + nm1 * nq1;
+    TData *s_w1     = s_w0 + nq0;
+
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nm0 * nq0)
+    {
+        s_basis0[sIndex] = basis0[sIndex];
+        sIndex += blockDim.x;
     }
 
-    // Correction for singular vertex in collpased coordinates.
-    // Basically we add phi_1 * phi_01 * (weighting, etc) to mode 00
-    // With contributions from every quadrature point
-    if (correct)
+    sIndex = threadIdx.x;
+    while (sIndex < nm1 * nq1)
     {
-        size_t eta_idx = 0;
-        TData iprod_01 = 0.0;
-        for (size_t eta1 = 0; eta1 < nq1; ++eta1)
-        {
-            TData tmp = w1[eta1] * basis1[nq1 + eta1];
+        s_basis1[sIndex] = basis1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq0)
+    {
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq1)
+    {
+        s_w1[sIndex] = w1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    __syncthreads();
+
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    // Allocate workspace memory.
+    TData *wsp = new TData[nq1];
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * e;
+        unsigned int outoffset = nmTot * e;
 
-            if (!DEFORMED)
+        for (unsigned int p = 0; p < nm0; ++p)
+        {
+            for (unsigned int j = 0, cnt_ji = 0; j < nq1; ++j)
             {
-                tmp *= jacptr[0];
+                TData sum = 0.0;
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_ji)
+                {
+                    unsigned int index    = inoffset + cnt_ji;
+                    unsigned int jacindex = DEFORMED ? index : e;
+                    sum += in[index] * s_basis0[p * nq0 + i] * jac[jacindex] *
+                           s_w0[i];
+                }
+                wsp[j] = sum;
             }
 
-            for (size_t eta0 = 0; eta0 < nq0; ++eta0)
+            for (unsigned int q = 0; q < nm1; ++q)
             {
-                TData prod = inptr[eta_idx++] * tmp * w0[eta0];
-                if (DEFORMED)
+                TData sum = 0.0;
+                for (unsigned int j = 0; j < nq1; ++j)
+                {
+                    sum += wsp[j] * s_basis1[q * nq1 + j] * s_w1[j];
+                }
+
+                if constexpr (SCALE)
+                {
+                    sum *= scale;
+                }
+
+                unsigned int index = outoffset + nm0 * q + p;
+                if constexpr (APPEND)
+                {
+                    out[index] += sum;
+                }
+                else
                 {
-                    prod *= jacptr[eta1 * nq0 + eta0];
+                    out[index] = sum;
                 }
-                iprod_01 += prod * basis0[nq0 + eta0];
             }
         }
-        outptr[1] += SCALE ? iprod_01 * scale : iprod_01;
+
+        e += blockDim.x * gridDim.x;
     }
-    delete wsp;
+
+    // Deallocate workspace memory.
+    delete[] wsp;
 }
 
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBaseHexKernel(
-    const size_t nm0, const size_t nm1, const size_t nm2, const size_t nq0,
-    const size_t nq1, const size_t nq2, const size_t nelmt, const TData *basis0,
-    const TData *basis1, const TData *basis2, const TData *w0, const TData *w1,
-    const TData *w2, const TData *jac, const TData *in, TData *out,
-    TData scale = 1.0)
+__global__ void IProductWRTBaseQuadKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict w0, const TData *__restrict w1,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
-
-    if (e >= nelmt)
-    {
-        return;
-    }
-
-    // Allocate workspace memory.
-    TData *wsp0 = new TData[nq2 * nq1];
-    TData *wsp1 = new TData[nq2];
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nq0 * nq1;
 
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * nq1 * nq2 * e);
-    TData *outptr       = out + (nm0 * nm1 * nm2 * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * nq1 * nq2 * e) : jac + e;
+    unsigned int e = blockIdx.x;
 
-    // Compute inner product.
-    for (size_t p = 0; p < nm0; ++p)
+    while (e < nelmt)
     {
-        size_t cnt_kji = 0, cnt_kj = 0;
-        for (size_t k = 0; k < nq2; ++k)
+        unsigned int inoffset  = nq0 * nq1 * e;
+        unsigned int outoffset = nmTot * e;
+
+        // Copy to shared memory.
+        for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
         {
-            for (size_t j = 0; j < nq1; ++j)
+            for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
             {
-                TData sum_kj = 0.0;
-                for (size_t i = 0; i < nq0; ++i)
-                {
-                    TData jac_val = DEFORMED
-                                        ? jacptr[nq0 * nq1 * k + nq0 * j + i]
-                                        : jacptr[0];
-                    sum_kj += inptr[cnt_kji++] * basis0[i + nq0 * p] * jac_val *
-                              w0[i];
-                }
-                wsp0[cnt_kj++] = sum_kj;
+                unsigned int cnt_ji   = nq0 * j + i;
+                unsigned int index    = inoffset + cnt_ji;
+                unsigned int jacindex = DEFORMED ? index : e;
+                s_wsp0[cnt_ji]        = in[index] * jac[jacindex];
             }
         }
 
-        for (size_t q = 0; q < nm1; ++q)
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            cnt_kj = 0;
-            for (size_t k = 0; k < nq2; ++k)
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData sum_k = 0.0;
-                for (size_t j = 0; j < nq1; ++j)
+                unsigned int cnt_ji = nq0 * j;
+                unsigned int cnt_pj = nq1 * p + j;
+
+                TData sum = 0.0;
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_ji)
                 {
-                    sum_k += wsp0[cnt_kj++] * basis1[q * nq1 + j] * w1[j];
+                    sum += s_wsp0[cnt_ji] * basis0[p * nq0 + i] * w0[i];
                 }
-                wsp1[k] = sum_k;
+                s_wsp1[cnt_pj] = sum;
             }
+        }
 
-            for (size_t r = 0; r < nm2; ++r)
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
             {
+                unsigned int cnt_pj = nq1 * p;
+                unsigned int cnt_pq = nm0 * q + p;
+
                 TData sum = 0.0;
-                for (size_t k = 0; k < nq2; ++k)
+                for (unsigned int j = 0; j < nq1; ++j, ++cnt_pj)
                 {
-                    sum += wsp1[k] * basis2[r * nq2 + k] * w2[k];
+                    sum += s_wsp1[cnt_pj] * basis1[q * nq1 + j] * w1[j];
                 }
-                if (SCALE)
+
+                if constexpr (SCALE)
                 {
                     sum *= scale;
                 }
-                outptr[r * nm0 * nm1 + q * nm0 + p] =
-                    APPEND ? outptr[r * nm0 * nm1 + q * nm0 + p] + sum : sum;
+
+                unsigned int index = outoffset + cnt_pq;
+                if constexpr (APPEND)
+                {
+                    out[index] += sum;
+                }
+                else
+                {
+                    out[index] = sum;
+                }
             }
         }
+
+        __syncthreads();
+
+        e += gridDim.x;
     }
-    delete wsp0;
-    delete wsp1;
 }
 
-// NOTE: Not workign when nm2 > nm1
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBaseTetKernel(
-    const size_t nm0, const size_t nm1, const size_t nm2, const size_t nmTot,
-    const size_t nq0, const size_t nq1, const size_t nq2, const size_t nelmt,
-    const bool correct, const TData *basis0, const TData *basis1,
-    const TData *basis2, const TData *w0, const TData *w1, const TData *w2,
-    const TData *jac, const TData *in, TData *out, TData scale = 1.0)
+__global__ void IProductWRTBaseTriKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const bool correct, const TData *__restrict basis0,
+    const TData *__restrict basis1, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict jac,
+    const TData *__restrict in, TData *__restrict out, TData scale = 1.0)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_w0 = shared;
+    TData *s_w1 = s_w0 + nq0;
 
-    if (e >= nelmt)
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nq0)
     {
-        return;
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
     }
 
-    // Allocate workspace memory.
-    TData *wsp0 = new TData[nq2 * nq1];
-    TData *wsp1 = new TData[nq2];
+    sIndex = threadIdx.x;
+    while (sIndex < nq1)
+    {
+        s_w1[sIndex] = w1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    __syncthreads();
 
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * nq1 * nq2 * e);
-    TData *outptr       = out + (nmTot * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * nq1 * nq2 * e) : jac + e;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    // Allocate workspace memory.
+    TData *wsp = new TData[nq1];
 
-    // Compute inner product.
-    size_t cnt_pqr = 0;
-    for (size_t p = 0, mode = 0, mode2 = 0; p < nm0; ++p)
+    while (e < nelmt)
     {
-        size_t cnt_kji = 0, cnt_kj = 0;
-        for (size_t k = 0; k < nq2; ++k)
+        unsigned int inoffset  = nq0 * nq1 * e;
+        unsigned int outoffset = nmTot * e;
+
+        for (unsigned int p = 0, mode_pq = 0; p < nm0; ++p)
         {
-            for (size_t j = 0; j < nq1; ++j)
+            for (unsigned int j = 0, cnt_ji = 0; j < nq1; ++j)
             {
-                TData jac_val =
-                    DEFORMED ? jacptr[nq0 * nq1 * k + nq0 * j] : jacptr[0];
-                TData sum_kj =
-                    inptr[cnt_kji++] * basis0[nq0 * p] * jac_val * w0[0];
-                for (size_t i = 1; i < nq0; ++i)
+                TData sum = 0.0;
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_ji)
                 {
-                    jac_val = DEFORMED ? jacptr[nq0 * nq1 * k + nq0 * j + i]
-                                       : jacptr[0];
-                    sum_kj += inptr[cnt_kji++] * basis0[i + nq0 * p] * jac_val *
-                              w0[i];
+                    unsigned int index    = inoffset + cnt_ji;
+                    unsigned int jacindex = DEFORMED ? index : e;
+                    sum += in[index] * basis0[p * nq0 + i] * jac[jacindex] *
+                           s_w0[i];
                 }
-                wsp0[cnt_kj++] = sum_kj;
+                wsp[j] = sum;
             }
-        }
 
-        for (size_t q = 0; q < nm1 - p; ++q, ++mode)
-        {
-            size_t cnt_kj = 0;
-            for (size_t k = 0; k < nq2; ++k)
+            for (unsigned int q = 0; q < nm1 - p; ++q, ++mode_pq)
             {
-                TData sum_k = basis1[mode * nq1] * wsp0[cnt_kj++] * w1[0];
-                for (size_t j = 1; j < nq1; ++j)
+                TData sum = 0.0;
+                for (unsigned int j = 0; j < nq1; ++j)
                 {
-                    sum_k += basis1[mode * nq1 + j] * wsp0[cnt_kj++] * w1[j];
+                    sum += wsp[j] * basis1[mode_pq * nq1 + j] * s_w1[j];
                 }
-                wsp1[k] = sum_k;
-            }
 
-            for (size_t r = 0; r < nm2 - p - q; ++r, ++mode2)
-            {
-                TData tmp = wsp1[0] * basis2[mode2 * nq2] * w2[0];
-                for (size_t k = 1; k < nq2; ++k)
+                if constexpr (SCALE)
                 {
-                    tmp += wsp1[k] * basis2[mode2 * nq2 + k] * w2[k];
+                    sum *= scale;
+                }
+
+                unsigned int index = outoffset + mode_pq;
+                if constexpr (APPEND)
+                {
+                    out[index] += sum;
                 }
-                if (SCALE)
+                else
                 {
-                    tmp *= scale;
+                    out[index] = sum;
                 }
-                outptr[cnt_pqr++] = APPEND ? outptr[cnt_pqr] + tmp : tmp;
             }
         }
-    }
 
-    // Add correction for collapsed coordinate.
-    if (correct)
-    {
-        size_t cnt = 0;
-        for (size_t k = 0; k < nq2; ++k)
+        // Correction for the singular vertex in collapsed coordinates:
+        // we add phi_1 * phi_01 * (weighting, etc.) to mode 00, with
+        // contributions from every quadrature point.
+        if (correct)
         {
-            TData tmpQ2 = w2[k];
-            if (!DEFORMED)
+            TData iprod_01 = 0.0;
+            for (unsigned int j = 0, cnt_ji = 0; j < nq1; ++j)
             {
-                tmpQ2 *= jacptr[0];
-            }
+                unsigned int jacindex = DEFORMED ? inoffset : e;
 
-            for (size_t j = 0; j < nq1; ++j)
-            {
-                TData tmpQ1 = tmpQ2 * w1[j];
-                for (size_t i = 0; i < nq0; ++i)
+                TData tmp = s_w1[j] * basis1[nq1 + j];
+                if constexpr (!DEFORMED)
                 {
-                    // Store jac * quadrature weight
-                    TData tmpQ = tmpQ1 * w0[i];
-
-                    if (DEFORMED)
-                    {
-                        tmpQ *= jacptr[k * nq0 * nq1 + j * nq0 + i];
-                    }
-
-                    // top vertex
-                    //
-                    TData tmp = basis0[i] * basis1[nq1 + j];
-                    tmp += basis0[nq0 + i] * basis1[j];
-                    tmp += basis0[nq0 + i] * basis1[nq1 + j];
-                    tmp *= basis2[nq2 + k];
-                    tmp *= inptr[cnt] * tmpQ;
-
-                    // add to existing entry
-                    outptr[1] += SCALE ? tmp * scale : tmp;
+                    tmp *= jac[e];
+                }
 
-                    // bottom vertex
-                    //
-                    tmp = basis0[nq0 + i] * basis1[nq1 + j] * basis2[k] *
-                          inptr[cnt] * tmpQ;
-                    outptr[nm2] += SCALE ? tmp * scale : tmp;
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_ji)
+                {
+                    unsigned int index    = inoffset + cnt_ji;
+                    unsigned int jacindex = DEFORMED ? index : e;
 
-                    // singular edge
-                    for (size_t r = 1; r < nm2 - 1; ++r)
+                    TData prod = in[index] * tmp * s_w0[i];
+                    if constexpr (DEFORMED)
                     {
-                        tmp = basis2[(r + 1) * nq2 + k] * basis1[nq1 + j] *
-                              basis0[nq0 + i] * inptr[cnt] * tmpQ;
-                        outptr[nm2 + r] += SCALE ? tmp * scale : tmp;
+                        prod *= jac[jacindex];
                     }
-                    cnt++;
+                    iprod_01 += prod * basis0[nq0 + i];
                 }
             }
+
+            unsigned int index = outoffset + 1;
+            if constexpr (SCALE)
+            {
+                out[index] += iprod_01 * scale;
+            }
+            else
+            {
+                out[index] += iprod_01;
+            }
         }
+
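+        // Grid-stride loop: advance to the next element handled by this
+        // thread.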
+        e += blockDim.x * gridDim.x;
     }
-    delete wsp0;
-    delete wsp1;
+
+    // Deallocate workspace memory.
+    delete[] wsp;
 }
 
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBasePrismKernel(
-    const size_t nm0, const size_t nm1, const size_t nm2, const size_t nmTot,
-    const size_t nq0, const size_t nq1, const size_t nq2, const size_t nelmt,
-    const bool correct, const TData *basis0, const TData *basis1,
-    const TData *basis2, const TData *w0, const TData *w1, const TData *w2,
-    const TData *jac, const TData *in, TData *out, TData scale = 1.0)
+__global__ void IProductWRTBaseTriKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const bool correct, const TData *__restrict basis0,
+    const TData *__restrict basis1, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict jac,
+    const TData *__restrict in, TData *__restrict out, TData scale = 1.0)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0     = shared;
+    TData *s_wsp1     = s_wsp0 + nq0 * nq1;
+    TData *s_iprod_01 = s_wsp1 + nm0 * nq1;
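+    // Dynamic shared memory layout: s_wsp0 holds the Jacobian-weighted
+    // input (nq0 * nq1 entries), s_wsp1 the partial sums over i
+    // (nm0 * nq1 entries), and s_iprod_01 a single accumulator for the
+    // collapsed-vertex correction.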
+
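+    // One element per block: the 2D thread tile parallelises the
+    // quadrature and mode loops, while blockIdx.x strides over elements.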
+    unsigned int e = blockIdx.x;
 
-    if (e >= nelmt)
+    while (e < nelmt)
     {
-        return;
-    }
+        unsigned int inoffset  = nq0 * nq1 * e;
+        unsigned int outoffset = nmTot * e;
 
-    // Allocate workspace memory.
-    TData *wsp0 = new TData[nq2 * nq1];
-    TData *wsp1 = new TData[nq2];
-    TData *wsp2 = new TData[nm1];
+        // Copy to shared memory.
+        for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+        {
+            for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+            {
+                unsigned int cnt_ji   = nq0 * j + i;
+                unsigned int index    = inoffset + cnt_ji;
+                unsigned int jacindex = DEFORMED ? index : e;
+                s_wsp0[cnt_ji]        = in[index] * jac[jacindex];
+            }
+        }
 
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * nq1 * nq2 * e);
-    TData *outptr       = out + (nmTot * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * nq1 * nq2 * e) : jac + e;
+        __syncthreads();
 
-    // Compute inner product.
-    size_t mode_pr = 0, mode_pqr = 0;
-    for (size_t p = 0; p < nm0; ++p)
-    {
-        size_t cnt_kji = 0, cnt_kj = 0;
-        for (size_t k = 0; k < nq2; ++k)
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            for (size_t j = 0; j < nq1; ++j)
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData sum_kj = 0.0;
-                for (size_t i = 0; i < nq0; ++i)
+                unsigned int cnt_ji = nq0 * j;
+                unsigned int cnt_pj = nq1 * p + j;
+
+                TData sum = 0.0;
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_ji)
                 {
-                    TData jac_val = DEFORMED
-                                        ? jacptr[nq0 * nq1 * k + nq0 * j + i]
-                                        : jacptr[0];
-                    sum_kj += inptr[cnt_kji++] * basis0[nq0 * p + i] * jac_val *
-                              w0[i];
+                    sum += s_wsp0[cnt_ji] * basis0[p * nq0 + i] * w0[i];
                 }
-                wsp0[cnt_kj++] = sum_kj;
+                s_wsp1[cnt_pj] = sum;
             }
         }
 
-        for (size_t q = 0; q < nm1; ++q)
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            cnt_kj = 0;
-            for (size_t k = 0; k < nq2; ++k)
+            for (unsigned int q = threadIdx.y; q < nm1 - p; q += blockDim.y)
             {
-                TData sum_k = basis1[q * nq1] * w1[0] * wsp0[cnt_kj++];
-                for (size_t j = 1; j < nq1; ++j)
+                unsigned int cnt_pj  = nq1 * p;
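+                // Offset of mode (p,q): rows of length (nm1 - s) are
+                // packed consecutively, so row p starts at
+                // sum_{s<p} (nm1 - s) = (2 * nm1 - p + 1) * p / 2.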
+                unsigned int mode_pq = (2 * nm1 - p + 1) * p / 2 + q;
+
+                TData sum = 0.0;
+                for (unsigned int j = 0; j < nq1; ++j, ++cnt_pj)
                 {
-                    sum_k += basis1[q * nq1 + j] * w1[j] * wsp0[cnt_kj++];
+                    sum += s_wsp1[cnt_pj] * basis1[mode_pq * nq1 + j] * w1[j];
                 }
-                wsp1[k] = sum_k;
-            }
 
-            for (int r = 0; r < nm2 - p; ++r)
-            {
-                TData sum_k = basis2[(mode_pr + r) * nq2] * w2[0] * wsp1[0];
-                for (size_t k = 1; k < nq2; ++k)
+                if constexpr (SCALE)
+                {
+                    sum *= scale;
+                }
+
+                unsigned int index = outoffset + mode_pq;
+                if constexpr (APPEND)
                 {
-                    sum_k += basis2[(mode_pr + r) * nq2 + k] * w2[k] * wsp1[k];
+                    out[index] += sum;
                 }
-                if (SCALE)
+                else
                 {
-                    sum_k *= scale;
+                    out[index] = sum;
                 }
-                outptr[mode_pqr++] = APPEND ? outptr[mode_pqr] + sum_k : sum_k;
             }
         }
-        mode_pr += nm2 - p;
-    }
 
-    // Add correction for collapsed coordinate.
-    if (correct)
-    {
-        for (size_t q = 0; q < nm1; ++q)
-        {
-            wsp2[q] = 0.0;
-        }
+        __syncthreads();
 
-        size_t cnt_kji = 0;
-        for (size_t k = 0; k < nq2; ++k)
+        // Correction for the singular vertex in collapsed coordinates:
+        // add phi_1 * phi_01 * (quadrature weights, Jacobian, etc.) to
+        // the (0,1) mode, with contributions from every quadrature point.
+        if (correct)
         {
-            TData k_weight = w2[k];
-            if (!DEFORMED)
+            if (threadIdx.x == 0 && threadIdx.y == 0)
             {
-                k_weight *= jacptr[0];
+                *s_iprod_01 = 0.0;
             }
 
-            for (size_t j = 0; j < nq1; ++j)
+            __syncthreads();
+
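+            // Each thread accumulates its partial contribution with an
+            // atomicAdd on shared memory; for TData = double this
+            // requires compute capability 6.0 or newer.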
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData kj_weight = k_weight * w1[j];
-                for (size_t i = 0; i < nq0; ++i)
+                TData tmp = w1[j] * basis1[nq1 + j];
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
                 {
-                    TData prod = kj_weight * w0[i] * inptr[cnt_kji++];
-                    if (DEFORMED)
-                    {
-                        prod *= jacptr[k * nq1 * nq0 + j * nq0 + i];
-                    }
+                    unsigned int cnt_ji = nq0 * j + i;
+                    TData prod          = s_wsp0[cnt_ji] * tmp * w0[i];
+                    atomicAdd(s_iprod_01, prod * basis0[nq0 + i]);
+                }
+            }
 
-                    for (size_t q = 0; q < nm1; ++q)
-                    {
-                        wsp2[q] += prod * basis2[nq2 + k] *
-                                   basis1[q * nq1 + j] * basis0[nq0 + i];
-                    }
+            __syncthreads();
+
+            if (threadIdx.x == 0 && threadIdx.y == 0)
+            {
+                unsigned int index = outoffset + 1;
+                if constexpr (SCALE)
+                {
+                    out[index] += (*s_iprod_01) * scale;
+                }
+                else
+                {
+                    out[index] += (*s_iprod_01);
                 }
             }
         }
 
-        for (size_t q = 0; q < nm1; ++q)
-        {
-            outptr[nm2 * q + 1] += SCALE ? wsp2[q] * scale : wsp2[q];
-        }
+        __syncthreads();
+
+        e += gridDim.x;
     }
-    delete wsp0;
-    delete wsp1;
-    delete wsp2;
 }
 
-// NOTE: Not workign when nm2 > nm1
 template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
-__global__ void IProductWRTBasePyrKernel(
-    const size_t nm0, const size_t nm1, const size_t nm2, const size_t nmTot,
-    const size_t nq0, const size_t nq1, const size_t nq2, const size_t nelmt,
-    const bool correct, const TData *basis0, const TData *basis1,
-    const TData *basis2, const TData *w0, const TData *w1, const TData *w2,
-    const TData *jac, const TData *in, TData *out, TData scale = 1.0)
+__global__ void IProductWRTBaseHexKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_basis0 = shared;
+    TData *s_basis1 = s_basis0 + nm0 * nq0;
+    TData *s_basis2 = s_basis1 + nm1 * nq1;
+    TData *s_w0     = s_basis2 + nm2 * nq2;
+    TData *s_w1     = s_w0 + nq0;
+    TData *s_w2     = s_w1 + nq1;
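+
+    // The basis tables and quadrature weights are identical for every
+    // element, so each block stages them once; the launch must supply
+    // (nm0*nq0 + nm1*nq1 + nm2*nq2 + nq0 + nq1 + nq2) entries of
+    // dynamic shared memory.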
+
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nm0 * nq0)
+    {
+        s_basis0[sIndex] = basis0[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nm1 * nq1)
+    {
+        s_basis1[sIndex] = basis1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nm2 * nq2)
+    {
+        s_basis2[sIndex] = basis2[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq0)
+    {
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq1)
+    {
+        s_w1[sIndex] = w1[sIndex];
+        sIndex += blockDim.x;
+    }
 
-    if (e >= nelmt)
+    sIndex = threadIdx.x;
+    while (sIndex < nq2)
     {
-        return;
+        s_w2[sIndex] = w2[sIndex];
+        sIndex += blockDim.x;
     }
 
+    __syncthreads();
+
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
     // Allocate workspace memory.
     TData *wsp0 = new TData[nq2 * nq1];
     TData *wsp1 = new TData[nq2];
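+    // Device-side new/delete draws from the runtime heap, sized via
+    // cudaDeviceSetLimit(cudaLimitMallocHeapSize); the host code is
+    // assumed to have set a limit large enough for all resident
+    // threads.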
 
-    // Assign pointers.
-    const TData *inptr  = in + (nq0 * nq1 * nq2 * e);
-    TData *outptr       = out + (nmTot * e);
-    const TData *jacptr = DEFORMED ? jac + (nq0 * nq1 * nq2 * e) : jac + e;
-
-    // Compute inner product.
-    size_t mode_pqr = 0;
-    for (size_t p = 0; p < nm0; ++p)
+    while (e < nelmt)
     {
-        size_t cnt_kji = 0, cnt_kj = 0;
-        for (size_t k = 0; k < nq2; ++k)
-        {
-            for (size_t j = 0; j < nq1; ++j)
-            {
-                TData sum_kj = 0.0;
-                for (size_t i = 0; i < nq0; ++i)
-                {
-                    TData jac_val = DEFORMED
-                                        ? jacptr[nq0 * nq1 * k + nq0 * j + i]
-                                        : jacptr[0];
-                    sum_kj += inptr[cnt_kji++] * basis0[nq0 * p + i] * jac_val *
-                              w0[i];
-                }
-                wsp0[cnt_kj++] = sum_kj;
-            }
-        }
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
 
-        for (size_t q = 0; q < p; ++q)
+        for (unsigned int p = 0; p < nm0; ++p)
         {
-            cnt_kj = 0;
-            for (size_t k = 0; k < nq2; ++k)
+            for (unsigned int k = 0, cnt_kj = 0, cnt_kji = 0; k < nq2; ++k)
             {
-                TData sum_k = basis1[q * nq1] * w1[0] * wsp0[cnt_kj++];
-                for (size_t j = 1; j < nq1; ++j)
+                for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
                 {
-                    sum_k += basis1[q * nq1 + j] * w1[j] * wsp0[cnt_kj++];
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index    = inoffset + cnt_kji;
+                        unsigned int jacindex = DEFORMED ? index : e;
+                        sum_kj += in[index] * s_basis0[i + nq0 * p] *
+                                  jac[jacindex] * s_w0[i];
+                    }
+                    wsp0[cnt_kj] = sum_kj;
                 }
-                wsp1[k] = sum_k;
             }
 
-            for (size_t r = 0; r < nm2 - p; ++r)
+            for (unsigned int q = 0; q < nm1; ++q)
             {
-                TData sum_k = basis2[mode_pqr * nq2] * w2[0] * wsp1[0];
-                for (size_t k = 1; k < nq2; ++k)
+                for (unsigned int k = 0, cnt_kj = 0; k < nq2; ++k)
                 {
-                    sum_k += basis2[mode_pqr * nq2 + k] * w2[k] * wsp1[k];
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                    {
+                        sum_k += wsp0[cnt_kj] * s_basis1[q * nq1 + j] * s_w1[j];
+                    }
+                    wsp1[k] = sum_k;
                 }
-                if (SCALE)
+
+                for (unsigned int r = 0; r < nm2; ++r)
                 {
-                    sum_k *= scale;
+                    unsigned int cnt_rqp = nm0 * nm1 * r + nm0 * q + p;
+
+                    TData sum = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k)
+                    {
+                        sum += wsp1[k] * s_basis2[r * nq2 + k] * s_w2[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        sum *= scale;
+                    }
+
+                    unsigned int index = outoffset + cnt_rqp;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += sum;
+                    }
+                    else
+                    {
+                        out[index] = sum;
+                    }
                 }
-                outptr[mode_pqr++] = APPEND ? outptr[mode_pqr] + sum_k : sum_k;
             }
         }
 
-        for (size_t q = p; q < nm1; ++q)
+        e += blockDim.x * gridDim.x;
+    }
+
+    // Deallocate workspace memory.
+    delete[] wsp0;
+    delete[] wsp1;
+}
+
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBaseHexKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nq0 * nq1 * nq2;
+    TData *s_wsp2 = s_wsp1 + nm0 * nq1 * nq2;
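+    // Three sum-factorisation stages live in dynamic shared memory:
+    // nq0*nq1*nq2 + nm0*nq1*nq2 + nm0*nm1*nq2 entries in total.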
+
+    unsigned int e = blockIdx.x;
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        // Copy to shared memory.
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            cnt_kj = 0;
-            for (size_t k = 0; k < nq2; ++k)
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData sum_k = basis1[q * nq1] * w1[0] * wsp0[cnt_kj++];
-                for (size_t j = 1; j < nq1; ++j)
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
                 {
-                    sum_k += basis1[q * nq1 + j] * w1[j] * wsp0[cnt_kj++];
+                    unsigned int cnt_kji  = nq0 * nq1 * k + nq0 * j + i;
+                    unsigned int index    = inoffset + cnt_kji;
+                    unsigned int jacindex = DEFORMED ? index : e;
+                    s_wsp0[cnt_kji]       = in[index] * jac[jacindex];
                 }
-                wsp1[k] = sum_k;
             }
+        }
+
+        __syncthreads();
 
-            for (size_t r = 0; r < nm2 - q; ++r)
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
             {
-                TData sum_k = basis2[mode_pqr * nq2] * w2[0] * wsp1[0];
-                for (size_t k = 1; k < nq2; ++k)
-                {
-                    sum_k += basis2[mode_pqr * nq2 + k] * w2[k] * wsp1[k];
-                }
-                if (SCALE)
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
                 {
-                    sum_k *= scale;
+                    unsigned int cnt_kji = nq0 * nq1 * k + nq0 * j;
+                    unsigned int cnt_pkj = nq2 * nq1 * p + nq1 * k + j;
+
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        sum_kj += s_wsp0[cnt_kji] * basis0[i + nq0 * p] * w0[i];
+                    }
+                    s_wsp1[cnt_pkj] = sum_kj;
                 }
-                outptr[mode_pqr++] = APPEND ? outptr[mode_pqr] + sum_k : sum_k;
             }
         }
-    }
 
-    // Add correction for collapsed coordinate.
-    if (correct)
-    {
-        size_t cnt = 0;
-        for (size_t k = 0; k < nq2; ++k)
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            TData tmpQ2 = w2[k];
-            if (!DEFORMED)
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
             {
-                tmpQ2 *= jacptr[0];
-            }
+                for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+                {
+                    unsigned int cnt_pkj = nq2 * nq1 * p + nq1 * k;
+                    unsigned int cnt_pqk = nm1 * nq2 * p + nq2 * q + k;
+
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_pkj)
+                    {
+                        sum_k += s_wsp1[cnt_pkj] * basis1[q * nq1 + j] * w1[j];
+                    }
+                    s_wsp2[cnt_pqk] = sum_k;
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+            {
+                for (unsigned int r = threadIdx.z; r < nm2; r += blockDim.z)
+                {
+                    unsigned int cnt_pqk = nm1 * nq2 * p + nq2 * q;
+                    unsigned int cnt_rqp = nm0 * nm1 * r + nm0 * q + p;
+
+                    TData sum = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k, ++cnt_pqk)
+                    {
+                        sum += s_wsp2[cnt_pqk] * basis2[r * nq2 + k] * w2[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        sum *= scale;
+                    }
+
+                    unsigned int index = outoffset + cnt_rqp;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += sum;
+                    }
+                    else
+                    {
+                        out[index] = sum;
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        e += gridDim.x;
+    }
+}
+
+// NOTE: Not working when nm2 > nm1
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBaseTetKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_w0 = shared;
+    TData *s_w1 = s_w0 + nq0;
+    TData *s_w2 = s_w1 + nq1;
+
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nq0)
+    {
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq1)
+    {
+        s_w1[sIndex] = w1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq2)
+    {
+        s_w2[sIndex] = w2[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    __syncthreads();
+
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    // Allocate workspace memory.
+    TData *wsp0 = new TData[nq2 * nq1];
+    TData *wsp1 = new TData[nq2];
+    TData *prod = new TData[nm2];
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        for (unsigned int p = 0, mode_pq = 0, mode_pqr = 0; p < nm0; ++p)
+        {
+            for (unsigned int k = 0, cnt_kj = 0, cnt_kji = 0; k < nq2; ++k)
+            {
+                for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                {
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index    = inoffset + cnt_kji;
+                        unsigned int jacindex = DEFORMED ? index : e;
+                        sum_kj += in[index] * basis0[i + nq0 * p] *
+                                  jac[jacindex] * s_w0[i];
+                    }
+                    wsp0[cnt_kj] = sum_kj;
+                }
+            }
+
+            for (unsigned int q = 0; q < nm1 - p; ++q, ++mode_pq)
+            {
+                for (unsigned int k = 0, cnt_kj = 0; k < nq2; ++k)
+                {
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                    {
+                        sum_k +=
+                            basis1[mode_pq * nq1 + j] * wsp0[cnt_kj] * s_w1[j];
+                    }
+                    wsp1[k] = sum_k;
+                }
+
+                for (unsigned int r = 0; r < nm2 - p - q; ++r, ++mode_pqr)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k)
+                    {
+                        tmp += wsp1[k] * basis2[mode_pqr * nq2 + k] * s_w2[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        tmp *= scale;
+                    }
+
+                    unsigned int index = outoffset + mode_pqr;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += tmp;
+                    }
+                    else
+                    {
+                        out[index] = tmp;
+                    }
+                }
+            }
+        }
+
+        // Add correction for collapsed coordinate.
+        if (correct)
+        {
+            for (unsigned int r = 0; r < nm2; ++r)
+            {
+                prod[r] = 0.0;
+            }
+
+            for (unsigned int k = 0, cnt_kji = 0; k < nq2; ++k)
+            {
+                TData tmpQ2 = s_w2[k];
+                if constexpr (!DEFORMED)
+                {
+                    tmpQ2 *= jac[e];
+                }
+
+                for (unsigned int j = 0; j < nq1; ++j)
+                {
+                    TData tmpQ1 = tmpQ2 * s_w1[j];
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index = inoffset + cnt_kji;
+
+                        // Store jac * quadrature weight
+                        TData tmpQ = tmpQ1 * s_w0[i];
+                        if constexpr (DEFORMED)
+                        {
+                            tmpQ *= jac[index];
+                        }
+
+                        // top vertex
+                        TData tmp = basis0[i] * basis1[nq1 + j];
+                        tmp += basis0[nq0 + i] * basis1[j];
+                        tmp += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp *= basis2[nq2 + k];
+                        tmp *= in[index] * tmpQ;
+                        prod[nm2 - 1] += tmp;
+
+                        // bottom vertex
+                        tmp = basis0[nq0 + i] * basis1[nq1 + j] * basis2[k] *
+                              in[index] * tmpQ;
+                        prod[0] += tmp;
+
+                        // singular edge
+                        for (unsigned int r = 1; r < nm2 - 1; ++r)
+                        {
+                            tmp = basis2[(r + 1) * nq2 + k] * basis1[nq1 + j] *
+                                  basis0[nq0 + i] * in[index] * tmpQ;
+                            prod[r] += tmp;
+                        }
+                    }
+                }
+            }
+
+            if constexpr (SCALE)
+            {
+                out[outoffset + 1] += prod[nm2 - 1] * scale;
+                for (unsigned int r = 0; r < nm2 - 1; ++r)
+                {
+                    out[outoffset + nm2 + r] += prod[r] * scale;
+                }
+            }
+            else
+            {
+                out[outoffset + 1] += prod[nm2 - 1];
+                for (unsigned int r = 0; r < nm2 - 1; ++r)
+                {
+                    out[outoffset + nm2 + r] += prod[r];
+                }
+            }
+        }
+
+        e += blockDim.x * gridDim.x;
+    }
+
+    // Deallocate workspace memory.
+    delete[] wsp0;
+    delete[] wsp1;
+    delete[] prod;
+}
+
+// NOTE: Not working when nm2 > nm1
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBaseTetKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_prod = shared;
+    TData *s_wsp0 = s_prod + nm2;
+    TData *s_wsp1 = s_wsp0 + nq0 * nq1 * nq2;
+    TData *s_wsp2 = s_wsp1 + nm0 * nq1 * nq2;
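+    // s_prod (nm2 entries) accumulates the collapsed-coordinate
+    // correction; the rest holds the three sum-factorisation stages.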
+
+    unsigned int e = blockIdx.x;
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        // Copy to shared memory.
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+        {
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_kji  = nq0 * nq1 * k + nq0 * j + i;
+                    unsigned int index    = inoffset + cnt_kji;
+                    unsigned int jacindex = DEFORMED ? index : e;
+                    s_wsp0[cnt_kji]       = in[index] * jac[jacindex];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+            {
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+                {
+                    unsigned int cnt_kji = nq0 * nq1 * k + nq0 * j;
+                    unsigned int cnt_pkj = nq1 * nq2 * p + nq1 * k + j;
+
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        sum_kj += s_wsp0[cnt_kji] * basis0[i + nq0 * p] * w0[i];
+                    }
+                    s_wsp1[cnt_pkj] = sum_kj;
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1 - p; q += blockDim.y)
+            {
+                for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+                {
+                    unsigned int cnt_pkj = nq1 * nq2 * p + nq1 * k;
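+                    // Packed triangular (p,q) offset, as in the Tri
+                    // kernels above.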
+                    unsigned int mode_pq = (2 * nm1 - p + 1) * p / 2 + q;
+
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_pkj)
+                    {
+                        sum_k +=
+                            basis1[mode_pq * nq1 + j] * s_wsp1[cnt_pkj] * w1[j];
+                    }
+                    s_wsp2[mode_pq * nq2 + k] = sum_k;
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1 - p; q += blockDim.y)
+            {
+                for (unsigned int r = threadIdx.z; r < nm2 - p - q;
+                     r += blockDim.z)
+                {
+                    unsigned int mode_pq  = (2 * nm1 - p + 1) * p / 2 + q;
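+                    // Closed-form offset of mode (p,q,0) in the packed
+                    // tetrahedral ordering, i.e. the number of modes
+                    // (s,t,u) preceding (p,q,0); valid for nm1 == nm2
+                    // (cf. the NOTE above).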
+                    unsigned int mode_pqr = (2 * (nm2 - p) - q + 1) * q;
+                    mode_pqr += nm2 * (nm2 + 1) * p;
+                    mode_pqr -= (2 * nm2 + 1) * (p - 1) * p / 2;
+                    mode_pqr += (p - 1) * p * (2 * p - 1) / 6;
+                    mode_pqr /= 2;
+
+                    TData tmp = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k)
+                    {
+                        tmp += s_wsp2[mode_pq * nq2 + k] *
+                               basis2[(mode_pqr + r) * nq2 + k] * w2[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        tmp *= scale;
+                    }
+
+                    unsigned int index = outoffset + mode_pqr + r;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += tmp;
+                    }
+                    else
+                    {
+                        out[index] = tmp;
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Add correction for collapsed coordinate.
+        if (correct)
+        {
+            if (threadIdx.x == 0 && threadIdx.y == 0)
+            {
+                for (unsigned int r = threadIdx.z; r < nm2; r += blockDim.z)
+                {
+                    s_prod[r] = 0.0;
+                }
+            }
+
+            __syncthreads();
+
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+            {
+                TData tmpQ2 = w2[k];
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+                {
+                    TData tmpQ1 = tmpQ2 * w1[j];
+                    for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                    {
+                        unsigned int cnt_kji = nq1 * nq0 * k + nq0 * j + i;
+
+                        // Quadrature weight; the Jacobian is already
+                        // folded into s_wsp0.
+                        TData tmpQ = tmpQ1 * w0[i];
+
+                        // top vertex
+                        TData tmp = basis0[i] * basis1[nq1 + j];
+                        tmp += basis0[nq0 + i] * basis1[j];
+                        tmp += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp *= basis2[nq2 + k];
+                        tmp *= s_wsp0[cnt_kji] * tmpQ;
+                        atomicAdd(s_prod + nm2 - 1, tmp);
+
+                        // bottom vertex
+                        tmp = basis0[nq0 + i] * basis1[nq1 + j] * basis2[k] *
+                              s_wsp0[cnt_kji] * tmpQ;
+                        atomicAdd(s_prod, tmp);
+
+                        // singular edge
+                        for (unsigned int r = 1; r < nm2 - 1; ++r)
+                        {
+                            tmp = basis2[(r + 1) * nq2 + k] * basis1[nq1 + j] *
+                                  basis0[nq0 + i] * s_wsp0[cnt_kji] * tmpQ;
+                            atomicAdd(s_prod + r, tmp);
+                        }
+                    }
+                }
+            }
+
+            __syncthreads();
+
+            if constexpr (SCALE)
+            {
+                if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
+                {
+                    out[outoffset + 1] += s_prod[nm2 - 1] * scale;
+                }
+                if (threadIdx.x == 0 && threadIdx.y == 0)
+                {
+                    for (unsigned int r = threadIdx.z; r < nm2 - 1;
+                         r += blockDim.z)
+                    {
+                        out[outoffset + nm2 + r] += s_prod[r] * scale;
+                    }
+                }
+            }
+            else
+            {
+                if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
+                {
+                    out[outoffset + 1] += s_prod[nm2 - 1];
+                }
+                if (threadIdx.x == 0 && threadIdx.y == 0)
+                {
+                    for (unsigned int r = threadIdx.z; r < nm2 - 1;
+                         r += blockDim.z)
+                    {
+                        out[outoffset + nm2 + r] += s_prod[r];
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        e += gridDim.x;
+    }
+}
+
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBasePrismKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_w0 = shared;
+    TData *s_w1 = s_w0 + nq0;
+    TData *s_w2 = s_w1 + nq1;
+
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nq0)
+    {
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq1)
+    {
+        s_w1[sIndex] = w1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq2)
+    {
+        s_w2[sIndex] = w2[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    __syncthreads();
+
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    // Allocate workspace memory.
+    TData *wsp0 = new TData[nq2 * nq1];
+    TData *wsp1 = new TData[nq2];
+    TData *wsp2 = new TData[nm1];
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        for (unsigned int p = 0, mode_pqr = 0; p < nm0; ++p)
+        {
+            for (unsigned int k = 0, cnt_kj = 0, cnt_kji = 0; k < nq2; ++k)
+            {
+                for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                {
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index    = inoffset + cnt_kji;
+                        unsigned int jacindex = DEFORMED ? index : e;
+                        sum_kj += in[index] * basis0[nq0 * p + i] *
+                                  jac[jacindex] * s_w0[i];
+                    }
+                    wsp0[cnt_kj] = sum_kj;
+                }
+            }
+
+            for (unsigned int q = 0; q < nm1; ++q)
+            {
+                for (unsigned int k = 0, cnt_kj = 0; k < nq2; ++k)
+                {
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                    {
+                        sum_k += basis1[q * nq1 + j] * s_w1[j] * wsp0[cnt_kj];
+                    }
+                    wsp1[k] = sum_k;
+                }
+
+                for (unsigned int r = 0; r < nm2 - p; ++r, ++mode_pqr)
+                {
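+                    // basis2 row offset for modes (p, *): row p starts
+                    // at sum_{s<p} (nm2 - s) = (2 * nm2 - p + 1) * p / 2.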
+                    unsigned int mode_pr = (2 * nm2 - p + 1) * p / 2;
+
+                    TData sum_k = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k)
+                    {
+                        sum_k +=
+                            basis2[(mode_pr + r) * nq2 + k] * s_w2[k] * wsp1[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        sum_k *= scale;
+                    }
+
+                    unsigned int index = outoffset + mode_pqr;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += sum_k;
+                    }
+                    else
+                    {
+                        out[index] = sum_k;
+                    }
+                }
+            }
+        }
+
+        // Add correction for collapsed coordinate.
+        if (correct)
+        {
+            for (unsigned int q = 0; q < nm1; ++q)
+            {
+                wsp2[q] = 0.0;
+            }
+
+            for (unsigned int k = 0, cnt_kji = 0; k < nq2; ++k)
+            {
+                TData k_weight = s_w2[k];
+                if constexpr (!DEFORMED)
+                {
+                    k_weight *= jac[e];
+                }
+
+                for (unsigned int j = 0; j < nq1; ++j)
+                {
+                    TData kj_weight = k_weight * s_w1[j];
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index = inoffset + cnt_kji;
+                        TData prod         = kj_weight * s_w0[i] * in[index];
+                        if constexpr (DEFORMED)
+                        {
+                            prod *= jac[index];
+                        }
+
+                        for (unsigned int q = 0; q < nm1; ++q)
+                        {
+                            wsp2[q] += prod * basis2[nq2 + k] *
+                                       basis1[q * nq1 + j] * basis0[nq0 + i];
+                        }
+                    }
+                }
+            }
+
+            for (unsigned int q = 0; q < nm1; ++q)
+            {
+                unsigned int index = outoffset + nm2 * q + 1;
+                if constexpr (SCALE)
+                {
+                    out[index] += wsp2[q] * scale;
+                }
+                else
+                {
+                    out[index] += wsp2[q];
+                }
+            }
+        }
+
+        e += blockDim.x * gridDim.x;
+    }
+
+    // Deallocate workspace memory.
+    delete[] wsp0;
+    delete[] wsp1;
+    delete[] wsp2;
+}
+
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBasePrismKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nq0 * nq1 * nq2;
+    TData *s_wsp2 = s_wsp1 + nm0 * nq1 * nq2;
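+    // Three sum-factorisation stages share the dynamic shared memory;
+    // s_wsp2 is reused as an nm1-entry accumulator for the
+    // collapsed-coordinate correction.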
+
+    unsigned int e = blockIdx.x;
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        // Copy to shared memory.
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+        {
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_kji  = nq1 * nq0 * k + nq0 * j + i;
+                    unsigned int index    = inoffset + cnt_kji;
+                    unsigned int jacindex = DEFORMED ? index : e;
+                    s_wsp0[cnt_kji]       = in[index] * jac[jacindex];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+            {
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+                {
+                    unsigned int cnt_kji = nq1 * nq0 * k + nq0 * j;
+                    unsigned int cnt_pkj = nq1 * nq2 * p + nq1 * k + j;
+
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        sum_kj += s_wsp0[cnt_kji] * basis0[nq0 * p + i] * w0[i];
+                    }
+                    s_wsp1[cnt_pkj] = sum_kj;
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+            {
+                for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+                {
+                    unsigned int cnt_pkj = nq1 * nq2 * p + nq1 * k;
+                    unsigned int cnt_pqk = nm1 * nq2 * p + nq2 * q + k;
+
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_pkj)
+                    {
+                        sum_k += basis1[q * nq1 + j] * w1[j] * s_wsp1[cnt_pkj];
+                    }
+                    s_wsp2[cnt_pqk] = sum_k;
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+            {
+                for (unsigned int r = threadIdx.z; r < nm2 - p;
+                     r += blockDim.z)
+                {
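+                    // Prism packing: each p-row holds nm1 blocks of
+                    // (nm2 - p) modes, so mode (p,q,r) sits at
+                    // mode_pr * nm1 + (nm2 - p) * q + r.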
+                    unsigned int cnt_pqk  = nm1 * nq2 * p + nq2 * q;
+                    unsigned int mode_pr  = (2 * nm2 - p + 1) * p / 2;
+                    unsigned int mode_pqr = mode_pr * nm1 + (nm2 - p) * q + r;
+
+                    TData sum_k = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k, ++cnt_pqk)
+                    {
+                        sum_k += basis2[(mode_pr + r) * nq2 + k] * w2[k] *
+                                 s_wsp2[cnt_pqk];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        sum_k *= scale;
+                    }
+
+                    unsigned int index = outoffset + mode_pqr;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += sum_k;
+                    }
+                    else
+                    {
+                        out[index] = sum_k;
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Add correction for collapsed coordinate.
+        if (correct)
+        {
+            if (threadIdx.x == 0 && threadIdx.z == 0)
+            {
+                for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+                {
+                    s_wsp2[q] = 0.0;
+                }
+            }
+
+            __syncthreads();
+
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+            {
+                TData k_weight = w2[k];
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+                {
+                    TData kj_weight = k_weight * w1[j];
+                    for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                    {
+                        unsigned int cnt_kji = nq1 * nq0 * k + nq0 * j + i;
+                        TData prod = kj_weight * w0[i] * s_wsp0[cnt_kji];
+                        for (unsigned int q = 0; q < nm1; ++q)
+                        {
+                            atomicAdd(s_wsp2 + q, prod * basis2[nq2 + k] *
+                                                      basis1[q * nq1 + j] *
+                                                      basis0[nq0 + i]);
+                        }
+                    }
+                }
+            }
+
+            __syncthreads();
+
+            if (threadIdx.x == 0 && threadIdx.z == 0)
+            {
+                for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+                {
+                    unsigned int index = outoffset + nm2 * q + 1;
+                    if constexpr (SCALE)
+                    {
+                        out[index] += s_wsp2[q] * scale;
+                    }
+                    else
+                    {
+                        out[index] += s_wsp2[q];
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        e += gridDim.x;
+    }
+}
+
+// NOTE: Not working when nm2 > nm1
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBasePyrKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_w0 = shared;
+    TData *s_w1 = s_w0 + nq0;
+    TData *s_w2 = s_w1 + nq1;
+
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nq0)
+    {
+        s_w0[sIndex] = w0[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq1)
+    {
+        s_w1[sIndex] = w1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    sIndex = threadIdx.x;
+    while (sIndex < nq2)
+    {
+        s_w2[sIndex] = w2[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    __syncthreads();
+
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    // Allocate workspace memory.
+    TData *wsp0 = new TData[nq2 * nq1];
+    TData *wsp1 = new TData[nq2];
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        for (unsigned int p = 0, mode_pqr = 0; p < nm0; ++p)
+        {
+            for (unsigned int k = 0, cnt_kj = 0, cnt_kji = 0; k < nq2; ++k)
+            {
+                for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                {
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index    = inoffset + cnt_kji;
+                        unsigned int jacindex = DEFORMED ? index : e;
+                        sum_kj += in[index] * basis0[nq0 * p + i] *
+                                  jac[jacindex] * s_w0[i];
+                    }
+                    wsp0[cnt_kj] = sum_kj;
+                }
+            }
+
+            for (unsigned int q = 0; q < p; ++q)
+            {
+                for (unsigned int k = 0, cnt_kj = 0; k < nq2; ++k)
+                {
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                    {
+                        sum_k += basis1[q * nq1 + j] * s_w1[j] * wsp0[cnt_kj];
+                    }
+                    wsp1[k] = sum_k;
+                }
+
+                for (unsigned int r = 0; r < nm2 - p; ++r, ++mode_pqr)
+                {
+                    TData sum_k = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k)
+                    {
+                        sum_k += basis2[mode_pqr * nq2 + k] * s_w2[k] * wsp1[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        sum_k *= scale;
+                    }
+
+                    unsigned int index = outoffset + mode_pqr;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += sum_k;
+                    }
+                    else
+                    {
+                        out[index] = sum_k;
+                    }
+                }
+            }
+
+            for (unsigned int q = p; q < nm1; ++q)
+            {
+                for (unsigned int k = 0, cnt_kj = 0; k < nq2; ++k)
+                {
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_kj)
+                    {
+                        sum_k += basis1[q * nq1 + j] * s_w1[j] * wsp0[cnt_kj];
+                    }
+                    wsp1[k] = sum_k;
+                }
+
+                for (unsigned int r = 0; r < nm2 - q; ++r, ++mode_pqr)
+                {
+                    TData sum_k = 0.0;
+                    for (unsigned int k = 0; k < nq2; ++k)
+                    {
+                        sum_k += basis2[mode_pqr * nq2 + k] * s_w2[k] * wsp1[k];
+                    }
+
+                    if constexpr (SCALE)
+                    {
+                        sum_k *= scale;
+                    }
+
+                    unsigned int index = outoffset + mode_pqr;
+                    if constexpr (APPEND)
+                    {
+                        out[index] += sum_k;
+                    }
+                    else
+                    {
+                        out[index] = sum_k;
+                    }
+                }
+            }
+        }
+
+        // Add correction for collapsed coordinate.
+        if (correct)
+        {
+            TData prod = 0.0;
+            for (unsigned int k = 0, cnt_kji = 0; k < nq2; ++k)
+            {
+                TData tmpQ2 = s_w2[k];
+                if constexpr (!DEFORMED)
+                {
+                    tmpQ2 *= jac[e];
+                }
+
+                for (unsigned int j = 0; j < nq1; ++j)
+                {
+                    TData tmpQ1 = tmpQ2 * s_w1[j];
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        unsigned int index = inoffset + cnt_kji;
+
+                        // Store jac * quadrature weight
+                        TData tmpQ = tmpQ1 * s_w0[i];
+                        if constexpr (DEFORMED)
+                        {
+                            tmpQ *= jac[index];
+                        }
+
+                        // top vertex
+                        TData tmp = basis0[i] * basis1[nq1 + j];
+                        tmp += basis0[nq0 + i] * basis1[j];
+                        tmp += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp *= basis2[nq2 + k];
+                        tmp *= in[index] * tmpQ;
+                        prod += tmp;
+                    }
+                }
+            }
+
+            // add to existing entry
+            if constexpr (SCALE)
+            {
+                out[outoffset + 1] += prod * scale;
+            }
+            else
+            {
+                out[outoffset + 1] += prod;
+            }
+        }
+
+        e += blockDim.x * gridDim.x;
+    }
+
+    // Deallocate workspace memory.
+    delete[] wsp0;
+    delete[] wsp1;
+}
+
+// NOTE: Not working when nm2 > nm1
+template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
+__global__ void IProductWRTBasePyrKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict w0,
+    const TData *__restrict w1, const TData *__restrict w2,
+    const TData *__restrict jac, const TData *__restrict in,
+    TData *__restrict out, TData scale = 1.0)
+{
+    extern __shared__ TData shared[];
+    TData *s_prod = shared;
+    TData *s_wsp0 = s_prod + 1;
+    TData *s_wsp1 = s_wsp0 + nq0 * nq1 * nq2;
+    TData *s_wsp2 = s_wsp1 + nm0 * nq1 * nq2;
+
+    unsigned int e = blockIdx.x;
+
+    while (e < nelmt)
+    {
+        unsigned int inoffset  = nq0 * nq1 * nq2 * e;
+        unsigned int outoffset = nmTot * e;
+
+        // Copy to shared memory.
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+        {
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_kji  = k * nq1 * nq0 + j * nq0 + i;
+                    unsigned int index    = inoffset + cnt_kji;
+                    unsigned int jacindex = DEFORMED ? index : e;
+                    s_wsp0[cnt_kji]       = in[index] * jac[jacindex];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+            {
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+                {
+                    unsigned int cnt_kji = k * nq1 * nq0 + j * nq0;
+                    unsigned int cnt_pkj = nq1 * nq2 * p + nq1 * k + j;
+
+                    TData sum_kj = 0.0;
+                    for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                    {
+                        sum_kj += s_wsp0[cnt_kji] * basis0[nq0 * p + i] * w0[i];
+                    }
+                    s_wsp1[cnt_pkj] = sum_kj;
+                }
+            }
+        }
+
+        __syncthreads();
+
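+        // Contract in direction 1:
+        // s_wsp2[p][q][k] = sum_j w1[j] * basis1[q * nq1 + j] * s_wsp1[p][k][j].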
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+            {
+                for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+                {
+                    unsigned int cnt_pkj = nq1 * nq2 * p + k * nq1;
+                    unsigned int cnt_pqk = nm1 * nq2 * p + nq2 * q + k;
+
+                    TData sum_k = 0.0;
+                    for (unsigned int j = 0; j < nq1; ++j, ++cnt_pkj)
+                    {
+                        sum_k += basis1[q * nq1 + j] * w1[j] * s_wsp1[cnt_pkj];
+                    }
+                    s_wsp2[cnt_pqk] = sum_k;
+                }
+            }
+        }
+
+        __syncthreads();
+
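+        // Contract in direction 2 and write out the modal coefficients. For a
+        // pyramid the number of radial modes in column (p, q) is
+        // nm2 - max(p, q), hence the two branches below.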
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+            {
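+                // mode_pq counts the modes that precede the current p-block
+                // in the packed pyramid modal ordering, where each (p', q')
+                // column carries nm2 - max(p', q') radial modes.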
+                unsigned int mode_pq = nm1 * (2 * nm2 + 1 - nm1) * p;
+                mode_pq -= (p - 1) * p / 2;
+                mode_pq -= (p - 1) * p * (2 * p - 1) / 6;
+                mode_pq /= 2;
+
+                if (q < p)
+                {
+                    for (unsigned int r = threadIdx.z; r < nm2 - p;
+                         r += blockDim.z)
+                    {
+                        unsigned int cnt_pqk = nm1 * nq2 * p + nq2 * q;
+                        unsigned int mode_pqr = mode_pq + q * (nm2 - p) + r;
+
+                        TData sum_k = 0.0;
+                        for (unsigned int k = 0; k < nq2; ++k, ++cnt_pqk)
+                        {
+                            sum_k += basis2[mode_pqr * nq2 + k] * w2[k] *
+                                     s_wsp2[cnt_pqk];
+                        }
+
+                        if constexpr (SCALE)
+                        {
+                            sum_k *= scale;
+                        }
+
+                        unsigned int index = outoffset + mode_pqr;
+                        if constexpr (APPEND)
+                        {
+                            out[index] += sum_k;
+                        }
+                        else
+                        {
+                            out[index] = sum_k;
+                        }
+                    }
+                }
+                else
+                {
+                    for (unsigned int r = threadIdx.z; r < nm2 - q;
+                         r += blockDim.z)
+                    {
+                        unsigned int cnt_pqk = nm1 * nq2 * p + nq2 * q;
+                        unsigned int mode_pqr = mode_pq + p * (nm2 - p);
+                        mode_pqr +=
+                            ((2 * (nm2 - p) - (q - p) + 1) * (q - p)) / 2 + r;
+
+                        TData sum_k = 0.0;
+                        for (unsigned int k = 0; k < nq2; ++k, ++cnt_pqk)
+                        {
+                            sum_k += basis2[mode_pqr * nq2 + k] * w2[k] *
+                                     s_wsp2[cnt_pqk];
+                        }
+                        if constexpr (SCALE)
+                        {
+                            sum_k *= scale;
+                        }
+
+                        unsigned int index = outoffset + mode_pqr;
+                        if constexpr (APPEND)
+                        {
+                            out[index] += sum_k;
+                        }
+                        else
+                        {
+                            out[index] = sum_k;
+                        }
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Add correction for collapsed coordinate.
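+        // Every thread accumulates its share of the top-vertex contribution
+        // into a single shared value via atomicAdd; one thread then adds the
+        // block-wide total to output mode 1.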
+        if (correct)
+        {
+            if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
+            {
+                (*s_prod) = 0.0;
+            }
+
+            __syncthreads();
 
-            for (size_t j = 0; j < nq1; ++j)
+            for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
             {
-                TData tmpQ1 = tmpQ2 * w1[j];
-                for (size_t i = 0; i < nq0; ++i)
+                TData tmpQ2 = w2[k];
+                for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
                 {
-                    // Store jac * quadrature weight
-                    TData tmpQ = tmpQ1 * w0[i];
-                    if (DEFORMED)
+                    TData tmpQ1 = tmpQ2 * w1[j];
+                    for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
                     {
-                        tmpQ *= jacptr[k * nq0 * nq1 + j * nq0 + i];
+                        unsigned int cnt_kji = nq0 * nq1 * k + nq0 * j + i;
+                        // Store jac * quadrature weight
+                        TData tmpQ = tmpQ1 * w0[i];
+
+                        // top vertex
+                        TData tmp = basis0[i] * basis1[nq1 + j];
+                        tmp += basis0[nq0 + i] * basis1[j];
+                        tmp += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp *= basis2[nq2 + k];
+                        tmp *= s_wsp0[cnt_kji] * tmpQ;
+                        atomicAdd(s_prod, tmp);
                     }
+                }
+            }
 
-                    // top vertex
-                    TData tmp = basis0[i] * basis1[nq1 + j];
-                    tmp += basis0[nq0 + i] * basis1[j];
-                    tmp += basis0[nq0 + i] * basis1[nq1 + j];
-                    tmp *= basis2[nq2 + k];
-                    tmp *= inptr[cnt++] * tmpQ;
+            __syncthreads();
 
-                    // add to existing entry
-                    outptr[1] += SCALE ? tmp * scale : tmp;
+            // add to existing entry
+            if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
+            {
+                if constexpr (SCALE)
+                {
+                    out[outoffset + 1] += (*s_prod) * scale;
+                }
+                else
+                {
+                    out[outoffset + 1] += (*s_prod);
                 }
             }
         }
+
+        __syncthreads();
+
+        e += gridDim.x;
     }
-    delete wsp0;
-    delete wsp1;
 }
 
 } // namespace Nektar::Operators::detail
diff --git a/Operators/IProductWRTBase/IProductWRTBaseMatFree.hpp b/Operators/IProductWRTBase/IProductWRTBaseMatFree.hpp
index eeb883d1..77bf7db6 100644
--- a/Operators/IProductWRTBase/IProductWRTBaseMatFree.hpp
+++ b/Operators/IProductWRTBase/IProductWRTBaseMatFree.hpp
@@ -1,11 +1,12 @@
-#include <LibUtilities/Foundations/Basis.h>
+#pragma once
 
 #include <LibUtilities/BasicUtils/ShapeType.hpp>
 #include <LibUtilities/BasicUtils/SharedArray.hpp>
+#include <LibUtilities/Foundations/Basis.h>
+#include <LibUtilities/SimdLib/tinysimd.hpp>
 
 #include "IProductWRTBaseMatFreeKernels.hpp"
 #include "Operators/OperatorIProductWRTBase.hpp"
-#include <LibUtilities/SimdLib/tinysimd.hpp>
 
 namespace Nektar::Operators::detail
 {
diff --git a/Operators/IProductWRTBase/IProductWRTBaseStdMat.hpp b/Operators/IProductWRTBase/IProductWRTBaseStdMat.hpp
index 7b6a8f40..1d003127 100644
--- a/Operators/IProductWRTBase/IProductWRTBaseStdMat.hpp
+++ b/Operators/IProductWRTBase/IProductWRTBaseStdMat.hpp
@@ -1,6 +1,9 @@
-#include "Operators/OperatorIProductWRTBase.hpp"
+#pragma once
+
 #include <StdRegions/StdExpansion.h>
 
+#include "Operators/OperatorIProductWRTBase.hpp"
+
 namespace Nektar::Operators::detail
 {
 
diff --git a/Operators/IProductWRTBase/IProductWRTBaseSumFac.hpp b/Operators/IProductWRTBase/IProductWRTBaseSumFac.hpp
index c2fec766..790d5d45 100644
--- a/Operators/IProductWRTBase/IProductWRTBaseSumFac.hpp
+++ b/Operators/IProductWRTBase/IProductWRTBaseSumFac.hpp
@@ -1,6 +1,9 @@
+#pragma once
+
+#include <StdRegions/StdExpansion.h>
+
 #include "IProductWRTBaseSumFacKernels.hpp"
 #include "Operators/OperatorIProductWRTBase.hpp"
-#include <StdRegions/StdExpansion.h>
 
 namespace Nektar::Operators::detail
 {
@@ -79,7 +82,6 @@ public:
                     break;
                 default:
                     std::cout << "shapetype not implemented" << std::endl;
-
             }
 
             inptr += in.GetBlocks()[block_idx].block_size;
diff --git a/Operators/NeuBndCond/NeuBndCondCUDA.hpp b/Operators/NeuBndCond/NeuBndCondCUDA.hpp
index 2b3d7db2..8fced00a 100644
--- a/Operators/NeuBndCond/NeuBndCondCUDA.hpp
+++ b/Operators/NeuBndCond/NeuBndCondCUDA.hpp
@@ -1,14 +1,14 @@
 #pragma once
 
+#include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
+#include <MultiRegions/ContField.h>
+#include <SpatialDomains/Conditions.h>
+
 #include "MemoryRegionCUDA.hpp"
 #include "Operators/NeuBndCond/NeuBndCondCUDAKernels.cuh"
 #include "Operators/OperatorHelper.cuh"
 #include "Operators/OperatorNeuBndCond.hpp"
 
-#include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
-#include <MultiRegions/ContField.h>
-#include <SpatialDomains/Conditions.h>
-
 using namespace Nektar;
 using namespace Nektar::MultiRegions;
 
diff --git a/Operators/NeuBndCond/NeuBndCondCUDAKernels.cuh b/Operators/NeuBndCond/NeuBndCondCUDAKernels.cuh
index 86019358..a5073dff 100644
--- a/Operators/NeuBndCond/NeuBndCondCUDAKernels.cuh
+++ b/Operators/NeuBndCond/NeuBndCondCUDAKernels.cuh
@@ -1,3 +1,5 @@
+#pragma once
+
 #include <SpatialDomains/Conditions.h>
 
 using namespace Nektar;
@@ -7,20 +9,21 @@ namespace Nektar::Operators::detail
 {
 
 template <typename TData>
-__global__ void NeuBndCondKernel(const size_t nsize, const int *offsetptr,
-                                 const BoundaryConditionType *bctypeptr,
-                                 const int *ncoeffptr, const int *mapptr,
-                                 const TData *inptr, TData *outptr)
+__global__ void NeuBndCondKernel(
+    const unsigned int nsize, const int *__restrict offsetptr,
+    const BoundaryConditionType *__restrict bctypeptr,
+    const int *__restrict ncoeffptr, const int *__restrict mapptr,
+    const TData *__restrict inptr, TData *__restrict outptr)
 {
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
         if (bctypeptr[i] == eNeumann || bctypeptr[i] == eRobin)
         {
-            size_t offset = offsetptr[i];
-            size_t ncoeff = ncoeffptr[i];
-            for (size_t j = 0; j < ncoeff; j++)
+            unsigned int offset = offsetptr[i];
+            unsigned int ncoeff = ncoeffptr[i];
+            for (unsigned int j = 0; j < ncoeff; j++)
             {
                 outptr[mapptr[offset + j]] += inptr[offset + j];
             }
@@ -30,21 +33,22 @@ __global__ void NeuBndCondKernel(const size_t nsize, const int *offsetptr,
 }
 
 template <typename TData>
-__global__ void NeuBndCondKernel(const size_t nsize, const int *offsetptr,
-                                 const BoundaryConditionType *bctypeptr,
-                                 const int *ncoeffptr, const TData *signptr,
-                                 const int *mapptr, const TData *inptr,
-                                 TData *outptr)
+__global__ void NeuBndCondKernel(
+    const unsigned int nsize, const int *__restrict offsetptr,
+    const BoundaryConditionType *__restrict bctypeptr,
+    const int *__restrict ncoeffptr, const TData *__restrict signptr,
+    const int *__restrict mapptr, const TData *__restrict inptr,
+    TData *__restrict outptr)
 {
-    size_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
 
     while (i < nsize)
     {
         if (bctypeptr[i] == eNeumann || bctypeptr[i] == eRobin)
         {
-            size_t offset = offsetptr[i];
-            size_t ncoeff = ncoeffptr[i];
-            for (size_t j = 0; j < ncoeff; j++)
+            unsigned int offset = offsetptr[i];
+            unsigned int ncoeff = ncoeffptr[i];
+            for (unsigned int j = 0; j < ncoeff; j++)
             {
                 outptr[mapptr[offset + j]] +=
                     signptr[offset + j] * inptr[offset + j];
diff --git a/Operators/NeuBndCond/NeuBndCondStdMat.hpp b/Operators/NeuBndCond/NeuBndCondStdMat.hpp
index 792c0af2..ad608f87 100644
--- a/Operators/NeuBndCond/NeuBndCondStdMat.hpp
+++ b/Operators/NeuBndCond/NeuBndCondStdMat.hpp
@@ -1,10 +1,10 @@
 #pragma once
 
-#include "Operators/OperatorNeuBndCond.hpp"
-
 #include <MultiRegions/AssemblyMap/AssemblyMapCG.h>
 #include <MultiRegions/ContField.h>
 
+#include "Operators/OperatorNeuBndCond.hpp"
+
 using namespace Nektar;
 using namespace Nektar::MultiRegions;
 
-- 
GitLab