diff --git a/Operators/BwdTrans/BwdTransCUDA.cu b/Operators/BwdTrans/BwdTransCUDA.cu
index 2f9ac718a4d1940b8c61a2d08db26dad51093145..2de93b2396de5b2442d85d7410cf816ed7a7fdd3 100644
--- a/Operators/BwdTrans/BwdTransCUDA.cu
+++ b/Operators/BwdTrans/BwdTransCUDA.cu
@@ -2,9 +2,11 @@
 
 namespace Nektar::Operators::detail
 {
+
 template <>
 std::string OperatorBwdTransImpl<double, ImplCUDA>::className =
     GetOperatorFactory<double>().RegisterCreatorFunction(
         "BwdTransCUDA", OperatorBwdTransImpl<double, ImplCUDA>::instantiate,
         "...");
+
 } // namespace Nektar::Operators::detail
diff --git a/Operators/BwdTrans/BwdTransCUDA.hpp b/Operators/BwdTrans/BwdTransCUDA.hpp
index 4cf7c9d69f0229abb7f8011695ed8804ed7c4b10..a8b2a19522da33bb5247b249151b63a358b6f05c 100644
--- a/Operators/BwdTrans/BwdTransCUDA.hpp
+++ b/Operators/BwdTrans/BwdTransCUDA.hpp
@@ -5,51 +5,35 @@
 #include "Operators/OperatorBwdTrans.hpp"
 #include "Operators/OperatorHelper.cuh"
 
+#define FLAG_QP false
+
 namespace Nektar::Operators::detail
 {
 
 template <typename TData>
-void BwdTrans1DKernel(const size_t gridSize, const size_t blockSize,
-                      const size_t nm0, const size_t nq0, const size_t nElmts,
-                      const TData *basis0, const TData *in, TData *out);
-
-template <typename TData>
-void BwdTrans1DKernel_QP(const size_t nm0, const size_t nq0,
-                         const size_t nElmts, const TData *basis0,
-                         const TData *in, TData *out);
-
-template <typename TData>
-void BwdTrans2DKernel(const size_t gridSize, const size_t blockSize,
-                      LibUtilities::ShapeType shapetype, const size_t nm0,
-                      const size_t nm1, const size_t nq0, const size_t nq1,
-                      const size_t nElmts, const bool correct,
-                      const TData *basis0, const TData *basis1, const TData *in,
-                      TData *out);
+void BwdTrans1DKernel(const unsigned int gridSize, const unsigned int blockSize,
+                      const unsigned int nm0, const unsigned int nq0,
+                      const unsigned int nElmts, const TData *basis0,
+                      const TData *in, TData *out);
 
 template <typename TData>
-void BwdTrans2DKernel_QP(LibUtilities::ShapeType shapetype, const size_t nm0,
-                         const size_t nm1, const size_t nq0, const size_t nq1,
-                         const size_t nElmts, const bool correct,
-                         const TData *basis0, const TData *basis1,
-                         const TData *in, TData *out);
+void BwdTrans2DKernel(const unsigned int gridSize, const unsigned int blockSize,
+                      LibUtilities::ShapeType shapetype, const unsigned int nm0,
+                      const unsigned int nm1, const unsigned int nq0,
+                      const unsigned int nq1, const unsigned int nElmts,
+                      const bool correct, const TData *basis0,
+                      const TData *basis1, const TData *in, TData *out);
 
 template <typename TData>
-void BwdTrans3DKernel(const size_t gridSize, const size_t blockSize,
-                      LibUtilities::ShapeType shapetype, const size_t nm0,
-                      const size_t nm1, const size_t nm2, const size_t nq0,
-                      const size_t nq1, const size_t nq2, const size_t nElmts,
+void BwdTrans3DKernel(const unsigned int gridSize, const unsigned int blockSize,
+                      LibUtilities::ShapeType shapetype, const unsigned int nm0,
+                      const unsigned int nm1, const unsigned int nm2,
+                      const unsigned int nq0, const unsigned int nq1,
+                      const unsigned int nq2, const unsigned int nElmts,
                       const bool correct, const TData *basis0,
                       const TData *basis1, const TData *basis2, const TData *in,
                       TData *out);
 
-template <typename TData>
-void BwdTrans3DKernel_QP(LibUtilities::ShapeType shapetype, const size_t nm0,
-                         const size_t nm1, const size_t nm2, const size_t nq0,
-                         const size_t nq1, const size_t nq2,
-                         const size_t nElmts, const bool correct,
-                         const TData *basis0, const TData *basis1,
-                         const TData *basis2, const TData *in, TData *out);
-
 // BwdTrans implementation
 template <typename TData>
 class OperatorBwdTransImpl<TData, ImplCUDA> : public OperatorBwdTrans<TData>
@@ -160,180 +144,163 @@ public:
 
 private:
     std::map<std::vector<LibUtilities::BasisKey>, std::vector<TData *>> m_basis;
+    size_t m_gridSize  = 32;
     size_t m_blockSize = 32;
-    size_t m_gridSize;
 };
 
 template <typename TData>
-void BwdTrans1DKernel(const size_t gridSize, const size_t blockSize,
-                      const size_t nm0, const size_t nq0, const size_t nElmts,
-                      const TData *basis0, const TData *in, TData *out)
-{
-    BwdTransSegKernel<<<gridSize, blockSize>>>(nm0, nq0, nElmts, basis0, in,
-                                               out);
-}
-
-template <typename TData>
-void BwdTrans1DKernel_QP(const size_t nm0, const size_t nq0,
-                         const size_t nElmts, const TData *basis0,
-                         const TData *in, TData *out)
-{
-    BwdTransSegKernel_QP<<<nElmts, nq0>>>(nm0, nq0, basis0, in, out);
-}
-
-template <typename TData>
-void BwdTrans2DKernel(const size_t gridSize, const size_t blockSize,
-                      LibUtilities::ShapeType shapetype, const size_t nm0,
-                      const size_t nm1, const size_t nq0, const size_t nq1,
-                      const size_t nElmts, const bool correct,
-                      const TData *basis0, const TData *basis1, const TData *in,
-                      TData *out)
+void BwdTrans1DKernel(const unsigned int gridSize, const unsigned int blockSize,
+                      const unsigned int nm0, const unsigned int nq0,
+                      const unsigned int nElmts, const TData *basis0,
+                      const TData *in, TData *out)
 {
-    if (shapetype == LibUtilities::Quad)
+    if (!FLAG_QP)
     {
-        BwdTransQuadKernel<<<gridSize, blockSize>>>(nm0, nm1, nq0, nq1, nElmts,
-                                                    basis0, basis1, in, out);
+        unsigned int nshared = sizeof(TData) * (nq0 * nm0);
+        BwdTransSegKernel<TData><<<gridSize, blockSize, nshared>>>(
+            nm0, nq0, nElmts, basis0, in, out);
     }
-    else if (shapetype == LibUtilities::Tri)
+    else
     {
-        size_t nmTot =
-            LibUtilities::StdTriData::getNumberOfCoefficients(nm0, nm1);
-        BwdTransTriKernel<<<gridSize, blockSize>>>(nm0, nm1, nmTot, nq0, nq1,
-                                                   nElmts, correct, basis0,
-                                                   basis1, in, out);
+        unsigned int nshared = sizeof(TData) * (nm0);
+        BwdTransSegKernel_QP<TData><<<gridSize, dim3(32), nshared>>>(
+            nm0, nq0, nElmts, basis0, in, out);
     }
 }
 
 template <typename TData>
-void BwdTrans2DKernel_QP(LibUtilities::ShapeType shapetype, const size_t nm0,
-                         const size_t nm1, const size_t nq0, const size_t nq1,
-                         const size_t nElmts, const bool correct,
-                         const TData *basis0, const TData *basis1,
-                         const TData *in, TData *out)
+void BwdTrans2DKernel(const unsigned int gridSize, const unsigned int blockSize,
+                      LibUtilities::ShapeType shapetype, const unsigned int nm0,
+                      const unsigned int nm1, const unsigned int nq0,
+                      const unsigned int nq1, const unsigned int nElmts,
+                      const bool correct, const TData *basis0,
+                      const TData *basis1, const TData *in, TData *out)
 {
     if (shapetype == LibUtilities::Quad)
     {
-        TData *wsp = 0;
-        cudaMalloc((void **)&wsp, sizeof(TData) * nq0 * nm1);
-        BwdTransQuadKernel_QP<<<nElmts, dim3(nq0, nq1)>>>(
-            nm0, nm1, nq0, nq1, basis0, basis1, in, wsp, out);
-        cudaFree(wsp);
+        unsigned int nmTot =
+            LibUtilities::StdQuadData::getNumberOfCoefficients(nm0, nm1);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared = sizeof(TData) * (nq0 * nm0 + nq1 * nm1);
+            BwdTransQuadKernel<TData><<<gridSize, blockSize, nshared>>>(
+                nm0, nm1, nmTot, nq0, nq1, nElmts, basis0, basis1, in, out);
+        }
+        else
+        {
+            unsigned int nshared = sizeof(TData) * (nmTot + nq0 * nm1);
+            BwdTransQuadKernel_QP<TData><<<gridSize, dim3(8, 8), nshared>>>(
+                nm0, nm1, nmTot, nq0, nq1, nElmts, basis0, basis1, in, out);
+        }
     }
     else if (shapetype == LibUtilities::Tri)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdTriData::getNumberOfCoefficients(nm0, nm1);
-        TData *wsp = 0;
-        cudaMalloc((void **)&wsp, sizeof(TData) * nm0 * nq1);
-        BwdTransTriKernel_QP<<<nElmts, dim3(nq0, nq1)>>>(
-            nm0, nm1, nmTot, nq0, nq1, correct, basis0, basis1, in, wsp, out);
-        cudaFree(wsp);
+        if (!FLAG_QP)
+        {
+            BwdTransTriKernel<TData>
+                <<<gridSize, blockSize>>>(nm0, nm1, nmTot, nq0, nq1, nElmts,
+                                          correct, basis0, basis1, in, out);
+        }
+        else
+        {
+            unsigned int nshared = sizeof(TData) * (nmTot + nm0 * nq1);
+            BwdTransTriKernel_QP<TData><<<gridSize, dim3(8, 8), nshared>>>(
+                nm0, nm1, nmTot, nq0, nq1, nElmts, correct, basis0, basis1, in,
+                out);
+        }
     }
 }
 
 template <typename TData>
