Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 41 additions & 28 deletions s3fs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,11 @@ def __init__(
self.use_ssl = use_ssl
self.cache_regions = cache_regions
self._s3 = None
self._set_session_lock = asyncio.Lock()
self.session = session
self._session_is_owned = (
session is None
) # False when the caller injected a session
self.fixed_upload_size = fixed_upload_size
self.local_expiry_check = local_expiry_check
if max_concurrency < 1:
Expand Down Expand Up @@ -599,7 +603,9 @@ async def set_session(self, refresh=False, kwargs={}):
if self._s3 is not None and not refresh:
hsess = getattr(getattr(self._s3, "_endpoint", None), "http_session", None)
if hsess is not None:
if all(_.closed for _ in hsess._sessions.values()):
if hsess._sessions is None or (
hsess._sessions and all(_.closed for _ in hsess._sessions.values())
):
refresh = True
if not refresh:
return self._s3
Expand Down Expand Up @@ -638,36 +644,43 @@ async def set_session(self, refresh=False, kwargs={}):
}
config_kwargs["signature_version"] = UNSIGNED

conf = AioConfig(**config_kwargs)
if self.session is None or refresh:
self.session = aiobotocore.session.AioSession(**self.kwargs)

for parameters in (config_kwargs, self.kwargs, init_kwargs, client_kwargs):
for option in ("region_name", "endpoint_url"):
if parameters.get(option):
self.cache_regions = False
break
else:
cache_regions = self.cache_regions
async with self._set_session_lock:
# Re-check under the lock: a concurrent task may have set up the
# session while we were waiting.
if self._s3 is not None and not refresh:
return self._s3
conf = AioConfig(**config_kwargs)
if self.session is None or (refresh and self._session_is_owned):
# Only (re)create the AioSession when s3fs owns it
self.session = aiobotocore.session.AioSession(**self.kwargs)
self._session_is_owned = True

for parameters in (config_kwargs, self.kwargs, init_kwargs, client_kwargs):
for option in ("region_name", "endpoint_url"):
if parameters.get(option):
self.cache_regions = False
break
else:
cache_regions = self.cache_regions

logger.debug(
"RC: caching enabled? %r (explicit option is %r)",
cache_regions,
self.cache_regions,
)
self.cache_regions = cache_regions
if self.cache_regions:
s3creator = S3BucketRegionCache(
self.session, config=conf, **init_kwargs, **client_kwargs
)
self._s3 = await s3creator.get_client()
else:
s3creator = self.session.create_client(
"s3", config=conf, **init_kwargs, **client_kwargs
logger.debug(
"RC: caching enabled? %r (explicit option is %r)",
cache_regions,
self.cache_regions,
)
self._s3 = await s3creator.__aenter__()
self.cache_regions = cache_regions
if self.cache_regions:
s3creator = S3BucketRegionCache(
self.session, config=conf, **init_kwargs, **client_kwargs
)
self._s3 = await s3creator.get_client()
else:
s3creator = self.session.create_client(
"s3", config=conf, **init_kwargs, **client_kwargs
)
self._s3 = await s3creator.__aenter__()

self._s3creator = s3creator
self._s3creator = s3creator
# the following actually closes the aiohttp connection; use of privates
# might break in the future, would cause exception at gc time
if not self.asynchronous:
Expand Down
155 changes: 149 additions & 6 deletions s3fs/tests/test_s3fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,14 +3172,14 @@ def test_find_missing_ls(s3):
assert set(listed_cached) == set(listed_no_cache)


def test_session_close():
def test_session_close(s3):
s3.pipe(f"{test_bucket_name}/dir/afile", b"small")

async def run_program(run):
s3 = s3fs.S3FileSystem(anon=True, asynchronous=True)
s3 = s3fs.S3FileSystem(anon=True, asynchronous=True, endpoint_url=endpoint_uri)
s3.invalidate_cache()
session = await s3.set_session()
files = await s3._ls(
"s3://noaa-hrrr-bdp-pds/hrrr.20140730/conus/"
) # Random open data store
print(f"Number of files {len(files)}")
files = await s3._ls(f"{test_bucket_name}/dir")
await session.close()

import aiobotocore.httpsession
Expand All @@ -3189,6 +3189,97 @@ async def run_program(run):
asyncio.run(run_program(False))


def test_set_session_sessions_none(s3):
    """Verify that ``set_session`` replaces a client whose ``_sessions`` is None.

    aiobotocore 3.x sets the HTTP session's ``_sessions`` attribute to None
    once the session has been closed. A later ``set_session()`` call must
    build a fresh client instead of handing back the dead one.
    """
    s3.pipe(f"{test_bucket_name}/dir/afile", b"small")

    async def scenario():
        fs_options = dict(
            anon=False,
            asynchronous=True,
            client_kwargs={"endpoint_url": endpoint_uri},
            skip_instance_cache=True,
        )
        fs = S3FileSystem(**fs_options)
        await fs.set_session()
        stale_client = fs._s3

        # Mimic what aiobotocore 3.x does in __aexit__: drop the sessions
        # mapping entirely rather than leaving an empty dict behind.
        fs._s3._endpoint.http_session._sessions = None

        # No refresh flag is passed: the dead client has to be detected.
        await fs.set_session()
        client_was_rebuilt = fs._s3 is not stale_client
        assert (
            client_was_rebuilt
        ), "set_session should have rebuilt the client when _sessions is None"
        await fs.set_session(refresh=True)  # clean up

    asyncio.run(scenario())


def test_set_session_concurrent_no_leak(s3):
    """Verify that simultaneous ``set_session`` calls share a single client.

    Eight coroutines call ``set_session`` on a fresh filesystem at the same
    time; each must receive the same client object, and that object must be
    the one stored on the filesystem.
    """
    s3.pipe(f"{test_bucket_name}/dir/afile", b"small")

    async def scenario():
        S3FileSystem.clear_instance_cache()
        fs_options = dict(
            anon=False,
            asynchronous=True,
            client_kwargs={"endpoint_url": endpoint_uri},
            skip_instance_cache=True,
        )
        fs = S3FileSystem(**fs_options)

        # Every coroutine is created while fs._s3 is still None, then they
        # are all awaited together.
        pending = (fs.set_session() for _ in range(8))
        clients = await asyncio.gather(*pending)

        # One distinct identity means every caller got the same client.
        distinct_ids = {id(client) for client in clients}
        assert (
            len(distinct_ids) == 1
        ), "Concurrent set_session calls returned different client objects"

        # The shared client is also the one the filesystem keeps.
        assert fs._s3 is clients[0]
        await fs.set_session(refresh=True)  # clean up

    asyncio.run(scenario())


def test_set_session_preserves_injected_session(s3):
    """Verify that a caller-supplied ``session=`` survives a client refresh.

    Two refresh paths are exercised: the implicit one, taken when
    ``set_session`` finds every HTTP connection closed, and the explicit
    ``set_session(refresh=True)`` call. In both cases ``fs.session`` must
    remain the injected ``AioSession`` object.
    """
    s3.pipe(f"{test_bucket_name}/dir/afile", b"small")

    async def run():
        import aiobotocore.session as aio_session

        custom_session = aio_session.AioSession()
        fs = S3FileSystem(
            anon=False,
            asynchronous=True,
            session=custom_session,
            client_kwargs={"endpoint_url": endpoint_uri},
            skip_instance_cache=True,
        )
        await fs.set_session()
        assert (
            fs.session is custom_session
        ), "session should not be replaced on first connect"

        # Issue a real request first: the HTTP sessions dict is only filled
        # once an operation has run. Without this the close loop below has
        # nothing to close and the following set_session() call returns early
        # without refreshing, making the test pass vacuously.
        await fs._ls(f"{test_bucket_name}/dir")
        original_client = fs._s3
        sessions = fs._s3._endpoint.http_session._sessions
        assert sessions, "expected a populated _sessions dict after an op"

        # Close every HTTP connection so set_session decides to refresh.
        for sess in sessions.values():
            await sess.close()

        await fs.set_session()
        # Guard against a vacuous pass: the refresh must really have happened.
        assert (
            fs._s3 is not original_client
        ), "set_session should have rebuilt the client after all sessions closed"
        assert (
            fs.session is custom_session
        ), "set_session must not replace a user-injected session on refresh"

        # An explicit refresh must keep the injected session as well.
        await fs.set_session(refresh=True)
        assert (
            fs.session is custom_session
        ), "set_session(refresh=True) must not replace a user-injected session"

    asyncio.run(run())


def test_rm_recursive_prfix(s3):
prefix = "logs/" # must end with "/"

Expand All @@ -3198,3 +3289,55 @@ def test_rm_recursive_prfix(s3):
logs_path = f"s3://{test_bucket_name}/{prefix}"
s3.rm(logs_path, recursive=True)
assert not s3.isdir(logs_path)


def test_set_session_closed_sessions_rebuilds_once(s3):
    """Regression test for the #1019 perf regression.

    When the ``_sessions`` dict is populated but every entry is closed, the
    client must be rebuilt exactly once and then reused. The rebuilt client
    starts with an empty ``_sessions`` dict; that must not be treated as
    "all closed" (``all([])`` is vacuously True), otherwise every call would
    force another refresh.
    """
    import aiobotocore.session as aio_session
    from unittest import mock

    s3.pipe(f"{test_bucket_name}/dir/afile", b"small")

    create_client_calls = 0
    real_create_client = aio_session.AioSession._create_client

    async def counting_create_client(self, *args, **kwargs):
        # Count the build, then delegate to the unpatched implementation.
        nonlocal create_client_calls
        create_client_calls += 1
        return await real_create_client(self, *args, **kwargs)

    async def scenario():
        fs_options = dict(
            anon=False,
            asynchronous=True,
            client_kwargs={"endpoint_url": endpoint_uri},
            skip_instance_cache=True,
        )
        fs = S3FileSystem(**fs_options)
        await fs._ls(f"{test_bucket_name}/dir")  # populates _sessions
        sessions = fs._s3._endpoint.http_session._sessions
        assert sessions, "expected a populated _sessions dict after an op"

        for http_session in sessions.values():
            await http_session.close()

        # Builds made so far belong to the warmup and are excluded.
        baseline = create_client_calls
        iterations = 10
        for _ in range(iterations):
            await fs.set_session()

        rebuilds = create_client_calls - baseline
        assert rebuilds == 1, (
            f"set_session rebuilt the client {rebuilds} times across {iterations} "
            f"calls; expected 1. >1 means the empty-dict case forces a refresh on "
            f"every call (#1019 regression)."
        )
        await fs.set_session(refresh=True)  # clean up

    patcher = mock.patch.object(
        aio_session.AioSession, "_create_client", counting_create_client
    )
    with patcher:
        asyncio.run(scenario())
Loading