Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4882,7 +4882,7 @@ def identical(self, other: Self) -> bool:

def __array_wrap__(self, obj, context=None, return_scalar=False) -> Self:
new_var = self.variable.__array_wrap__(obj, context, return_scalar)
return self._replace(new_var)
return self._replace_maybe_drop_dims(new_var)

def __matmul__(self, obj: T_Xarray) -> T_Xarray:
return self.dot(obj)
Expand Down
13 changes: 13 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,19 @@ def real(self) -> Variable:
return self._new(data=self.data.real)

def __array_wrap__(self, obj, context=None, return_scalar=False):
if obj.shape != self.shape:
# Shape changed during the numpy operation.
# Check if only the last two dims are transposed (e.g. np.linalg.pinv).
if (
obj.ndim == self.ndim
and obj.ndim >= 2
and obj.shape[:-2] == self.shape[:-2]
and obj.shape[-2:] == self.shape[-2:][::-1]
):
new_dims = self.dims[:-2] + (self.dims[-1], self.dims[-2])
return Variable(new_dims, obj)
# Fallback: use generic dim names since we can't reliably map dims.
return Variable(tuple(f"dim_{i}" for i in range(obj.ndim)), obj)
return Variable(self.dims, obj)

def _unary_op(self, f, *args, **kwargs):
Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2478,6 +2478,34 @@ def test_array_interface(self) -> None:
bar = Variable(["x", "y"], np.zeros((10, 20)))
assert_equal(self.dv, np.maximum(self.dv, bar))

def test_array_wrap_drops_mismatched_coords(self) -> None:
array = DataArray(
np.arange(12).reshape(3, 4),
dims=("foo", "bar"),
coords={
"foo": ["x", "y", "z"],
"bar": ["a", "b", "c", "d"],
"aux": (("foo", "bar"), np.ones((3, 4))),
"scalar": 1,
},
)

actual = np.linalg.pinv(array)
# pinv transposes the last two dims, so dims are swapped from
# (foo: 3, bar: 4) to (bar: 4, foo: 3). All coords whose dims
# still exist in the result are preserved.
expected = DataArray(
np.linalg.pinv(array.data),
dims=("bar", "foo"),
coords={
"foo": ["x", "y", "z"],
"bar": ["a", "b", "c", "d"],
"aux": (("foo", "bar"), np.ones((3, 4))),
"scalar": 1,
},
)
assert_identical(expected, actual)

def test_astype_attrs(self) -> None:
# Split into two loops for mypy - Variable, DataArray, and Dataset
# don't share a common base class, so mypy infers type object for v,
Expand Down
Loading