From f6fff33a19a3517f619e2938e2e33defd56c8eaa Mon Sep 17 00:00:00 2001 From: C1-BA-B1-F3 Date: Fri, 26 Jun 2026 10:28:22 +0800 Subject: [PATCH] fix: cumulative argmax/argmin returns absolute indices (GH#11336) When calling argmax/argmin on a cumulative rolling window, the result returned window-local indices (positions within the NaN-padded window) instead of absolute indices in the original array. The fix detects cumulative rolling (window size == dimension size, not centered) and adjusts the returned indices by subtracting the NaN padding offset: absolute = local - (window_size - 1 - position). Regression tests included for 1D and 2D cases. --- xarray/computation/rolling.py | 34 +++++++++++++++++++++++++++++++++- xarray/tests/test_rolling.py | 25 +++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/xarray/computation/rolling.py b/xarray/computation/rolling.py index 0f914691211..5104a3b31cf 100644 --- a/xarray/computation/rolling.py +++ b/xarray/computation/rolling.py @@ -177,7 +177,7 @@ def _reduce_method( # type: ignore[misc] def method(self, keep_attrs=None, **kwargs): keep_attrs = self._get_keep_attrs(keep_attrs) - return self._array_reduce( + result = self._array_reduce( array_agg_func=array_agg_func, bottleneck_move_func=bottleneck_move_func, numbagg_move_func=numbagg_move_func, @@ -188,6 +188,11 @@ def method(self, keep_attrs=None, **kwargs): **kwargs, ) + if name in ("argmax", "argmin"): + result = self._adjust_argminmax_result(result) + + return result + method.__name__ = name method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name) return method @@ -580,6 +585,33 @@ def reduce( counts = self._counts(keep_attrs=False) return result.where(counts >= self.min_periods) + def _is_cumulative(self) -> bool: + for i, d in enumerate(self.dim): + if self.center[i]: + return False + if self.window[i] != self.obj.sizes[d]: + return False + return True + + def _adjust_argminmax_result(self, result: DataArray) -> DataArray: + if not self._is_cumulative(): + return result + + for i, d in enumerate(self.dim): + window_size = self.window[i] + if window_size <= 1: + continue + + n = result.sizes[d] + position = np.arange(n) + # GH#11336: cumulative windows have NaN padding on the left, + # so window-local indices need adjustment to become absolute: + # absolute = local - (window_size - 1 - position) + offset = window_size - 1 - position + result = result - offset + + return result + def _counts(self, keep_attrs: bool | None) -> DataArray: """Number of non-nan entries in each rolling window.""" diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index b4f9f45ea52..720179fdbbc 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -49,6 +49,31 @@ def test_cumulative_vs_cum(d) -> None: assert_identical(result, expected) +@pytest.mark.parametrize("func", ["argmax", "argmin"]) +def test_cumulative_argminmax(func) -> None: + # Regression test for GH#11336: cumulative argmax/argmin should return + # absolute indices, not window-local indices + da = DataArray([3, 1, 4, 2, 5], dims=["time"]) + result = getattr(da.cumulative("time"), func)() + + if func == "argmax": + expected = DataArray([0, 0, 2, 2, 4], dims=["time"]) + else: + expected = DataArray([0, 1, 1, 1, 1], dims=["time"]) + + assert_identical(result, expected) + + +def test_cumulative_argmax_2d() -> None: + da = DataArray( + [[1, 3, 2], [4, 1, 5]], + dims=("x", "time"), + ) + result = da.cumulative("time").argmax() + expected = DataArray([[0, 1, 1], [0, 0, 2]], dims=("x", "time")) + assert_identical(result, expected) + + class TestDataArrayRolling: @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("center", [True, False])