
Commit 64c7be9: Fixes

dcherian committed Jul 26, 2024
1 parent 5e0b065 commit 64c7be9
Showing 3 changed files with 111 additions and 18 deletions.
3 changes: 3 additions & 0 deletions flox/aggregations.py
@@ -69,6 +69,9 @@ def generic_aggregate(
     if func == "identity":
         return array
 
+    if func in ["nanfirst", "nanlast"] and array.dtype.kind in "US":
+        func = func[3:]
+
     if engine == "flox":
         try:
             method = getattr(aggregate_flox, func)
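Note: this is safe because fixed-width string dtypes (kinds "S" and "U") cannot represent NaN, so the nan-skipping variants select exactly the same element as their plain counterparts, and func[3:] simply strips the "nan" prefix. A minimal sketch of the idea (resolve_func is a hypothetical name, not flox API):

    import numpy as np

    def resolve_func(func: str, array: np.ndarray) -> str:
        # Bytes ("S") and unicode ("U") arrays cannot hold NaN, so there is
        # nothing for the nan-skipping variants to skip.
        if func in ["nanfirst", "nanlast"] and array.dtype.kind in "US":
            return func[3:]  # "nanfirst" -> "first", "nanlast" -> "last"
        return func

    assert resolve_func("nanfirst", np.array(["a", "b"])) == "first"
    assert resolve_func("nanlast", np.array([1.0, 2.0])) == "nanlast"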
117 changes: 105 additions & 12 deletions tests/__init__.py
@@ -101,27 +101,32 @@ def assert_equal(a, b, tolerance=None):
     else:
         tolerance = {}
 
+    # Always run the numpy comparison first, so that we get nice error messages with dask.
+    # sometimes it's nice to see values and shapes
+    # rather than being dropped into some file in dask
+    if a.dtype != b.dtype:
+        raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
+
+    if has_dask:
+        a_eager = a.compute() if isinstance(a, dask_array_type) else a
+        b_eager = b.compute() if isinstance(b, dask_array_type) else b
+
+        if a.dtype.kind in "SUMm":
+            np.testing.assert_equal(a_eager, b_eager)
+        else:
+            np.testing.assert_allclose(a_eager, b_eager, equal_nan=True, **tolerance)
+
     if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
-        # sometimes it's nice to see values and shapes
-        # rather than being dropped into some file in dask
-        np.testing.assert_allclose(a, b, **tolerance)
         # does some validation of the dask graph
-        da.utils.assert_eq(a, b, equal_nan=True)
-    else:
-        if a.dtype != b.dtype:
-            raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
-        if a.dtype.kind in "SU":
-            np.testing.assert_equal(a, b)
-        else:
-            np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
+        dask_assert_eq(a, b, equal_nan=True)
 
 
 def assert_equal_tuple(a, b):
     """assert_equal for .blocks indexing tuples"""
     assert len(a) == len(b)
 
     for a_, b_ in zip(a, b):
-        assert type(a_) == type(b_)
+        assert type(a_) is type(b_)
         if isinstance(a_, np.ndarray):
             np.testing.assert_array_equal(a_, b_)
         else:
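Note: the numpy comparison now runs before any dask checks so failures show values and shapes instead of a dask traceback, and np.testing.assert_allclose only supports numeric data, so byte-string, unicode, datetime64, and timedelta64 arrays (dtype kinds "S", "U", "M", "m") are routed to np.testing.assert_equal instead. A small illustration of the kinds involved (plain numpy, nothing flox-specific):

    import numpy as np

    # Kinds assert_allclose cannot handle: "S" (bytes), "U" (str),
    # "M" (datetime64), "m" (timedelta64); these are compared exactly.
    for arr in [
        np.array([b"x"]),
        np.array(["x"]),
        np.array(["2024-07-26"], dtype="datetime64[ns]"),
        np.array([1, 2], dtype="timedelta64[s]"),
    ]:
        assert arr.dtype.kind in "SUMm"
        np.testing.assert_equal(arr, arr.copy())

(The == to is change in assert_equal_tuple is the usual E721 fix: exact type comparison should use identity.)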
@@ -158,3 +163,91 @@ def assert_equal_tuple(a, b):
"quantile",
"nanquantile",
) + tuple(SCIPY_STATS_FUNCS)


def dask_assert_eq(
a,
b,
check_shape=True,
check_graph=True,
check_meta=True,
check_chunks=True,
check_ndim=True,
check_type=True,
check_dtype=True,
equal_nan=True,
scheduler="sync",
**kwargs,
):
"""dask.array.utils.assert_eq modified to skip value checks. Their code is buggy for some dtypes.
We just check values through numpy and care about validating the graph in this function."""
from dask.array.utils import _get_dt_meta_computed

a_original = a
b_original = b

if isinstance(a, (list, int, float)):
a = np.array(a)
if isinstance(b, (list, int, float)):
b = np.array(b)

a, adt, a_meta, a_computed = _get_dt_meta_computed(
a,
check_shape=check_shape,
check_graph=check_graph,
check_chunks=check_chunks,
check_ndim=check_ndim,
scheduler=scheduler,
)
b, bdt, b_meta, b_computed = _get_dt_meta_computed(
b,
check_shape=check_shape,
check_graph=check_graph,
check_chunks=check_chunks,
check_ndim=check_ndim,
scheduler=scheduler,
)

if check_type:
_a = a if a.shape else a.item()
_b = b if b.shape else b.item()
assert type(_a) is type(_b), f"a and b have different types (a: {type(_a)}, b: {type(_b)})"
if check_meta:
if hasattr(a, "_meta") and hasattr(b, "_meta"):
dask_assert_eq(a._meta, b._meta)
if hasattr(a_original, "_meta"):
msg = (
f"compute()-ing 'a' changes its number of dimensions "
f"(before: {a_original._meta.ndim}, after: {a.ndim})"
)
assert a_original._meta.ndim == a.ndim, msg
if a_meta is not None:
msg = (
f"compute()-ing 'a' changes its type "
f"(before: {type(a_original._meta)}, after: {type(a_meta)})"
)
assert type(a_original._meta) is type(a_meta), msg
if not (np.isscalar(a_meta) or np.isscalar(a_computed)):
msg = (
f"compute()-ing 'a' results in a different type than implied by its metadata "
f"(meta: {type(a_meta)}, computed: {type(a_computed)})"
)
assert type(a_meta) is type(a_computed), msg
if hasattr(b_original, "_meta"):
msg = (
f"compute()-ing 'b' changes its number of dimensions "
f"(before: {b_original._meta.ndim}, after: {b.ndim})"
)
assert b_original._meta.ndim == b.ndim, msg
if b_meta is not None:
msg = (
f"compute()-ing 'b' changes its type "
f"(before: {type(b_original._meta)}, after: {type(b_meta)})"
)
assert type(b_original._meta) is type(b_meta), msg
if not (np.isscalar(b_meta) or np.isscalar(b_computed)):
msg = (
f"compute()-ing 'b' results in a different type than implied by its metadata "
f"(meta: {type(b_meta)}, computed: {type(b_computed)})"
)
assert type(b_meta) is type(b_computed), msg
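Note: a usage sketch, assuming dask is installed and dask_assert_eq is imported from tests/__init__.py. Values have already been compared eagerly through numpy in assert_equal above; this call only exercises the graph, meta, ndim, and type validation:

    import dask.array as da
    import numpy as np

    x = da.from_array(np.arange(6.0), chunks=3)
    y = x + 0.0  # same shape, chunks, dtype, and meta as x

    # Validates the dask graph and checks that compute() preserves the
    # type and ndim implied by each array's _meta; value equality is
    # deliberately not checked here.
    dask_assert_eq(x, y, equal_nan=True)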
9 changes: 3 additions & 6 deletions tests/test_properties.py
@@ -210,6 +210,7 @@ def test_scans(data, array, func):
 
 @given(data=st.data(), array=chunked_arrays())
 def test_ffill_bfill_reverse(data, array):
+    # TODO: test NaT and timedelta, datetime
     assume(not_overflowing_array(np.asarray(array)))
     by = data.draw(by_arrays(shape=(array.shape[-1],)))
 
@@ -240,9 +241,7 @@ def test_first_last(data, array, func):
     mate = MATES[func]
 
     if func in ["first", "last"]:
-        newchunks = list(array.chunks)
-        newchunks[-1] = -1
-        array = array.rechunk(newchunks)
+        array = array.rechunk((*array.chunks[:-1], -1))
 
     for arr in [array, array.compute()]:
         forward, fg = groupby_reduce(arr, by, func=func, engine="flox")
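Note: the one-liner used here and in the next hunk replaces the mutate-a-list pattern: array.chunks[:-1] keeps the existing chunking on the leading axes, and -1 collapses the last axis into a single chunk (presumably so the order-sensitive first/last reductions see the whole reduced axis in one block). A quick sketch with illustrative shapes:

    import dask.array as da

    array = da.ones((4, 6), chunks=(2, 2))
    # Keep leading-axis chunks, merge the last axis into one chunk.
    rechunked = array.rechunk((*array.chunks[:-1], -1))
    assert rechunked.chunks == ((2, 2), (6,))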
@@ -256,9 +255,7 @@ def test_first_last(data, array, func):
 
     if arr.dtype.kind == "f" and not np.isnan(array.compute()).any():
         if mate in ["first", "last"]:
-            newchunks = list(array.chunks)
-            newchunks[-1] = -1
-            array = array.rechunk(newchunks)
+            array = array.rechunk((*array.chunks[:-1], -1))
 
         first, _ = groupby_reduce(array, by, func=func, engine="flox")
         second, _ = groupby_reduce(array, by, func=mate, engine="flox")