Commit 4c83a458 authored by Dave Moxey

Merge branch 'feature/SumFacPyr' into 'master'

Feature/sum fac pyr

See merge request !750
parents 236cb5f9 9d6a3423
......@@ -7,6 +7,10 @@ v4.5.0
- Add periodic boundary condition meshing in 2D (!733)
- Adjust boundary layer thickness in corners in 2D (!739)
**Library**
- Added a sum-factorisation implementation for pyramid expansions and an
orthogonal expansion in pyramids (!750)
**Documentation**:
- Added the developer-guide repository as a submodule (!751)
......
......@@ -757,6 +757,7 @@ INPUT = @CMAKE_SOURCE_DIR@/docs/doxygen/ \
@CMAKE_SOURCE_DIR@/library/LibUtilities/ \
@CMAKE_SOURCE_DIR@/library/StdRegions/ \
@CMAKE_SOURCE_DIR@/library/SpatialDomains/ \
@CMAKE_SOURCE_DIR@/library/Collections/ \
@CMAKE_SOURCE_DIR@/library/LocalRegions/ \
@CMAKE_SOURCE_DIR@/library/MultiRegions/ \
@CMAKE_SOURCE_DIR@/library/GlobalMapping/ \
......
......@@ -918,7 +918,7 @@ class BwdTrans_SumFac_Prism : public Operator
{
Blas::Dgemm('N', 'N', m_nquad2, m_numElmt, m_nmodes2-i,
1.0, m_base2.get()+mode*m_nquad2, m_nquad2,
&input[0]+mode1, totmodes, 0.0,
input.get()+mode1, totmodes, 0.0,
&wsp[j*m_nquad2*m_numElmt*m_nmodes0+ cnt],
m_nquad2);
mode1 += m_nmodes2-i;
......@@ -1018,5 +1018,168 @@ OperatorKey BwdTrans_SumFac_Prism::m_type = GetOperatorFactory().
OperatorKey(ePrism, eBwdTrans, eSumFac,false),
BwdTrans_SumFac_Prism::create, "BwdTrans_SumFac_Prism");
/**
 * @brief Backward transform operator using sum-factorisation (Pyr)
 *
 * Transforms the expansion coefficients of a collection of #m_numElmt
 * pyramid elements to values at the quadrature points. The work is done
 * as three successive contractions, one per coordinate direction, each
 * implemented with BLAS DGEMM calls:
 *
 *  -# direction 2: one DGEMM per (i,j) mode pair, contracting the
 *     `m_nmodes2 - max(i,j)` coefficients of that pair with the matching
 *     columns of #m_base2;
 *  -# direction 1: one DGEMM per i, contracting with #m_base1;
 *  -# direction 0: a single DGEMM contracting with #m_base0.
 *
 * The workspace is split into two parts: the first
 * `m_nmodes0*m_nmodes1*m_nquad2*m_numElmt` entries hold the result of
 * step 1, the remainder holds the result of step 2.
 */
class BwdTrans_SumFac_Pyr : public Operator
{
    public:
        OPERATOR_CREATE(BwdTrans_SumFac_Pyr)

        virtual ~BwdTrans_SumFac_Pyr()
        {
        }

