Skip to content
Snippets Groups Projects
Commit 2d43d5f9 authored by Stephen Liu's avatar Stephen Liu
Browse files

Add filter data models for AeroForces,HistoryPoint and CHeckpoint

parent bfa3e2c1
No related branches found
No related tags found
No related merge requests found
from __future__ import annotations
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any,Callable
import os
"""
Currently only implemented filters of interest
"""
class NekFilter:
pass
@dataclass(frozen=True)
class AerodynamicForcesFilter(NekFilter):
boundary: str #should be of form B[...]
output_file: str
output_frequency: int=1
pivot_point: tuple[float,float,float]=(0.0,0.0,0.0)
@dataclass(frozen=True)
class CheckpointFilter(NekFilter):
output_frequency: int
output_file: str
output_start_time: float = 0.0
@dataclass(frozen=True)
class HistoryPointsFilter(NekFilter):
output_file: str
output_frequency: int=1
output_one_file: bool=True
#homogeneous simulation params
#need way to check with __post_init at some point
output_plane: int=-1
wave_space: bool=False
points: tuple[float,...]=()
line: tuple[float,...]=()
plane: tuple[float,...]=()
box: tuple[float,...]=()
class NekFilterFactory:
#allows lazy loading in of filters
_filter_registry: MappingProxyType[str, Callable[[dict[str, Any], str], NekFilter]] = MappingProxyType({
"AeroForces": lambda params, session_name: AerodynamicForcesFilter(
output_file=params.get("OutputFile", session_name),
output_frequency=int(params.get("OutputFrequency",1)),
boundary=params.get("Boundary",""),
pivot_point=tuple(params.get("PivotPoint",(0,0,0)))
),
"Checkpoint": lambda params, session_name: CheckpointFilter(
output_frequency=int(params.get("OutputFrequency",1)),
output_file=params.get("OutputFile", session_name),
output_start_time=float(params.get("OutputStartTime",0.0))
),
#TODO History Points
})
@classmethod
def get_filter(cls, name: str, params: dict[str, Any], session_file: str) -> NekFilter | None:
session_name = os.path.basename(session_file)
filter_constructor = cls._filter_registry.get(name, None)
return filter_constructor(params, session_name) if filter_constructor else None
......@@ -168,6 +168,9 @@ class NekSessionFile:
params: dict[str,int | float] = parsing.evaluate_parameters(raw_params)
return params
def get_filters(self):
pass
def ensure_composite_format(text: str) -> str:
"""Ensure any [...] are converted to C[...]. This is to accouint for the fact
......
import pytest
from NekUpload.NekData.filters import *
from typing import Any
@pytest.mark.unit
@pytest.mark.parametrize("session_file,params,expected_params",[
("test.xml",
{"OutputFile":"IntermediateFields","OutputFrequency": 100,"OutputStartTime":50.0},
{"OutputFile":"IntermediateFields","OutputFrequency": 100,"OutputStartTime":50.0}
),
("test.xml",
{"OutputFrequency": 100,"OutputStartTime":50.0},
{"OutputFile":"test.xml","OutputFrequency": 100,"OutputStartTime":50.0}
),
("path/to/test.xml",
{"OutputFrequency": 100,"OutputStartTime":50.0},
{"OutputFile":"test.xml","OutputFrequency": 100,"OutputStartTime":50.0}
),
("test.xml",
{"OutputFile":"IntermediateFields","OutputFrequency": 100},
{"OutputFile":"IntermediateFields","OutputFrequency": 100,"OutputStartTime":0.0}
),
("test.xml",
{"OutputFile":"IntermediateFields","OutputStartTime":50.0},
{"OutputFile":"IntermediateFields","OutputFrequency": 1,"OutputStartTime":50.0}
),
("test.xml",
{"OutputFile":"IntermediateFields","OutputFrequency": "100","OutputStartTime":50.0},
{"OutputFile":"IntermediateFields","OutputFrequency": 100,"OutputStartTime":50.0}
),
("test.xml",
{"OutputFile":"IntermediateFields","OutputFrequency": 100,"OutputStartTime":"50.0"},
{"OutputFile":"IntermediateFields","OutputFrequency": 100,"OutputStartTime":50.0}
),
])
def test_nek_filter_checkpoint_factory(session_file:str,params:dict[str,Any],
expected_params:dict[str,Any]):
filter: CheckpointFilter = NekFilterFactory.get_filter("Checkpoint",params,session_file)
assert filter.output_file == expected_params["OutputFile"], f"Expected OutputFile: {expected_params['OutputFile']}, but got: {filter.output_file}"
assert filter.output_frequency == expected_params["OutputFrequency"], f"Expected OutputFrequency: {expected_params['OutputFrequency']}, but got: {filter.output_frequency}"
assert filter.output_start_time == expected_params["OutputStartTime"], f"Expected OutputStartTime: {expected_params['OutputStartTime']}, but got: {filter.output_start_time}"
@pytest.mark.unit
@pytest.mark.parametrize("session_file,params,expected_params",[
("test.xml",
{"OutputFile":"DragLift","OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":[1.0,.1,0.0]},
{"OutputFile":"DragLift","OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":(1.0,.1,0.0)},
),
("test.xml",
{"OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":[1.0,.1,0.0]},
{"OutputFile":"test.xml","OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":(1.0,.1,0.0)},
),
("path/to/test.xml",
{"OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":[1.0,.1,0.0]},
{"OutputFile":"test.xml","OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":(1.0,.1,0.0)},
),
("test.xml",
{"OutputFile":"DragLift","Boundary":"B[1,2]","PivotPoint":[1.0,.1,0.0]},
{"OutputFile":"DragLift","OutputFrequency":1,"Boundary":"B[1,2]","PivotPoint":(1.0,.1,0.0)},
),
("test.xml",
{"OutputFile":"DragLift","OutputFrequency":10,"PivotPoint":[1.0,.1,0.0]},
{"OutputFile":"DragLift","OutputFrequency":10,"Boundary":"","PivotPoint":(1.0,.1,0.0)},
),
("test.xml",
{"OutputFile":"DragLift","OutputFrequency":10,"Boundary":"B[1,2]"},
{"OutputFile":"DragLift","OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":(0.0,0.0,0.0)},
),
("test.xml",
{"OutputFile":"DragLift","OutputFrequency":"10","Boundary":"B[1,2]"},
{"OutputFile":"DragLift","OutputFrequency":10,"Boundary":"B[1,2]","PivotPoint":(0.0,0.0,0.0)},
),
])
def test_nek_filter_aeroforces_factory(session_file:str,params:dict[str,Any],
expected_params:dict[str,Any]):
filter: AerodynamicForcesFilter = NekFilterFactory.get_filter("AeroForces",params,session_file)
assert filter.output_file == expected_params["OutputFile"], f"Expected OutputFile: {expected_params['OutputFile']}, but got: {filter.output_file}"
assert filter.output_frequency == expected_params["OutputFrequency"], f"Expected OutputFrequency: {expected_params['OutputFrequency']}, but got: {filter.output_frequency}"
assert filter.boundary == expected_params["Boundary"]
assert filter.pivot_point == expected_params["PivotPoint"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment