Commit 1bc367cb authored by Chris Cantwell's avatar Chris Cantwell

Initial working version with dynamic state recovery.

parent 24452ecd
......@@ -146,6 +146,14 @@ public:
LIB_UTILITIES_EXPORT inline int EnrolSpare();
LIB_UTILITIES_EXPORT inline void BeginTransactionLog();
LIB_UTILITIES_EXPORT inline void EndTransactionLog();
LIB_UTILITIES_EXPORT inline bool IsRecovering();
LIB_UTILITIES_EXPORT inline void StateAdd(const std::string& name, const int& data);
LIB_UTILITIES_EXPORT inline void StateAdd(const std::string& name, const NekDouble* data, const int n);
LIB_UTILITIES_EXPORT inline void StateGet(const std::string& name, int& data);
LIB_UTILITIES_EXPORT inline void StateGet(const std::string& name, NekDouble* data, const int n);
LIB_UTILITIES_EXPORT inline void StateCommit();
LIB_UTILITIES_EXPORT inline void StateRestore();
protected:
int m_size; ///< Number of processes
......@@ -208,6 +216,14 @@ protected:
virtual int v_EnrolSpare() = 0;
virtual void v_BeginTransactionLog() {}
virtual void v_EndTransactionLog() {}
virtual bool v_IsRecovering() { return false; }
virtual void v_StateAdd(const std::string& name, const int& data) {}
virtual void v_StateAdd(const std::string& name, const NekDouble* data, const int n) {}
virtual void v_StateGet(const std::string& name, int& data) {}
virtual void v_StateGet(const std::string& name, NekDouble* data, const int n) {}
virtual void v_StateCommit() {}
virtual void v_StateRestore() {}
};
......@@ -536,6 +552,41 @@ inline void Comm::EndTransactionLog()
v_EndTransactionLog();
}
inline bool Comm::IsRecovering()
{
return v_IsRecovering();
}
inline void Comm::StateAdd(const std::string& name, const int& data)
{
v_StateAdd(name, data);
}
inline void Comm::StateAdd(const std::string& name, const NekDouble* data, const int n)
{
v_StateAdd(name, data, n);
}
inline void Comm::StateGet(const std::string& name, int& data)
{
v_StateGet(name, data);
}
inline void Comm::StateGet(const std::string& name, NekDouble* data, const int n)
{
v_StateGet(name, data, n);
}
inline void Comm::StateCommit()
{
v_StateCommit();
}
inline void Comm::StateRestore()
{
v_StateRestore();
}
}
}
......
......@@ -47,6 +47,7 @@ using namespace std;
#include <boost/serialization/queue.hpp>
#include <boost/serialization/deque.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/serialization/map.hpp>
namespace mpi = boost::mpi;
#include <LibUtilities/BasicUtils/SharedArray.hpp>
......@@ -869,7 +870,7 @@ void CommMpi::RestoreState()
// a) backup of the recovering process's data for its recovery
// b) replacement copy of this processes backup data
// c) queue of flags for recovering derived communicators on recovering process
const int nReq = 6;
const int nReq = 8;
if (sendRecoveryData)
{
cout << "Restore: Sending " << m_dataBackup.size() << " backup items." << endl;
......@@ -881,6 +882,8 @@ void CommMpi::RestoreState()
reqs[3] = c.isend(send_rank, 3, m_derivedCommFlag);
reqs[4] = c.isend(send_rank, 4, m_gsInitDataBackup);
reqs[5] = c.isend(send_rank, 5, m_gsInitData);
reqs[6] = c.isend(send_rank, 6, m_stateDataBackup);
reqs[7] = c.isend(send_rank, 7, m_stateData);
cout << "Restore: Waiting for data to be sent to " << send_rank << endl;
mpi::wait_all(reqs, reqs + nReq);
cout << "Restore: Complete" << endl;
......@@ -898,6 +901,8 @@ void CommMpi::RestoreState()
reqs[3] = c.irecv(recv_rank, 3, m_derivedCommFlagBackup);
reqs[4] = c.irecv(recv_rank, 4, m_gsInitData);
reqs[5] = c.irecv(recv_rank, 5, m_gsInitDataBackup);
reqs[6] = c.irecv(recv_rank, 6, m_stateData);
reqs[7] = c.irecv(recv_rank, 7, m_stateDataBackup);
cout << "Restore: Waiting for data from " << recv_rank << endl;
mpi::wait_all(reqs, reqs + nReq);
cout << "Restore: Complete" << endl;
......@@ -1055,6 +1060,95 @@ void CommMpi::v_EndTransactionLog()
}
}
bool CommMpi::v_IsRecovering()
{
return m_isRecovering;
}
void CommMpi::v_StateAdd(const std::string& name, const int& data)
{
std::vector<char> x;
int dtsize = sizeof(int);
x.assign((char*)(&data), (char*)(&data)+dtsize);
m_stateData[name] = x;
}
void CommMpi::v_StateAdd(const std::string& name, const NekDouble* data, const int n)
{
std::vector<char> x;
int dtsize = sizeof(NekDouble);
x.assign((char*)data, (char*)data+n*dtsize);
m_stateData[name] = x;
}
void CommMpi::v_StateGet(const std::string& name, int& data)
{
ASSERTL0(m_stateData.count(name), "STATE ITEM DOES NOT EXIST!!");
std::vector<char> x = m_stateData[name];
memcpy(&data, &x[0], sizeof(int));
}
void CommMpi::v_StateGet(const std::string& name, NekDouble* data, const int n)
{
ASSERTL0(m_stateData.count(name), "STATE ITEM DOES NOT EXIST!!");
std::vector<char> x = m_stateData[name];
memcpy(data, &x[0], n*sizeof(NekDouble));
}
void CommMpi::v_StateCommit()
{
mpi::communicator c(m_comm, mpi::comm_attach);
mpi::request reqs[2];
int rank = c.rank();
int size = c.size();
if (size > 1)
{
int recv_rank = (rank + size - 1) % size;
int send_rank = (rank + 1) % size;
cout << "StateCommit: Sending " << m_stateData.size() << " items in queue." << endl;
reqs[0] = c.isend(send_rank, 0, m_stateData);
reqs[1] = c.irecv(recv_rank, 0, m_stateDataBackup);
cout << "StateCommit: Sent to " << send_rank << endl;
cout << "StateCommit: Waiting for data from " << recv_rank << endl;
mpi::wait_all(reqs, reqs + 2);
cout << "StateCommit: Received " << m_stateDataBackup.size() << " items." << endl;
}
else
{
cout << "StateCommit: Not backing up as comm of size 1" << endl;
}
}
void CommMpi::v_StateRestore()
{
mpi::communicator c(m_comm, mpi::comm_attach);
mpi::request reqs[2];
int rank = c.rank();
int size = c.size();
if (size > 1)
{
int send_rank = (rank + size - 1) % size;
int recv_rank = (rank + 1) % size;
cout << "StateRestore: Sending " << m_stateDataBackup.size() << " items in queue." << endl;
reqs[0] = c.isend(send_rank, 0, m_stateDataBackup);
reqs[1] = c.irecv(recv_rank, 0, m_stateData);
cout << "StateRestore: Sent to " << send_rank << endl;
cout << "StateRestore: Waiting for data from " << recv_rank << endl;
mpi::wait_all(reqs, reqs + 2);
cout << "StateRestore: Received " << m_stateData.size() << " items." << endl;
}
else
{
cout << "StateRestore: Not restoring up as comm of size 1" << endl;
}
}
void CommMpi::ReplaceComm(MPI_Comm commptr)
{
m_comm = commptr;
......
......@@ -150,12 +150,21 @@ protected:
virtual int v_EnrolSpare();
virtual void v_BeginTransactionLog();
virtual void v_EndTransactionLog();
virtual bool v_IsRecovering();
virtual void v_StateAdd(const std::string& name, const int& data);
virtual void v_StateAdd(const std::string& name, const NekDouble* data, const int n);
virtual void v_StateGet(const std::string& name, int& data);
virtual void v_StateGet(const std::string& name, NekDouble* data, const int n);
virtual void v_StateCommit();
virtual void v_StateRestore();
private:
typedef std::queue<std::vector<char>> StorageType;
typedef std::list<CommMpiSharedPtr> DerivedCommType;
typedef std::queue<int> DerivedCommFlagType;
typedef std::vector<Gs::gs_data*> GsHandlesType;
typedef std::map<std::string, std::vector<char>> StateStorageType;
MPI_Comm m_comm;
MPI_Comm m_agreecomm;
......@@ -173,7 +182,8 @@ private:
StorageType m_gsInitDataBackup;
GsHandlesType::iterator m_gsHandlesRestoreIt;
GsHandlesType m_gsHandles; ///< Handles to Gs library
StateStorageType m_stateData;
StateStorageType m_stateDataBackup;
static void HandleMpiError(MPI_Comm* pcomm, int* perr, ...);
......
......@@ -40,6 +40,15 @@
#include <SpatialDomains/MeshGraph.h>
#include <MultiRegions/ContField2D.h>
#include <boost/mpi/environment.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/nonblocking.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/queue.hpp>
#include <boost/serialization/deque.hpp>
#include <boost/serialization/vector.hpp>
namespace mpi = boost::mpi;
using namespace std;
using namespace Nektar;
......@@ -89,10 +98,17 @@ int main(int argc, char *argv[])
// Zero field coefficients for initial guess for linear solver.
Vmath::Zero(field->GetNcoeffs(), field->UpdateCoeffs(), 1);
session->GetComm()->EndTransactionLog();
// int n = 0;
// if (session->GetComm()->IsRecovering())
// {
//// session->GetComm()->StateRestore();
//// session->GetComm()->StateGet("n", n);
//// session->GetComm()->StateGet("Field", &field->UpdatePhys()[0], nq);
// n++;
// }
// Time integrate using backward Euler
for (unsigned int n = 0; n < nSteps; ++n)
for (int n = 0 ; n < nSteps; ++n)
{
cout << "Time step: " << n << endl;
try {
......@@ -103,15 +119,46 @@ int main(int argc, char *argv[])
NullFlagList, factors);
field->BwdTrans(field->GetCoeffs(), field->UpdatePhys());
if (session->GetComm()->IsRecovering())
{
cout << "Restoring field data back to last step" << endl;
session->GetComm()->StateGet("n", n);
session->GetComm()->StateGet("Field", &field->UpdatePhys()[0], nq);
}
else
{
cout << "Add field state data" << endl;
session->GetComm()->StateAdd("Field", &field->GetPhys()[0], nq);
cout << "Add time step data" << endl;
session->GetComm()->StateAdd("n", n);
cout << "Commit state" << endl;
session->GetComm()->StateCommit();
}
if (n == 0 || session->GetComm()->IsRecovering()) {
cout << "Ending transaction log" << endl;
session->GetComm()->EndTransactionLog();
}
} catch (...) {
try {
cout << "Caught an error - trying to invoke a spare." << endl;
int x = session->GetComm()->EnrolSpare();
cout << "Enroled spare, result: " << x << endl;
--n; // need to roll back to previous time step here...
Vmath::Smul(nq, -delta_t*epsilon, field->GetPhys(), 1,
field->UpdatePhys(), 1);
// --n; // need to roll back to previous time step here...
// Vmath::Smul(nq, -delta_t*epsilon, field->GetPhys(), 1,
// field->UpdatePhys(), 1);
// cout << "Recover state" << endl;
// session->GetComm()->StateGet("Field",
// &field->UpdatePhys()[0],
// field->GetTotPoints());
// session->GetComm()->StateGet("n", n);
// cout << "Finished recovering state" << endl;
cout << "Restoring last state" << endl;
// session->GetComm()->StateRestore();
session->GetComm()->StateGet("n", n);
session->GetComm()->StateGet("Field", &field->UpdatePhys()[0], nq);
cout << "Completed restoring last state" << endl;
} catch (...) {
cout << "ERROR WHEN PERFORMING ENROLSPARE!!!" << endl;
exit(-1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment