Skip to content

Commit

Permalink
use engine flox for ordered groups
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Sep 29, 2023
1 parent 5ea713e commit ef6bbc4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
22 changes: 16 additions & 6 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,7 +1755,7 @@ def groupby_reduce(
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
engine: T_Engine = None,
reindex: bool | None = None,
finalize_kwargs: dict[Any, Any] | None = None,
) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]: # type: ignore[misc] # Unpack not in mypy yet
Expand Down Expand Up @@ -1851,17 +1851,27 @@ def groupby_reduce(
xarray.xarray_reduce
"""

bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

if engine is None:
# choose numpy per default
engine = "numpy"

if nby == 1 and not any_by_dask and bys[0].ndim == 1:
# maybe move to helper function
issorted = lambda arr: (arr[:-1] <= arr[1:]).all()
if not _is_arg_reduction(func) and issorted(bys[0]):
engine = "flox"

if engine == "flox" and _is_arg_reduction(func):
raise NotImplementedError(
"argreductions not supported for engine='flox' yet."
"Try engine='numpy' or engine='numba' instead."
)

bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

if method in ["split-reduce", "cohorts"] and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

Expand Down
2 changes: 1 addition & 1 deletion flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def xarray_reduce(
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
engine: str = "numpy",
engine: str = None,
keep_attrs: bool | None = True,
skipna: bool | None = None,
min_count: int | None = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
@pytest.fixture(scope="module", params=[None, "flox", "numpy", "numba"])
def engine(request):
if request.param == "numba":
try:
Expand Down

0 comments on commit ef6bbc4

Please sign in to comment.