From 7421cb152ad5fc4336bb1570165f889d18e0ea9b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 8 Sep 2024 00:05:46 -0600 Subject: [PATCH] Preserve dtype better when specified. (#389) * Preserve dtype better when specified. * Add one more test * tweak test * more test * [revert] test with Xarray PR branch * tweak * show versions * Drop python 3.9, use ruff * switch to Ruff * fix mypy * remove toctrees * fix * one more --- .github/workflows/ci.yaml | 4 +++- ci/environment.yml | 3 ++- ci/no-dask.yml | 3 ++- flox/aggregations.py | 13 ++++++++----- flox/xrdtypes.py | 9 +++++++-- tests/strategies.py | 4 ++-- tests/test_core.py | 15 ++++++++++++++- tests/test_properties.py | 24 +++++++++++++++++++++++- tests/test_xarray.py | 25 ++++++++++++++++++++++++- 9 files changed, 85 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3f1416b2..bbee1506 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -70,6 +70,7 @@ jobs: - name: Run Tests id: status run: | + python -c "import xarray; xarray.show_versions()" pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci - name: Upload code coverage to Codecov uses: codecov/codecov-action@v4.5.0 @@ -98,7 +99,7 @@ jobs: steps: - uses: actions/checkout@v4 with: - repository: "pydata/xarray" + repository: "dcherian/xarray" fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up conda environment uses: mamba-org/setup-micromamba@v1 @@ -112,6 +113,7 @@ jobs: pint>=0.22 - name: Install xarray run: | + git checkout flox-preserve-dtype python -m pip install --no-deps . - name: Install upstream flox run: | diff --git a/ci/environment.yml b/ci/environment.yml index 82995d07..dac6880a 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -19,7 +19,6 @@ dependencies: - pytest-pretty - pytest-xdist - syrupy - - xarray - pre-commit - numpy_groupies>=0.9.19 - pooch @@ -27,3 +26,5 @@ dependencies: - numba - numbagg>=0.3 - hypothesis + - pip: + - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype diff --git a/ci/no-dask.yml b/ci/no-dask.yml index 1f05c63a..fb2bac92 100644 --- a/ci/no-dask.yml +++ b/ci/no-dask.yml @@ -14,7 +14,6 @@ dependencies: - pytest-pretty - pytest-xdist - syrupy - - xarray - numpydoc - pre-commit - numpy_groupies>=0.9.19 @@ -22,3 +21,5 @@ dependencies: - toolz - numba - numbagg>=0.3 + - pip: + - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype diff --git a/flox/aggregations.py b/flox/aggregations.py index 4e031219..0906c8cc 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -549,12 +549,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]: return (Dim(name="quantile", values=q),) +# if the input contains integers or floats smaller than float64, +# the output data-type is float64. Otherwise, the output data-type is the same as that +# of the input. quantile = Aggregation( name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, - final_dtype=np.floating, + final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) nanquantile = Aggregation( @@ -562,7 +565,7 @@ def quantile_new_dims_func(q) -> tuple[Dim]: fill_value=dtypes.NA, chunk=None, combine=None, - final_dtype=np.floating, + final_dtype=np.float64, new_dims_func=quantile_new_dims_func, ) mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) @@ -801,9 +804,9 @@ def _initialize_aggregation( dtype_: np.dtype | None = ( np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype ) - final_dtype = dtypes._normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value) - if not agg.preserves_dtype: - final_dtype = dtypes._maybe_promote_int(final_dtype) + final_dtype = dtypes._normalize_dtype( + dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value + ) agg.dtype = { "user": dtype, # Save to automatically choose an engine "final": final_dtype, diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index 3fd0f4fe..34d0d2a5 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -150,9 +150,14 @@ def is_datetime_like(dtype): return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype: +def _normalize_dtype( + dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None +) -> np.dtype: if dtype is None: - dtype = array_dtype + if not preserves_dtype: + dtype = _maybe_promote_int(array_dtype) + else: + dtype = array_dtype if dtype is np.floating: # mean, std, var always result in floating # but we preserve the array's dtype if it is floating diff --git a/tests/strategies.py b/tests/strategies.py index 6f28db32..b1dc7ce3 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -27,7 +27,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: # TODO: stop excluding everything but U -array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU") +array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU") by_dtype_st = supported_dtypes() NON_NUMPY_FUNCS = [ @@ -43,7 +43,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]) numeric_arrays = npst.arrays( - elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st + elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes ) all_arrays = npst.arrays( elements={"allow_subnormal": False}, diff --git a/tests/test_core.py b/tests/test_core.py index 2c33ebcf..cef9ad8a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -81,7 +81,7 @@ def _get_array_func(func: str) -> Callable: def npfunc(x, **kwargs): x = np.asarray(x) - return (~np.isnan(x)).sum() + return (~xrutils.isnull(x)).sum(**kwargs) elif func in ["nanfirst", "nanlast"]: npfunc = getattr(xrutils, func) @@ -1984,3 +1984,16 @@ def test_blockwise_nans(): ) assert_equal(expected_groups, actual_groups) assert_equal(expected, actual) + + +@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"]) +@pytest.mark.parametrize("engine", ["flox", "numpy"]) +def test_agg_dtypes(func, engine): + # regression test for GH388 + counts = np.array([0, 2, 1, 0, 1]) + group = np.array([1, 1, 1, 2, 2]) + actual, _ = groupby_reduce( + counts, group, expected_groups=(np.array([1, 2]),), func=func, dtype="uint8", engine=engine + ) + expected = _get_array_func(func)(counts, dtype="uint8") + assert actual.dtype == np.uint8 == expected.dtype diff --git a/tests/test_properties.py b/tests/test_properties.py index 584314ce..0437ef25 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -20,7 +20,7 @@ from flox.xrutils import notnull from . import assert_equal -from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays +from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays from .strategies import chunks as chunks_strategy dask.config.set(scheduler="sync") @@ -244,3 +244,25 @@ def test_first_last_useless(data, func): actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy") expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype) assert_equal(actual, expected) + + +@given( + func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]), + engine=st.sampled_from(["numpy", "flox"]), + array_dtype=st.none() | array_dtypes, + dtype=st.none() | array_dtypes, +) +def test_agg_dtype_specified(func, array_dtype, dtype, engine): + # regression test for GH388 + counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype) + group = np.array([1, 1, 1, 2, 2]) + actual, _ = groupby_reduce( + counts, + group, + expected_groups=(np.array([1, 2]),), + func=func, + dtype=dtype, + engine=engine, + ) + expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) + assert actual.dtype == expected.dtype diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 2592fa07..9423eb11 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -24,7 +24,7 @@ # test against legacy xarray implementation # avoid some compilation overhead -xr.set_options(use_flox=False, use_numbagg=False) +xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False) tolerance64 = {"rtol": 1e-15, "atol": 1e-18} np.random.seed(123) @@ -760,3 +760,26 @@ def test_direct_reduction(func): with xr.set_options(use_flox=False): expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs) xr.testing.assert_identical(expected, actual) + + +@pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"]) +def test_groupby_preserve_dtype(reduction): + # all groups are present, we should follow numpy exactly + ds = xr.Dataset( + { + "test": ( + ["x", "y"], + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"), + ) + }, + coords={"idx": ("x", [1, 2, 1])}, + ) + + kwargs = {"engine": "numpy"} + if "nan" in reduction: + kwargs["skipna"] = True + with xr.set_options(use_flox=True): + actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(**kwargs).test.dtype + expected = getattr(np, reduction)(ds.test.data, axis=0).dtype + + assert actual == expected