Skip to content
Snippets Groups Projects
Commit 6d02d9bb authored by Dave Moxey's avatar Dave Moxey
Browse files

Various changes to allow Operator to take Field as arguments in apply()

parent 53202173
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,8 @@ set(SRC main.cpp Operator.cpp)
if (NEKTAR_USE_CUDA)
enable_language(CUDA)
add_definitions(-DNEKTAR_USE_CUDA)
set(SRC ${SRC} Operator.cu)
endif()
add_executable(main ${SRC})
\ No newline at end of file
add_executable(main ${SRC})
......@@ -2,12 +2,15 @@
#include <string>
#include "MemRef.hpp"
#include "MemRefCPU.hpp"
// Device options
struct DeviceCPU;
struct DeviceCUDA;
#if NEKTAR_USE_CUDA
using DefaultDevice = DeviceCUDA;
#include "MemRefCUDA.hpp"
#else
using DefaultDevice = DeviceCPU;
#endif
......@@ -30,9 +33,11 @@ template<typename TType = double, typename TState = DefaultState, typename tBack
class Field
{
public:
Field();
Field() : m_storage(10)
{
}
Field(const Field&) = default;
~Field();
~Field() = default;
MemRef<TType, tBackend> &GetStorage()
{
......
#pragma once
struct BackendCPU;
struct BackendCUDA;
......
#pragma once
#include "MemRef.hpp"
template<typename tData, typename tBackend = DefaultMemRef>
class MemRef;
template<typename tData>
class MemRef<tData, BackendCPU>;
class MemRef<tData, BackendCPU>
{
public:
MemRef(size_t n)
{
m_host = new tData[n];
......
#pragma once
#include "MemRef.hpp"
#include <cuda.h>
template<typename tData>
class MemRef<tData, BackendCUDA>;
class MemRef<tData, BackendCUDA>
{
MemRef(size_t n)
{
m_host = new tData[n];
m_device = cudaMalloc(sizeof(tData) * n);
}
MemRef(size_t n);
double *m_host = nullptr;
double *m_device = nullptr;
......
#include "OperatorBwdTrans.hpp"
/*
template<typename TData>
OperatorFactory<TData> &GetOperatorFactory() {
static OperatorFactory<TData> instance;
......@@ -9,4 +10,5 @@ OperatorFactory<TData> &GetOperatorFactory() {
std::string OpBwdTransLocMatCPU =
GetOperatorFactory<double>().RegisterCreatorFunction(
"BwdTransLocMatCPU",
Operator<double, OpBwdTrans, MethodLocMat, DeviceCPU>::create);
\ No newline at end of file
Operator<double, OpBwdTrans, MethodLocMat, DeviceCPU>::create);
*/
......@@ -16,20 +16,35 @@ struct MethodMatFree;
using DefaultMethod = MethodLocMat;
// Base class
template<typename TData>
template<typename TDerived, typename TData, typename TStateIn, typename TStateOut>
class OperatorBase
{
public:
virtual void apply(TData in, TData out) = 0;
void apply(Field<TData, TStateIn> &in, Field<TData, TStateOut> &out)
{
static_cast<TDerived *>(this)->apply_impl(in, out);
}
};
#if 0
template<typename TData, typename TStateIn, typename TStateOut>
class OperatorBase
{
public:
virtual void apply(Field<TData, TStateIn> &in, Field<TData, TStateOut> &out) = 0;
};
#endif
using namespace Nektar::LibUtilities;
/*
template<typename TData>
using OperatorFactory = NekFactory<std::string, OperatorBase<TData>>;
using OperatorFactory = NekFactory<std::string, OperatorBase<TData, TState>>;
template<typename TData>
OperatorFactory<TData> &GetOperatorFactory();
*/
// Templated Operator
template<typename TData,
......@@ -37,5 +52,3 @@ template<typename TData,
typename TMethod = DefaultMethod,
typename TDevice = DefaultDevice>
class Operator;
......@@ -4,18 +4,20 @@
// BwdTrans local matrix operator
template<typename TData>
class Operator<TData, OpBwdTrans, MethodLocMat, DeviceCPU>
: public OperatorBase<TData>
class Operator<TData, OpBwdTrans, MethodLocMat, DeviceCPU>
: public OperatorBase<Operator<TData, OpBwdTrans, MethodLocMat, DeviceCPU>, TData, StateCoeff, StatePhys>
{
public:
using ClassType = Operator<TData, OpBwdTrans, MethodLocMat, DeviceCPU>;
/*
static std::unique_ptr<ClassType> create()
{
return std::unique_ptr<ClassType>(new ClassType());
}
*/
void apply(TData in, TData out) override
void apply_impl(Field<TData, StateCoeff> &in, Field<TData, StatePhys> &out)
{
std::cout << "Perform BwdTrans op with LocMat" << std::endl;
}
......@@ -23,18 +25,20 @@ public:
// BwdTrans mat-free operator
template<typename TData>
class Operator<TData, OpBwdTrans, MethodMatFree, DeviceCPU>
: public OperatorBase<TData>
class Operator<TData, OpBwdTrans, MethodMatFree, DeviceCPU>
: public OperatorBase<Operator<TData, OpBwdTrans, MethodMatFree, DeviceCPU>, TData, StateCoeff, StatePhys>
{
public:
using ClassType = Operator<TData, OpBwdTrans, MethodMatFree, DeviceCPU>;
/*
static std::unique_ptr<ClassType> create()
{
return std::unique_ptr<ClassType>(new ClassType());
}
*/
void apply(TData in, TData out)
void apply_impl(Field<TData, StateCoeff> &in, Field<TData, StatePhys> &out)
{
std::cout << "Perform BwdTrans op with MatFree" << std::endl;
......
......@@ -10,9 +10,18 @@ int main() {
using DefaultMethod = MethodLocMat;
using TData = double;
Operator<TData, OpBwdTrans>::create()->apply(5.0, 6.0);
Field<double, StateCoeff> in;
Field<double, StatePhys> out;
auto test2 = std::make_shared<Operator<double, OpBwdTrans>>();
auto test = std::dynamic_pointer_cast<OperatorBase<Operator<double, OpBwdTrans>, double, StateCoeff, StatePhys>>(test2);
test->apply(in, out);
/*
Operator<TData, OpBwdTrans>::create()->apply(in, out);
auto o = Operator<TData, OpBwdTrans>::create();
o->apply(5.0, 6.0);
o->apply(in, out);
cout << boost::core::demangle(typeid(o).name()) << endl;
*/
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment