Commit 1b2f2bf0 authored by Michael Bareford's avatar Michael Bareford

Added support for variable gather/scatter.

parent 336552bf
......@@ -126,7 +126,16 @@ public:
void Exscan(T &pData, const enum ReduceOperator pOp, T &ans);
template <class T> T Gather(const int rootProc, T &val);
template <class T> T Scatter(const int rootProc, T &pData);
template <class T> T Scatter(const int rootProc, T &pData,
const unsigned int recvcnt);
template <class T> T Gatherv(const int rootProc, T &val,
std::vector<int> &pRecvDataSizeMap,
std::vector<int> &pRecvDataOffsetMap);
template <class T> T Scatterv(const int rootProc, T &pData,
std::vector<int> &pSendDataSizeMap,
std::vector<int> &pSendDataOffsetMap,
const unsigned int recvcnt);
LIB_UTILITIES_EXPORT inline CommSharedPtr CommCreateIf(int flag);
......@@ -189,6 +198,13 @@ protected:
void *recvbuf, int recvcount, CommDataType recvtype,
int root) = 0;
virtual void v_Gatherv(void *sendbuf, int sendcount, CommDataType sendtype,
void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype,
int root) = 0;
virtual void v_Scatterv(void *sendbuf, int sendcounts[], int sdispls[], CommDataType sendtype,
void *recvbuf, int recvcount, CommDataType recvtype,
int root) = 0;
virtual CommSharedPtr v_CommCreateIf(int flag) = 0;
virtual void v_SplitComm(int pRows, int pColumns) = 0;
virtual bool v_TreatAsRankZero(void) = 0;
......@@ -445,24 +461,98 @@ template <class T> T Comm::Gather(const int rootProc, T &val)
/**
* Scatter pData across ranks in chunks of len(pData)/num_ranks
*/
template <class T> T Comm::Scatter(const int rootProc, T &pData)
template <class T> T Comm::Scatter(const int rootProc, T &pData,
const unsigned int recvcnt)
{
static_assert(
CommDataTypeTraits<T>::IsVector,
"Scatter only valid with Array or vector arguments.");
bool amRoot = (GetRank() == rootProc);
unsigned nEl = CommDataTypeTraits<T>::GetCount(pData) / GetSize();
void *sendbuf = amRoot ? CommDataTypeTraits<T>::GetPointer(pData) : NULL;
T ans(nEl);
v_Scatter(sendbuf, nEl, CommDataTypeTraits<T>::GetDataType(),
CommDataTypeTraits<T>::GetPointer(ans), nEl,
T ans(recvcnt);
v_Scatter(sendbuf, recvcnt, CommDataTypeTraits<T>::GetDataType(),
CommDataTypeTraits<T>::GetPointer(ans), recvcnt,
CommDataTypeTraits<T>::GetDataType(), rootProc);
return ans;
}
template <class T> T Comm::Gatherv(const int rootProc, T &val,
std::vector<int> &recvDataSizes,
std::vector<int> &recvDataOffsets)
{
static_assert(
CommDataTypeTraits<T>::IsVector,
"Gatherv only valid with Array or vector arguments.");
bool amRoot = (GetRank() == rootProc);
bool mapExists = false;
unsigned nOut = 0;
void *recvbuf = NULL;
T ans;
if (amRoot)
{
mapExists = recvDataSizes.size() == recvDataOffsets.size() &&
recvDataSizes.size() > 0;
if (mapExists)
{
nOut = recvDataOffsets.back() + recvDataSizes.back();
}
if (nOut > 0)
{
ans.resize(nOut);
recvbuf = CommDataTypeTraits<T>::GetPointer(ans);
}
}
v_Gatherv(CommDataTypeTraits<T>::GetPointer(val),
CommDataTypeTraits<T>::GetCount(val),
CommDataTypeTraits<T>::GetDataType(),
recvbuf,
amRoot && mapExists ? &recvDataSizes[0] : NULL,
amRoot && mapExists ? &recvDataOffsets[0] : NULL,
CommDataTypeTraits<T>::GetDataType(),
rootProc);
return ans;
}
template <class T> T Comm::Scatterv(const int rootProc, T &pData,
std::vector<int> &sendDataSizes,
std::vector<int> &sendDataOffsets,
const unsigned int recvcnt)
{
static_assert(
CommDataTypeTraits<T>::IsVector,
"Scatterv only valid with Array or vector arguments.");
bool amRoot = (GetRank() == rootProc);
bool mapExists = sendDataSizes.size() == sendDataOffsets.size() &&
sendDataSizes.size() > 0;
T ans(recvcnt);
void *recvbuf = CommDataTypeTraits<T>::GetPointer(ans);
v_Scatterv(amRoot ? CommDataTypeTraits<T>::GetPointer(pData) : NULL,
amRoot && mapExists ? &sendDataSizes[0] : NULL,
amRoot && mapExists ? &sendDataOffsets[0] : NULL,
CommDataTypeTraits<T>::GetDataType(),
recvbuf, recvcnt,
CommDataTypeTraits<T>::GetDataType(),
rootProc);
return ans;
}
/**
* @brief If the flag is non-zero create a new communicator.
*/
......
......@@ -385,6 +385,28 @@ void CommMpi::v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype,
ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Scatter.");
}
void CommMpi::v_Gatherv(void *sendbuf, int sendcount, CommDataType sendtype,
void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype,
int root)
{
int retval = MPI_Gatherv(sendbuf, sendcount, sendtype,
recvbuf, recvcounts, rdispls, recvtype,
root, m_comm);
ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Gatherv.");
}
void CommMpi::v_Scatterv(void *sendbuf, int sendcounts[], int sdispls[], CommDataType sendtype,
void *recvbuf, int recvcount, CommDataType recvtype,
int root)
{
int retval = MPI_Scatterv(sendbuf, sendcounts, sdispls, sendtype,
recvbuf, recvcount, recvtype,
root, m_comm);
ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Scatterv.");
}
/**
* Processes are considered as a grid of size pRows*pColumns. Comm
* objects are created corresponding to the rows and columns of this
......
......@@ -126,6 +126,12 @@ protected:
void *recvbuf, int recvcount, CommDataType recvtype,
int root);
virtual void v_Gatherv(void *sendbuf, int sendcount, CommDataType sendtype,
void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype,
int root);
virtual void v_Scatterv(void *sendbuf, int sendcounts[], int sdispls[], CommDataType sendtype,
void *recvbuf, int recvcount, CommDataType recvtype,
int root);
virtual void v_SplitComm(int pRows, int pColumns);
virtual CommSharedPtr v_CommCreateIf(int flag);
......
......@@ -207,6 +207,21 @@ void CommSerial::v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype,
{
std::memcpy(recvbuf, sendbuf, sendcount * CommDataTypeGetSize(sendtype));
}
void CommSerial::v_Gatherv(void *sendbuf, int sendcount, CommDataType sendtype,
void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype,
int root)
{
std::memcpy(recvbuf, sendbuf, sendcount * CommDataTypeGetSize(sendtype));
}
void CommSerial::v_Scatterv(void *sendbuf, int sendcounts[], int sdispls[], CommDataType sendtype,
void *recvbuf, int recvcount, CommDataType recvtype,
int root)
{
std::memcpy(recvbuf, sendbuf, sendcounts[0] * CommDataTypeGetSize(sendtype));
}
/**
*
*/
......
......@@ -117,6 +117,7 @@ protected:
LIB_UTILITIES_EXPORT virtual void v_Exscan(
Array<OneD, unsigned long long> &pData, const enum ReduceOperator pOp,
Array<OneD, unsigned long long> &ans);
LIB_UTILITIES_EXPORT virtual void v_Gather(void *sendbuf, int sendcount,
CommDataType sendtype,
void *recvbuf, int recvcount,
......@@ -127,6 +128,17 @@ protected:
CommDataType recvtype,
int root);
LIB_UTILITIES_EXPORT virtual void v_Gatherv(void *sendbuf, int sendcount,
CommDataType sendtype,
void *recvbuf, int recvcounts[],
int rdispls[], CommDataType recvtype,
int root);
LIB_UTILITIES_EXPORT virtual void v_Scatterv(void *sendbuf, int sendcounts[],
int sdispls[], CommDataType sendtype,
void *recvbuf, int recvcount,
CommDataType recvtype,
int root);
LIB_UTILITIES_EXPORT virtual void v_SplitComm(int pRows, int pColumns);
LIB_UTILITIES_EXPORT virtual CommSharedPtr v_CommCreateIf(int flag);
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment