From 64548062bc9d97a000fda4571860a12faf32337a Mon Sep 17 00:00:00 2001
From: C1-BA-B1-F3
Date: Fri, 26 Jun 2026 13:01:53 +0800
Subject: [PATCH] Fix np.linalg.pinv dimension inference for DataArrays

When np.linalg.pinv is called on a DataArray or Variable, the result has
swapped last two dimensions. Previously, __array_wrap__ blindly preserved
the original dimension names, causing a mismatch between dimension names
and their actual sizes.

This fix adds dimension inference in __array_wrap__ when the result shape
differs from the original. It finds the permutation that maps old dimensions
to new ones based on their sizes.

Fixes #11396
---
Review notes (ignored by git am; not part of the commit):
- _infer_new_dims matches purely by axis length: each new axis takes the
  first not-yet-used old axis of the same length. A change in the number
  of dimensions, or a new length with no unused match, raises ValueError.
- Inference only runs when obj.shape != self.shape. If the shape is
  unchanged (e.g. square trailing axes), the original dims are kept as-is;
  test_array_wrap_preserves_dims_when_shape_unchanged pins this.

 xarray/core/variable.py        | 60 +++++++++++++++++++++++++++++++++-
 xarray/tests/test_dataarray.py | 13 ++++++++
 xarray/tests/test_variable.py  | 14 ++++++++
 3 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 1b5e0d4ff69..9d0ca05257a 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -95,6 +95,60 @@ class MissingDimensionsError(ValueError):
     # TODO: move this to an xarray.exceptions module?
 
 
+def _infer_new_dims(old_dims, old_shape, new_shape):
+    """Infer new dimension names when shape changes in __array_wrap__.
+
+    Handles the case where numpy operations rearrange dimensions (e.g.,
+    np.linalg.pinv swaps the last two axes). When the new shape is a
+    permutation of the old shape, we map dimension names accordingly.
+
+    Parameters
+    ----------
+    old_dims : tuple
+        Original dimension names.
+    old_shape : tuple
+        Original shape.
+    new_shape : tuple
+        New shape after the numpy operation.
+
+    Returns
+    -------
+    tuple
+        New dimension names that match the new shape.
+
+    Raises
+    ------
+    ValueError
+        If the shape change cannot be resolved as a dimension permutation.
+    """
+    if len(old_shape) != len(new_shape):
+        raise ValueError(
+            f"Cannot infer dimensions: shape changed from {old_shape} to {new_shape} "
+            f"(different number of dimensions)"
+        )
+
+    # Try to find a permutation that maps old_shape to new_shape
+    # This handles cases like np.linalg.pinv which swaps last two dims
+    used = [False] * len(old_shape)
+    new_dims = list(old_dims)
+
+    for i, new_size in enumerate(new_shape):
+        found = False
+        for j, old_size in enumerate(old_shape):
+            if not used[j] and old_size == new_size:
+                new_dims[i] = old_dims[j]
+                used[j] = True
+                found = True
+                break
+        if not found:
+            raise ValueError(
+                f"Cannot infer dimensions: shape changed from {old_shape} to {new_shape} "
+                f"(not a permutation of existing dimensions)"
+            )
+
+    return tuple(new_dims)
+
+
 def as_variable(
     obj: T_DuckArray | Any, name=None, auto_convert: bool = True
 ) -> Variable | IndexVariable:
@@ -2432,7 +2486,11 @@ def real(self) -> Variable:
         return self._new(data=self.data.real)
 
     def __array_wrap__(self, obj, context=None, return_scalar=False):
-        return Variable(self.dims, obj)
+        if obj.shape != self.shape:
+            dims = _infer_new_dims(self.dims, self.shape, obj.shape)
+        else:
+            dims = self.dims
+        return Variable(dims, obj)
 
     def _unary_op(self, f, *args, **kwargs):
         keep_attrs = kwargs.pop("keep_attrs", None)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 377fd2f8a8b..cd4b45a40b0 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -7673,3 +7673,16 @@ def test_unstack_index_var() -> None:
         name="x",
     )
     assert_identical(actual, expected)
+
+
+def test_pinv_infers_dims_on_shape_change():
+    # Regression test for GH#11396
+    da = xr.DataArray(
+        np.arange(12).reshape(3, 4),
+        coords={"foo": ["x", "y", "z"], "bar": ["a", "b", "c", "d"]},
+    )
+    result = np.linalg.pinv(da)
+    assert result.shape == (4, 3)
+    assert result.dims == ("bar", "foo")
+    assert list(result.coords["foo"].values) == ["x", "y", "z"]
+    assert list(result.coords["bar"].values) == ["a", "b", "c", "d"]
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 8e240352436..be89e11c5ab 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -2947,6 +2947,20 @@ def __init__(self, array):
         orig = Variable(dims=(), data=array2)
         assert isinstance(orig._data.item(), CustomWithValuesAttr)  # type: ignore[union-attr]
 
+    def test_array_wrap_infers_dims_on_shape_change(self):
+        # Regression test for https://github.com/pydata/xarray/issues/11396
+        # np.linalg.pinv swaps last two dims; __array_wrap__ should infer this
+        var = Variable(dims=("x", "y"), data=np.arange(12).reshape(3, 4))
+        result = np.linalg.pinv(var)
+        assert result.shape == (4, 3)
+        assert result.dims == ("y", "x")
+
+    def test_array_wrap_preserves_dims_when_shape_unchanged(self):
+        var = Variable(dims=("x", "y"), data=np.random.rand(3, 3))
+        result = np.linalg.pinv(var)
+        assert result.shape == (3, 3)
+        assert result.dims == ("x", "y")
+
 
 def test_raise_no_warning_for_nan_in_binary_ops():
     with assert_no_warnings():