Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ markers = [
"network: marks tests requiring network access",
"slow: marks other tests that cause bottlenecks",
"hypothesis: tests that require hypothesis",
"hypothesis_dbf: hypothesis tests that test dbf functionality",
]
python_files = "test_*.py *_test.py *_tests.py"

Expand Down
1 change: 1 addition & 0 deletions src/shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def from_unchecked(

# Raise Exception or trigger warning early, before user adds more fields
# (fields are only written when first record added, and on close)
# Tests field_type, size and decimal. Name already tested and cached above.
inst.encode_field_descriptor(
encoding=encoding,
encodingErrors=encodingErrors,
Expand Down
63 changes: 52 additions & 11 deletions tests/hypothesis_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io
import itertools
import string
import warnings

import pytest
from hypothesis import HealthCheck, given, settings, reproduce_failure
Expand All @@ -26,6 +27,16 @@

import shapefile as shp

@contextlib.contextmanager
def ignore_warnings(category=None):
with warnings.catch_warnings():
if category:
warnings.simplefilter("ignore", category)
else:
warnings.simplefilter("ignore")
yield


float_nums = floats(allow_nan=False, allow_infinity=False)
xs = float_nums
ys = float_nums
Expand Down Expand Up @@ -542,7 +553,13 @@ def test_shx_reader_writer_roundtrip(codes_and_shapes)-> None:

ENCODINGS = [
"ascii",
"latin1",
"utf-8",
"utf-16-be",
"utf-16-le",
"utf-16",
"utf-32-be",
"utf-32-le",
]

encodings = sampled_from(ENCODINGS)
Expand All @@ -567,6 +584,12 @@ def _dbf_fields_strategy(draw, encoding: str) -> dict[str, str | int]:

max_length = bounds_dict.get("max_length", 254)
min_length = bounds_dict.get("min_length", 1)
if field_type in {"C", "M"}:
# Make sure field is big enough to store any BOM
# used by non-endianness specified codecs
# (e.g. utf-16 and utf-32)
min_length = max(min_length, len("".encode(encoding)))
max_length = max(max_length, min_length)
max_decimal = bounds_dict.get("max_decimal", 0)
size = draw(integers(min_value=min_length, max_value=max_length))
decimal = draw(integers(min_value=0, max_value=max(0,min(size - 3, max_decimal))))
Expand All @@ -581,7 +604,7 @@ def encodings_and_dbf_fields(draw):
field = draw(fields_strategy)
return encoding, field

def _get_fields_context(fields, codec, strict=False):
def _get_fields_w_context(fields, codec, strict=False):
for field in fields:
if (len(field["name"].encode(codec)) > 10 or
"\x00" in field["name"] or
Expand All @@ -592,32 +615,47 @@ def _get_fields_context(fields, codec, strict=False):
return pytest.warns(shp.PossibleDataLoss), False
return contextlib.nullcontext(), False

def _get_fields_r_context(codec):
# In utf-16-le and utf-32-le, many low code points encode
# to code units ending in null bytes, causing warnings in field
# names (which use trailing null bytes for padding).
normalised = codec.lower().replace("-","").replace("_","")
if (any(normalised.startswith(prefix) for prefix in ["utf16", "utf32"]) and
not codec.lower().endswith("-be")):

return ignore_warnings(shp.PossibleDataLoss)
return contextlib.nullcontext()


@pytest.mark.hypothesis
@pytest.mark.hypothesis_dbf
@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large])
@given(encoding_and_dbf_field=encodings_and_dbf_fields())
def test_dbf_Field_roundtrips(encoding_and_dbf_field: dict) -> None:

encoding, field_kwargs = encoding_and_dbf_field

w_context, error_expected = _get_fields_context([field_kwargs], encoding, strict=True)
w_context, error_expected = _get_fields_w_context([field_kwargs], encoding, strict=True)

with w_context:
expected = shp.Field.from_unchecked(
encoding=encoding,
strict=True,
**field_kwargs,
)
encoded = expected.encode_field_descriptor(strict=True)
encoded = expected.encode_field_descriptor(encoding=encoding, strict=True)
if error_expected:
return
stream = io.BytesIO()
stream.write(encoded)
stream.seek(0)

actual = shp.Field.from_byte_stream(
stream,
encoding=encoding,
)

with _get_fields_r_context(encoding):
actual = shp.Field.from_byte_stream(
stream,
encoding=encoding,
)

assert isinstance(actual, shp.Field)
assert actual.name == expected.name
Expand Down Expand Up @@ -753,13 +791,15 @@ def _write_fields_and_records_to_strict(w, fields, records):
return written_fields, written_records

@pytest.mark.hypothesis
@pytest.mark.hypothesis_dbf
@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large])
@given(codec_fields_and_records=dbf_encoding_fields_and_records())
def test_dbf_reader_writer_roundtrip(codec_fields_and_records)-> None:
codec, fields, records = codec_fields_and_records
stream = io.BytesIO()

# pytest.raises and pytest.warns can obscure other
# exceptions inside them
# exceptions inside them, when iterating on the test code
w = shp.DbfWriter(dbf=stream, encoding=codec, strict=True)

written_fields, written_records = _write_fields_and_records_to_strict(w, fields, records)
Expand All @@ -770,7 +810,7 @@ def test_dbf_reader_writer_roundtrip(codec_fields_and_records)-> None:
w.close()


with shp.DbfReader(dbf=stream, encoding=codec) as r:
with _get_fields_r_context(codec), shp.DbfReader(dbf=stream, encoding=codec) as r:
_assert_reader_matches_expected_fields(r, written_fields, True)
_assert_reader_matches_expected_records(r, written_fields, written_records)

Expand All @@ -786,6 +826,7 @@ def codes_codecs_fields_shapes_and_records(draw):


@pytest.mark.hypothesis
@pytest.mark.hypothesis_dbf
@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large])
@given(codes_codecs_fields_shapes_and_records=codes_codecs_fields_shapes_and_records())
def test_shapefile_reader_writer_roundtrip(codes_codecs_fields_shapes_and_records)-> None:
Expand Down Expand Up @@ -813,7 +854,7 @@ def test_shapefile_reader_writer_roundtrip(codes_codecs_fields_shapes_and_record

w.close()

with shp.Reader(encoding=encoding, **streams) as r:
with _get_fields_r_context(encoding), shp.Reader(encoding=encoding, **streams) as r:
_assert_reader_matches_expected_fields(r, written_fields, True)
_assert_reader_matches_expected_records(r, written_fields, written_records)
_assert_reader_matches_expected_shapes(r, code_ex, expected_shapes)
_assert_reader_matches_expected_shapes(r, code_ex, expected_shapes)
Loading