Skip to content

Commit

Permalink
Fix quite broken bfill
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jul 28, 2024
1 parent 63d700d commit ed73a5d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
4 changes: 2 additions & 2 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,8 @@ def __post_init__(self):


def reverse(a: AlignedArrays) -> AlignedArrays:
    """Flip an ``AlignedArrays`` pair along its trailing axis, in place.

    Both the data array and its matching group indices are reversed with
    ``[..., ::-1]`` so that any leading (e.g. broadcast/batch) dimensions
    are left untouched. The mutated input object is returned to allow
    chaining.
    """
    flipped_idx = a.group_idx[..., ::-1]
    flipped_arr = a.array[..., ::-1]
    a.group_idx = flipped_idx
    a.array = flipped_arr
    return a


Expand Down
31 changes: 20 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2788,13 +2788,24 @@ def groupby_scan(
if by_.shape[-1] == 1 or by_.shape == grp_shape:
return array.astype(agg.dtype)

# Made a design choice here to have `preprocess` handle both array and group_idx
# Example: for reversing, we need to reverse the whole array, not just reverse
# each block independently
inp = AlignedArrays(array=array, group_idx=by_)
if agg.preprocess:
inp = agg.preprocess(inp)

if not has_dask:
final_state = chunk_scan(
AlignedArrays(array=array, group_idx=by_), axis=single_axis, agg=agg, dtype=agg.dtype
)
return extract_array(final_state)
final_state = chunk_scan(inp, axis=single_axis, agg=agg, dtype=agg.dtype)
result = _finalize_scan(final_state)
else:
return dask_groupby_scan(array, by_, axes=axis_, agg=agg)
result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg)

# Made a design choice here to have `postprocess` handle both array and group_idx
out = AlignedArrays(array=result, group_idx=by_)
if agg.finalize:
out = agg.finalize(out)
return out.array


def chunk_scan(inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims=None) -> ScanState:
Expand Down Expand Up @@ -2836,10 +2847,9 @@ def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays:
return AlignedArrays(group_idx=group_idx, array=array)


def extract_array(block: ScanState, finalize: Callable | None = None) -> np.ndarray:
def _finalize_scan(block: ScanState) -> np.ndarray:
assert block.result is not None
result = finalize(block.result) if finalize is not None else block.result
return result.array
return block.result.array


def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
Expand All @@ -2855,9 +2865,8 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
array, by = _unify_chunks(array, by)

# 1. zip together group indices & array
to_map = _zip if agg.preprocess is None else tlz.compose(agg.preprocess, _zip)
zipped = map_blocks(
to_map, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
_zip, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
)

scan_ = partial(chunk_scan, agg=agg)
Expand All @@ -2878,7 +2887,7 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
)

# 3. Unzip and extract the final result array, discard groups
result = map_blocks(extract_array, accumulated, dtype=agg.dtype, finalize=agg.finalize)
result = map_blocks(_finalize_scan, accumulated, dtype=agg.dtype)

assert result.chunks == array.chunks

Expand Down

0 comments on commit ed73a5d

Please sign in to comment.