Skip to content
Snippets Groups Projects

Implement CUDA IProductWRTDerivBase sum-factorization kernels

1 unresolved thread
Files
14
#pragma once
#include "MemoryRegionCUDA.hpp"
#include "Operators/IProductWRTBase/IProductWRTBaseCUDAKernels.cuh"
#include "Operators/OperatorHelper.cuh"
@@ -6,14 +8,16 @@
namespace Nektar::Operators::detail
{
template <typename TData, bool APPEND = false, bool DEFORMED = false>
template <typename TData, bool SCALE = false, bool APPEND = false,
bool DEFORMED = false>
void IProductWRTBase1DKernel(const size_t gridSize, const size_t blockSize,
const size_t nm0, const size_t nq0,
const size_t nElmts, const TData *basis0,
const TData *w0, const TData *jac, const TData *in,
TData *out);
TData *out, TData scale = 1.0);
template <typename TData, bool APPEND = false, bool DEFORMED = false>
template <typename TData, bool SCALE = false, bool APPEND = false,
bool DEFORMED = false>
void IProductWRTBase2DKernel(const size_t gridSize, const size_t blockSize,
LibUtilities::ShapeType shapetype,
const size_t nm0, const size_t nm1,
@@ -21,16 +25,20 @@ void IProductWRTBase2DKernel(const size_t gridSize, const size_t blockSize,
const size_t nElmts, const bool correct,
const TData *basis0, const TData *basis1,
const TData *w0, const TData *w1, const TData *jac,
const TData *in, TData *out);
const TData *in, TData *out, TData scale = 1.0);
template <typename TData, bool APPEND = false, bool DEFORMED = false>
void IProductWRTBase3DKernel(
const size_t gridSize, const size_t blockSize,
LibUtilities::ShapeType shapetype, const size_t nm0, const size_t nm1,
const size_t nm2, const size_t nq0, const size_t nq1, const size_t nq2,
const size_t nElmts, const bool correct, const TData *basis0,
const TData *basis1, const TData *basis2, const TData *w0, const TData *w1,
const TData *w2, const TData *jac, const TData *in, TData *out);
template <typename TData, bool SCALE = false, bool APPEND = false,
bool DEFORMED = false>
void IProductWRTBase3DKernel(const size_t gridSize, const size_t blockSize,
LibUtilities::ShapeType shapetype,
const size_t nm0, const size_t nm1,
const size_t nm2, const size_t nq0,
const size_t nq1, const size_t nq2,
const size_t nElmts, const bool correct,
const TData *basis0, const TData *basis1,
const TData *basis2, const TData *w0,
const TData *w1, const TData *w2, const TData *jac,
const TData *in, TData *out, TData scale = 1.0);
// IProductWRTBase implementation
template <typename TData>
@@ -49,7 +57,7 @@ public:
cudaMemcpy(m_jac, jac.get(), sizeof(TData) * jacSize,
cudaMemcpyHostToDevice);
// Initialize basiskey.
// Initialize basis.
m_basis = GetBasisDataCUDA<TData>(expansionList);
// Initialize weight.
@@ -58,25 +66,14 @@ public:
~OperatorIProductWRTBaseImpl(void)
{
for (auto &basis : m_basis)
{
for (size_t i = 0; i < basis.second.size(); i++)
{
cudaFree(basis.second[i]);
}
}
for (auto &weight : m_weight)
{
for (size_t i = 0; i < weight.second.size(); i++)
{
cudaFree(weight.second[i]);
}
}
DeallocateDataCUDA<TData>(m_basis);
DeallocateDataCUDA<TData>(m_weight);
cudaFree(m_jac);
}
void apply(Field<TData, FieldState::Phys> &in,
Field<TData, FieldState::Coeff> &out) override
Field<TData, FieldState::Coeff> &out,
const TData lambda = 1.0) override
{
// Copy memory to GPU, if necessary and get raw pointers.
auto const *inptr =
@@ -126,15 +123,33 @@ public:
auto nq0 = expPtr->GetNumPoints(0);
if (deformed)
{
IProductWRTBase1DKernel<TData, false, true>(
m_gridSize, m_blockSize, nm0, nq0, nElmts, basis0, w0,
jacptr, inptr, outptr);
if (lambda == 1.0)
{
IProductWRTBase1DKernel<TData, false, false, true>(
m_gridSize, m_blockSize, nm0, nq0, nElmts, basis0,
w0, jacptr, inptr, outptr);
}
else
{
IProductWRTBase1DKernel<TData, true, false, true>(
m_gridSize, m_blockSize, nm0, nq0, nElmts, basis0,
w0, jacptr, inptr, outptr, lambda);
}
}
else
{
IProductWRTBase1DKernel<TData, false, false>(
m_gridSize, m_blockSize, nm0, nq0, nElmts, basis0, w0,
jacptr, inptr, outptr);
if (lambda == 1.0)
{
IProductWRTBase1DKernel<TData, false, false, false>(
m_gridSize, m_blockSize, nm0, nq0, nElmts, basis0,
w0, jacptr, inptr, outptr);
}
else
{
IProductWRTBase1DKernel<TData, true, false, false>(
m_gridSize, m_blockSize, nm0, nq0, nElmts, basis0,
w0, jacptr, inptr, outptr, lambda);
}
}
}
else if (expPtr->GetShapeDimension() == 2)
@@ -150,17 +165,37 @@ public:
auto nq1 = expPtr->GetNumPoints(1);
if (deformed)
{
IProductWRTBase2DKernel<TData, false, true>(
m_gridSize, m_blockSize, shape, nm0, nm1, nq0, nq1,
nElmts, correct, basis0, basis1, w0, w1, jacptr, inptr,
outptr);
if (lambda == 1.0)
{
IProductWRTBase2DKernel<TData, false, false, true>(
m_gridSize, m_blockSize, shape, nm0, nm1, nq0, nq1,
nElmts, correct, basis0, basis1, w0, w1, jacptr,
inptr, outptr);
}
else
{
IProductWRTBase2DKernel<TData, true, false, true>(
m_gridSize, m_blockSize, shape, nm0, nm1, nq0, nq1,
nElmts, correct, basis0, basis1, w0, w1, jacptr,
inptr, outptr, lambda);
}
}
else
{
IProductWRTBase2DKernel<TData, false, false>(
m_gridSize, m_blockSize, shape, nm0, nm1, nq0, nq1,
nElmts, correct, basis0, basis1, w0, w1, jacptr, inptr,
outptr);
if (lambda == 1.0)
{
IProductWRTBase2DKernel<TData, false, false, false>(
m_gridSize, m_blockSize, shape, nm0, nm1, nq0, nq1,
nElmts, correct, basis0, basis1, w0, w1, jacptr,
inptr, outptr);
}
else
{
IProductWRTBase2DKernel<TData, true, false, false>(
m_gridSize, m_blockSize, shape, nm0, nm1, nq0, nq1,
nElmts, correct, basis0, basis1, w0, w1, jacptr,
inptr, outptr, lambda);
}
}
}
else if (expPtr->GetShapeDimension() == 3)
@@ -180,17 +215,37 @@ public:
auto nq2 = expPtr->GetNumPoints(2);
if (deformed)
{
IProductWRTBase3DKernel<TData, false, true>(
m_gridSize, m_blockSize, shape, nm0, nm1, nm2, nq0, nq1,
nq2, nElmts, correct, basis0, basis1, basis2, w0, w1,
w2, jacptr, inptr, outptr);
if (lambda == 1.0)
{
IProductWRTBase3DKernel<TData, false, false, true>(
m_gridSize, m_blockSize, shape, nm0, nm1, nm2, nq0,
nq1, nq2, nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jacptr, inptr, outptr);
}
else
{
IProductWRTBase3DKernel<TData, true, false, true>(
m_gridSize, m_blockSize, shape, nm0, nm1, nm2, nq0,
nq1, nq2, nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jacptr, inptr, outptr, lambda);
}
}
else
{
IProductWRTBase3DKernel<TData, false, false>(
m_gridSize, m_blockSize, shape, nm0, nm1, nm2, nq0, nq1,
nq2, nElmts, correct, basis0, basis1, basis2, w0, w1,
w2, jacptr, inptr, outptr);
if (lambda == 1.0)
{
IProductWRTBase3DKernel<TData, false, false, false>(
m_gridSize, m_blockSize, shape, nm0, nm1, nm2, nq0,
nq1, nq2, nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jacptr, inptr, outptr);
}
else
{
IProductWRTBase3DKernel<TData, true, false, false>(
m_gridSize, m_blockSize, shape, nm0, nm1, nm2, nq0,
nq1, nq2, nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jacptr, inptr, outptr, lambda);
}
}
}
@@ -220,18 +275,19 @@ private:
size_t m_gridSize;
};
template <typename TData, bool APPEND, bool DEFORMED>
template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
void IProductWRTBase1DKernel(const size_t gridSize, const size_t blockSize,
const size_t nm0, const size_t nq0,
const size_t nElmts, const TData *basis0,
const TData *w0, const TData *jac, const TData *in,
TData *out)
TData *out, TData scale)
{
IProductWRTBaseSegKernel<TData, false, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nq0, nElmts, basis0, w0, jac, in, out);
IProductWRTBaseSegKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nq0, nElmts, basis0, w0, jac, in, out,
scale);
}
template <typename TData, bool APPEND, bool DEFORMED>
template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
void IProductWRTBase2DKernel(const size_t gridSize, const size_t blockSize,
LibUtilities::ShapeType shapetype,
const size_t nm0, const size_t nm1,
@@ -239,67 +295,67 @@ void IProductWRTBase2DKernel(const size_t gridSize, const size_t blockSize,
const size_t nElmts, const bool correct,
const TData *basis0, const TData *basis1,
const TData *w0, const TData *w1, const TData *jac,
const TData *in, TData *out)
const TData *in, TData *out, TData scale)
{
if (shapetype == LibUtilities::Quad)
{
IProductWRTBaseQuadKernel<TData, false, APPEND, DEFORMED>
IProductWRTBaseQuadKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nm1, nq0, nq1, nElmts, basis0,
basis1, w0, w1, jac, in, out);
basis1, w0, w1, jac, in, out, scale);
}
else if (shapetype == LibUtilities::Tri)
{
size_t nmTot =
LibUtilities::StdTriData::getNumberOfCoefficients(nm0, nm1);
IProductWRTBaseTriKernel<TData, false, APPEND, DEFORMED>
IProductWRTBaseTriKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nm1, nmTot, nq0, nq1, nElmts,
correct, basis0, basis1, w0, w1, jac, in,
out);
out, scale);
}
}
template <typename TData, bool APPEND, bool DEFORMED>
template <typename TData, bool SCALE, bool APPEND, bool DEFORMED>
void IProductWRTBase3DKernel(
const size_t gridSize, const size_t blockSize,
LibUtilities::ShapeType shapetype, const size_t nm0, const size_t nm1,
const size_t nm2, const size_t nq0, const size_t nq1, const size_t nq2,
const size_t nElmts, const bool correct, const TData *basis0,
const TData *basis1, const TData *basis2, const TData *w0, const TData *w1,
const TData *w2, const TData *jac, const TData *in, TData *out)
const TData *w2, const TData *jac, const TData *in, TData *out, TData scale)
{
if (shapetype == LibUtilities::Hex)
{
IProductWRTBaseHexKernel<TData, false, APPEND, DEFORMED>
IProductWRTBaseHexKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nm1, nm2, nq0, nq1, nq2, nElmts,
basis0, basis1, basis2, w0, w1, w2, jac,
in, out);
in, out, scale);
}
else if (shapetype == LibUtilities::Tet)
{
size_t nmTot =
LibUtilities::StdTetData::getNumberOfCoefficients(nm0, nm1, nm2);
IProductWRTBaseTetKernel<TData, false, APPEND, DEFORMED>
IProductWRTBaseTetKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nm1, nm2, nmTot, nq0, nq1, nq2,
nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jac, in, out);
w0, w1, w2, jac, in, out, scale);
}
else if (shapetype == LibUtilities::Pyr)
{
size_t nmTot =
LibUtilities::StdPyrData::getNumberOfCoefficients(nm0, nm1, nm2);
IProductWRTBasePyrKernel<TData, false, APPEND, DEFORMED>
IProductWRTBasePyrKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nm1, nm2, nmTot, nq0, nq1, nq2,
nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jac, in, out);
w0, w1, w2, jac, in, out, scale);
}
else if (shapetype == LibUtilities::Prism)
{
size_t nmTot =
LibUtilities::StdPrismData::getNumberOfCoefficients(nm0, nm1, nm2);
IProductWRTBasePrismKernel<TData, false, APPEND, DEFORMED>
IProductWRTBasePrismKernel<TData, SCALE, APPEND, DEFORMED>
<<<gridSize, blockSize>>>(nm0, nm1, nm2, nmTot, nq0, nq1, nq2,
nElmts, correct, basis0, basis1, basis2,
w0, w1, w2, jac, in, out);
w0, w1, w2, jac, in, out, scale);
}
}
Loading