Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve dtype better when specified. #389

Merged
merged 16 commits into from
Sep 8, 2024
4 changes: 3 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
- name: Run Tests
id: status
run: |
python -c "import xarray; xarray.show_versions()"
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
- name: Upload code coverage to Codecov
uses: codecov/[email protected]
Expand Down Expand Up @@ -98,7 +99,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
repository: "pydata/xarray"
repository: "dcherian/xarray"
fetch-depth: 0 # Fetch all history for all branches and tags.
- name: Set up conda environment
uses: mamba-org/setup-micromamba@v1
Expand All @@ -112,6 +113,7 @@ jobs:
pint>=0.22
- name: Install xarray
run: |
git checkout flox-preserve-dtype
python -m pip install --no-deps .
- name: Install upstream flox
run: |
Expand Down
3 changes: 2 additions & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ dependencies:
- pytest-pretty
- pytest-xdist
- syrupy
- xarray
- pre-commit
- numpy_groupies>=0.9.19
- pooch
- toolz
- numba
- numbagg>=0.3
- hypothesis
- pip:
- git+https://github.com/dcherian/xarray.git@flox-preserve-dtype
3 changes: 2 additions & 1 deletion ci/no-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ dependencies:
- pytest-pretty
- pytest-xdist
- syrupy
- xarray
- numpydoc
- pre-commit
- numpy_groupies>=0.9.19
- pooch
- toolz
- numba
- numbagg>=0.3
- pip:
- git+https://github.com/dcherian/xarray.git@flox-preserve-dtype
13 changes: 8 additions & 5 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,20 +549,23 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
return (Dim(name="quantile", values=q),)


# Output dtype: if the input contains integers or floats smaller than float64,
# the output data-type is float64. Otherwise, the output data-type is the same as that
# of the input.
quantile = Aggregation(
name="quantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.floating,
final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
nanquantile = Aggregation(
name="nanquantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.floating,
final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
Expand Down Expand Up @@ -801,9 +804,9 @@ def _initialize_aggregation(
dtype_: np.dtype | None = (
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)
final_dtype = dtypes._normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
if not agg.preserves_dtype:
final_dtype = dtypes._maybe_promote_int(final_dtype)
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
Expand Down
9 changes: 7 additions & 2 deletions flox/xrdtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,14 @@ def is_datetime_like(dtype):
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
def _normalize_dtype(
dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
) -> np.dtype:
if dtype is None:
dtype = array_dtype
if not preserves_dtype:
dtype = _maybe_promote_int(array_dtype)
else:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
Expand Down
4 changes: 2 additions & 2 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:


# TODO: stop excluding everything but U
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()

NON_NUMPY_FUNCS = [
Expand All @@ -43,7 +43,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:

func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False},
Expand Down
15 changes: 14 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _get_array_func(func: str) -> Callable:

def npfunc(x, **kwargs):
x = np.asarray(x)
return (~np.isnan(x)).sum()
return (~xrutils.isnull(x)).sum(**kwargs)

elif func in ["nanfirst", "nanlast"]:
npfunc = getattr(xrutils, func)
Expand Down Expand Up @@ -1984,3 +1984,16 @@ def test_blockwise_nans():
)
assert_equal(expected_groups, actual_groups)
assert_equal(expected, actual)


@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"])
@pytest.mark.parametrize("engine", ["flox", "numpy"])
def test_agg_dtypes(func, engine):
    """Regression test for GH388.

    Reduces a small integer array by group while requesting ``dtype="uint8"``
    and checks that both the grouped result and the reference function
    returned by ``_get_array_func`` report ``uint8``.
    """
    values = np.array([0, 2, 1, 0, 1])
    labels = np.array([1, 1, 1, 2, 2])

    result, _ = groupby_reduce(
        values,
        labels,
        expected_groups=(np.array([1, 2]),),
        func=func,
        dtype="uint8",
        engine=engine,
    )
    reference = _get_array_func(func)(values, dtype="uint8")

    # Both the grouped reduction and the reference must honour the request.
    assert result.dtype == np.uint8
    assert np.uint8 == reference.dtype
24 changes: 23 additions & 1 deletion tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flox.xrutils import notnull

from . import assert_equal
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import chunks as chunks_strategy

dask.config.set(scheduler="sync")
Expand Down Expand Up @@ -244,3 +244,25 @@ def test_first_last_useless(data, func):
actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
assert_equal(actual, expected)


@given(
    func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
    engine=st.sampled_from(["numpy", "flox"]),
    array_dtype=st.one_of(st.none(), array_dtypes),
    dtype=st.one_of(st.none(), array_dtypes),
)
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
    """Regression test for GH388.

    For every drawn combination of input dtype and requested output dtype
    (either may be ``None``), the dtype of the grouped reduction must equal
    the dtype produced by the numpy function of the same name called with
    the same ``dtype`` argument.
    """
    values = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
    labels = np.array([1, 1, 1, 2, 2])

    result, _ = groupby_reduce(
        values,
        labels,
        expected_groups=(np.array([1, 2]),),
        func=func,
        dtype=dtype,
        engine=engine,
    )

    numpy_reduction = getattr(np, func)
    reference = numpy_reduction(values, keepdims=True, dtype=dtype)
    assert result.dtype == reference.dtype
25 changes: 24 additions & 1 deletion tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# test against legacy xarray implementation
# avoid some compilation overhead
xr.set_options(use_flox=False, use_numbagg=False)
xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False)
tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
np.random.seed(123)

Expand Down Expand Up @@ -760,3 +760,26 @@ def test_direct_reduction(func):
with xr.set_options(use_flox=False):
expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
xr.testing.assert_identical(expected, actual)


@pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"])
def test_groupby_preserve_dtype(reduction):
    """Check that a flox-backed xarray groupby reduction matches numpy's dtype.

    An ``int16`` variable is grouped by the ``idx`` coordinate and reduced
    with the xarray method named by ``reduction`` minus any ``"nan"`` prefix;
    ``skipna=True`` is passed when the name contains ``"nan"``. The resulting
    dtype is compared with that of the numpy function called over axis 0.
    """
    # all groups are present, we should follow numpy exactly
    values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16")
    ds = xr.Dataset(
        {"test": (["x", "y"], values)},
        coords={"idx": ("x", [1, 2, 1])},
    )

    kwargs = {"engine": "numpy", "skipna": True} if "nan" in reduction else {"engine": "numpy"}
    method_name = reduction.removeprefix("nan")

    with xr.set_options(use_flox=True):
        reduced = getattr(ds.groupby("idx"), method_name)(**kwargs)
        actual = reduced.test.dtype

    numpy_reduction = getattr(np, reduction)
    expected = numpy_reduction(ds.test.data, axis=0).dtype

    assert actual == expected
Loading