        /**
         * @brief Apply the backward transform.
         *
         * @param input    Coefficients, `m_stdExp->GetNcoeffs()` values per
         *                 element, elements stored one after another.
         * @param output   Receives the values produced by the final
         *                 contraction (leading dimension #m_nquad0).
         * @param output1  Unused by this operator.
         * @param output2  Unused by this operator.
         * @param wsp      Workspace; must hold exactly #m_wspSize entries.
         */
        virtual void operator()(
                const Array<OneD, const NekDouble> &input,
                      Array<OneD,       NekDouble> &output,
                      Array<OneD,       NekDouble> &output1,
                      Array<OneD,       NekDouble> &output2,
                      Array<OneD,       NekDouble> &wsp)
        {
            ASSERTL1(wsp.num_elements() == m_wspSize,
                     "Incorrect workspace size");

            // Assign second half of workspace for 2nd DGEMM operation.
            int totmodes  = m_stdExp->GetNcoeffs();

            Array<OneD, NekDouble> wsp2
                = wsp + m_nmodes0*m_nmodes1*m_nquad2*m_numElmt;

            Vmath::Zero(m_nmodes0*m_nmodes1*m_nquad2*m_numElmt, wsp, 1);

            int i     = 0;
            int j     = 0;
            int mode  = 0; // running column offset into m_base2
            int mode1 = 0; // running offset into each element's coefficients
            int cnt   = 0; // index of the current (i,j) pair

            // Perform summation over '2' direction
            for (i = 0; i < m_nmodes0; ++i)
            {
                for (j = 0; j < m_nmodes1; ++j, ++cnt)
                {
                    int ijmax = max(i,j);
                    Blas::Dgemm('N', 'N', m_nquad2, m_numElmt,
                                m_nmodes2-ijmax,
                                1.0, m_base2.get()+mode*m_nquad2, m_nquad2,
                                input.get()+mode1, totmodes, 0.0,
                                wsp.get() + cnt*m_nquad2*m_numElmt,
                                m_nquad2);
                    mode  += m_nmodes2-ijmax;
                    mode1 += m_nmodes2-ijmax;
                }

                // Increment mode in case order1 != order2: step over the
                // m_base2 columns belonging to j in [m_nmodes1, m_nmodes2),
                // which have no coefficients in this expansion. The bound
                // has to be m_nmodes2 for every i so that m_base2 is
                // traversed exactly as in PyrIProduct (the transpose of
                // this operation); a bound of m_nmodes2-i skips too few
                // columns once i > 0 and misaligns all later (i,j) pairs.
                for (j = m_nmodes1; j < m_nmodes2; ++j)
                {
                    int ijmax = max(i,j);
                    mode += m_nmodes2-ijmax;
                }
            }

            // vertex mode - currently (1+c)/2 x (1-b)/2 x (1-a)/2
            // component is evaluated
            if (m_sortTopVertex)
            {
                for (i = 0; i < m_numElmt; ++i)
                {
                    // top singular vertex
                    // (1+c)/2 x (1+b)/2 x (1-a)/2 component
                    // -> added to the (i=0, j=1) slot of the workspace
                    Blas::Daxpy(m_nquad2, input[1+i*totmodes],
                                m_base2.get() + m_nquad2, 1,
                                &wsp[m_nquad2*m_numElmt] + i*m_nquad2, 1);

                    // top singular vertex
                    // (1+c)/2 x (1-b)/2 x (1+a)/2 component
                    // -> added to the (i=1, j=0) slot of the workspace
                    Blas::Daxpy(m_nquad2, input[1+i*totmodes],
                                m_base2.get() + m_nquad2, 1,
                                &wsp[m_nmodes1*m_nquad2*m_numElmt]
                                + i*m_nquad2, 1);

                    // top singular vertex
                    // (1+c)/2 x (1+b)/2 x (1+a)/2 component
                    // -> added to the (i=1, j=1) slot of the workspace
                    Blas::Daxpy(m_nquad2, input[1+i*totmodes],
                                m_base2.get() + m_nquad2, 1,
                                &wsp[(m_nmodes1+1)*m_nquad2*m_numElmt]
                                + i*m_nquad2, 1);
                }
            }

            // Perform summation over '1' direction
            mode = 0;
            for (i = 0; i < m_nmodes0; ++i)
            {
                Blas::Dgemm('N', 'T', m_nquad1, m_nquad2*m_numElmt,
                            m_nmodes1,
                            1.0, m_base1.get(), m_nquad1,
                            wsp.get() + mode*m_nquad2*m_numElmt,
                            m_nquad2*m_numElmt,
                            0.0, wsp2.get() + i*m_nquad1*m_nquad2*m_numElmt,
                            m_nquad1);
                mode += m_nmodes1;
            }

            // Perform summation over '0' direction
            Blas::Dgemm('N', 'T', m_nquad0, m_nquad1*m_nquad2*m_numElmt,
                        m_nmodes0, 1.0, m_base0.get(), m_nquad0,
                        wsp2.get(), m_nquad1*m_nquad2*m_numElmt,
                        0.0, output.get(), m_nquad0);
        }

        /**
         * @brief Directional variant; not supported by this operator and
         *        always raises a level-0 assertion.
         */
        virtual void operator()(
                      int                           dir,
                const Array<OneD, const NekDouble> &input,
                      Array<OneD,       NekDouble> &output,
                      Array<OneD,       NekDouble> &wsp)
        {
            ASSERTL0(false, "Not valid for this operator.");
        }

    protected:
        /// Number of quadrature points in each direction.
        const int  m_nquad0;
        const int  m_nquad1;
        const int  m_nquad2;
        /// Number of modes of the basis in each direction.
        const int  m_nmodes0;
        const int  m_nmodes1;
        const int  m_nmodes2;
        /// Basis data (GetBdata()) of the basis in each direction.
        Array<OneD, const NekDouble> m_base0;
        Array<OneD, const NekDouble> m_base1;
        Array<OneD, const NekDouble> m_base2;
        /// True when basis 0 is eModified_A, in which case the extra
        /// top-vertex contributions are added after the first contraction.
        bool m_sortTopVertex;

    private:
        /**
         * @brief Cache the point/mode counts and basis data of the
         *        collection's standard expansion and size the workspace.
         */
        BwdTrans_SumFac_Pyr(
                vector<StdRegions::StdExpansionSharedPtr> pCollExp,
                CoalescedGeomDataSharedPtr                pGeomData)
            : Operator  (pCollExp, pGeomData),
              m_nquad0  (m_stdExp->GetNumPoints(0)),
              m_nquad1  (m_stdExp->GetNumPoints(1)),
              m_nquad2  (m_stdExp->GetNumPoints(2)),
              m_nmodes0 (m_stdExp->GetBasisNumModes(0)),
              m_nmodes1 (m_stdExp->GetBasisNumModes(1)),
              m_nmodes2 (m_stdExp->GetBasisNumModes(2)),
              m_base0   (m_stdExp->GetBasis(0)->GetBdata()),
              m_base1   (m_stdExp->GetBasis(1)->GetBdata()),
              m_base2   (m_stdExp->GetBasis(2)->GetBdata()),
              m_sortTopVertex(m_stdExp->GetBasis(0)->GetBasisType()
                              == LibUtilities::eModified_A)
        {
            // First part: nmodes0*nmodes1 blocks of nquad2*numElmt values;
            // second part: nmodes0 blocks of nquad1*nquad2*numElmt values.
            m_wspSize = m_numElmt*m_nmodes0*m_nquad2*(m_nmodes1 + m_nquad1);
        }
};
/// Factory initialisation for the BwdTrans_SumFac_Pyr operator.
/// Registers BwdTrans_SumFac_Pyr::create with the operator factory under
/// the key (ePyramid, eBwdTrans, eSumFac, false) and the description
/// string "BwdTrans_SumFac_Pyr"; the value returned by the registration is
/// stored in m_type.
/// NOTE(review): the meaning of the trailing `false` key component is not
/// visible from this file section - confirm against OperatorKey.
OperatorKey BwdTrans_SumFac_Pyr::m_type = GetOperatorFactory().
RegisterCreatorFunction(
OperatorKey(ePyramid, eBwdTrans, eSumFac,false),
BwdTrans_SumFac_Pyr::create, "BwdTrans_SumFac_Pyr");
}
}
......@@ -422,6 +422,101 @@ void PrismIProduct(bool sortTopVertex, int numElmt,
}
/**
*
*/
void PyrIProduct(bool sortTopVertex, int numElmt,
int nquad0, int nquad1, int nquad2,
int nmodes0, int nmodes1, int nmodes2,
const Array<OneD, const NekDouble> &base0,
const Array<OneD, const NekDouble> &base1,
const Array<OneD, const NekDouble> &base2,
const Array<OneD, const NekDouble> &jac,
const Array<OneD, const NekDouble> &input,
Array<OneD, NekDouble> &output,
Array<OneD, NekDouble> &wsp)
{
int totmodes = LibUtilities::StdPyrData::getNumberOfCoefficients(
nmodes0,nmodes1,nmodes2);
int totpoints = nquad0*nquad1*nquad2;
int cnt;
int mode, mode1;
ASSERTL1(wsp.num_elements() >= numElmt*(nquad1*nquad2*nmodes0 +
nquad2*max(nquad0*nquad1,nmodes0*nmodes1)),
"Insufficient workspace size");
Vmath::Vmul(numElmt*totpoints,jac,1,input,1,wsp,1);
Array<OneD, NekDouble> wsp1 = wsp + numElmt * nquad2
* (max(nquad0*nquad1,
nmodes0*nmodes1));
// Perform iproduct with respect to the '0' direction
Blas::Dgemm('T', 'N', nquad1*nquad2*numElmt, nmodes0, nquad0,
1.0, wsp.get(), nquad0, base0.get(),
nquad0, 0.0, wsp1.get(), nquad1*nquad2*numElmt);
// Inner product with respect to the '1' direction
mode = 0;
for(int i=0; i < nmodes0; ++i)
{
Blas::Dgemm('T', 'N', nquad2*numElmt, nmodes1, nquad1,
1.0, wsp1.get()+ i*nquad1*nquad2*numElmt, nquad1,
base1.get(), nquad1,
0.0, wsp.get() + mode*nquad2*numElmt,nquad2*numElmt);
mode += nmodes1;
}
// Inner product with respect to the '2' direction
mode = mode1 = cnt = 0;
for(int i = 0; i < nmodes0; ++i)
{
for(int j = 0; j < nmodes1; ++j, ++cnt)
{
int ijmax = max(i,j);
Blas::Dgemm('T', 'N', nmodes2-ijmax, numElmt, nquad2,
1.0, base2.get()+mode*nquad2, nquad2,
wsp.get()+cnt*nquad2*numElmt, nquad2,
0.0, output.get()+mode1, totmodes);
mode += nmodes2-ijmax;
mode1 += nmodes2-ijmax;
}
//increment mode in case order1!=order2
for(int j = nmodes1; j < nmodes2; ++j)
{
int ijmax = max(i,j);
mode += nmodes2-ijmax;
}
}
// fix for modified basis for top singular vertex component
// Already have evaluated (1+c)/2 (1-b)/2 (1-a)/2
if(sortTopVertex)
{
for(int n = 0; n < numElmt; ++n)
{
// add in (1+c)/2 (1+b)/2 component
output[1+n*totmodes] += Blas::Ddot(nquad2,
base2.get()+nquad2,1,
&wsp[nquad2*numElmt + n*nquad2],1);
// add in (1+c)/2 (1-b)/2 (1+a)/2 component
output[1+n*totmodes] += Blas::Ddot(nquad2,
base2.get()+nquad2,1,
&wsp[nquad2*nmodes1*numElmt+n*nquad2],1);
// add in (1+c)/2 (1+b)/2 (1+a)/2 component
output[1+n*totmodes] += Blas::Ddot(nquad2,
base2.get()+nquad2,1,
&wsp[nquad2*(nmodes1+1)*numElmt+n*nquad2],1);
}
}
}
/**
*
*/
......
......@@ -82,6 +82,19 @@ void PrismIProduct(bool sortTopVert, int numElmt,
Array<OneD, NekDouble> &output,
Array<OneD, NekDouble> &wsp);
/// Sum-factorised inner product with respect to a pyramid expansion basis
/// for numElmt elements: input is multiplied pointwise by jac and then
/// contracted with base0, base1 and base2 in turn; the coefficients are
/// written to output and wsp is used as scratch storage. When sortTopVert
/// is true the additional top-vertex contributions are added to
/// coefficient 1 of every element.
void PyrIProduct(bool sortTopVert, int numElmt,
int nquad0, int nquad1, int nquad2,
int nmodes0, int nmodes1, int nmodes2,
const Array<OneD, const NekDouble> &base0,
const Array<OneD, const NekDouble> &base1,
const Array<OneD, const NekDouble> &base2,
const Array<OneD, const NekDouble> &jac,
const Array<OneD, const NekDouble> &input,
Array<OneD, NekDouble> &output,
Array<OneD, NekDouble> &wsp);
void TetIProduct(bool sortTopEdge, int numElmt,
int nquad0, int nquad1, int nquad2,
int nmodes0, int nmodes1, int nmodes2,
......
......@@ -586,7 +586,7 @@ OperatorKey IProductWRTBase_SumFac_Tri::m_type = GetOperatorFactory().
/**
* @brief Backward transform operator using sum-factorisation (Hex)
* @brief Inner Product operator using sum-factorisation (Hex)
*/
class IProductWRTBase_SumFac_Hex : public Operator
{
......@@ -765,7 +765,7 @@ OperatorKey IProductWRTBase_SumFac_Tet::m_type = GetOperatorFactory().
/**
* @brief Backward transform operator using sum-factorisation (Prism)
* @brief Inner Product operator using sum-factorisation (Prism)
*/
class IProductWRTBase_SumFac_Prism : public Operator
{
......@@ -856,5 +856,99 @@ OperatorKey IProductWRTBase_SumFac_Prism::m_type = GetOperatorFactory().
OperatorKey(ePrism, eIProductWRTBase, eSumFac,false),
IProductWRTBase_SumFac_Prism::create, "IProductWRTBase_SumFac_Prism");
/**
 * @brief Inner Product operator using sum-factorisation (Pyr)
 *
 * Thin wrapper that caches the point/mode counts, the basis data and the
 * Jacobian-with-weights array of a collection of pyramid elements and
 * forwards the actual computation to PyrIProduct().
 */
class IProductWRTBase_SumFac_Pyr : public Operator
{
    public:
        OPERATOR_CREATE(IProductWRTBase_SumFac_Pyr)

        virtual ~IProductWRTBase_SumFac_Pyr()
        {
        }

        /**
         * @brief Compute the inner product of @p input with the basis.
         *
         * @param input    Values at the quadrature points.
         * @param output   Receives the resulting coefficients.
         * @param output1  Unused by this operator.
         * @param output2  Unused by this operator.
         * @param wsp      Workspace; must hold exactly #m_wspSize entries.
         */
        virtual void operator()(
                const Array<OneD, const NekDouble> &input,
                      Array<OneD,       NekDouble> &output,
                      Array<OneD,       NekDouble> &output1,
                      Array<OneD,       NekDouble> &output2,
                      Array<OneD,       NekDouble> &wsp)
        {
            ASSERTL1(wsp.num_elements() == m_wspSize,
                     "Incorrect workspace size");

            PyrIProduct(m_sortTopVertex,
                        m_numElmt,
                        m_nquad0,  m_nquad1,  m_nquad2,
                        m_nmodes0, m_nmodes1, m_nmodes2,
                        m_base0,   m_base1,   m_base2,
                        m_jac,
                        input,
                        output,
                        wsp);
        }

        /**
         * @brief Directional variant; not supported by this operator and
         *        always raises a level-0 assertion.
         */
        virtual void operator()(
                      int                           dir,
                const Array<OneD, const NekDouble> &input,
                      Array<OneD,       NekDouble> &output,
                      Array<OneD,       NekDouble> &wsp)
        {
            ASSERTL0(false, "Not valid for this operator.");
        }

    protected:
        /// Number of quadrature points in each direction.
        const int  m_nquad0;
        const int  m_nquad1;
        const int  m_nquad2;
        /// Number of modes of the basis in each direction.
        const int  m_nmodes0;
        const int  m_nmodes1;
        const int  m_nmodes2;
        /// Array returned by GetJacWithStdWeights() for this collection.
        Array<OneD, const NekDouble> m_jac;
        /// Basis data (GetBdata()) of the basis in each direction.
        Array<OneD, const NekDouble> m_base0;
        Array<OneD, const NekDouble> m_base1;
        Array<OneD, const NekDouble> m_base2;
        /// True when basis 0 is eModified_A; forwarded to PyrIProduct().
        bool m_sortTopVertex;

    private:
        /**
         * @brief Cache the expansion data and size the workspace.
         */
        IProductWRTBase_SumFac_Pyr(
                vector<StdRegions::StdExpansionSharedPtr> pCollExp,
                CoalescedGeomDataSharedPtr                pGeomData)
            : Operator  (pCollExp, pGeomData),
              m_nquad0  (m_stdExp->GetNumPoints(0)),
              m_nquad1  (m_stdExp->GetNumPoints(1)),
              m_nquad2  (m_stdExp->GetNumPoints(2)),
              m_nmodes0 (m_stdExp->GetBasisNumModes(0)),
              m_nmodes1 (m_stdExp->GetBasisNumModes(1)),
              m_nmodes2 (m_stdExp->GetBasisNumModes(2)),
              m_base0   (m_stdExp->GetBasis(0)->GetBdata()),
              m_base1   (m_stdExp->GetBasis(1)->GetBdata()),
              m_base2   (m_stdExp->GetBasis(2)->GetBdata())
        {
            m_jac = pGeomData->GetJacWithStdWeights(pCollExp);

            // Workspace = first part (large enough for either the physical
            // values or the nmodes0*nmodes1 slots) + second part (result of
            // the direction-0 contraction).
            const int firstPart  = m_numElmt * m_nquad2
                * max(m_nquad0*m_nquad1, m_nmodes0*m_nmodes1);
            const int secondPart = m_nquad1*m_nquad2*m_numElmt*m_nmodes0;
            m_wspSize = firstPart + secondPart;

            m_sortTopVertex = (m_stdExp->GetBasis(0)->GetBasisType()
                               == LibUtilities::eModified_A);
        }
};
/// Factory initialisation for the IProductWRTBase_SumFac_Pyr operator.
/// Registers IProductWRTBase_SumFac_Pyr::create with the operator factory
/// under the key (ePyramid, eIProductWRTBase, eSumFac, false) and the
/// description string "IProductWRTBase_SumFac_Pyr"; the value returned by
/// the registration is stored in m_type.
/// NOTE(review): the meaning of the trailing `false` key component is not
/// visible from this file section - confirm against OperatorKey.
OperatorKey IProductWRTBase_SumFac_Pyr::m_type = GetOperatorFactory().
RegisterCreatorFunction(
OperatorKey(ePyramid, eIProductWRTBase, eSumFac,false),
IProductWRTBase_SumFac_Pyr::create, "IProductWRTBase_SumFac_Pyr");
}
}
This diff is collapsed.
......@@ -31,15 +31,15 @@ int main(int argc, char *argv[])
"dictates the basis as:\n");
fprintf(stderr,"\t Ortho_A = 1\n");
fprintf(stderr,"\t Modified_A = 4\n");
fprintf(stderr,"\t Fourier = 7\n");
fprintf(stderr,"\t Lagrange = 8\n");
fprintf(stderr,"\t Gauss Lagrange = 9\n");
fprintf(stderr,"\t Legendre = 10\n");
fprintf(stderr,"\t Chebyshev = 11\n");
fprintf(stderr,"\t Monomial = 12\n");
fprintf(stderr,"\t FourierSingleMode = 13\n");
fprintf(stderr,"\t Fourier = 9\n");
fprintf(stderr,"\t Lagrange = 10\n");
fprintf(stderr,"\t Gauss Lagrange = 11\n");
fprintf(stderr,"\t Legendre = 12\n");
fprintf(stderr,"\t Chebyshev = 13\n");
fprintf(stderr,"\t Monomial = 14\n");
fprintf(stderr,"\t FourierSingleMode = 15\n");
fprintf(stderr,"Note type = 1,2,4,5 are for higher dimensional basis\n");
fprintf(stderr,"Note type = 1,2,4,5,7,8 are for higher dimensional basis\n");
exit(1);
}
......
......@@ -53,15 +53,16 @@ int main(int argc, char *argv[])
fprintf(stderr,"\t Ortho_B = 2\n");
fprintf(stderr,"\t Modified_A = 4\n");
fprintf(stderr,"\t Modified_B = 5\n");
fprintf(stderr,"\t Fourier = 7\n");
fprintf(stderr,"\t Lagrange = 8\n");
fprintf(stderr,"\t Gauss Lagrange = 9\n");
fprintf(stderr,"\t Legendre = 10\n");
fprintf(stderr,"\t Chebyshev = 11\n");
fprintf(stderr,"\t Nodal tri (Electro) = 13\n");
fprintf(stderr,"\t Nodal tri (Fekete) = 14\n");
fprintf(stderr,"\t Fourier = 9\n");
fprintf(stderr,"\t Lagrange = 10\n");
fprintf(stderr,"\t Gauss Lagrange = 11\n");
fprintf(stderr,"\t Legendre = 12\n");
fprintf(stderr,"\t Chebyshev = 13\n");
fprintf(stderr,"\t Monomial = 14\n");
fprintf(stderr,"\t Nodal tri (Electro) = 15\n");
fprintf(stderr,"\t Nodal tri (Fekete) = 16\n");
fprintf(stderr,"Note type = 3,6 are for three-dimensional basis\n");
fprintf(stderr,"Note type = 3,6,7,8 are for three-dimensional basis\n");
fprintf(stderr,"The last series of values are the coordinates\n");
exit(1);
......@@ -78,17 +79,17 @@ int main(int argc, char *argv[])
int btype1_val = atoi(argv[2]);
int btype2_val = atoi(argv[3]);
if(( btype1_val <= 11)&&( btype2_val <= 11))
if(( btype1_val <= 14)&&( btype2_val <= 14))
{
btype1 = (LibUtilities::BasisType) btype1_val;
btype2 = (LibUtilities::BasisType) btype2_val;
}
else if(( btype1_val >=13)&&(btype2_val <= 14))
else if(( btype1_val >=15)&&(btype2_val <= 16))
{
btype1 = LibUtilities::eOrtho_A;
btype2 = LibUtilities::eOrtho_B;
if(btype1_val == 13)
if(btype1_val == 15)
{
NodalType = LibUtilities::eNodalTriElec;
}
......@@ -204,7 +205,7 @@ int main(int argc, char *argv[])
const LibUtilities::BasisKey Bkey1(btype1,order1,Pkey1);
const LibUtilities::BasisKey Bkey2(btype2,order2,Pkey2);
if(btype1_val >= 11)
if(btype1_val >= 15)
{
E = new LocalRegions::NodalTriExp(Bkey1,Bkey2,NodalType,geom);
}
......
......@@ -88,10 +88,12 @@ int main(int argc, char *argv[]){
fprintf(stderr,"\t Modified_A = 4\n");
fprintf(stderr,"\t Modified_B = 5\n");
fprintf(stderr,"\t Modified_C = 6\n");
fprintf(stderr,"\t Fourier = 7\n");
fprintf(stderr,"\t Lagrange = 8\n");
fprintf(stderr,"\t Legendre = 9\n");
fprintf(stderr,"\t Chebyshev = 10\n");
fprintf(stderr,"\t OrthoPyr_C = 7\n");
fprintf(stderr,"\t ModifiedPyr_C = 8\n");
fprintf(stderr,"\t Fourier = 9\n");
fprintf(stderr,"\t Lagrange = 10\n");
fprintf(stderr,"\t Legendre = 11\n");
fprintf(stderr,"\t Chebyshev = 12\n");
exit(1);
}
......@@ -142,25 +144,28 @@ int main(int argc, char *argv[]){
break;
case LibUtilities::ePyramid:
if((btype1 == eOrtho_B) || (btype1 == eOrtho_C)
|| (btype1 == eModified_B) || (btype1 == eModified_C))
|| (btype1 == eModified_B) || (btype1 == eModified_C)
|| (btype1 == eModifiedPyr_C))
{
NEKERROR(ErrorUtil::efatal,
"Basis 1 cannot be of type Ortho_B, Ortho_C, Modified_B "
"or Modified_C");
"Basis 1 cannot be of type Ortho_B, Ortho_C, Modified_B, "
"Modified_C or ModifiedPyr_C");
}
if((btype2 == eOrtho_B) || (btype2 == eOrtho_C)
|| (btype2 == eModified_B) || (btype2 == eModified_C))
|| (btype2 == eModified_B) || (btype2 == eModified_C)
|| (btype2 == eModifiedPyr_C))
{
NEKERROR(ErrorUtil::efatal,
"Basis 2 cannot be of type Ortho_B, Ortho_C, Modified_B "
"or Modified_C");
"Basis 2 cannot be of type Ortho_B, Ortho_C, Modified_B, "
"Modified_C or ModifiedPyr_C");
}
if((btype3 == eOrtho_A) || (btype3 == eOrtho_B)
|| (btype3 == eModified_A) || (btype3 == eModified_B))
|| (btype3 == eModified_A) || (btype3 == eModified_B)
|| (btype3 == eModified_C))
{
NEKERROR(ErrorUtil::efatal,
"Basis 3 cannot be of type Ortho_A, Ortho_B, Modified_A "
"or Modified_B");
"Basis 3 cannot be of type Ortho_A, Ortho_B, Modified_A, "
"Modified_B or ModifiedPyr_C");
}
break;
case LibUtilities::ePrism:
......
......@@ -34,15 +34,15 @@ int main(int argc, char *argv[])
"dictates the basis as:\n");
fprintf(stderr,"\t Ortho_A = 1\n");
fprintf(stderr,"\t Modified_A = 4\n");
fprintf(stderr,"\t Fourier = 7\n");
fprintf(stderr,"\t Lagrange = 8\n");
fprintf(stderr,"\t Gauss Lagrange = 9\n");
fprintf(stderr,"\t Legendre = 10\n");
fprintf(stderr,"\t Chebyshev = 11\n");