-void BwdTrans3DKernel(const size_t gridSize, const size_t blockSize,
-                      LibUtilities::ShapeType shapetype, const size_t nm0,
-                      const size_t nm1, const size_t nm2, const size_t nq0,
-                      const size_t nq1, const size_t nq2, const size_t nElmts,
+void BwdTrans3DKernel(const unsigned int gridSize, const unsigned int blockSize,
+                      LibUtilities::ShapeType shapetype, const unsigned int nm0,
+                      const unsigned int nm1, const unsigned int nm2,
+                      const unsigned int nq0, const unsigned int nq1,
+                      const unsigned int nq2, const unsigned int nElmts,
                       const bool correct, const TData *basis0,
                       const TData *basis1, const TData *basis2, const TData *in,
                       TData *out)
 {
     if (shapetype == LibUtilities::Hex)
     {
-        BwdTransHexKernel<<<gridSize, blockSize>>>(nm0, nm1, nm2, nq0, nq1, nq2,
-                                                   nElmts, basis0, basis1,
-                                                   basis2, in, out);
+        unsigned int nmTot =
+            LibUtilities::StdHexData::getNumberOfCoefficients(nm0, nm1, nm2);
+        if (!FLAG_QP)
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nq0 * nm0 + nq1 * nm1 + nq2 * nm2);
+            BwdTransHexKernel<TData><<<gridSize, blockSize, nshared>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, basis0, basis1,
+                basis2, in, out);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nmTot + (nq0 * nm1 * nm2) + (nq0 * nq1 * nm2));
+            BwdTransHexKernel_QP<TData><<<gridSize, dim3(4, 4, 4), nshared>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, basis0, basis1,
+                basis2, in, out);
+        }
     }
     else if (shapetype == LibUtilities::Tet)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdTetData::getNumberOfCoefficients(nm0, nm1, nm2);
-        BwdTransTetKernel<<<gridSize, blockSize>>>(
-            nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
-            basis1, basis2, in, out);
-    }
-    else if (shapetype == LibUtilities::Pyr)
-    {
-        size_t nmTot =
-            LibUtilities::StdPyrData::getNumberOfCoefficients(nm0, nm1, nm2);
-        BwdTransPyrKernel<<<gridSize, blockSize>>>(
-            nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
-            basis1, basis2, in, out);
+        if (!FLAG_QP)
+        {
+            BwdTransTetKernel<TData><<<gridSize, blockSize>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
+                basis1, basis2, in, out);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nmTot + ((2 * nm1 - nm0 + 1) * nm0 / 2 * nq2) +
+                                 (nm0 * nq1 * nq2));
+            BwdTransTetKernel_QP<TData><<<gridSize, dim3(4, 4, 4), nshared>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
+                basis1, basis2, in, out);
+        }
     }
     else if (shapetype == LibUtilities::Prism)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdPrismData::getNumberOfCoefficients(nm0, nm1, nm2);
-        BwdTransPrismKernel<<<gridSize, blockSize>>>(
-            nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
-            basis1, basis2, in, out);
-    }
-}
-
-template <typename TData>
-void BwdTrans3DKernel_QP(LibUtilities::ShapeType shapetype, const size_t nm0,
-                         const size_t nm1, const size_t nm2, const size_t nq0,
-                         const size_t nq1, const size_t nq2,
-                         const size_t nElmts, const bool correct,
-                         const TData *basis0, const TData *basis1,
-                         const TData *basis2, const TData *in, TData *out)
-{
-    if (shapetype == LibUtilities::Hex)
-    {
-        TData *wsp0 = 0;
-        TData *wsp1 = 0;
-        cudaMalloc((void **)&wsp0, sizeof(TData) * nq0 * nm1 * nm2);
-        cudaMalloc((void **)&wsp1, sizeof(TData) * nm0 * nq1 * nq2);
-        BwdTransHexKernel_QP<<<nElmts, dim3(nq0, nq1, nq2)>>>(
-            nm0, nm1, nm2, nq0, nq1, nq2, basis0, basis1, basis2, in, wsp0,
-            wsp1, out);
-        cudaFree(wsp0);
-        cudaFree(wsp1);
-    }
-    if (shapetype == LibUtilities::Tet)
-    {
-        size_t nmTot =
-            LibUtilities::StdTetData::getNumberOfCoefficients(nm0, nm1, nm2);
-        TData *fpq = 0;
-        TData *fp  = 0;
-        cudaMalloc((void **)&fpq,
-                   sizeof(TData) * (2 * nm1 - nm0 + 1) * nm0 / 2 * nq2);
-        cudaMalloc((void **)&fp, sizeof(TData) * nm0 * nq1 * nq2);
-        BwdTransTetKernel_QP<<<nElmts, dim3(nq0, nq1, nq2)>>>(
-            nm0, nm1, nm2, nmTot, nq0, nq1, nq2, correct, basis0, basis1,
-            basis2, in, fpq, fp, out);
-        cudaFree(fpq);
-        cudaFree(fp);
+        if (!FLAG_QP)
+        {
+            BwdTransPrismKernel<TData><<<gridSize, blockSize>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
+                basis1, basis2, in, out);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nmTot + (nm0 * nm1 * nq2) + (nm0 * nq1 * nq2));
+            BwdTransPrismKernel_QP<TData><<<gridSize, dim3(4, 4, 4), nshared>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
+                basis1, basis2, in, out);
+        }
     }
-    if (shapetype == LibUtilities::Pyr)
+    else if (shapetype == LibUtilities::Pyr)
     {
-        size_t nmTot =
+        unsigned int nmTot =
             LibUtilities::StdPyrData::getNumberOfCoefficients(nm0, nm1, nm2);
-        TData *fpq = 0;
-        TData *fp  = 0;
-        cudaMalloc((void **)&fpq, sizeof(TData) * nm0 * nm1 * nq2);
-        cudaMalloc((void **)&fp, sizeof(TData) * nm0 * nq1 * nq2);
-        BwdTransPyrKernel_QP<<<nElmts, dim3(nq0, nq1, nq2)>>>(
-            nm0, nm1, nm2, nmTot, nq0, nq1, nq2, correct, basis0, basis1,
-            basis2, in, fpq, fp, out);
-        cudaFree(fpq);
-        cudaFree(fp);
-    }
-    if (shapetype == LibUtilities::Prism)
-    {
-        size_t nmTot =
-            LibUtilities::StdPrismData::getNumberOfCoefficients(nm0, nm1, nm2);
-        TData *fpq = 0;
-        TData *fp  = 0;
-        cudaMalloc((void **)&fpq, sizeof(TData) * nm0 * nm1 * nq2);
-        cudaMalloc((void **)&fp, sizeof(TData) * nm0 * nq1 * nq2);
-        BwdTransPrismKernel_QP<<<nElmts, dim3(nq0, nq1, nq2)>>>(
-            nm0, nm1, nm2, nmTot, nq0, nq1, nq2, correct, basis0, basis1,
-            basis2, in, fpq, fp, out);
-        cudaFree(fpq);
-        cudaFree(fp);
+        if (!FLAG_QP)
+        {
+            BwdTransPyrKernel<TData><<<gridSize, blockSize>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
+                basis1, basis2, in, out);
+        }
+        else
+        {
+            unsigned int nshared =
+                sizeof(TData) * (nmTot + (nm0 * nm1 * nq2) + (nm0 * nq1 * nq2));
+            BwdTransPyrKernel_QP<TData><<<gridSize, dim3(4, 4, 4), nshared>>>(
+                nm0, nm1, nm2, nmTot, nq0, nq1, nq2, nElmts, correct, basis0,
+                basis1, basis2, in, out);
+        }
     }
 }
 
diff --git a/Operators/BwdTrans/BwdTransCUDAKernels.cuh b/Operators/BwdTrans/BwdTransCUDAKernels.cuh
index ca3ae323e69084dc22d46c27d0c22b5d31e57ebd..db7c8b17a866e960549a5db6ede00df53a7ee73b 100644
--- a/Operators/BwdTrans/BwdTransCUDAKernels.cuh
+++ b/Operators/BwdTrans/BwdTransCUDAKernels.cuh
@@ -1,942 +1,1239 @@
+#pragma once
+
+#include <cstdio>
+
 namespace Nektar::Operators::detail
 {
+
 template <typename TData>
-__global__ void BwdTransSegKernel(const size_t nm0, const size_t nq0,
-                                  const size_t nelmt, const TData *basis0,
-                                  const TData *in, TData *out)
+__global__ void BwdTransSegKernel(const unsigned int nm0,
+                                  const unsigned int nq0,
+                                  const unsigned int nelmt,
+                                  const TData *__restrict basis0,
+                                  const TData *__restrict in,
+                                  TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_basis0 = shared;
 
-    if (e >= nelmt)
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nm0 * nq0)
     {
-        return;
+        s_basis0[sIndex] = basis0[sIndex];
+        sIndex += blockDim.x;
     }
 
-    const TData *inptr = in + (nm0 * e);
-    TData *outptr      = out + (nq0 * e);
+    __syncthreads();
 
-    for (size_t i = 0; i < nq0; ++i)
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    while (e < nelmt)
     {
-        TData tmp = inptr[0] * basis0[i];
-        for (size_t p = 1; p < nm0; ++p)
+        unsigned int inoffset  = nm0 * e;
+        unsigned int outoffset = nq0 * e;
+
+        for (unsigned int i = 0; i < nq0; ++i)
         {
-            tmp += inptr[p] * basis0[p * nq0 + i];
+            TData tmp = 0.0;
+            for (unsigned int p = 0; p < nm0; ++p)
+            {
+                tmp += in[inoffset + p] * s_basis0[p * nq0 + i];
+            }
+            out[outoffset + i] = tmp;
         }
-        outptr[i] = tmp;
+
+        e += blockDim.x * gridDim.x;
     }
 }
 
 template <typename TData>
