Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,39 @@
"""Configuration for pytest."""

import os

import pytest


def _use_dask_array(config: pytest.Config) -> bool:
    """Report whether dask-array mode was requested for this test run.

    The request can come from the ``--use-dask-array-with-expr`` command-line
    option, or from the ``XR_USE_DASK_ARRAY_WITH_EXPR`` environment variable
    when it is set to ``1``, ``true``, ``yes`` or ``on`` (case-insensitive).
    """
    from_cli = config.getoption("--use-dask-array-with-expr")
    if from_cli:
        return from_cli

    # Fall back to the environment variable; unset is treated as empty.
    truthy_spellings = ("1", "true", "yes", "on")
    from_env = os.environ.get("XR_USE_DASK_ARRAY_WITH_EXPR", "")
    return from_env.lower() in truthy_spellings


def _register_dask_array() -> None:
    """Register dask-array with xarray and refresh the test-suite helpers.

    Raises
    ------
    pytest.UsageError
        If ``dask_array.xarray`` cannot be imported, or if ``isactive()``
        is still false after ``register()`` has been called.
    """
    try:
        import dask_array.xarray
    except ImportError as exc:
        message = "--use-dask-array-with-expr requires dask-array to be importable"
        raise pytest.UsageError(message) from exc

    integration = dask_array.xarray
    integration.register()
    if not integration.isactive():
        raise pytest.UsageError(
            "--use-dask-array-with-expr registered dask-array, but it is not the active dask chunk manager"
        )

    # Imported here rather than at module top, so xarray.tests is only
    # loaded by this function once registration has succeeded.
    from xarray.tests import refresh_dask_chunkmanager_helpers

    refresh_dask_chunkmanager_helpers()


def pytest_addoption(parser: pytest.Parser):
"""Add command-line flags for pytest."""
parser.addoption("--run-flaky", action="store_true", help="runs flaky tests")
Expand All @@ -12,9 +43,31 @@ def pytest_addoption(parser: pytest.Parser):
help="runs tests requiring a network connection",
)
parser.addoption("--run-mypy", action="store_true", help="runs mypy tests")
parser.addoption(
"--use-dask-array-with-expr",
action="store_true",
help="register dask-array as xarray's dask chunk manager",
)


def pytest_configure(config: pytest.Config):
    """Declare the dask-array markers and register dask-array when requested."""
    marker_lines = (
        "skip_with_dask_array: skip when dask-array is registered as xarray's dask chunk manager",
        "xfail_with_dask_array: xfail when dask-array is registered as xarray's dask chunk manager",
    )
    # The markers are always declared, whether or not dask-array mode is on.
    for line in marker_lines:
        config.addinivalue_line("markers", line)

    if _use_dask_array(config):
        _register_dask_array()


def pytest_runtest_setup(item):
if _use_dask_array(item.config):
_register_dask_array()
# based on https://stackoverflow.com/questions/47559524
if "flaky" in item.keywords and not item.config.getoption("--run-flaky"):
pytest.skip("set --run-flaky option to run flaky tests")
Expand All @@ -39,6 +92,18 @@ def pytest_collection_modifyitems(items):
# marking approach, meaning that each test case must contain "mypy" in the
# name.
item.add_marker(pytest.mark.mypy)
if _use_dask_array(item.config) and "skip_with_dask_array" in item.keywords:
item.add_marker(
pytest.mark.skip(reason="skipped with dask-array chunk manager")
)
if _use_dask_array(item.config) and "xfail_with_dask_array" in item.keywords:
mark = item.get_closest_marker("xfail_with_dask_array")
kwargs = dict(mark.kwargs) if mark is not None else {}
kwargs.setdefault(
"reason", "expected failure with dask-array chunk manager"
)
kwargs.setdefault("strict", True)
item.add_marker(pytest.mark.xfail(**kwargs))


