diff --git a/example/observations_station_gaussian_model.py b/example/observations_station_gaussian_model.py index dc2143664..a55cf789c 100644 --- a/example/observations_station_gaussian_model.py +++ b/example/observations_station_gaussian_model.py @@ -15,6 +15,7 @@ """ # Noqa:D205,D400 import logging import os +from pathlib import Path from typing import Tuple import matplotlib.pyplot as plt @@ -30,6 +31,8 @@ DwdObservationResolution, ) +HERE = Path(__file__).parent + log = logging.getLogger() try: @@ -63,7 +66,7 @@ class ModelYearlyGaussians: """ - def __init__(self, station_data: StationsResult): + def __init__(self, station_data: StationsResult, plot_path: Path): self._station_data = station_data result_values = station_data.values.all().df.drop_nulls() @@ -81,7 +84,7 @@ def __init__(self, station_data: StationsResult): log.info(f"Fit Result message: {out.result.message}") - self.plot_data_and_model(valid_data, out, savefig_to_file=True) + self.plot_data_and_model(valid_data, out, savefig_to_file=True, plot_path=plot_path) def get_valid_data(self, result_values: pl.DataFrame) -> pl.DataFrame: valid_data_lst = [] @@ -137,7 +140,7 @@ def model_pars_update( return pars - def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file=True) -> None: + def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file, plot_path: Path) -> None: """plots the data and the model""" if savefig_to_file: _ = plt.subplots(figsize=(12, 12)) @@ -153,21 +156,21 @@ def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefi if savefig_to_file: number_of_years = valid_data.get_column("date").dt.year().n_unique() filename = f"{self.__class__.__qualname__}_wetter_model_{number_of_years}" - plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.savefig(plot_path / filename, dpi=300, bbox_inches="tight") log.info("saved fig to file: " + filename) if "PYTEST_CURRENT_TEST" not in os.environ: plt.show() -def main(): +def main(plot_path=HERE): """Run example.""" logging.basicConfig(level=logging.INFO) station_data_one_year = station_example(start_date="2020-12-25", end_date="2022-01-01") - _ = ModelYearlyGaussians(station_data_one_year) + _ = ModelYearlyGaussians(station_data_one_year, plot_path=plot_path) station_data_many_years = station_example(start_date="1995-12-25", end_date="2022-12-31") - _ = ModelYearlyGaussians(station_data_many_years) + _ = ModelYearlyGaussians(station_data_many_years, plot_path=plot_path) if __name__ == "__main__": diff --git a/tests/example/test_regular_examples.py b/tests/example/test_regular_examples.py index effc47a6d..edf69b370 100644 --- a/tests/example/test_regular_examples.py +++ b/tests/example/test_regular_examples.py @@ -35,10 +35,10 @@ def test_pdbufr_examples(): @pytest.mark.skipif(IS_CI and IS_LINUX, reason="stalls on Mac/Windows in CI") @pytest.mark.cflake -def test_gaussian_example(): +def test_gaussian_example(tmp_path): from example import observations_station_gaussian_model - assert observations_station_gaussian_model.main() is None + assert observations_station_gaussian_model.main(tmp_path) is None # @pytest.mark.skipif(IS_CI, reason="radar examples not working in CI")