-__global__ void BwdTransQuadKernel(const size_t nm0, const size_t nm1,
-                                   const size_t nq0, const size_t nq1,
-                                   const size_t nelmt, const TData *basis0,
-                                   const TData *basis1, const TData *in,
-                                   TData *out)
+__global__ void BwdTransSegKernel_QP(const unsigned int nm0,
+                                     const unsigned int nq0,
+                                     const unsigned int nelmt,
+                                     const TData *__restrict basis0,
+                                     const TData *__restrict in,
+                                     TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
-
-    if (e >= nelmt)
-    {
-        return;
-    }
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
 
-    TData *wsp = new TData[nq0 * nm1];
+    unsigned int e = blockIdx.x;
 
-    const TData *inptr = in + (nm0 * nm1 * e);
-    TData *outptr      = out + (nq0 * nq1 * e);
-
-    size_t cnt_iq = 0;
-    for (size_t i = 0; i < nq0; ++i)
+    while (e < nelmt)
     {
-        size_t cnt_pq = 0;
-        for (size_t q = 0; q < nm1; ++q)
+        unsigned int inoffset  = nm0 * e;
+        unsigned int outoffset = nq0 * e;
+
+        // Cooperatively stage this element's coefficients in shared memory.
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            TData tmp = inptr[cnt_pq++] * basis0[i];
-            for (size_t p = 1; p < nm0; ++p)
-            {
-                tmp += inptr[cnt_pq++] * basis0[p * nq0 + i];
-            }
-            wsp[cnt_iq++] = tmp;
+            s_wsp0[p] = in[inoffset + p];
         }
-    }
 
-    size_t cnt_ij = 0;
-    for (size_t j = 0; j < nq1; ++j)
-    {
-        size_t cnt_iq = 0;
-        for (size_t i = 0; i < nq0; ++i)
+        __syncthreads();
+
+        for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
         {
-            TData tmp = wsp[cnt_iq++] * basis1[j];
-            for (size_t q = 1; q < nm1; ++q)
+            TData tmp = 0.0;
+            for (unsigned int p = 0; p < nm0; p++)
             {
-                tmp += wsp[cnt_iq++] * basis1[q * nq1 + j];
+                tmp += s_wsp0[p] * basis0[p * nq0 + i];
             }
-            outptr[cnt_ij++] = tmp;
+            out[outoffset + i] = tmp;
         }
-    }
 
-    delete wsp;
+        __syncthreads();
+
+        e += gridDim.x;
+    }
 }
 
 template <typename TData>
-__global__ void BwdTransTriKernel(const size_t nm0, const size_t nm1,
-                                  const size_t nmTot, const size_t nq0,
-                                  const size_t nq1, const size_t nelmt,
-                                  const bool correct, const TData *basis0,
-                                  const TData *basis1, const TData *in,
-                                  TData *out)
+__global__ void BwdTransQuadKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict in, TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_basis0 = shared;
+    TData *s_basis1 = s_basis0 + nm0 * nq0;
 
-    if (e >= nelmt)
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nm0 * nq0)
     {
-        return;
+        s_basis0[sIndex] = basis0[sIndex];
+        sIndex += blockDim.x;
     }
 
-    TData *wsp = new TData[nm0];
+    sIndex = threadIdx.x;
+    while (sIndex < nm1 * nq1)
+    {
+        s_basis1[sIndex] = basis1[sIndex];
+        sIndex += blockDim.x;
+    }
+
+    __syncthreads();
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * e);
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
-    size_t cnt_ij = 0;
-    for (size_t j = 0; j < nq1; ++j)
+    TData *wsp = new TData[nm1];
+
+    while (e < nelmt)
     {
-        size_t mode = 0;
-        for (size_t p = 0; p < nm0; ++p)
-        {
-            TData tmp = 0.0;
-            for (size_t q = 0; q < (nm1 - p); ++q)
-            {
-                tmp += basis1[mode * nq1 + j] * inptr[mode];
-                mode++;
-            }
-            wsp[p] = tmp;
-        }
+        unsigned int inoffset  = nmTot * e;
+        unsigned int outoffset = nq0 * nq1 * e;
 
-        for (size_t i = 0; i < nq0; ++i)
+        for (unsigned int i = 0; i < nq0; ++i)
         {
-            TData tmp = wsp[0] * basis0[i];
-            for (size_t p = 1; p < nm0; ++p)
+            for (unsigned int q = 0, cnt_qp = 0; q < nm1; ++q)
             {
-                tmp += wsp[p] * basis0[p * nq0 + i];
+                TData tmp = 0.0;
+                for (unsigned int p = 0; p < nm0; ++p, ++cnt_qp)
+                {
+                    tmp += in[inoffset + cnt_qp] * s_basis0[p * nq0 + i];
+                }
+                wsp[q] = tmp;
             }
 
-            if (correct)
+            for (unsigned int j = 0; j < nq1; ++j)
             {
-                tmp += inptr[1] * basis0[nq0 + i] * basis1[nq1 + j];
+                TData tmp = 0.0;
+                for (unsigned int q = 0; q < nm1; ++q)
+                {
+                    tmp += wsp[q] * s_basis1[q * nq1 + j];
+                }
+                out[outoffset + nq0 * j + i] = tmp;
             }
-            outptr[cnt_ij++] = tmp;
         }
+
+        e += blockDim.x * gridDim.x;
     }
 
-    delete wsp;
+    delete[] wsp;
 }
 
 template <typename TData>
-__global__ void BwdTransHexKernel(const size_t nm0, const size_t nm1,
-                                  const size_t nm2, const size_t nq0,
-                                  const size_t nq1, const size_t nq2,
-                                  const size_t nelmt, const TData *basis0,
-                                  const TData *basis1, const TData *basis2,
-                                  const TData *in, TData *out)
+__global__ void BwdTransQuadKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict in, TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nmTot;
+
+    unsigned int e = blockIdx.x;
 
-    if (e >= nelmt)
+    while (e < nelmt)
     {
-        return;
-    }
+        unsigned int inoffset  = nmTot * e;
+        unsigned int outoffset = nq0 * nq1 * e;
 
-    TData *wsp0 = new TData[nq0 * nm1 * nm2];
-    TData *wsp1 = new TData[nq0 * nq1 * nm2];
+        // Copy to shared memory.
+        for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+        {
+            unsigned int cnt_qp = nm0 * q;
+
+            for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+            {
+                s_wsp0[cnt_qp + p] = in[inoffset + cnt_qp + p];
+            }
+        }
 
-    const TData *inptr = in + (nm0 * nm1 * nm2 * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+        __syncthreads();
 
-    size_t cnt_irq = 0;
-    for (size_t i = 0; i < nq0; ++i)
-    {
-        size_t cnt_rqp = 0;
-        for (size_t r = 0; r < nm2; ++r)
+        for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
         {
-            for (size_t q = 0; q < nm1; ++q)
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
             {
-                TData tmp = inptr[cnt_rqp++] * basis0[i];
-                for (size_t p = 1; p < nm0; ++p)
+                unsigned int cnt_iq = nm1 * i + q;
+                unsigned int cnt_qp = nm0 * q;
+
+                TData tmp = 0.0;
+                for (unsigned int p = 0; p < nm0; ++p, ++cnt_qp)
                 {
-                    tmp += inptr[cnt_rqp++] * basis0[p * nq0 + i];
+                    tmp += s_wsp0[cnt_qp] * basis0[p * nq0 + i];
                 }
-                wsp0[cnt_irq++] = tmp;
+                s_wsp1[cnt_iq] = tmp;
             }
         }
-    }
 
-    size_t cnt_jir = 0;
-    for (size_t j = 0; j < nq1; ++j)
-    {
-        size_t cnt_irq = 0;
-        for (size_t i = 0; i < nq0; ++i)
+        __syncthreads();
+
+        for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
         {
-            for (size_t r = 0; r < nm2; ++r)
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData tmp = wsp0[cnt_irq++] * basis1[j];
-                for (size_t q = 1; q < nm1; ++q)
+                unsigned int cnt_iq = nm1 * i;
+                unsigned int cnt_ji = nq0 * j + i;
+
+                TData tmp = 0.0;
+                for (unsigned int q = 0; q < nm1; ++q, ++cnt_iq)
                 {
-                    tmp += wsp0[cnt_irq++] * basis1[q * nq1 + j];
+                    tmp += s_wsp1[cnt_iq] * basis1[q * nq1 + j];
                 }
-                wsp1[cnt_jir++] = tmp;
+                out[outoffset + cnt_ji] = tmp;
             }
         }
+
+        __syncthreads();
+
+        e += gridDim.x;
     }
+}
+
+template <typename TData>
+__global__ void BwdTransTriKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const bool correct, const TData *__restrict basis0,
+    const TData *__restrict basis1, const TData *__restrict in,
+    TData *__restrict out)
+{
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
-    size_t cnt_kji = 0;
-    for (size_t k = 0; k < nq2; ++k)
+    TData *wsp = new TData[nm0];
+
+    while (e < nelmt)
     {
-        size_t cnt_jir = 0;
-        for (size_t j = 0; j < nq1; ++j)
+        unsigned int inoffset  = nmTot * e;
+        unsigned int outoffset = nq0 * nq1 * e;
+
+        for (unsigned int j = 0, cnt_ji = 0; j < nq1; ++j)
         {
-            for (size_t i = 0; i < nq0; ++i)
+            for (unsigned int p = 0, mode_pq = 0; p < nm0; ++p)
+            {
+                TData tmp = 0.0;
+                for (unsigned int q = 0; q < (nm1 - p); ++q, ++mode_pq)
+                {
+                    tmp += basis1[mode_pq * nq1 + j] * in[inoffset + mode_pq];
+                }
+                wsp[p] = tmp;
+            }
+
+            for (unsigned int i = 0; i < nq0; ++i, ++cnt_ji)
             {
-                TData tmp = wsp1[cnt_jir++] * basis2[k];
-                for (size_t r = 1; r < nm2; ++r)
+                TData tmp = 0.0;
+                for (unsigned int p = 0; p < nm0; ++p)
+                {
+                    tmp += wsp[p] * basis0[p * nq0 + i];
+                }
+
+                if (correct)
                 {
-                    tmp += wsp1[cnt_jir++] * basis2[r * nq2 + k];
+                    tmp += in[inoffset + 1] * basis0[nq0 + i] * basis1[nq1 + j];
                 }
-                outptr[cnt_kji++] = tmp;
+
+                out[outoffset + cnt_ji] = tmp;
             }
         }
+
+        e += blockDim.x * gridDim.x;
     }
 
-    delete wsp0;
-    delete wsp1;
+    delete[] wsp;
 }
 
-template <typename TData> // not working for nm2 > nm1
-__global__ void BwdTransTetKernel(const size_t nm0, const size_t nm1,
-                                  const size_t nm2, const size_t nmTot,
-                                  const size_t nq0, const size_t nq1,
-                                  const size_t nq2, const size_t nelmt,
-                                  const bool correct, const TData *basis0,
-                                  const TData *basis1, const TData *basis2,
-                                  const TData *in, TData *out)
+template <typename TData>
+__global__ void BwdTransTriKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nmTot,
+    const unsigned int nq0, const unsigned int nq1, const unsigned int nelmt,
+    const bool correct, const TData *__restrict basis0,
+    const TData *__restrict basis1, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nmTot;
 
-    if (e >= nelmt)
-    {
-        return;
-    }
-
-    size_t nm01 = (2 * nm1 - nm0 + 1) * nm0 / 2;
-    TData *fpq  = new TData[nm01];
-    TData *fp   = new TData[nm0];
-
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+    unsigned int e = blockIdx.x;
 
-    size_t cnt_kji = 0;
-    for (size_t k = 0; k < nq2; ++k)
+    while (e < nelmt)
     {
-        size_t cnt_pq = 0;
-        size_t mode   = 0;
-        for (size_t p = 0; p < nm0; ++p)
+        size_t inoffset  = static_cast<size_t>(nmTot) * e; // 64-bit, no wrap
+        size_t outoffset = static_cast<size_t>(nq0) * nq1 * e;
+
+        // Copy to shared memory.
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            for (size_t q = 0; q < nm1 - p; ++q)
-            {
-                TData tmp = basis2[k + nq2 * mode] * inptr[mode++];
-                for (size_t r = 1; r < nm2 - p - q; ++r)
-                {
-                    tmp += basis2[k + nq2 * mode] * inptr[mode++];
-                }
-                fpq[cnt_pq++] = tmp;
-            }
+            unsigned int mode_pq = (2 * nm1 - p + 1) * p / 2;
 
-            // increment mode in case order1!=order2
-            for (size_t q = nm1 - p; q < nm2 - p; ++q)
+            for (unsigned int q = threadIdx.y; q < nm1 - p; q += blockDim.y)
             {
-                mode += nm2 - p - q;
+                s_wsp0[mode_pq + q] = in[inoffset + mode_pq + q];
             }
         }
 
-        for (size_t j = 0; j < nq1; ++j)
+        __syncthreads();
+
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            mode   = 0;
-            cnt_pq = 0;
-            for (size_t p = 0; p < nm0; ++p)
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData tmp = fpq[cnt_pq++] * basis1[mode * nq1 + j];
-                for (size_t q = 1; q < nm1 - p; ++q)
+                unsigned int mode_pq = (2 * nm1 - p + 1) * p / 2;
+                unsigned int cnt_jp  = nm0 * j + p;
+
+                TData tmp = 0.0;
+                for (unsigned int q = 0; q < (nm1 - p); ++q, ++mode_pq)
                 {
-                    tmp += fpq[cnt_pq++] * basis1[(mode + q) * nq1 + j];
+                    tmp += basis1[mode_pq * nq1 + j] * s_wsp0[mode_pq];
                 }
-
-                fp[p] = tmp;
-                mode += nm1 - p;
+                s_wsp1[cnt_jp] = tmp;
             }
+        }
+
+        __syncthreads();
 
-            for (size_t i = 0; i < nq0; ++i)
+        for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+        {
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                TData tmp = basis0[i] * fp[0];
-                for (size_t p = 1; p < nm0; ++p)
+                unsigned int cnt_jp = nm0 * j;
+                unsigned int cnt_ij = nq0 * j + i;
+
+                TData tmp = 0.0;
+                for (unsigned int p = 0; p < nm0; ++p, ++cnt_jp)
                 {
-                    tmp += basis0[p * nq0 + i] * fp[p];
+                    tmp += s_wsp1[cnt_jp] * basis0[p * nq0 + i];
                 }
 
                 if (correct)
                 {
-                    // top vertex
-                    TData tmp1 = basis0[i] * basis1[nq1 + j];
-                    tmp1 += basis0[nq0 + i] * basis1[j];
-                    tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
-                    tmp1 *= basis2[nq2 + k];
-                    tmp += tmp1 * inptr[1];
-
-                    // bottom vertex
-                    tmp1 = basis0[nq0 + i] * basis1[nq1 + j];
-                    tmp1 *= basis2[k];
-                    tmp += tmp1 * inptr[nm2];
-
-                    // singular edge
-                    for (size_t r = 1; r < nm2 - 1; ++r)
-                    {
-                        tmp1 = basis1[nq1 + j] * basis0[nq0 + i];
-                        tmp1 *= basis2[(r + 1) * nq2 + k];
-                        tmp += tmp1 * inptr[nm2 + r];
-                    }
+                    tmp += s_wsp0[1] * basis0[nq0 + i] * basis1[nq1 + j];
                 }
-                outptr[cnt_kji++] = tmp;
+
+                out[outoffset + cnt_ij] = tmp;
             }
         }
-    }
 
-    delete fpq;
-    delete fp;
+        __syncthreads();
+
+        e += gridDim.x;
+    }
 }
 
 template <typename TData>
-__global__ void BwdTransPrismKernel(const size_t nm0, const size_t nm1,
-                                    const size_t nm2, const size_t nmTot,
-                                    const size_t nq0, const size_t nq1,
-                                    const size_t nq2, const size_t nelmt,
-                                    const bool correct, const TData *basis0,
-                                    const TData *basis1, const TData *basis2,
-                                    const TData *in, TData *out)
+__global__ void BwdTransHexKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_basis0 = shared;
+    TData *s_basis1 = s_basis0 + nm0 * nq0;
+    TData *s_basis2 = s_basis1 + nm1 * nq1;
+
+    // Copy to shared memory.
+    unsigned int sIndex = threadIdx.x;
+    while (sIndex < nm0 * nq0)
+    {
+        s_basis0[sIndex] = basis0[sIndex];
+        sIndex += blockDim.x;
+    }
 
-    if (e >= nelmt)
+    sIndex = threadIdx.x;
+    while (sIndex < nm1 * nq1)
     {
-        return;
+        s_basis1[sIndex] = basis1[sIndex];
+        sIndex += blockDim.x;
     }
 
-    TData *fpq = new TData[nm0 * nm1];
-    TData *fp  = new TData[nm0];
+    sIndex = threadIdx.x;
+    while (sIndex < nm2 * nq2)
+    {
+        s_basis2[sIndex] = basis2[sIndex];
+        sIndex += blockDim.x;
+    }
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+    __syncthreads();
+
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    TData *wsp1 = new TData[nm1 * nm2];
+    TData *wsp2 = new TData[nm2];
 
-    size_t cnt_kji = 0;
-    for (size_t k = 0; k < nq2; ++k)
+    while (e < nelmt)
     {
-        size_t mode_pqr = 0;
-        size_t mode_pr  = 0;
-        size_t mode_pq  = 0;
-        for (size_t p = 0; p < nm0; ++p)
-        {
-            for (size_t q = 0; q < nm1; ++q)
-            {
-                TData tmp = 0.0;
-                for (size_t r = 0; r < nm2 - p; ++r)
-                {
-                    tmp += inptr[mode_pqr++] * basis2[(mode_pr + r) * nq2 + k];
-                }
-                fpq[mode_pq++] = tmp;
-            }
-            mode_pr += nm2 - p;
-        }
+        size_t inoffset  = static_cast<size_t>(nmTot) * e; // 64-bit, no wrap
+        size_t outoffset = static_cast<size_t>(nq0) * nq1 * nq2 * e;
 
-        for (size_t j = 0; j < nq1; ++j)
+        for (unsigned int i = 0; i < nq0; ++i)
         {
-            size_t mode_pq = 0;
-            for (size_t p = 0; p < nm0; ++p)
+            for (unsigned int r = 0, cnt_rqp = 0, cnt_rq = 0; r < nm2; ++r)
             {
-                TData tmp = fpq[mode_pq++] * basis1[j];
-                for (size_t q = 1; q < nm1; ++q)
+                for (unsigned int q = 0; q < nm1; ++q, ++cnt_rq)
                 {
-                    tmp += fpq[mode_pq++] * basis1[q * nq1 + j];
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p, ++cnt_rqp)
+                    {
+                        tmp += in[inoffset + cnt_rqp] * s_basis0[p * nq0 + i];
+                    }
+                    wsp1[cnt_rq] = tmp;
                 }
-                fp[p] = tmp;
             }
 
-            for (size_t i = 0; i < nq0; ++i)
+            for (unsigned int j = 0; j < nq1; ++j)
             {
-                TData tmp = fp[0] * basis0[i];
-                for (size_t p = 1; p < nm0; ++p)
+                for (unsigned int r = 0, cnt_rq = 0; r < nm2; ++r)
                 {
-                    tmp += fp[p] * basis0[p * nq0 + i];
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1; ++q, ++cnt_rq)
+                    {
+                        tmp += wsp1[cnt_rq] * s_basis1[q * nq1 + j];
+                    }
+                    wsp2[r] = tmp;
                 }
 
-                if (correct)
+                for (unsigned int k = 0; k < nq2; ++k)
                 {
-                    for (size_t q = 0; q < nm1; ++q)
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2; ++r)
                     {
-                        tmp += basis2[nq2 + k] * basis1[q * nq1 + j] *
-                               basis0[nq0 + i] * inptr[q * nm2 + 1];
+                        tmp += wsp2[r] * s_basis2[r * nq2 + k];
                     }
+                    out[outoffset + k * nq1 * nq0 + j * nq0 + i] = tmp;
                 }
-                outptr[cnt_kji++] = tmp;
             }
         }
+
+        e += blockDim.x * gridDim.x;
     }
 
-    delete fpq;
-    delete fp;
+    delete[] wsp1;
+    delete[] wsp2;
 }
 
-template <typename TData> // not working for nm2 > nm1
-__global__ void BwdTransPyrKernel(const size_t nm0, const size_t nm1,
-                                  const size_t nm2, const size_t nmTot,
-                                  const size_t nq0, const size_t nq1,
-                                  const size_t nq2, const size_t nelmt,
-                                  const bool correct, const TData *basis0,
-                                  const TData *basis1, const TData *basis2,
-                                  const TData *in, TData *out)
+template <typename TData>
+__global__ void BwdTransHexKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockDim.x * blockIdx.x + threadIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nmTot;
+    TData *s_wsp2 = s_wsp1 + (nq0 * nm1 * nm2);
+
+    unsigned int e = blockIdx.x;
 
-    if (e >= nelmt)
+    while (e < nelmt)
     {
-        return;
-    }
+        size_t inoffset  = static_cast<size_t>(nmTot) * e; // 64-bit, no wrap
+        size_t outoffset = static_cast<size_t>(nq0) * nq1 * nq2 * e;
 
-    TData *fpq = new TData[nm0 * nm1];
-    TData *fp  = new TData[nm0];
+        // Copy to shared memory.
+        for (unsigned int r = threadIdx.z; r < nm2; r += blockDim.z)
+        {
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+            {
+                unsigned int cnt_rqp = nm1 * nm0 * r + nm0 * q;
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+                for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+                {
+                    s_wsp0[cnt_rqp + p] = in[inoffset + cnt_rqp + p];
+                }
+            }
+        }
 
-    size_t cnt_kji = 0;
-    for (int k = 0; k < nq2; ++k)
-    {
-        size_t mode_pqr = 0;
-        size_t mode_pq  = 0;
-        for (size_t p = 0; p < nm0; ++p)
+        __syncthreads();
+
+        for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
         {
-            for (size_t q = 0; q < p; ++q)
+            for (unsigned int r = threadIdx.z; r < nm2; r += blockDim.z)
             {
-                TData tmp = 0.0;
-                for (size_t r = 0; r < nm2 - p; ++r)
+                for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
                 {
-                    tmp += basis2[mode_pqr * nq2 + k] * inptr[mode_pqr++];
+                    unsigned int cnt_rqp = nm1 * nm0 * r + nm0 * q;
+                    unsigned int cnt_irq = nm1 * nm2 * i + nm1 * r + q;
+
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p, ++cnt_rqp)
+                    {
+                        tmp += s_wsp0[cnt_rqp] * basis0[p * nq0 + i];
+                    }
+                    s_wsp1[cnt_irq] = tmp;
                 }
-                fpq[mode_pq++] = tmp;
             }
+        }
+
+        __syncthreads();
 
-            for (size_t q = p; q < nm1; ++q)
+        for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+        {
+            for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
             {
-                TData tmp = 0.0;
-                for (size_t r = 0; r < nm2 - q; ++r)
+                for (unsigned int r = threadIdx.z; r < nm2; r += blockDim.z)
                 {
-                    tmp += basis2[mode_pqr * nq2 + k] * inptr[mode_pqr++];
+                    unsigned int cnt_irq = nm1 * nm2 * i + nm1 * r;
+                    unsigned int cnt_jir = nq0 * nm2 * j + nm2 * i + r;
+
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1; ++q, ++cnt_irq)
+                    {
+                        tmp += s_wsp1[cnt_irq] * basis1[q * nq1 + j];
+                    }
+                    s_wsp2[cnt_jir] = tmp;
                 }
-                fpq[mode_pq++] = tmp;
             }
+        }
+
+        __syncthreads();
 
-            // increment mode in case nm2>nm1
-            for (size_t q = nm1; q < nm2 - p; ++q)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
+        {
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                mode_pqr += nm2 - q;
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_jir = nq0 * nm2 * j + nm2 * i;
+                    unsigned int cnt_kji = nq0 * nq1 * k + nq0 * j + i;
+
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2; ++r, ++cnt_jir)
+                    {
+                        tmp += s_wsp2[cnt_jir] * basis2[r * nq2 + k];
+                    }
+                    out[outoffset + cnt_kji] = tmp;
+                }
             }
         }
 
-        for (size_t j = 0; j < nq1; ++j)
+        __syncthreads();
+
+        e += gridDim.x;
+    }
+}
+
+template <typename TData> // not working for nm2 > nm1
+__global__ void BwdTransTetKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
+{
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
+
+    unsigned int nm01 = (2 * nm1 - nm0 + 1) * nm0 / 2;
+    TData *fpq        = new TData[nm01];
+    TData *fp         = new TData[nm0];
+
+    while (e < nelmt)
+    {
+        size_t inoffset  = static_cast<size_t>(nmTot) * e; // 64-bit, no wrap
+        size_t outoffset = static_cast<size_t>(nq0) * nq1 * nq2 * e;
+
+        for (unsigned int k = 0, cnt_kji = 0; k < nq2; ++k)
         {
-            size_t mode_pq = 0;
-            for (size_t p = 0; p < nm0; ++p)
+            for (unsigned int p = 0, cnt_pq = 0, mode_pqr = 0; p < nm0; ++p)
             {
-                TData tmp = fpq[mode_pq++] * basis1[j];
-                for (int q = 1; q < nm1; ++q)
+                for (unsigned int q = 0; q < nm1 - p; ++q, ++cnt_pq)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2 - p - q; ++r, ++mode_pqr)
+                    {
+                        tmp += basis2[k + nq2 * mode_pqr] *
+                               in[inoffset + mode_pqr];
+                    }
+                    fpq[cnt_pq] = tmp;
+                }
+
+                // increment mode in case order1!=order2
+                for (unsigned int q = nm1 - p; q < nm2 - p; ++q)
                 {
-                    tmp += fpq[mode_pq++] * basis1[q * nq1 + j];
+                    mode_pqr += nm2 - p - q;
                 }
-                fp[p] = tmp;
             }
 
-            for (size_t i = 0; i < nq0; ++i)
+            for (unsigned int j = 0; j < nq1; ++j)
             {
-                TData tmp = fp[0] * basis0[i];
-                for (size_t p = 1; p < nm0; ++p)
+                for (unsigned int p = 0, mode_pq = 0; p < nm0; ++p)
                 {
-                    tmp += fp[p] * basis0[p * nq0 + i];
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1 - p; ++q, ++mode_pq)
+                    {
+                        tmp += fpq[mode_pq] * basis1[mode_pq * nq1 + j];
+                    }
+                    fp[p] = tmp;
                 }
 
-                if (correct)
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
                 {
-                    // top vertex
-                    TData tmp1 = basis0[i] * basis1[nq1 + j];
-                    tmp1 += basis0[nq0 + i] * basis1[j];
-                    tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
-                    tmp1 *= basis2[nq2 + k];
-                    tmp += tmp1 * inptr[1];
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p)
+                    {
+                        tmp += basis0[p * nq0 + i] * fp[p];
+                    }
+
+                    if (correct)
+                    {
+                        // top vertex
+                        TData tmp1 = basis0[i] * basis1[nq1 + j];
+                        tmp1 += basis0[nq0 + i] * basis1[j];
+                        tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp1 *= basis2[nq2 + k];
+                        tmp += tmp1 * in[inoffset + 1];
+
+                        // bottom vertex
+                        tmp1 = basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp1 *= basis2[k];
+                        tmp += tmp1 * in[inoffset + nm2];
+
+                        // singular edge
+                        for (unsigned int r = 1; r < nm2 - 1; ++r)
+                        {
+                            tmp1 = basis1[nq1 + j] * basis0[nq0 + i];
+                            tmp1 *= basis2[(r + 1) * nq2 + k];
+                            tmp += tmp1 * in[inoffset + nm2 + r];
+                        }
+                    }
+
+                    out[outoffset + cnt_kji] = tmp;
                 }
-                outptr[cnt_kji++] = tmp;
             }
         }
+
+        e += blockDim.x * gridDim.x;
     }
 
-    delete fpq;
-    delete fp;
+    delete[] fpq;
+    delete[] fp;
 }
 
-template <typename TData>
-__global__ void BwdTransSegKernel_QP(const size_t nm0, const size_t nq0,
-                                     const TData *basis0, const TData *in,
-                                     TData *out)
+template <typename TData> // not working for nm2 > nm1
+__global__ void BwdTransTetKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockIdx.x;
+    // Number of (p, q) mode pairs; s_wsp1 holds nm01 values per k plane.
+    const unsigned int nm01 = (2 * nm1 - nm0 + 1) * nm0 / 2;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nmTot;
+    TData *s_wsp2 = s_wsp1 + nm01 * nq2;
 
-    const TData *inptr = in + (nm0 * e);
-    TData *outptr      = out + (nq0 * e);
+    unsigned int e = blockIdx.x;
 
-    size_t i = threadIdx.x;
-    if (i < nq0)
+    while (e < nelmt)
     {
-        TData tmp = inptr[0] * basis0[i];
-        for (size_t p = 1; p < nm0; p++)
+        size_t inoffset  = static_cast<size_t>(nmTot) * e; // 64-bit, no wrap
+        size_t outoffset = static_cast<size_t>(nq0) * nq1 * nq2 * e;
+
+        // Copy to shared memory.
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            tmp += inptr[p] * basis0[p * nq0 + i];
+            for (unsigned int q = threadIdx.y; q < nm1 - p; q += blockDim.y)
+            {
+                unsigned int mode_pqr = (2 * (nm2 - p) - q + 1) * q;
+                mode_pqr += nm2 * (nm2 + 1) * p;
+                mode_pqr -= (2 * nm2 + 1) * (p - 1) * p / 2;
+                mode_pqr += (p - 1) * p * (2 * p - 1) / 6;
+                mode_pqr /= 2;
+
+                for (unsigned int r = threadIdx.z; r < nm2 - p - q;
+                     r += blockDim.z)
+                {
+                    s_wsp0[mode_pqr + r] = in[inoffset + mode_pqr + r];
+                }
+            }
         }
-        outptr[i] = tmp;
-    }
-}
-
-template <typename TData>
-__global__ void BwdTransQuadKernel_QP(const size_t nm0, const size_t nm1,
-                                      const size_t nq0, const size_t nq1,
-                                      const TData *basis0, const TData *basis1,
-                                      const TData *in, TData *wsp, TData *out)
-{
-    size_t e = blockIdx.x;
 
-    const TData *inptr = in + (nm0 * nm1 * e);
-    TData *outptr      = out + (nq0 * nq1 * e);
+        __syncthreads();
 
-    size_t i, j, q;
-    i = threadIdx.x;
-    q = threadIdx.y;
-    if (i < nq0 && q < nm1)
-    {
-        size_t cnt_iq = nm1 * i + q;
-        size_t cnt_pq = nm0 * q;
-        TData tmp     = inptr[cnt_pq++] * basis0[i];
-        for (size_t p = 1; p < nm0; ++p)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            tmp += inptr[cnt_pq++] * basis0[p * nq0 + i];
+            for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+            {
+                for (unsigned int q = threadIdx.y; q < nm1 - p; q += blockDim.y)
+                {
+                    unsigned int cnt_kpq =
+                        nm01 * k + (2 * nm1 - p + 1) * p / 2 + q;
+                    unsigned int mode_pqr = (2 * (nm2 - p) - q + 1) * q;
+                    mode_pqr += nm2 * (nm2 + 1) * p;
+                    mode_pqr -= (2 * nm2 + 1) * (p - 1) * p / 2;
+                    mode_pqr += (p - 1) * p * (2 * p - 1) / 6;
+                    mode_pqr /= 2;
+
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2 - p - q; ++r, ++mode_pqr)
+                    {
+                        tmp += basis2[k + nq2 * mode_pqr] * s_wsp0[mode_pqr];
+                    }
+                    s_wsp1[cnt_kpq] = tmp;
+                }
+            }
         }
-        wsp[cnt_iq] = tmp;
-    }
 
-    __syncthreads();
+        __syncthreads();
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    if (i < nq0 && j < nq1)
-    {
-        size_t cnt_iq = nm1 * i;
-        size_t cnt_ij = nq0 * j + i;
-        TData tmp     = wsp[cnt_iq++] * basis1[j];
-        for (size_t q = 1; q < nm1; ++q)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            tmp += wsp[cnt_iq++] * basis1[q * nq1 + j];
-        }
-        outptr[cnt_ij] = tmp;
-    }
-}
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+                {
+                    unsigned int mode_pq  = (2 * nm1 - p + 1) * p / 2;
+                    unsigned int cnt_kpq  = nm01 * k + mode_pq;
+                    unsigned int mode_kjp = nm0 * nq1 * k + nm0 * j + p;
 
-template <typename TData>
-__global__ void BwdTransTriKernel_QP(const size_t nm0, const size_t nm1,
-                                     const size_t nmTot, const size_t nq0,
-                                     const size_t nq1, const bool correct,
-                                     const TData *basis0, const TData *basis1,
-                                     const TData *in, TData *wsp, TData *out)
-{
-    size_t e = blockIdx.x;
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1 - p; ++q, ++cnt_kpq)
+                    {
+                        tmp +=
+                            s_wsp1[cnt_kpq] * basis1[(mode_pq + q) * nq1 + j];
+                    }
+                    s_wsp2[mode_kjp] = tmp;
+                }
+            }
+        }
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * e);
+        __syncthreads();
 
-    size_t i, j, p;
-    p = threadIdx.x;
-    j = threadIdx.y;
-    if (p < nm0 && j < nq1)
-    {
-        size_t mode = (2 * nm1 - p + 1) * p / 2;
-        TData tmp   = 0.0;
-        for (size_t q = 0; q < (nm1 - p); ++q)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            tmp += basis1[mode * nq1 + j] * inptr[mode];
-            mode++;
-        }
-        wsp[nm0 * j + p] = tmp;
-    }
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_kji  = nq0 * nq1 * k + nq0 * j + i;
+                    unsigned int mode_kjp = nm0 * nq1 * k + nm0 * j;
 
-    __syncthreads();
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p, ++mode_kjp)
+                    {
+                        tmp += basis0[p * nq0 + i] * s_wsp2[mode_kjp];
+                    }
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    if (i < nq0 && j < nq1)
-    {
-        size_t cnt_ij = nq0 * j + i;
-        TData tmp     = wsp[nm0 * j] * basis0[i];
-        for (size_t p = 1; p < nm0; ++p)
-        {
-            tmp += wsp[nm0 * j + p] * basis0[p * nq0 + i];
-        }
+                    if (correct)
+                    {
+                        // top vertex
+                        TData tmp1 = basis0[i] * basis1[nq1 + j];
+                        tmp1 += basis0[nq0 + i] * basis1[j];
+                        tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp1 *= basis2[nq2 + k];
+                        tmp += tmp1 * s_wsp0[1];
+
+                        // bottom vertex
+                        tmp1 = basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp1 *= basis2[k];
+                        tmp += tmp1 * s_wsp0[nm2];
+
+                        // singular edge
+                        for (unsigned int r = 1; r < nm2 - 1; ++r)
+                        {
+                            tmp1 = basis1[nq1 + j] * basis0[nq0 + i];
+                            tmp1 *= basis2[(r + 1) * nq2 + k];
+                            tmp += tmp1 * s_wsp0[nm2 + r];
+                        }
+                    }
 
-        if (correct)
-        {
-            tmp += inptr[1] * basis0[nq0 + i] * basis1[nq1 + j];
+                    out[outoffset + cnt_kji] = tmp;
+                }
+            }
         }
 
-        outptr[cnt_ij] = tmp;
+        __syncthreads();
+
+        e += gridDim.x;
     }
 }
 
 template <typename TData>
-__global__ void BwdTransHexKernel_QP(const size_t nm0, const size_t nm1,
-                                     const size_t nm2, const size_t nq0,
-                                     const size_t nq1, const size_t nq2,
-                                     const TData *basis0, const TData *basis1,
-                                     const TData *basis2, const TData *in,
-                                     TData *wsp0, TData *wsp1, TData *out)
+__global__ void BwdTransPrismKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
-    const TData *inptr = in + (nm0 * nm1 * nm2 * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+    TData *fpq = new TData[nm0 * nm1];
+    TData *fp  = new TData[nm0];
 
-    size_t i, j, k, q, r;
-    i = threadIdx.x;
-    q = threadIdx.y;
-    r = threadIdx.z;
-    if (i < nq0 && q < nm1 && r < nm2)
+    while (e < nelmt)
     {
-        size_t cnt_rqp = nm1 * nm0 * r + nm0 * q;
-        size_t cnt_irq = nm1 * nm2 * i + nm1 * r + q;
-        TData tmp      = inptr[cnt_rqp++] * basis0[i];
-        for (size_t p = 1; p < nm0; ++p)
+        size_t inoffset  = static_cast<size_t>(nmTot) * e; // 64-bit, no wrap
+        size_t outoffset = static_cast<size_t>(nq0) * nq1 * nq2 * e;
+
+        for (unsigned int k = 0, cnt_kji = 0; k < nq2; ++k)
         {
-            tmp += inptr[cnt_rqp++] * basis0[p * nq0 + i];
-        }
-        wsp0[cnt_irq] = tmp;
-    }
+            for (unsigned int p = 0, mode_pr = 0, mode_pq = 0, mode_pqr = 0;
+                 p < nm0; ++p)
+            {
+                for (unsigned int q = 0; q < nm1; ++q, ++mode_pq)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2 - p; ++r, ++mode_pqr)
+                    {
+                        tmp += in[inoffset + mode_pqr] *
+                               basis2[(mode_pr + r) * nq2 + k];
+                    }
+                    fpq[mode_pq] = tmp;
+                }
+                mode_pr += nm2 - p;
+            }
 
-    __syncthreads();
+            for (unsigned int j = 0; j < nq1; ++j)
+            {
+                for (unsigned int p = 0, mode_pq = 0; p < nm0; ++p)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1; ++q, ++mode_pq)
+                    {
+                        tmp += fpq[mode_pq] * basis1[q * nq1 + j];
+                    }
+                    fp[p] = tmp;
+                }
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    r = threadIdx.z;
-    if (i < nq0 && j < nq1 && r < nm2)
-    {
-        size_t cnt_irq = nm1 * nm2 * i + nm1 * r;
-        size_t cnt_jir = nq0 * nm2 * j + nm2 * i + r;
-        TData tmp      = wsp0[cnt_irq++] * basis1[j];
-        for (size_t q = 1; q < nm1; ++q)
-        {
-            tmp += wsp0[cnt_irq++] * basis1[q * nq1 + j];
-        }
-        wsp1[cnt_jir] = tmp;
-    }
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p)
+                    {
+                        tmp += fp[p] * basis0[p * nq0 + i];
+                    }
 
-    __syncthreads();
+                    if (correct)
+                    {
+                        for (unsigned int q = 0; q < nm1; ++q)
+                        {
+                            tmp += basis2[nq2 + k] * basis1[q * nq1 + j] *
+                                   basis0[nq0 + i] * in[inoffset + nm2 * q + 1];
+                        }
+                    }
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (i < nq0 && j < nq1 && k < nq2)
-    {
-        size_t cnt_jir = nq0 * nm2 * j + nm2 * i;
-        size_t cnt_kji = nq0 * nq1 * k + nq0 * j + i;
-        TData tmp      = wsp1[cnt_jir++] * basis2[k];
-        for (size_t r = 1; r < nm2; ++r)
-        {
-            tmp += wsp1[cnt_jir++] * basis2[r * nq2 + k];
+                    out[outoffset + cnt_kji] = tmp;
+                }
+            }
         }
-        outptr[cnt_kji] = tmp;
+
+        e += blockDim.x * gridDim.x;
     }
+
+    delete[] fpq;
+    delete[] fp;
 }
 
-template <typename TData> // not working for nm2 > nm1
-__global__ void BwdTransTetKernel_QP(const size_t nm0, const size_t nm1,
-                                     const size_t nm2, const size_t nmTot,
-                                     const size_t nq0, const size_t nq1,
-                                     const size_t nq2, const bool correct,
-                                     const TData *basis0, const TData *basis1,
-                                     const TData *basis2, const TData *in,
-                                     TData *fpq, TData *fp, TData *out)
+template <typename TData>
+__global__ void BwdTransPrismKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t nm01 = (2 * nm1 - nm0 + 1) * nm0 / 2;
-    size_t e    = blockIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nmTot;
+    TData *s_wsp2 = s_wsp1 + (nm0 * nm1 * nq2);
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+    unsigned int e = blockIdx.x;
 
-    size_t i, j, k, p, q;
-    p = threadIdx.x;
-    q = threadIdx.y;
-    k = threadIdx.z;
-    if (p < nm0 && q < nm1 && k < nq2)
+    while (e < nelmt)
     {
-        size_t cnt_pq = nm01 * k + (2 * nm1 - p + 1) * p / 2 + q;
-        size_t mode   = (2 * (nm2 - p) - q + 1) * q / 2;
-        for (size_t n = 0; n < p; ++n)
-        {
-            mode += (nm2 - n + 1) * (nm2 - n) / 2;
-        }
+        unsigned int inoffset  = nmTot * e;
+        unsigned int outoffset = nq0 * nq1 * nq2 * e;
 
-        if (q < nm1 - p)
+        // Copy this element's slice of in[] into s_wsp0 (shared memory).
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            TData tmp = basis2[k + nq2 * mode] * inptr[mode++];
-            for (size_t r = 1; r < nm2 - p - q; ++r)
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
             {
-                tmp += basis2[k + nq2 * mode] * inptr[mode++];
+                unsigned int mode_pr  = (2 * nm2 - p + 1) * p / 2;
+                unsigned int mode_pqr = mode_pr * nm1 + (nm2 - p) * q;
+
+                for (unsigned int r = threadIdx.z; r < nm2 - p; r += blockDim.z)
+                {
+                    s_wsp0[mode_pqr + r] = in[inoffset + mode_pqr + r];
+                }
             }
-            fpq[cnt_pq] = tmp;
         }
-    }
 
-    __syncthreads();
+        __syncthreads();
 
