diff --git a/benchmarks/mpas_ocean.py b/benchmarks/mpas_ocean.py index e8f783aa3..53a67a1b6 100644 --- a/benchmarks/mpas_ocean.py +++ b/benchmarks/mpas_ocean.py @@ -4,6 +4,8 @@ import uxarray as ux +import numpy as np + current_path = Path(os.path.dirname(os.path.realpath(__file__))) data_var = 'bottomDepth' @@ -164,3 +166,11 @@ def teardown(self, resolution): def time_check_norm(self, resolution): from uxarray.grid.validation import _check_normalization _check_normalization(self.uxgrid) + + +class CrossSections(DatasetBenchmark): + param_names = DatasetBenchmark.param_names + ['n_lat'] + params = DatasetBenchmark.params + [[1, 2, 4, 8]] + def time_constant_lat_fast(self, resolution, n_lat): + for lat in np.linspace(-89, 89, n_lat): + self.uxds.uxgrid.constant_latitude_cross_section(lat, method='fast') diff --git a/docs/api.rst b/docs/api.rst index 4c6469279..31f162f80 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -291,6 +291,30 @@ UxDataArray UxDataArray.subset.bounding_circle +Cross Sections +-------------- + + +Grid +~~~~ + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + Grid.cross_section + Grid.cross_section.constant_latitude + +UxDataArray +~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + UxDataArray.cross_section + UxDataArray.cross_section.constant_latitude + Remapping --------- diff --git a/docs/user-guide/cross-sections.ipynb b/docs/user-guide/cross-sections.ipynb new file mode 100644 index 000000000..9c0ba703c --- /dev/null +++ b/docs/user-guide/cross-sections.ipynb @@ -0,0 +1,208 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4a432a8bf95d9cdb", + "metadata": {}, + "source": [ + "# Cross-Sections\n", + "\n", + "This section demonstrates how to extract cross-sections from an unstructured grid using UXarray, which allows the analysis and visualization across slices of grids.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b35ba4a2c30750e4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-09T17:50:50.244285Z", + "start_time": "2024-10-09T17:50:50.239653Z" + } + }, + "outputs": [], + "source": [ + "import uxarray as ux\n", + "import geoviews.feature as gf\n", + "\n", + "import cartopy.crs as ccrs\n", + "import geoviews as gv\n", + "\n", + "projection = ccrs.Robinson()" + ] + }, + { + "cell_type": "markdown", + "id": "395a3db7-495c-4cff-b733-06bbe522a604", + "metadata": {}, + "source": [ + "## Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4160275c09fe6b0", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-09T17:50:51.217211Z", + "start_time": "2024-10-09T17:50:50.540946Z" + } + }, + "outputs": [], + "source": [ + "base_path = \"../../test/meshfiles/ugrid/outCSne30/\"\n", + "grid_path = base_path + \"outCSne30.ug\"\n", + "data_path = base_path + \"outCSne30_vortex.nc\"\n", + "\n", + "uxds = ux.open_dataset(grid_path, data_path)\n", + "uxds[\"psi\"].plot(\n", + " cmap=\"inferno\",\n", + " periodic_elements=\"split\",\n", + " projection=projection,\n", + " title=\"Global Plot\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a7a40958-0a4d-47e4-9e38-31925261a892", + "metadata": {}, + "source": [ + "## Constant Latitude\n", + "\n", + "Cross-sections along constant latitude lines can be obtained using the ``.cross_section.constant_latitude`` method, available for both ``ux.Grid`` and ``ux.DataArray`` objects. This functionality allows users to extract and analyze slices of data at specified latitudes, providing insights into variations along horizontal sections of the grid.\n" + ] + }, + { + "cell_type": "markdown", + "id": "2fbe9f6e5bb59a17", + "metadata": {}, + "source": [ + "For example, we can obtain a cross-section at 30 degrees latitude by doing the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3775daa1-2f1d-4738-bab5-2b69ebd689d9", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-09T17:50:53.093314Z", + "start_time": "2024-10-09T17:50:53.077719Z" + } + }, + "outputs": [], + "source": [ + "lat = 30\n", + "\n", + "uxda_constant_lat = uxds[\"psi\"].cross_section.constant_latitude(lat)" + ] + }, + { + "cell_type": "markdown", + "id": "dcec0b96b92e7f4", + "metadata": {}, + "source": [ + "Since the result is a new ``UxDataArray``, we can directly plot the result to see the cross-section." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "484b77a6-86da-4395-9e63-f5ac56e37deb", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " uxda_constant_lat.plot(\n", + " rasterize=False,\n", + " backend=\"bokeh\",\n", + " cmap=\"inferno\",\n", + " projection=projection,\n", + " global_extent=True,\n", + " coastline=True,\n", + " title=f\"Cross Section at {lat} degrees latitude\",\n", + " )\n", + " * gf.grid(projection=projection)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c7cca7de4722c121", + "metadata": {}, + "source": [ + "You can also perform operations on the cross-section, such as taking the mean." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cbee722-34a4-4e67-8e22-f393d7d36c99", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Global Mean: {uxds['psi'].data.mean()}\")\n", + "print(f\"Mean at {lat} degrees lat: {uxda_constant_lat.data.mean()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c4a7ee25-0b60-470f-bab7-92ff70563076", + "metadata": {}, + "source": [ + "## Constant Longitude" + ] + }, + { + "cell_type": "markdown", + "id": "9fcc8ec5-c6a8-4bde-a33d-7f37f9116ee2", + "metadata": {}, + "source": [ + "```{warning}\n", + "Constant longitude cross sections are not yet supported.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "54d9eff1-67f1-4691-a3b0-1ee0c874c98f", + "metadata": {}, + "source": [ + "## Arbitrary Great Circle Arc (GCA)" + ] + }, + { + "cell_type": "markdown", + "id": "ea94ff9f-fe86-470d-813b-45f32a633ffc", + "metadata": {}, + "source": [ + "```{warning}\n", + "Arbitrary great circle arc cross sections are not yet supported.\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/userguide.rst b/docs/userguide.rst index 007079780..c2bd7670d 100644 --- a/docs/userguide.rst +++ b/docs/userguide.rst @@ -43,6 +43,9 @@ These user guides provide detailed explanations of the core functionality in UXa `Subsetting `_ Select specific regions of a grid +`Cross-Sections `_ + Select cross-sections of a grid + `Remapping `_ Remap (a.k.a Regrid) between unstructured grids @@ -82,6 +85,7 @@ These user guides provide additional detail about specific features in UXarray. user-guide/mpl.ipynb user-guide/advanced-plotting.ipynb user-guide/subset.ipynb + user-guide/cross-sections.ipynb user-guide/remapping.ipynb user-guide/topological-aggregations.ipynb user-guide/calculus.ipynb diff --git a/test/meshfiles/ugrid/quad-hexagon/grid.nc b/test/meshfiles/ugrid/quad-hexagon/grid.nc index 6c0327dbb..aa38e3d46 100644 Binary files a/test/meshfiles/ugrid/quad-hexagon/grid.nc and b/test/meshfiles/ugrid/quad-hexagon/grid.nc differ diff --git a/test/test_cross_sections.py b/test/test_cross_sections.py new file mode 100644 index 000000000..26bf8777c --- /dev/null +++ b/test/test_cross_sections.py @@ -0,0 +1,94 @@ +import uxarray as ux +import pytest +from pathlib import Path +import os + +import numpy.testing as nt + +# Define the current path and file paths for grid and data +current_path = Path(os.path.dirname(os.path.realpath(__file__))) +quad_hex_grid_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'grid.nc' +quad_hex_data_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'data.nc' + +cube_sphere_grid = current_path / "meshfiles" / "geos-cs" / "c12" / "test-c12.native.nc4" + + + +class TestQuadHex: + """The quad hexagon grid contains four faces. + + Top Left Face: Index 1 + + Top Right Face: Index 2 + + Bottom Left Face: Index 0 + + Bottom Right Face: Index 3 + + The top two faces intersect a constant latitude of 0.1 + + The bottom two faces intersect a constant latitude of -0.1 + + All four faces intersect a constant latitude of 0.0 + """ + + def test_constant_lat_cross_section_grid(self): + uxgrid = ux.open_grid(quad_hex_grid_path) + + grid_top_two = uxgrid.cross_section.constant_latitude(lat=0.1) + + assert grid_top_two.n_face == 2 + + grid_bottom_two = uxgrid.cross_section.constant_latitude(lat=-0.1) + + assert grid_bottom_two.n_face == 2 + + grid_all_four = uxgrid.cross_section.constant_latitude(lat=0.0) + + assert grid_all_four.n_face == 4 + + with pytest.raises(ValueError): + # no intersections found at this line + uxgrid.cross_section.constant_latitude(lat=10.0) + + + def test_constant_lat_cross_section_uxds(self): + uxds = ux.open_dataset(quad_hex_grid_path, quad_hex_data_path) + + da_top_two = uxds['t2m'].cross_section.constant_latitude(lat=0.1) + + nt.assert_array_equal(da_top_two.data, uxds['t2m'].isel(n_face=[1, 2]).data) + + da_bottom_two = uxds['t2m'].cross_section.constant_latitude(lat=-0.1) + + nt.assert_array_equal(da_bottom_two.data, uxds['t2m'].isel(n_face=[0, 3]).data) + + da_all_four = uxds['t2m'].cross_section.constant_latitude(lat=0.0) + + nt.assert_array_equal(da_all_four.data , uxds['t2m'].data) + + with pytest.raises(ValueError): + # no intersections found at this line + uxds['t2m'].cross_section.constant_latitude(lat=10.0) + + +class TestGeosCubeSphere: + def test_north_pole(self): + uxgrid = ux.open_grid(cube_sphere_grid) + + lats = [89.85, 89.9, 89.95, 89.99] + + for lat in lats: + cross_grid = uxgrid.cross_section.constant_latitude(lat=lat) + # Cube sphere grid should have 4 faces centered around the pole + assert cross_grid.n_face == 4 + + def test_south_pole(self): + uxgrid = ux.open_grid(cube_sphere_grid) + + lats = [-89.85, -89.9, -89.95, -89.99] + + for lat in lats: + cross_grid = uxgrid.cross_section.constant_latitude(lat=lat) + # Cube sphere grid should have 4 faces centered around the pole + assert cross_grid.n_face == 4 diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index ece6387cb..a038ca5a0 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -33,6 +33,7 @@ from uxarray.plot.accessor import UxDataArrayPlotAccessor from uxarray.subset import DataArraySubsetAccessor from uxarray.remap import UxDataArrayRemapAccessor +from uxarray.cross_sections import UxDataArrayCrossSectionAccessor from uxarray.core.aggregation import _uxda_grid_aggregate import warnings @@ -85,6 +86,7 @@ def __init__(self, *args, uxgrid: Grid = None, **kwargs): plot = UncachedAccessor(UxDataArrayPlotAccessor) subset = UncachedAccessor(DataArraySubsetAccessor) remap = UncachedAccessor(UxDataArrayRemapAccessor) + cross_section = UncachedAccessor(UxDataArrayCrossSectionAccessor) def _repr_html_(self) -> str: if OPTIONS["display_style"] == "text": diff --git a/uxarray/cross_sections/__init__.py b/uxarray/cross_sections/__init__.py new file mode 100644 index 000000000..8694cabca --- /dev/null +++ b/uxarray/cross_sections/__init__.py @@ -0,0 +1,7 @@ +from .dataarray_accessor import UxDataArrayCrossSectionAccessor +from .grid_accessor import GridCrossSectionAccessor + +__all__ = ( + "GridCrossSectionAccessor", + "UxDataArrayCrossSectionAccessor", +) diff --git a/uxarray/cross_sections/dataarray_accessor.py b/uxarray/cross_sections/dataarray_accessor.py new file mode 100644 index 000000000..d840892b6 --- /dev/null +++ b/uxarray/cross_sections/dataarray_accessor.py @@ -0,0 +1,70 @@ +from __future__ import annotations + + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + pass + + +class UxDataArrayCrossSectionAccessor: + """Accessor for cross-section operations on a ``UxDataArray``""" + + def __init__(self, uxda) -> None: + self.uxda = uxda + + def __repr__(self): + prefix = "\n" + methods_heading = "Supported Methods:\n" + + methods_heading += " * constant_latitude(center_coord, k, element, **kwargs)\n" + + return prefix + methods_heading + + def constant_latitude(self, lat: float, method="fast"): + """Extracts a cross-section of the data array at a specified constant + latitude. + + Parameters + ---------- + lat : float + The latitude at which to extract the cross-section, in degrees. + method : str, optional + The internal method to use when identifying faces at the constant latitude. + Options are: + - 'fast': Uses a faster but potentially less accurate method for face identification. + - 'accurate': Uses a slower but more accurate method. + Default is 'fast'. + + Raises + ------ + ValueError + If no intersections are found at the specified latitude, a ValueError is raised. + + Examples + -------- + >>> uxda.constant_latitude_cross_section(lat=-15.5) + + Notes + ----- + The accuracy and performance of the function can be controlled using the `method` parameter. + For higher precision requreiments, consider using method='acurate'. + """ + faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat, method) + + return self.uxda.isel(n_face=faces) + + def constant_longitude(self, *args, **kwargs): + raise NotImplementedError + + def gca(self, *args, **kwargs): + raise NotImplementedError + + def bounded_latitude(self, *args, **kwargs): + raise NotImplementedError + + def bounded_longitude(self, *args, **kwargs): + raise NotImplementedError + + def gca_gca(self, *args, **kwargs): + raise NotImplementedError diff --git a/uxarray/cross_sections/grid_accessor.py b/uxarray/cross_sections/grid_accessor.py new file mode 100644 index 000000000..067e8f5fb --- /dev/null +++ b/uxarray/cross_sections/grid_accessor.py @@ -0,0 +1,95 @@ +from __future__ import annotations + + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from uxarray.grid import Grid + + +class GridCrossSectionAccessor: + """Accessor for cross-section operations on a ``Grid``""" + + def __init__(self, uxgrid: Grid) -> None: + self.uxgrid = uxgrid + + def __repr__(self): + prefix = "\n" + methods_heading = "Supported Methods:\n" + + methods_heading += " * constant_latitude(lat, )\n" + return prefix + methods_heading + + def constant_latitude(self, lat: float, return_face_indices=False, method="fast"): + """Extracts a cross-section of the grid at a specified constant + latitude. + + This method identifies and returns all faces (or grid elements) that intersect + with a given latitude. The returned cross-section can include either just the grid + or both the grid elements and the corresponding face indices, depending + on the `return_face_indices` parameter. + + Parameters + ---------- + lat : float + The latitude at which to extract the cross-section, in degrees. + return_face_indices : bool, optional + If True, returns both the grid at the specified latitude and the indices + of the intersecting faces. If False, only the grid is returned. + Default is False. + method : str, optional + The internal method to use when identifying faces at the constant latitude. + Options are: + - 'fast': Uses a faster but potentially less accurate method for face identification. + - 'accurate': Uses a slower but more accurate method. + Default is 'fast'. + + Returns + ------- + grid_at_constant_lat : Grid + The grid with faces that interesected at a given lattitude + faces : array, optional + The indices of the faces that intersect with the specified latitude. This is only + returned if `return_face_indices` is set to True. + + Raises + ------ + ValueError + If no intersections are found at the specified latitude, a ValueError is raised. + + Examples + -------- + >>> grid, indices = grid.cross_section.constant_latitude(lat=30.0, return_face_indices=True) + >>> grid = grid.cross_section.constant_latitude(lat=-15.5) + + Notes + ----- + The accuracy and performance of the function can be controlled using the `method` parameter. + For higher precision requreiments, consider using method='acurate'. + """ + faces = self.uxgrid.get_faces_at_constant_latitude(lat, method) + + if len(faces) == 0: + raise ValueError(f"No intersections found at lat={lat}.") + + grid_at_constant_lat = self.uxgrid.isel(n_face=faces) + + if return_face_indices: + return grid_at_constant_lat, faces + else: + return grid_at_constant_lat + + def constant_longitude(self, *args, **kwargs): + raise NotImplementedError + + def gca(self, *args, **kwargs): + raise NotImplementedError + + def bounded_latitude(self, *args, **kwargs): + raise NotImplementedError + + def bounded_longitude(self, *args, **kwargs): + raise NotImplementedError + + def gca_gca(self, *args, **kwargs): + raise NotImplementedError diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index 1d9e52811..211aa7949 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -66,12 +66,18 @@ _populate_edge_node_distances, ) +from uxarray.grid.intersections import ( + fast_constant_lat_intersections, +) + from spatialpandas import GeoDataFrame from uxarray.plot.accessor import GridPlotAccessor from uxarray.subset import GridSubsetAccessor +from uxarray.cross_sections import GridCrossSectionAccessor + from uxarray.grid.validation import ( _check_connectivity, _check_duplicate_nodes, @@ -83,7 +89,6 @@ from uxarray.conventions import ugrid - from xarray.core.utils import UncachedAccessor from warnings import warn @@ -92,6 +97,8 @@ import copy + +from uxarray.constants import INT_FILL_VALUE from uxarray.grid.dual import construct_dual @@ -218,6 +225,9 @@ def __init__( # declare subset accessor subset = UncachedAccessor(GridSubsetAccessor) + # declare cross section accessor + cross_section = UncachedAccessor(GridCrossSectionAccessor) + @classmethod def from_dataset( cls, dataset: xr.Dataset, use_dual: Optional[bool] = False, **kwargs @@ -983,7 +993,7 @@ def face_node_connectivity(self, value): def edge_node_connectivity(self) -> xr.DataArray: """Indices of the two nodes that make up each edge. - Dimensions: ``(n_edge, n_max_edge_nodes)`` + Dimensions: ``(n_edge, two)`` Nodes are in arbitrary order. """ @@ -998,6 +1008,23 @@ def edge_node_connectivity(self, value): assert isinstance(value, xr.DataArray) self._ds["edge_node_connectivity"] = value + @property + def edge_node_z(self) -> xr.DataArray: + """Cartesian z location for the two nodes that make up every edge. + + Dimensions: ``(n_edge, two)`` + """ + + if "edge_node_z" not in self._ds: + _edge_node_z = self.node_z.values[self.edge_node_connectivity.values] + + self._ds["edge_node_z"] = xr.DataArray( + data=_edge_node_z, + dims=["n_edge", "two"], + ) + + return self._ds["edge_node_z"] + @property def node_node_connectivity(self) -> xr.DataArray: """Indices of the nodes that surround each node.""" @@ -1880,6 +1907,31 @@ def to_linecollection( return line_collection + def get_dual(self): + """Compute the dual for a grid, which constructs a new grid centered + around the nodes, where the nodes of the primal become the face centers + of the dual, and the face centers of the primal become the nodes of the + dual. Returns a new `Grid` object. + + Returns + -------- + dual : Grid + Dual Mesh Grid constructed + """ + + if _check_duplicate_nodes_indices(self): + raise RuntimeError("Duplicate nodes found, cannot construct dual") + + # Get dual mesh node face connectivity + dual_node_face_conn = construct_dual(grid=self) + + # Construct dual mesh + dual = self.from_topology( + self.face_lon.values, self.face_lat.values, dual_node_face_conn + ) + + return dual + def isel(self, **dim_kwargs): """Indexes an unstructured grid along a given dimension (``n_node``, ``n_edge``, or ``n_face``) and returns a new grid. @@ -1918,27 +1970,67 @@ def isel(self, **dim_kwargs): "Indexing must be along a grid dimension: ('n_node', 'n_edge', 'n_face')" ) - def get_dual(self): - """Compute the dual for a grid, which constructs a new grid centered - around the nodes, where the nodes of the primal become the face centers - of the dual, and the face centers of the primal become the nodes of the - dual. Returns a new `Grid` object. + def get_edges_at_constant_latitude(self, lat, method="fast"): + """Identifies the edges of the grid that intersect with a specified + constant latitude. + + This method computes the intersection of grid edges with a given latitude and + returns a collection of edges that cross or are aligned with that latitude. + The method used for identifying these edges can be controlled by the `method` + parameter. + + Parameters + ---------- + lat : float + The latitude at which to identify intersecting edges, in degrees. + method : str, optional + The computational method used to determine edge intersections. Options are: + - 'fast': Uses a faster but potentially less accurate method for determining intersections. + - 'accurate': Uses a slower but more precise method. + Default is 'fast'. Returns - -------- - dual : Grid - Dual Mesh Grid constructed + ------- + edges : array + A squeezed array of edges that intersect the specified constant latitude. """ + if method == "fast": + edges = fast_constant_lat_intersections( + lat, self.edge_node_z.values, self.n_edge + ) + elif method == "accurate": + raise NotImplementedError("Accurate method not yet implemented.") + else: + raise ValueError(f"Invalid method: {method}.") + return edges.squeeze() - if _check_duplicate_nodes_indices(self): - raise RuntimeError("Duplicate nodes found, cannot construct dual") + def get_faces_at_constant_latitude(self, lat, method="fast"): + """Identifies the faces of the grid that intersect with a specified + constant latitude. - # Get dual mesh node face connectivity - dual_node_face_conn = construct_dual(grid=self) + This method finds the faces (or cells) of the grid that intersect a given latitude + by first identifying the intersecting edges and then determining the faces connected + to these edges. The method used for identifying edges can be adjusted with the `method` + parameter. - # Construct dual mesh - dual = self.from_topology( - self.face_lon.values, self.face_lat.values, dual_node_face_conn - ) + Parameters + ---------- + lat : float + The latitude at which to identify intersecting faces, in degrees. + method : str, optional + The computational method used to determine intersecting edges. Options are: + - 'fast': Uses a faster but potentially less accurate method for determining intersections. + - 'accurate': Uses a slower but more precise method. + Default is 'fast'. - return dual + Returns + ------- + faces : array + An array of unique face indices that intersect the specified latitude. + Faces that are invalid or missing (e.g., with a fill value) are excluded + from the result. + """ + edges = self.get_edges_at_constant_latitude(lat, method) + faces = np.unique(self.edge_face_connectivity[edges].data.ravel()) + + return faces[faces != INT_FILL_VALUE] diff --git a/uxarray/grid/intersections.py b/uxarray/grid/intersections.py index ccf767df5..7e74622d4 100644 --- a/uxarray/grid/intersections.py +++ b/uxarray/grid/intersections.py @@ -7,6 +7,48 @@ from uxarray.utils.computing import cross_fma, allclose, dot, cross, norm +from numba import njit, prange + + +@njit(parallel=True, nogil=True, cache=True) +def fast_constant_lat_intersections(lat, edge_node_z, n_edge): + """Determine which edges intersect a constant line of latitude on a sphere. + + Parameters + ---------- + lat: + Constant latitude value in degrees. + edge_node_z: + Array of shape (n_edge, 2) containing z-coordinates of the edge nodes. + n_edge: + Total number of edges to check. + + Returns + ------- + intersecting_edges: + array of indices of edges that intersect the constant latitude. + """ + lat = np.deg2rad(lat) + + intersecting_edges_mask = np.zeros(n_edge, dtype=np.int32) + + # Calculate the constant z-value for the given latitude + z_constant = np.sin(lat) + + # Iterate through each edge and check for intersections + for i in prange(n_edge): + # Get the z-coordinates of the edge's nodes + z0 = edge_node_z[i, 0] + z1 = edge_node_z[i, 1] + + if (z0 - z_constant) * (z1 - z_constant) < 0.0: + intersecting_edges_mask[i] = 1 + + intersecting_edges = np.argwhere(intersecting_edges_mask) + + return np.unique(intersecting_edges) + + def gca_gca_intersection(gca1_cart, gca2_cart, fma_disabled=True): """Calculate the intersection point(s) of two Great Circle Arcs (GCAs) in a Cartesian coordinate system.