@pytest.fixture(autouse=True)
Expand Down
8 changes: 6 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from xarray.core.types import ReadBuffer
from xarray.core.utils import emit_user_level_warning, is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import guess_chunkmanager
from xarray.namedarray.parallelcompat import guess_chunkmanager, list_chunkmanagers
from xarray.namedarray.utils import _get_chunk
from xarray.structure.chunks import _maybe_chunk
from xarray.structure.combine import (
Expand Down Expand Up @@ -232,7 +232,11 @@ def _chunk_ds(
chunkmanager = guess_chunkmanager(chunked_array_type)

# TODO refactor to move this dask-specific logic inside the DaskManager class
if isinstance(chunkmanager, DaskManager):
is_dask_chunkmanager = isinstance(chunkmanager, DaskManager) or any(
name == "dask" and manager is chunkmanager
for name, manager in list_chunkmanagers().items()
)
if is_dask_chunkmanager:
from dask.base import tokenize

mtime = _get_mtime(filename_or_obj)
Expand Down
20 changes: 20 additions & 0 deletions xarray/compat/dask_array_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.utils import module_available


Expand All @@ -12,6 +13,8 @@ def reshape_blockwise(
try:
array_api = get_chunked_array_type(x).array_api
except TypeError:
if is_chunked_array(x):
raise
array_api = None

if array_api is not None and hasattr(array_api, "reshape_blockwise"):
Expand All @@ -29,6 +32,23 @@ def sliding_window_view(
):
# Backcompat for handling `automatic_rechunk`, delete when dask>=2024.11.0
# Note that subok, writeable are unsupported by dask, so we ignore those in kwargs
try:
array_api = get_chunked_array_type(x).array_api
except TypeError:
if is_chunked_array(x):
raise
array_api = None

if array_api is not None:
array_sliding_window_view = getattr(array_api, "sliding_window_view", None)
if array_sliding_window_view is not None:
return array_sliding_window_view(
x,
window_shape=window_shape,
axis=axis,
automatic_rechunk=automatic_rechunk,
)

from dask.array.lib.stride_tricks import sliding_window_view

if module_available("dask", "2024.11.0"):
Expand Down
26 changes: 14 additions & 12 deletions xarray/core/accessor_dt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Generic
from typing import TYPE_CHECKING, Generic, cast

import numpy as np
import pandas as pd
Expand All @@ -17,6 +17,7 @@
)
from xarray.core.types import T_DataArray
from xarray.core.variable import IndexVariable, Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.utils import is_duck_dask_array

if TYPE_CHECKING:
Expand Down Expand Up @@ -129,15 +130,14 @@ def _get_date_field(values, name, dtype):
access_method = _access_through_cftimeindex

if is_duck_dask_array(values):
from dask.array import map_blocks

chunkmanager = get_chunked_array_type(values)
new_axis = chunks = None
# isocalendar adds an axis
if name == "isocalendar":
chunks = (3,) + values.chunksize
chunks = cast(tuple[int, ...], (3,) + values.chunksize)
new_axis = 0

return map_blocks(
return chunkmanager.map_blocks(
access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks
)
else:
Expand Down Expand Up @@ -187,10 +187,13 @@ def _round_field(values, name, freq):

"""
if is_duck_dask_array(values):
from dask.array import map_blocks

dtype = np.datetime64 if is_np_datetime_like(values.dtype) else np.dtype("O")
return map_blocks(
chunkmanager = get_chunked_array_type(values)
dtype = (
np.dtype(values.dtype)
if is_np_datetime_like(values.dtype)
else np.dtype("O")
)
return chunkmanager.map_blocks(
_round_through_series_or_index, values, name, freq=freq, dtype=dtype
)
else:
Expand Down Expand Up @@ -224,9 +227,8 @@ def _strftime(values, date_format):
else:
access_method = _strftime_through_cftimeindex
if is_duck_dask_array(values):
from dask.array import map_blocks

return map_blocks(access_method, values, date_format)
chunkmanager = get_chunked_array_type(values)
return chunkmanager.map_blocks(access_method, values, date_format)
else:
return access_method(values, date_format)

Expand Down
Loading
Loading