-    p = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (p < nm0 && j < nq1 && k < nq2)
-    {
-        size_t cnt_pq = nm01 * k + (2 * nm1 - p + 1) * p / 2;
-        size_t mode   = (2 * nm1 - p + 1) * p / 2;
-        size_t mode_p = nm0 * nq1 * k + nm0 * j + p;
-        TData tmp     = fpq[cnt_pq++] * basis1[mode * nq1 + j];
-        for (size_t q = 1; q < nm1 - p; ++q)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            tmp += fpq[cnt_pq++] * basis1[(mode + q) * nq1 + j];
+            for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+            {
+                for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+                {
+                    unsigned int mode_pr  = (2 * nm2 - p + 1) * p / 2;
+                    unsigned int mode_pqr = mode_pr * nm1 + (nm2 - p) * q;
+                    unsigned int mode_kpq = nm0 * nm1 * k + nm1 * p + q;
+
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2 - p; ++r, ++mode_pqr)
+                    {
+                        tmp +=
+                            s_wsp0[mode_pqr] * basis2[(mode_pr + r) * nq2 + k];
+                    }
+                    s_wsp1[mode_kpq] = tmp;
+                }
+            }
         }
-        fp[mode_p] = tmp;
-    }
 
-    __syncthreads();
+        __syncthreads();
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (i < nq0 && j < nq1 && k < nq2)
-    {
-        size_t cnt_kji = nq0 * nq1 * k + nq0 * j + i;
-        size_t mode_p  = nm0 * nq1 * k + nm0 * j;
-        TData tmp      = basis0[i] * fp[mode_p++];
-        for (size_t p = 1; p < nm0; ++p)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            tmp += basis0[p * nq0 + i] * fp[mode_p++];
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+                {
+                    unsigned int mode_kpq = nm0 * nm1 * k + nm1 * p;
+                    unsigned int mode_kjp = nm0 * nq1 * k + nm0 * j + p;
+                    // s_wsp2[k][j][p] = sum_q s_wsp1[k][p][q] * basis1[q][j]
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1; ++q, ++mode_kpq)
+                    {
+                        tmp += s_wsp1[mode_kpq] * basis1[q * nq1 + j];
+                    }
+                    s_wsp2[mode_kjp] = tmp;
+                }
+            }
         }
 
-        if (correct)
+        __syncthreads();
+
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            // top vertex
-            TData tmp1 = basis0[i] * basis1[nq1 + j];
-            tmp1 += basis0[nq0 + i] * basis1[j];
-            tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
-            tmp1 *= basis2[nq2 + k];
-            tmp += tmp1 * inptr[1];
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_kji  = nq0 * nq1 * k + nq0 * j + i;
+                    unsigned int mode_kjp = nm0 * nq1 * k + nm0 * j;
 
-            // bottom vertex
-            tmp1 = basis0[nq0 + i] * basis1[nq1 + j];
-            tmp1 *= basis2[k];
-            tmp += tmp1 * inptr[nm2];
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p, ++mode_kjp)
+                    {
+                        tmp += s_wsp2[mode_kjp] * basis0[p * nq0 + i];
+                    }
 
-            // singular edge
-            for (size_t r = 1; r < nm2 - 1; ++r)
-            {
-                tmp1 = basis1[nq1 + j] * basis0[nq0 + i];
-                tmp1 *= basis2[(r + 1) * nq2 + k];
-                tmp += tmp1 * inptr[nm2 + r];
+                    if (correct)
+                    {
+                        for (unsigned int q = 0; q < nm1; ++q)
+                        {
+                            tmp += basis2[nq2 + k] * basis1[q * nq1 + j] *
+                                   basis0[nq0 + i] * s_wsp0[q * nm2 + 1];
+                        }
+                    }
+                    // Store the result for point (i, j, k) of element e.
+                    out[outoffset + cnt_kji] = tmp;
+                }
             }
         }
 
-        outptr[cnt_kji] = tmp;
+        __syncthreads();
+
+        e += gridDim.x;
     }
 }
 
-template <typename TData>
-__global__ void BwdTransPrismKernel_QP(const size_t nm0, const size_t nm1,
-                                       const size_t nm2, const size_t nmTot,
-                                       const size_t nq0, const size_t nq1,
-                                       const size_t nq2, const bool correct,
-                                       const TData *basis0, const TData *basis1,
-                                       const TData *basis2, const TData *in,
-                                       TData *fpq, TData *fp, TData *out)
+template <typename TData> // not working for nm2 > nm1
+__global__ void BwdTransPyrKernel(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockIdx.x;
+    unsigned int e = blockDim.x * blockIdx.x + threadIdx.x;
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+    TData *fpq = new TData[nm0 * nm1];
+    TData *fp  = new TData[nm0];
 
-    size_t i, j, k, p, q;
-    p = threadIdx.x;
-    q = threadIdx.y;
-    k = threadIdx.z;
-    if (p < nm0 && q < nm1 && k < nq2)
+    while (e < nelmt)
     {
-        size_t mode_pr  = (2 * nm2 - p + 1) * p / 2;
-        size_t mode_pqr = mode_pr * nm1 + (nm2 - p) * q;
-        size_t mode_pq  = nm0 * nm1 * k + nm1 * p + q;
-        TData tmp       = 0.0;
-        for (size_t r = 0; r < nm2 - p; ++r)
-        {
-            tmp += inptr[mode_pqr++] * basis2[(mode_pr + r) * nq2 + k];
-        }
-        fpq[mode_pq] = tmp;
-    }
-
-    __syncthreads();
+        unsigned int inoffset  = nmTot * e;
+        unsigned int outoffset = nq0 * nq1 * nq2 * e;
 
-    p = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (p < nm0 && j < nq1 && k < nq2)
-    {
-        size_t mode_pq = nm0 * nm1 * k + nm1 * p;
-        size_t mode_p  = nm0 * nq1 * k + nm0 * j + p;
-        TData tmp      = fpq[mode_pq++] * basis1[j];
-        for (int q = 1; q < nm1; ++q)
+        for (unsigned int k = 0, cnt_kji = 0; k < nq2; ++k)
         {
-            tmp += fpq[mode_pq++] * basis1[q * nq1 + j];
-        }
-        fp[mode_p] = tmp;
-    }
+            for (unsigned int p = 0, mode_pq = 0, mode_pqr = 0; p < nm0; ++p)
+            {
+                for (unsigned int q = 0; q < p; ++q, ++mode_pq)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2 - p; ++r, ++mode_pqr)
+                    {
+                        tmp += basis2[mode_pqr * nq2 + k] *
+                               in[inoffset + mode_pqr];
+                    }
+                    fpq[mode_pq] = tmp;
+                }
 
-    __syncthreads();
+                for (unsigned int q = p; q < nm1; ++q, ++mode_pq)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int r = 0; r < nm2 - q; ++r, ++mode_pqr)
+                    {
+                        tmp += basis2[mode_pqr * nq2 + k] *
+                               in[inoffset + mode_pqr];
+                    }
+                    fpq[mode_pq] = tmp;
+                }
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (i < nq0 && j < nq1 && k < nq2)
-    {
-        size_t cnt_kji = nq0 * nq1 * k + nq0 * j + i;
-        size_t mode_p  = nm0 * nq1 * k + nm0 * j;
-        TData tmp      = fp[mode_p++] * basis0[i];
-        for (int p = 1; p < nm0; ++p)
-        {
-            tmp += fp[mode_p++] * basis0[p * nq0 + i];
-        }
+                // skip the modes with q in [nm1, nm2 - p) when nm2 > nm1
+                for (unsigned int q = nm1; q < nm2 - p; ++q)
+                {
+                    mode_pqr += nm2 - q;
+                }
+            }
 
-        if (correct)
-        {
-            for (int q = 0; q < nm1; ++q)
+            for (unsigned int j = 0; j < nq1; ++j)
             {
-                tmp += basis2[nq2 + k] * basis1[q * nq1 + j] * basis0[nq0 + i] *
-                       inptr[q * nm2 + 1];
+                for (unsigned int p = 0, mode_pq = 0; p < nm0; ++p)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1; ++q, ++mode_pq)
+                    {
+                        tmp += fpq[mode_pq] * basis1[q * nq1 + j];
+                    }
+                    fp[p] = tmp;
+                }
+
+                for (unsigned int i = 0; i < nq0; ++i, ++cnt_kji)
+                {
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p)
+                    {
+                        tmp += fp[p] * basis0[p * nq0 + i];
+                    }
+
+                    if (correct)
+                    {
+                        // top vertex
+                        TData tmp1 = basis0[i] * basis1[nq1 + j];
+                        tmp1 += basis0[nq0 + i] * basis1[j];
+                        tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp1 *= basis2[nq2 + k];
+                        tmp += tmp1 * in[inoffset + 1];
+                    }
+
+                    out[outoffset + cnt_kji] = tmp;
+                }
             }
         }
-        outptr[cnt_kji] = tmp;
+
+        e += blockDim.x * gridDim.x;
     }
+
+    delete[] fpq;
+    delete[] fp;
 }
 
 template <typename TData> // not working for nm2 > nm1
-__global__ void BwdTransPyrKernel_QP(const size_t nm0, const size_t nm1,
-                                     const size_t nm2, const size_t nmTot,
-                                     const size_t nq0, const size_t nq1,
-                                     const size_t nq2, const bool correct,
-                                     const TData *basis0, const TData *basis1,
-                                     const TData *basis2, const TData *in,
-                                     TData *fpq, TData *fp, TData *out)
+__global__ void BwdTransPyrKernel_QP(
+    const unsigned int nm0, const unsigned int nm1, const unsigned int nm2,
+    const unsigned int nmTot, const unsigned int nq0, const unsigned int nq1,
+    const unsigned int nq2, const unsigned int nelmt, const bool correct,
+    const TData *__restrict basis0, const TData *__restrict basis1,
+    const TData *__restrict basis2, const TData *__restrict in,
+    TData *__restrict out)
 {
-    size_t e = blockIdx.x;
+    extern __shared__ TData shared[];
+    TData *s_wsp0 = shared;
+    TData *s_wsp1 = s_wsp0 + nmTot;
+    TData *s_wsp2 = s_wsp1 + (nm0 * nm1 * nq2);
 
-    const TData *inptr = in + (nmTot * e);
-    TData *outptr      = out + (nq0 * nq1 * nq2 * e);
+    unsigned int e = blockIdx.x;
 
-    size_t i, j, k, p, q;
-    p = threadIdx.x;
-    q = threadIdx.y;
-    k = threadIdx.z;
-    if (p < nm0 && q < nm1 && k < nq2)
+    while (e < nelmt)
     {
-        size_t mode_pq  = nm0 * nm1 * k + nm1 * p + q;
-        size_t mode_tmp = 0;
-        for (size_t n = 0; n < p; ++n)
+        unsigned int inoffset  = nmTot * e;
+        unsigned int outoffset = nq0 * nq1 * nq2 * e;
+
+        // Copy this element's slice of in[] into s_wsp0 (shared memory).
+        for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
         {
-            mode_tmp += n * (nm2 - n);
-            mode_tmp += ((2 * nm2 - nm1 - n + 1) * (nm1 - n)) / 2;
-            if (nm2 > nm1 && nm2 - nm1 > n)
+            for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
             {
-                mode_tmp += (((nm2 - nm1) + (n + 1)) * (nm2 - nm1 - n)) / 2;
+                unsigned int mode_pqr = nm1 * (2 * nm2 + 1 - nm1) * p;
+                mode_pqr -= (p - 1) * p / 2;
+                mode_pqr -= (p - 1) * p * (2 * p - 1) / 6;
+                mode_pqr /= 2;
+
+                if (q < p)
+                {
+                    mode_pqr += q * (nm2 - p);
+                    for (unsigned int r = threadIdx.z; r < nm2 - p;
+                         r += blockDim.z)
+                    {
+                        s_wsp0[mode_pqr + r] = in[inoffset + mode_pqr + r];
+                    }
+                }
+                else
+                {
+                    mode_pqr += p * (nm2 - p);
+                    mode_pqr += ((2 * (nm2 - p) - (q - p) + 1) * (q - p)) / 2;
+                    for (unsigned int r = threadIdx.z; r < nm2 - q;
+                         r += blockDim.z)
+                    {
+                        s_wsp0[mode_pqr + r] = in[inoffset + mode_pqr + r];
+                    }
+                }
             }
         }
 
-        if (q < p)
+        __syncthreads();
+
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            size_t mode_pqr = mode_tmp + q * (nm2 - p);
-            TData tmp       = 0.0;
-            for (size_t r = 0; r < nm2 - p; ++r)
+            for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
             {
-                tmp += basis2[mode_pqr * nq2 + k] * inptr[mode_pqr++];
+                for (unsigned int q = threadIdx.y; q < nm1; q += blockDim.y)
+                {
+                    unsigned int mode_kpq = nm0 * nm1 * k + nm1 * p + q;
+                    unsigned int mode_pqr = nm1 * (2 * nm2 + 1 - nm1) * p;
+                    mode_pqr -= (p - 1) * p / 2;
+                    mode_pqr -= (p - 1) * p * (2 * p - 1) / 6;
+                    mode_pqr /= 2;
+
+                    if (q < p)
+                    {
+                        mode_pqr += q * (nm2 - p);
+                        TData tmp = 0.0;
+                        for (unsigned int r = 0; r < nm2 - p; ++r, ++mode_pqr)
+                        {
+                            tmp +=
+                                basis2[mode_pqr * nq2 + k] * s_wsp0[mode_pqr];
+                        }
+                        s_wsp1[mode_kpq] = tmp;
+                    }
+                    else
+                    {
+                        mode_pqr += p * (nm2 - p);
+                        mode_pqr +=
+                            ((2 * (nm2 - p) - (q - p) + 1) * (q - p)) / 2;
+
+                        TData tmp = 0.0;
+                        for (unsigned int r = 0; r < nm2 - q; ++r, ++mode_pqr)
+                        {
+                            tmp +=
+                                basis2[mode_pqr * nq2 + k] * s_wsp0[mode_pqr];
+                        }
+                        s_wsp1[mode_kpq] = tmp;
+                    }
+                }
             }
-            fpq[mode_pq] = tmp;
         }
-        else if (q < nm1)
+
+        __syncthreads();
+
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            size_t mode_pqr = mode_tmp + p * (nm2 - p);
-            mode_pqr += ((2 * (nm2 - p) - (q - p) + 1) * (q - p)) / 2;
-            TData tmp = 0.0;
-            for (size_t r = 0; r < nm2 - q; ++r)
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
             {
-                tmp += basis2[mode_pqr * nq2 + k] * inptr[mode_pqr++];
+                for (unsigned int p = threadIdx.x; p < nm0; p += blockDim.x)
+                {
+                    unsigned int mode_kjp = nm0 * nq1 * k + nm0 * j + p;
+                    unsigned int mode_kpq = nm0 * nm1 * k + nm1 * p;
+
+                    TData tmp = 0.0;
+                    for (unsigned int q = 0; q < nm1; ++q, ++mode_kpq)
+                    {
+                        tmp += s_wsp1[mode_kpq] * basis1[q * nq1 + j];
+                    }
+                    s_wsp2[mode_kjp] = tmp;
+                }
             }
-            fpq[mode_pq] = tmp;
         }
-    }
 
-    __syncthreads();
+        __syncthreads();
 
-    p = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (p < nm0 && j < nq1 && k < nq2)
-    {
-        size_t mode_p  = nm0 * nq1 * k + nm0 * j + p;
-        size_t mode_pq = nm0 * nm1 * k + nm1 * p;
-        TData tmp      = fpq[mode_pq++] * basis1[j];
-        for (size_t q = 1; q < nm1; ++q)
+        for (unsigned int k = threadIdx.z; k < nq2; k += blockDim.z)
         {
-            tmp += fpq[mode_pq++] * basis1[q * nq1 + j];
-        }
-        fp[mode_p] = tmp;
-    }
+            for (unsigned int j = threadIdx.y; j < nq1; j += blockDim.y)
+            {
+                for (unsigned int i = threadIdx.x; i < nq0; i += blockDim.x)
+                {
+                    unsigned int cnt_kji  = nq0 * nq1 * k + nq0 * j + i;
+                    unsigned int mode_kjp = nm0 * nq1 * k + nm0 * j;
 
-    __syncthreads();
+                    TData tmp = 0.0;
+                    for (unsigned int p = 0; p < nm0; ++p, ++mode_kjp)
+                    {
+                        tmp += s_wsp2[mode_kjp] * basis0[p * nq0 + i];
+                    }
 
-    i = threadIdx.x;
-    j = threadIdx.y;
-    k = threadIdx.z;
-    if (i < nq0 && j < nq1 && k < nq2)
-    {
-        size_t cnt_kji = nq0 * nq1 * k + nq0 * j + i;
-        size_t mode_p  = nm0 * nq1 * k + nm0 * j;
-        TData tmp      = fp[mode_p++] * basis0[i];
-        for (size_t p = 1; p < nm0; ++p)
-        {
-            tmp += fp[mode_p++] * basis0[p * nq0 + i];
-        }
+                    if (correct)
+                    {
+                        // top vertex
+                        TData tmp1 = basis0[i] * basis1[nq1 + j];
+                        tmp1 += basis0[nq0 + i] * basis1[j];
+                        tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
+                        tmp1 *= basis2[nq2 + k];
+                        tmp += tmp1 * s_wsp0[1];
+                    }
 
-        if (correct)
-        {
-            // top vertex
-            TData tmp1 = basis0[i] * basis1[nq1 + j];
-            tmp1 += basis0[nq0 + i] * basis1[j];
-            tmp1 += basis0[nq0 + i] * basis1[nq1 + j];
-            tmp1 *= basis2[nq2 + k];
-            tmp += tmp1 * inptr[1];
+                    out[outoffset + cnt_kji] = tmp;
+                }
+            }
         }
-        outptr[cnt_kji] = tmp;
+
+        __syncthreads();
+
+        e += gridDim.x;
     }
 }
 
diff --git a/Operators/BwdTrans/BwdTransMatFree.hpp b/Operators/BwdTrans/BwdTransMatFree.hpp
index 418617c8202dfea4f45c7f6dd7de9d2d7a19590b..0f228182eb43ee31cfc0723189cfb743fad88900 100644
--- a/Operators/BwdTrans/BwdTransMatFree.hpp
+++ b/Operators/BwdTrans/BwdTransMatFree.hpp
@@ -1,11 +1,12 @@
-#include <LibUtilities/Foundations/Basis.h>
+#pragma once
 
 #include <LibUtilities/BasicUtils/ShapeType.hpp>
 #include <LibUtilities/BasicUtils/SharedArray.hpp>
+#include <LibUtilities/Foundations/Basis.h>
+#include <LibUtilities/SimdLib/tinysimd.hpp>
 
 #include "BwdTransMatFreeKernels.hpp"
 #include "Operators/OperatorBwdTrans.hpp"
-#include <LibUtilities/SimdLib/tinysimd.hpp>
 
 namespace Nektar::Operators::detail
 {
diff --git a/Operators/BwdTrans/BwdTransMatFreeKernels.hpp b/Operators/BwdTrans/BwdTransMatFreeKernels.hpp
index 7b5cde9cafa75e08c038589d2f88e943ac6002e3..8c0d0cba2f657675f2b2e1ca1aeb45980264d5cd 100644
--- a/Operators/BwdTrans/BwdTransMatFreeKernels.hpp
+++ b/Operators/BwdTrans/BwdTransMatFreeKernels.hpp
@@ -1,4 +1,5 @@
 #pragma once
+
 #include <LibUtilities/BasicUtils/NekInline.hpp>
 
 namespace Nektar::Operators::detail
diff --git a/Operators/BwdTrans/BwdTransStdMat.hpp b/Operators/BwdTrans/BwdTransStdMat.hpp
index 8b9a0dd844ca18dcc642e1eb02efaeecca3589d0..4c8d1604ded0e9d6492fa40095163589a41ac972 100644
--- a/Operators/BwdTrans/BwdTransStdMat.hpp
+++ b/Operators/BwdTrans/BwdTransStdMat.hpp
@@ -1,6 +1,9 @@
-#include "Operators/OperatorBwdTrans.hpp"
+#pragma once
+
 #include <StdRegions/StdExpansion.h>
 
+#include "Operators/OperatorBwdTrans.hpp"
+
 namespace Nektar::Operators::detail
 {
 
diff --git a/Operators/BwdTrans/BwdTransSumFac.hpp b/Operators/BwdTrans/BwdTransSumFac.hpp
index e39568d6c346ab8a28f6fff954e5f260e5054d99..50c2c370d3344e22ebd70a3d14abd37c2615a70d 100644
--- a/Operators/BwdTrans/BwdTransSumFac.hpp
+++ b/Operators/BwdTrans/BwdTransSumFac.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "Operators/OperatorBwdTrans.hpp"
 
 namespace Nektar::Operators::detail