Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,60 @@ class MissingDimensionsError(ValueError):
# TODO: move this to an xarray.exceptions module?


def _infer_new_dims(old_dims, old_shape, new_shape):
"""Infer new dimension names when shape changes in __array_wrap__.

Handles the case where numpy operations rearrange dimensions (e.g.,
np.linalg.pinv swaps the last two axes). When the new shape is a
permutation of the old shape, we map dimension names accordingly.

Parameters
----------
old_dims : tuple
Original dimension names.
old_shape : tuple
Original shape.
new_shape : tuple
New shape after the numpy operation.

Returns
-------
tuple
New dimension names that match the new shape.

Raises
------
ValueError
If the shape change cannot be resolved as a dimension permutation.
"""
if len(old_shape) != len(new_shape):
raise ValueError(
f"Cannot infer dimensions: shape changed from {old_shape} to {new_shape} "
f"(different number of dimensions)"
)

# Try to find a permutation that maps old_shape to new_shape
# This handles cases like np.linalg.pinv which swaps last two dims
used = [False] * len(old_shape)
new_dims = list(old_dims)

for i, new_size in enumerate(new_shape):
found = False
for j, old_size in enumerate(old_shape):
if not used[j] and old_size == new_size:
new_dims[i] = old_dims[j]
used[j] = True
found = True
break
if not found:
raise ValueError(
f"Cannot infer dimensions: shape changed from {old_shape} to {new_shape} "
f"(not a permutation of existing dimensions)"
)

return tuple(new_dims)


def as_variable(
obj: T_DuckArray | Any, name=None, auto_convert: bool = True
) -> Variable | IndexVariable:
Expand Down Expand Up @@ -2432,7 +2486,11 @@ def real(self) -> Variable:
return self._new(data=self.data.real)

def __array_wrap__(self, obj, context=None, return_scalar=False):
return Variable(self.dims, obj)
if obj.shape != self.shape:
dims = _infer_new_dims(self.dims, self.shape, obj.shape)
else:
dims = self.dims
return Variable(dims, obj)

def _unary_op(self, f, *args, **kwargs):
keep_attrs = kwargs.pop("keep_attrs", None)
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7673,3 +7673,16 @@ def test_unstack_index_var() -> None:
name="x",
)
assert_identical(actual, expected)


def test_pinv_infers_dims_on_shape_change():
# Regression test for GH#11396
da = xr.DataArray(
np.arange(12).reshape(3, 4),
coords={"foo": ["x", "y", "z"], "bar": ["a", "b", "c", "d"]},
)
result = np.linalg.pinv(da)
assert result.shape == (4, 3)
assert result.dims == ("bar", "foo")
assert list(result.coords["foo"].values) == ["x", "y", "z"]
assert list(result.coords["bar"].values) == ["a", "b", "c", "d"]
14 changes: 14 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2947,6 +2947,20 @@ def __init__(self, array):
orig = Variable(dims=(), data=array2)
assert isinstance(orig._data.item(), CustomWithValuesAttr) # type: ignore[union-attr]

def test_array_wrap_infers_dims_on_shape_change(self):
# Regression test for https://github.com/pydata/xarray/issues/11396
# np.linalg.pinv swaps last two dims; __array_wrap__ should infer this
var = Variable(dims=("x", "y"), data=np.arange(12).reshape(3, 4))
result = np.linalg.pinv(var)
assert result.shape == (4, 3)
assert result.dims == ("y", "x")

def test_array_wrap_preserves_dims_when_shape_unchanged(self):
var = Variable(dims=("x", "y"), data=np.random.rand(3, 3))
result = np.linalg.pinv(var)
assert result.shape == (3, 3)
assert result.dims == ("x", "y")


def test_raise_no_warning_for_nan_in_binary_ops():
with assert_no_warnings():
Expand Down
Loading