diff --git a/docs/getting-started/benchmark.md b/docs/getting-started/benchmark.md index 698dfa880..498930225 100644 --- a/docs/getting-started/benchmark.md +++ b/docs/getting-started/benchmark.md @@ -195,7 +195,7 @@ guidellm run --profile kind=sweep,sweep_size=10,rampup_duration=10,strategy_type #### Replay Profile -Replays trace events using timestamps from a `trace_synthetic` dataset. See [Trace Replay Benchmarking](#trace-replay-benchmarking-beta) below for data setup. +Replays trace events using timestamps from a trace file dataset. See [Trace Replay Benchmarking](#trace-replay-benchmarking) below for data setup. ```bash guidellm run --profile kind=replay,time_scale=1.0 @@ -225,9 +225,9 @@ guidellm run \ You can customize synthetic data generation with additional parameters such as standard deviation, minimum, and maximum values. See the [Datasets Synthetic data documentation](../guides/datasets.md#synthetic-data) for more details. -### Trace Replay Benchmarking (beta) +### Trace Replay Benchmarking -For realistic load testing, replay trace events using each row's timestamp and token lengths. Trace files must be JSONL and are loaded with the `trace_synthetic` data type. By default, each row uses `timestamp`, `input_length`, and `output_length` fields. Timestamps may be absolute or monotonic values; GuideLLM sorts them and converts them to offsets from the first event before scheduling: +For realistic load testing, replay trace events using each row's timestamp and token lengths. Trace files must be JSONL, JSON, CSV, or Parquet and are loaded with a supported [trace file format](../guides/trace_replay.md#supported-formats). 
Timestamps may be absolute or monotonic values; GuideLLM sorts them and converts them to offsets from the first event before scheduling: ```json {"timestamp": 1234500.0, "input_length": 256, "output_length": 128} @@ -249,7 +249,7 @@ The replay profile parameter `time_scale` acts as a scaling factor for the inter GuideLLM orders trace rows by timestamp before scheduling and payload generation, so each scheduled event uses the token lengths from the same sorted row. Use `--data-loader kind=pytorch,samples=1000` to limit how many trace rows are loaded and replayed. `--constraint kind=max_requests,count=1000` remains a runtime completion constraint; it does not truncate the trace dataset. -If your trace uses different column names, include `timestamp_column`, `prompt_tokens_column`, and `output_tokens_column` in the data config: +By default, every format looks for the columns `timestamp`, `input_length`, and `output_length`. If your trace uses different column names, include `timestamp_column`, `prompt_tokens_column`, and `output_tokens_column` in the data config: ```bash guidellm run \ @@ -258,7 +258,7 @@ guidellm run \ --profile kind=replay,time_scale=1.0 ``` -For very small prompts (roughly under 15 tokens, depending on the tokenizer), GuideLLM may not have enough room to include the full per-row unique prefix. Different rows can then produce similar or identical prompts, which reduces cache resistance in replay benchmarks. +These column overrides also apply to columns required by specific formats. Those additional columns and other format-specific arguments are described in the [Trace File Formats documentation](../guides/trace_replay.md). ### Working with Real Data diff --git a/docs/guides/datasets.md b/docs/guides/datasets.md index 23be265c7..e7fccce0e 100644 --- a/docs/guides/datasets.md +++ b/docs/guides/datasets.md @@ -20,11 +20,11 @@ The following arguments configure datasets and their processing: - `synthetic_text` — generates synthetic prompts on the fly. 
Required field: `prompt_tokens`. Optional: `output_tokens`, `turns`, `prefix_tokens`, `prefix_count`, `prefix_buckets`, and distribution controls (`prompt_tokens_stdev`, `output_tokens_stdev`, etc.). - `huggingface` (alias `hf`) — loads from HuggingFace Hub or a local directory/file. Required field: `source` (dataset ID or path). Pass dataset loading arguments (for example `split`, `name`) via `load_kwargs`. - `json_file`, `csv_file`, `text_file`, `parquet_file`, `arrow_file`, `hdf5_file`, `db_file`, `tar_file` — loads from a local file. Required field: `path`. - - `trace_synthetic` — loads a JSONL trace file for replay benchmarking. Required field: `path`. Optional: `timestamp_column` (default: `timestamp`), `prompt_tokens_column` (default: `input_length`), `output_tokens_column` (default: `output_length`). + - `trace_synthetic`, `mooncake` — loads a JSONL, JSON, CSV, or Parquet trace file for replay benchmarking. Required field: `path`. Optional: `timestamp_column` (default: `timestamp`), `prompt_tokens_column` (default: `input_length`), `output_tokens_column` (default: `output_length`). -In addition, you can specify additional arguments to the dataset loading with the data argument `loader_kwargs`: +In addition, you can specify additional arguments to the dataset loading with the data argument `load_kwargs`: -- loader_kwargs: Additional arguments to the dataset loading. For example, dataset splits can be specified with `--data '{"kind":"huggingface","source":"my/dataset","loader_kwargs":{"split":"test"}}'`. +- load_kwargs: Additional arguments to the dataset loading. For example, dataset splits can be specified with `--data '{"kind":"huggingface","source":"my/dataset","load_kwargs":{"split":"test"}}'`. 
### Data Loader @@ -188,7 +188,7 @@ GuideLLM supports various file formats for datasets, including text, CSV, JSON, {"prompt": "What is your name?", "output_tokens_count": 3, "additional_column": "baz", "additional_column2": "qux"} ``` -- **Trace files (`.jsonl` with `trace_synthetic` type)**: Specialized JSONL files for replay benchmarking with `timestamp`, `input_length`, and `output_length` fields. Used with `--profile kind=replay` to replay trace events using each row's timestamp and token lengths. Timestamps must be numbers expressed in seconds on a shared timeline with any consistent zero point; GuideLLM sorts them and converts them to offsets from the first event before scheduling. Date strings are not parsed yet, so provide timestamps as numbers. See [Trace Replay Benchmarking](../getting-started/benchmark.md#trace-replay-benchmarking-beta). +- **Trace files (`.jsonl`, `.json`, `.csv` or `.parquet` with a supported trace file format)**: Specialized files for replay. Used with `--profile kind=replay` to replay trace events using each row's timestamp and token lengths. Timestamps must be numbers expressed in seconds on a shared timeline with any consistent zero point; GuideLLM sorts them and converts them to offsets from the first event before scheduling. Date strings are not parsed yet, so provide timestamps as numbers. See [Trace Replay Benchmarking](../getting-started/benchmark.md#trace-replay-benchmarking). ```json {"timestamp": 1234500.0, "input_length": 256, "output_length": 128} @@ -197,7 +197,7 @@ GuideLLM supports various file formats for datasets, including text, CSV, JSON, In this example, the second request is scheduled 0.5 seconds after the first request. Trace rows are ordered by timestamp before GuideLLM schedules requests and generates synthetic payloads. This keeps each scheduled event aligned with the prompt and output token lengths from the same row. 
- Use `trace_synthetic` to enable trace loading: + Use a supported [trace file format](./trace_replay.md#supported-formats) to enable trace loading: ```bash guidellm run \ @@ -206,7 +206,7 @@ GuideLLM supports various file formats for datasets, including text, CSV, JSON, --data kind=trace_synthetic,path=path/to/trace.jsonl ``` - If your trace uses different column names, include `timestamp_column`, `prompt_tokens_column`, and `output_tokens_column` in the data config: + All trace formats by default look for the columns "timestamp", "input_length", and "output_length". If your trace uses different column names, include `timestamp_column`, `prompt_tokens_column`, and `output_tokens_column` in the data config: ```bash guidellm run \ @@ -217,8 +217,6 @@ GuideLLM supports various file formats for datasets, including text, CSV, JSON, For replay, `time_scale` on the profile is a time scale for the intervals between trace events rather than requests per second. Use `--data-loader kind=pytorch,samples=1000` to limit how many trace rows are loaded and replayed. Use `--constraint kind=max_requests,count=` only as a runtime completion constraint; it does not limit the trace rows loaded from the file. - Very small `input_length` values (roughly under 15 tokens, depending on the tokenizer) may not leave enough room for the full per-row unique prefix in the synthetic prompt. This can make prompts more similar across rows and weaken cache resistance. See [Trace Replay Benchmarking](../getting-started/benchmark.md#trace-replay-benchmarking) for details. - - **JSON files (`.json`)**: Where the entire dataset is represented as a JSON array of objects nested under a specific key. To surface the correct key to use, a `--data-column-mapper` argument must be passed in of `"field": "NAME"` for where the array exists. The objects should include `prompt` or other common names for the prompt which will be used as the prompt column. 
Additional fields can be included based on the previously mentioned aliases for the `--data-column-mapper` argument. ```json diff --git a/docs/guides/trace_replay.md b/docs/guides/trace_replay.md new file mode 100644 index 000000000..c9c07cd58 --- /dev/null +++ b/docs/guides/trace_replay.md @@ -0,0 +1,44 @@ +# Trace File Formats + +Many trace files are formatted in ways that need special handling to create an accurate replay. This guide covers all trace file formats currently supported by GuideLLM, along with the format-agnostic and format-specific data arguments. + +Detailed use of the replay profile and file-based datasets as a whole is explained in [Trace Replay Benchmarking](../getting-started/benchmark.md#trace-replay-benchmarking). + +## Supported Formats + +These are passed to the `--data` argument as `kind=format`: + +- `trace_synthetic`: A trace format that does the bare minimum needed to complete a fully functioning trace replay benchmark with synthetic prompt generation +- `mooncake`: The trace format used by the serving platform Mooncake, as defined in [https://doi.org/10.48550/arXiv.2407.00079](https://doi.org/10.48550/arXiv.2407.00079) + +## Format-Agnostic Data Arguments + +All trace formats can accept the following optional data arguments: + +| Argument | Default | Description | +| ---------------------- | --------------- | ----------------------------------------------------- | +| `timestamp_column` | "timestamp" | Column name for timestamps in the trace file | +| `prompt_tokens_column` | "input_length" | Column name for prompt token counts in the trace file | +| `output_tokens_column` | "output_length" | Column name for output token counts in the trace file | + +These are passed through the `--data` argument as shown below: + +```bash +guidellm benchmark \ + --target http://localhost:8000 \ + --profile kind=replay \ + --data 
"kind=trace_synthetic,path=replay.jsonl,timestamp_column=ts,prompt_tokens_column=input_tokens,output_tokens_column=generated_tokens" +``` + +`trace_synthetic` can be thought of as the format-agnostic option: it only looks for the timestamp, prompt token count, and output token count columns and ignores all other features contained in a dataset. While primarily used for testing, `trace_synthetic` may be used as a fallback for trace formats not currently supported by GuideLLM. + +## Format-Specific Data Arguments + +### `mooncake` + +The Mooncake format expects an additional column for hash IDs. During prompt generation, hash IDs sharing the same previous ID are required to represent distinct blocks of token IDs. + +| Argument | Default | Description | +| -------------------- | ---------- | --------------------------------------------------- | +| `hash_ids_column` | "hash_ids" | Column name for lists of hash IDs in the trace file | +| `hash_id_block_size` | 512 | Amount of tokens represented by one hash ID | diff --git a/src/guidellm/data/deserializers/__init__.py b/src/guidellm/data/deserializers/__init__.py index ebaca6fc0..caf1ab579 100644 --- a/src/guidellm/data/deserializers/__init__.py +++ b/src/guidellm/data/deserializers/__init__.py @@ -28,8 +28,16 @@ SyntheticTextDataset, SyntheticTextDatasetDeserializer, ) -from .trace_mooncake import TraceMooncakeDataArgs, TraceMooncakeDatasetDeserializer -from .trace_synthetic import TraceSyntheticDataArgs, TraceSyntheticDatasetDeserializer +from .trace_common import ( + TraceDataArgs, + TraceDatasetDeserializer, + TraceFormatBase, + TraceFormatRegistry, + decode_prompt, + generate_token_ids, +) +from .trace_minimal import MinimalTraceFormatArgs +from .trace_mooncake import MooncakeTraceFormatArgs __all__ = [ "ArrowFileDatasetDeserializer", @@ -49,14 +57,18 @@ "InMemoryItemListDataArgs", "InMemoryItemListDatasetDeserializer", "JSONFileDatasetDeserializer", + "MinimalTraceFormatArgs", + "MooncakeTraceFormatArgs", 
"ParquetFileDatasetDeserializer", "SyntheticTextDataArgs", "SyntheticTextDataset", "SyntheticTextDatasetDeserializer", "TarFileDatasetDeserializer", "TextFileDatasetDeserializer", - "TraceMooncakeDataArgs", - "TraceMooncakeDatasetDeserializer", - "TraceSyntheticDataArgs", - "TraceSyntheticDatasetDeserializer", + "TraceDataArgs", + "TraceDatasetDeserializer", + "TraceFormatBase", + "TraceFormatRegistry", + "decode_prompt", + "generate_token_ids", ] diff --git a/src/guidellm/data/deserializers/trace_common.py b/src/guidellm/data/deserializers/trace_common.py new file mode 100644 index 000000000..f665def3d --- /dev/null +++ b/src/guidellm/data/deserializers/trace_common.py @@ -0,0 +1,347 @@ +"""Trace file deserializer that generates synthetic prompts per row. + +Reads a trace file (consisting of at least the columns timestamp, input_length, +output_length) and yields one row per line with a synthetic prompt matching the +requested input_length for replay benchmarks.""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable +from pathlib import Path +from typing import Any, Protocol + +import numpy as np +from datasets import ( + Dataset, + DatasetInfo, + Features, + IterableDataset, + Value, +) +from datasets.exceptions import DatasetGenerationError +from datasets.iterable_dataset import _BaseExamplesIterable +from faker import Faker +from pydantic import Field +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers.deserializer import ( + DataNotSupportedError, + DatasetDeserializer, + DatasetDeserializerFactory, +) +from guidellm.data.schemas import DataArgs +from guidellm.utils.hf_datasets import load_dataset_from_file +from guidellm.utils.registry import RegistryMixin + +__all__ = [ + "TraceDataArgs", + "TraceDatasetDeserializer", + "TraceFormatBase", + "TraceFormatRegistry", + "decode_prompt", + "generate_token_ids", +] + + +def decode_prompt( + processor: PreTrainedTokenizerBase, + token_ids: 
list[int], +) -> str: + """Decode token ids into a prompt string.""" + decoded = processor.decode(token_ids, skip_special_tokens=True) + if isinstance(decoded, list): + return decoded[0] if decoded else "" + return decoded + + +def generate_token_ids( + token_count: int, + processor: PreTrainedTokenizerBase, + faker: Faker, + margin_of_safety: int = 8, +) -> list[int]: + """Generate `token_count` synthetic token ids for trace prompt construction. + + Ideally, `margin_of_safety` should be set to slightly more than + the average number of characters used by tokenizers to form one token.""" + attempt = 0 + while True: + attempt += 1 + # Faker.text() can only generate text of at least 5 characters. + num_chars = max(token_count * margin_of_safety * attempt, 5) + text = faker.text(max_nb_chars=num_chars) + token_ids = processor.encode(text) + if len(token_ids) >= token_count: + return token_ids[:token_count] + + +def validate_trace_path(path: Path | str) -> Path: + path = Path(path) + if path.stat().st_size == 0: + raise ValueError(f"Trace file is empty: {path}") + return path + + +def check_and_raise_missing_columns( + required_columns: list[str], actual_columns: list[str] +) -> None: + missing = [c for c in required_columns if c not in actual_columns] + if missing: + raise KeyError(f"Trace row missing required columns: {missing}") + + +def load_trace_rows( + path: Path | str, + timestamp_column_name: str, + required_columns: Features, + **data_kwargs: Any, +) -> Dataset: + """ + Load trace file rows as a HuggingFace Dataset. + + Every column in required_columns must exist in the dataset; + otherwise KeyError is raised with a descriptive message. + Rows are sorted by column timestamp_column_name. + + :param path: Path to the trace file. + :param timestamp_column_name: Name of the timestamp column used to sort trace rows. + :param required_columns: Columns and feature types each row must have. Must contain + the timestamp column. 
+ :param data_kwargs: Additional keyword arguments forwarded to load_dataset. + :return: HuggingFace Dataset (iterable as dicts, column-accessible). + :raises DataNotSupportedError: For any of the following reasons: + - The dataset is empty or has no valid rows + - A required column contains a NoneType + - A required column failed during cast to feature type + + :raises KeyError: If a required column is missing in the dataset. + :raises ValueError: If the file format is not .jsonl, .json, .csv or .parquet. + """ + path = validate_trace_path(path) + trace_dataset = load_dataset_from_file(path, **data_kwargs) + if required_columns: + check_and_raise_missing_columns( + required_columns.keys(), trace_dataset.column_names + ) + + if not trace_dataset: + raise DataNotSupportedError(f"Trace file has no valid rows: {path}") + for name, val in required_columns.items(): + if trace_dataset.data[name].null_count != 0: + raise DataNotSupportedError(f"Missing column values in {name}") + try: + trace_dataset.cast_column(name, val) + except ValueError as e: + raise DataNotSupportedError(str(e)) from e + + return trace_dataset.sort(timestamp_column_name) + + +class TraceFormatBase(Protocol): + def __init__(self) -> None: ... + + def required_columns(self, config) -> Features: ... + + def validate_row(self, config, row: dict) -> None: + """Called within `trace_common.TraceExamplesIterable` on initialization, + immediately after doing its own checks on the row.""" + + def create_prompt( + self, + config, + row: dict, + processor: PreTrainedTokenizerBase, + faker: Faker, + ) -> str: + """Called within `trace_common.TraceExamplesIterable` on each iteration. 
+ Returns a generated synthetic prompt.""" + + +class TraceFormatRegistry(RegistryMixin[type[TraceFormatBase]]): + @classmethod + def dispatch(cls, config: TraceDataArgs) -> TraceFormatBase: + format_from_type = cls.get_registered_object(config.kind) + if format_from_type is None: + raise DataNotSupportedError( + f"Format type '{config.kind}' is not registered." + ) + return format_from_type() + + +class TraceDataArgs(DataArgs): + """Abstract class meant to be inherited by a trace format. + For testing, use `trace_minimal.MinimalTraceFormatArgs` instead.""" + + kind: str = Field( + description="Type identifier for the trace dataset deserializer.", + ) + path: Path = Field(description="Path to the trace file.") + timestamp_column: str = Field( + default="timestamp", + description="Column name for timestamps in the trace file.", + ) + prompt_tokens_column: str = Field( + default="input_length", + description="Column name for prompt token counts in the trace file.", + ) + output_tokens_column: str = Field( + default="output_length", + description="Column name for output token counts in the trace file.", + ) + + +def validate_row(row: dict, config: TraceDataArgs) -> None: + n_in = row[config.prompt_tokens_column] + n_out = row[config.output_tokens_column] + if n_in < 0 or n_out < 0: + raise DataNotSupportedError( + f"Trace token counts must be non-negative, got " + f"input_length={n_in}, output_length={n_out}" + ) + + +class TraceExamplesIterable(_BaseExamplesIterable): + """Custom examples iterable for synthetic prompt generation. 
Used to avoid + pre-generating a prompt for every row in the dataset on load.""" + + def __init__( + self, + config: TraceDataArgs, + processor: PreTrainedTokenizerBase, + random_seed: int, + ): + super().__init__() + self.config = config + self.format = TraceFormatRegistry.dispatch(self.config) + self.processor = processor + self.faker = Faker() + self.faker.seed_instance(random_seed) + try: + self.trace_rows = load_trace_rows( + config.path, + config.timestamp_column, + required_columns=Features( + { + config.timestamp_column: Value("float"), + config.prompt_tokens_column: Value("int32"), + config.output_tokens_column: Value("int32"), + **dict(self.format.required_columns(self.config)), + } + ), + **config.load_kwargs, + ) + except (DatasetGenerationError, KeyError, ValueError) as e: + raise DataNotSupportedError(str(e)) from e + + for row in self.trace_rows: + validate_row(row, self.config) + self.format.validate_row(self.config, row) + self.iteration_count = 0 + + def __iter__(self) -> Iterable[tuple[int, dict[str, Any]]]: + self.iteration_count += 1 + row_idx = 0 + timestamps = self.trace_rows[self.config.timestamp_column] + while True: + try: + row = self.trace_rows[row_idx] + except IndexError: + break + + prompt = self.format.create_prompt( + self.config, row, self.processor, self.faker + ) + relative_timestamp = timestamps[row_idx] - timestamps[0] + yield ( + row_idx, + { + "prompt": prompt, + "prompt_tokens_count": row[self.config.prompt_tokens_column], + "output_tokens_count": row[self.config.output_tokens_column], + "relative_timestamp": relative_timestamp, + }, + ) + row_idx += 1 + + @property + def is_typed(self) -> bool: + return True + + @property + def features(self) -> Features: + return Features( + { + "prompt": Value("string"), + "prompt_tokens_count": Value("int32"), + "output_tokens_count": Value("int32"), + "relative_timestamp": Value("float"), + } + ) + + @property + def num_shards(self) -> int: + return 1 + + def shuffle_data_sources( + 
self, + generator: np.random.Generator, # noqa: ARG002 + ) -> TraceExamplesIterable: + """Returns self as sharding is not implemented yet.""" + return self + + def shard_data_sources( + self, + num_shards: int, # noqa: ARG002 + index: int, # noqa: ARG002 + contiguous: bool = True, # noqa: ARG002 + ) -> TraceExamplesIterable: + """Returns self as sharding is not implemented yet.""" + return self + + def load_state_dict(self, state_dict: dict) -> None: + """Load the state from a state dict.""" + self.iteration_count = state_dict.get("iteration_count", 0) + + def _init_state_dict(self): + """Initialize the state dict for the iterable.""" + self._state_dict = {"iteration_count": self.iteration_count} + return self._state_dict + + +class TraceDataset(IterableDataset): + def __init__( + self, + config: TraceDataArgs, + processor: PreTrainedTokenizerBase, + random_seed: int, + ): + ex_iterable = TraceExamplesIterable(config, processor, random_seed) + super().__init__( + ex_iterable=ex_iterable, + info=DatasetInfo( + description="Synthetic trace dataset generator", + features=ex_iterable.features, + ), + ) + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset iteration.""" + if hasattr(self._ex_iterable, "iteration_count"): + self._ex_iterable.iteration_count = epoch + + +@DatasetDeserializerFactory.register(["trace_synthetic"]) +class TraceDatasetDeserializer(DatasetDeserializer): + """Dataset deserializer for all trace formats.""" + + def __call__( + self, + config: TraceDataArgs, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int = 42, + ) -> IterableDataset: + if not config.path.exists(): + raise DataNotSupportedError(f"Trace file not found: {config.path}") + if not config.path.is_file(): + raise DataNotSupportedError(f"Trace path is not a file: {config.path}") + return TraceDataset(config, processor_factory(), random_seed) diff --git a/src/guidellm/data/deserializers/trace_minimal.py 
b/src/guidellm/data/deserializers/trace_minimal.py new file mode 100644 index 000000000..8b9b5df91 --- /dev/null +++ b/src/guidellm/data/deserializers/trace_minimal.py @@ -0,0 +1,67 @@ +""" +A minimal trace file format primarily used for testing. Designed to do the bare minimum +needed to complete a fully functioning trace deserializer with synthetic prompt +generation. + +Reads a trace file (timestamp, input_length, output_length) and yields one row per +line with a synthetic prompt matching the requested input_length for replay benchmarks. +""" + +from __future__ import annotations + +from typing import Literal + +from datasets import Features +from faker import Faker +from pydantic import Field +from transformers import PreTrainedTokenizerBase + +from guidellm.data.deserializers.trace_common import ( + TraceDataArgs, + TraceFormatBase, + TraceFormatRegistry, + decode_prompt, + generate_token_ids, +) +from guidellm.data.schemas import DataArgs + +__all__ = ["MinimalTraceFormatArgs"] + + +@DataArgs.register("trace_synthetic") +class MinimalTraceFormatArgs(TraceDataArgs): + kind: Literal["trace_synthetic"] = Field( + default="trace_synthetic", + description="Type identifier for the minimal trace format.", + ) + + +@TraceFormatRegistry.register("trace_synthetic") +class MinimalTraceFormat(TraceFormatBase): + def __init__(self) -> None: + pass + + def required_columns( + self, + config: MinimalTraceFormatArgs, # noqa: ARG002 + ) -> Features: + return Features({}) + + def validate_row( + self, + config: MinimalTraceFormatArgs, # noqa: ARG002 + row: dict, # noqa: ARG002 + ) -> None: + return + + def create_prompt( + self, + config: MinimalTraceFormatArgs, + row: dict, + processor: PreTrainedTokenizerBase, + faker: Faker, + ) -> str: + token_ids = generate_token_ids( + row[config.prompt_tokens_column], processor, faker + ) + return decode_prompt(processor, token_ids) diff --git a/src/guidellm/data/deserializers/trace_mooncake.py 
b/src/guidellm/data/deserializers/trace_mooncake.py 
index 34273b627..e2ea2d5bb 100644 --- a/src/guidellm/data/deserializers/trace_mooncake.py +++ b/src/guidellm/data/deserializers/trace_mooncake.py @@ -1,58 +1,37 @@ """ -Trace deserializer for Mooncake formatted files that generates synthetic prompts per -row. +The Mooncake trace format and data arguments. Reads a trace file (timestamp, input_length, output_length, hash_ids) and yields one row per line with a synthetic prompt matching the requested input_length for replay -benchmarks. +benchmarks. Checks for distinctness between hash IDs that share the +same previous hash ID. """ from __future__ import annotations -import dataclasses import math -from collections.abc import Callable, Iterable -from pathlib import Path from typing import Any, Literal -import numpy as np -from datasets import Dataset, DatasetInfo, Features, IterableDataset, List, Value -from datasets.exceptions import DatasetGenerationError -from datasets.iterable_dataset import _BaseExamplesIterable +from datasets import Features, List, Value from faker import Faker from pydantic import Field from transformers import PreTrainedTokenizerBase from guidellm.data.deserializers.deserializer import ( DataNotSupportedError, - DatasetDeserializer, DatasetDeserializerFactory, ) -from guidellm.data.deserializers.trace_synthetic import TraceSyntheticDataArgs +from guidellm.data.deserializers.trace_common import ( + TraceDataArgs, + TraceDatasetDeserializer, + TraceFormatBase, + TraceFormatRegistry, + decode_prompt, + generate_token_ids, +) from guidellm.data.schemas import DataArgs -from guidellm.utils.trace_io import load_trace_rows - -__all__ = ["TraceMooncakeDataArgs", "TraceMooncakeDatasetDeserializer"] - - -@DataArgs.register("mooncake") -class TraceMooncakeDataArgs(TraceSyntheticDataArgs): - """Model for Mooncake trace dataset deserializer arguments.""" - kind: Literal["mooncake"] = Field( # type: ignore[assignment] - default="mooncake", - description="Type identifier for the trace Mooncake dataset 
deserializer.", - ) - hash_ids_column: str = Field( - default="hash_ids", - description="Column name for lists of hash IDs in the trace file.", - ) - hash_id_block_size: int = Field( - gt=0, - # Default used in Mooncake's paper https://arxiv.org/pdf/2407.00079 - default=512, - description="Amount of tokens represented by one hash ID.", - ) +__all__ = ["MooncakeTraceFormatArgs"] def _is_in_table(hash_id_table: list[Any], hash_id: int) -> bool: @@ -69,7 +48,7 @@ def _resize_to_hold_id(hash_id_table: list[Any], hash_id: int) -> None: def _calculate_required_prompt_tokens( - row: dict, config: TraceMooncakeDataArgs, hash_id: int + config: MooncakeTraceFormatArgs, row: dict, hash_id: int ) -> int: """Returns the number of prompt tokens needed to satisfy the row input length. This will be less than the block_size if the input length is not divisible by it @@ -80,26 +59,6 @@ def _calculate_required_prompt_tokens( return config.hash_id_block_size -def _generate_token_ids( - token_count: int, - processor: PreTrainedTokenizerBase, - faker: Faker, -) -> list[int]: - """Generate `token_count` synthetic token ids for trace prompt construction.""" - # Ideally, `margin_of_safety` should be set to slighty more than - # the average number of characters used by tokenizers to form one token. - margin_of_safety = 8 - attempt = 0 - while True: - attempt += 1 - # The Faker.text() can only generate text of at least 5 characters. 
- num_chars = max(token_count * margin_of_safety * attempt, 5) - text = faker.text(max_nb_chars=num_chars) - token_ids = processor.encode(text) - if len(token_ids) >= token_count: - return token_ids[:token_count] - - def _create_distinct_token_block( block_size: int, sibling_token_blocks: list[list[int]], @@ -111,7 +70,7 @@ def _create_distinct_token_block( `sibling_token_blocks`.""" attempt = 0 while attempt < max_attempts: - token_ids = _generate_token_ids(block_size, processor, faker) + token_ids = generate_token_ids(block_size, processor, faker) if token_ids not in sibling_token_blocks: return token_ids attempt += 1 @@ -131,209 +90,33 @@ def _create_prompt_from_hash_ids( prompt_token_ids = [ token for hash_id in hash_ids for token in hash_id_table[hash_id] ] - prompt = processor.decode(prompt_token_ids, skip_special_tokens=True) - if isinstance(prompt, list): - return prompt[0] if prompt else "" - return prompt - - -@dataclasses.dataclass -class Column: - name: str - feature_type: Value + return decode_prompt(processor, prompt_token_ids) -def _load_formatted_trace_rows( - path: Path, - timestamp_column: Column, - required_columns: list[Column], -) -> Dataset: - """Load a trace file and format the columns.""" - try: - rows = load_trace_rows( - path, - [col.name for col in required_columns], - timestamp_column.name, - ) - except (DatasetGenerationError, KeyError, ValueError) as e: - raise DataNotSupportedError(str(e)) from e - if not rows: - raise DataNotSupportedError("Trace file is empty") - for col in [timestamp_column] + required_columns: - if rows.data[col.name].null_count != 0: - raise DataNotSupportedError(f"NoneType found in {col}") - try: - rows.cast_column(col.name, col.feature_type) - except ValueError as e: - raise DataNotSupportedError(str(e)) from e - return rows - - -def _validate_row(row: dict, config: TraceMooncakeDataArgs) -> None: - n_in = row[config.prompt_tokens_column] - n_out = row[config.output_tokens_column] - n_blocks = 
len(row[config.hash_ids_column]) - if n_in < 0 or n_out < 0: - raise DataNotSupportedError( - f"Trace token counts must be non-negative, got " - f"input_length={n_in}, output_length={n_out}" - ) - for hash_id in row[config.hash_ids_column]: - if hash_id < 0: - raise DataNotSupportedError(f"Hash ID must be non-negative, got {hash_id}") - if math.ceil(n_in / config.hash_id_block_size) != n_blocks: - raise DataNotSupportedError( - f"Input token count of {n_in} split into blocks of size " - f"{config.hash_id_block_size} does not match given {n_blocks} blocks" - ) - - -class _TraceMooncakeExamplesIterable(_BaseExamplesIterable): - """Custom examples iterable for synthetic prompt generation. - - Used to avoid pre-generating a prompt for every row in the dataset on load. - """ - - def __init__( - self, - config: TraceMooncakeDataArgs, - processor: PreTrainedTokenizerBase, - random_seed: int = 42, - ): - super().__init__() - self.config = config - self.processor = processor - self.faker = Faker() - self.faker.seed_instance(random_seed) - self.trace_rows = _load_formatted_trace_rows( - config.path, - Column(config.timestamp_column, Value("float")), - required_columns=[ - Column(config.prompt_tokens_column, Value("int32")), - Column(config.output_tokens_column, Value("int32")), - Column(config.hash_ids_column, List(Value("int32"))), - ], - ) - for row in self.trace_rows: - _validate_row(row, self.config) - self.iteration_count = 0 +DatasetDeserializerFactory.register_decorator(TraceDatasetDeserializer, "mooncake") - def __iter__(self) -> Iterable[tuple[int, dict[str, Any]]]: - self.iteration_count += 1 - row_idx = 0 - hash_id_table: list[Any] = [] - sibling_token_blocks: dict[Any, list[list[int]]] = {} - timestamps = self.trace_rows[self.config.timestamp_column] - while True: - try: - row = self.trace_rows[row_idx] - except IndexError: - break - - ids = row[self.config.hash_ids_column] - for idx, hash_id in enumerate(ids): - if not _is_in_table(hash_id_table, hash_id): - 
_resize_to_hold_id(hash_id_table, hash_id) - prev_id = None if idx == 0 else ids[idx - 1] - num_tokens = _calculate_required_prompt_tokens( - row, self.config, hash_id - ) - sibling_token_blocks.setdefault(prev_id, []) - hash_id_table[hash_id] = _create_distinct_token_block( - num_tokens, - sibling_token_blocks[prev_id], - self.processor, - self.faker, - ) - sibling_token_blocks[prev_id].append(hash_id_table[hash_id]) - prompt = _create_prompt_from_hash_ids(ids, hash_id_table, self.processor) - relative_timestamp = timestamps[row_idx] - timestamps[0] - yield ( - row_idx, - { - "prompt_tokens_count": row[self.config.prompt_tokens_column], - "output_tokens_count": row[self.config.output_tokens_column], - "prompt": prompt, - "relative_timestamp": relative_timestamp, - }, - ) - row_idx += 1 - - @property - def is_typed(self) -> bool: - return True - - @property - def features(self) -> Features: - return Features( - { - "prompt": Value("string"), - "prompt_tokens_count": Value("int32"), - "output_tokens_count": Value("int32"), - "relative_timestamp": Value("float"), - } - ) - - @property - def num_shards(self) -> int: - return 1 - - def shuffle_data_sources( - self, - generator: np.random.Generator, # noqa: ARG002 - ) -> _TraceMooncakeExamplesIterable: - """Returns self as sharding is not implemented yet.""" - return self - - def shard_data_sources( - self, - num_shards: int, # noqa: ARG002 - index: int, # noqa: ARG002 - contiguous: bool = True, # noqa: ARG002 - ) -> _TraceMooncakeExamplesIterable: - """Returns self as sharding is not implemented yet.""" - return self - - def load_state_dict(self, state_dict: dict) -> None: - """Load the state from a state dict.""" - self.iteration_count = state_dict.get("iteration_count", 0) - - def _init_state_dict(self): - """Initialize the state dict for the iterable.""" - self._state_dict = {"iteration_count": self.iteration_count} - return self._state_dict - - -class _TraceMooncakeDataset(IterableDataset): - def __init__( - self, 
- config: TraceMooncakeDataArgs, - processor: PreTrainedTokenizerBase, - random_seed: int = 42, - ): - self.config = config - self.processor = processor - self.random_seed = random_seed - ex_iterable = _TraceMooncakeExamplesIterable(config, processor, random_seed) - super().__init__( - ex_iterable=ex_iterable, - info=DatasetInfo( - description="Mooncake trace dataset generator", - features=ex_iterable.features, - ), - ) - - def set_epoch(self, epoch: int): - """Set the epoch for the dataset iteration.""" - if isinstance(self._ex_iterable, _TraceMooncakeExamplesIterable): - self._ex_iterable.iteration_count = epoch +@DataArgs.register("mooncake") +class MooncakeTraceFormatArgs(TraceDataArgs): + kind: Literal["mooncake"] = Field( + default="mooncake", + description="Type identifier for the trace Mooncake dataset deserializer.", + ) + hash_ids_column: str = Field( + default="hash_ids", + description="Column name for lists of hash IDs in the trace file.", + ) + hash_id_block_size: int = Field( + gt=0, + # Default used in Mooncake's paper https://arxiv.org/pdf/2407.00079 + default=512, + description="Amount of tokens represented by one hash ID.", + ) -@DatasetDeserializerFactory.register("mooncake") -class TraceMooncakeDatasetDeserializer(DatasetDeserializer): - """Mooncake trace format deserializer - The Mooncake trace format requires a column for timestamps, prompt token counts, +@TraceFormatRegistry.register("mooncake") +class MooncakeTraceFormat(TraceFormatBase): + """Mooncake trace format requires a column for timestamps, prompt token counts, output token counts and lists of hash IDs. 
Hash IDs are globally unique identifiers based on the current and previous token @@ -345,15 +128,48 @@ class TraceMooncakeDatasetDeserializer(DatasetDeserializer): Generated prompts match the prompt token count of the row.""" - def __call__( - self, - config: TraceMooncakeDataArgs, - processor_factory: Callable[[], PreTrainedTokenizerBase], - random_seed: int, - ) -> IterableDataset: - if not config.path.is_file(): + def __init__(self) -> None: + self.hash_id_table: list[Any] = [] + self.sibling_token_blocks: dict[Any, list[list[int]]] = {} + + def required_columns(self, config: MooncakeTraceFormatArgs) -> Features: + return Features({config.hash_ids_column: List(Value("int32"))}) + + def validate_row(self, config: MooncakeTraceFormatArgs, row: dict) -> None: + n_in = row[config.prompt_tokens_column] + n_blocks = len(row[config.hash_ids_column]) + for hash_id in row[config.hash_ids_column]: + if hash_id < 0: + raise DataNotSupportedError( + f"Hash ID must be non-negative, got {hash_id}" + ) + if math.ceil(n_in / config.hash_id_block_size) != n_blocks: raise DataNotSupportedError( - f"{type(self).__name__} expects a path to a trace file, " - f"got {config.path}" + f"Input token count of {n_in} split into blocks of size " + f"{config.hash_id_block_size} does not match given {n_blocks} blocks" ) - return _TraceMooncakeDataset(config, processor_factory(), random_seed) + + def create_prompt( + self, + config: MooncakeTraceFormatArgs, + row: dict, + processor: PreTrainedTokenizerBase, + faker: Faker, + ) -> str: + """Before generating the prompt, this first generates a block of tokens for + each hash ID that has not already been seen.""" + ids = row[config.hash_ids_column] + for idx, hash_id in enumerate(ids): + if not _is_in_table(self.hash_id_table, hash_id): + _resize_to_hold_id(self.hash_id_table, hash_id) + prev_id = None if idx == 0 else ids[idx - 1] + num_tokens = _calculate_required_prompt_tokens(config, row, hash_id) + 
self.sibling_token_blocks.setdefault(prev_id, []) + self.hash_id_table[hash_id] = _create_distinct_token_block( + num_tokens, + self.sibling_token_blocks[prev_id], + processor, + faker, + ) + self.sibling_token_blocks[prev_id].append(self.hash_id_table[hash_id]) + return _create_prompt_from_hash_ids(ids, self.hash_id_table, processor) diff --git a/src/guidellm/data/deserializers/trace_synthetic.py b/src/guidellm/data/deserializers/trace_synthetic.py deleted file mode 100644 index 68f083038..000000000 --- a/src/guidellm/data/deserializers/trace_synthetic.py +++ /dev/null @@ -1,238 +0,0 @@ -""" -Trace file deserializer that generates synthetic prompts per row. - -Reads a trace file (timestamp, input_length, output_length) and yields one row per -line with a synthetic prompt matching the requested input_length for replay benchmarks. -""" - -from __future__ import annotations - -from collections.abc import Callable -from pathlib import Path -from typing import Any, Literal - -from datasets import Dataset -from datasets.exceptions import DatasetGenerationError -from faker import Faker -from pydantic import Field, field_serializer -from transformers import PreTrainedTokenizerBase - -from guidellm.data.deserializers.deserializer import ( - DataNotSupportedError, - DatasetDeserializer, - DatasetDeserializerFactory, -) -from guidellm.data.schemas import DataArgs -from guidellm.utils.trace_io import load_trace_rows - -__all__ = ["TraceSyntheticDataArgs", "TraceSyntheticDatasetDeserializer"] - - -def _encode_prompt( - processor: PreTrainedTokenizerBase, - text: str, -) -> list[int]: - """Encode text with the configured tokenizer defaults.""" - return processor.encode(text) - - -def _decode_prompt( - processor: PreTrainedTokenizerBase, - token_ids: list[int], -) -> str: - """Decode token ids into a prompt string.""" - decoded = processor.decode(token_ids, skip_special_tokens=True) - if isinstance(decoded, list): - return decoded[0] if decoded else "" - return decoded - - -def 
_create_base_prompt_token_ids( - processor: PreTrainedTokenizerBase, - faker: Faker, - token_count: int, -) -> list[int]: - """Generate reusable synthetic token ids for trace prompt construction.""" - if token_count <= 0: - return [] - - token_text = (faker.word() or "x")[0] - text = token_text - token_ids = _encode_prompt(processor, text) - max_attempts = 8 - attempts = 0 - - while len(token_ids) < token_count and attempts < max_attempts: - attempts += 1 - missing_tokens = token_count - len(token_ids) - text = f"{text} {' '.join([token_text] * missing_tokens)}" - token_ids = _encode_prompt(processor, text) - - if len(token_ids) < token_count: - raise DataNotSupportedError( - "Could not generate enough synthetic prompt tokens for " - f"{token_count} tokens after {max_attempts} attempts" - ) - - return token_ids - - -def _create_prompt( - processor: PreTrainedTokenizerBase, - prompt_tokens_count: int, - base_prompt_token_ids: list[int], - request_index: int, -) -> str: - """ - Build a prompt from unique prefix tokens and reusable base prompt tokens. - - For very small prompt lengths (roughly under 15 tokens, depending on the - tokenizer), the target slice can truncate the per-row unique prefix before - it includes the request index, so prompts may become similar across rows and - less cache-resistant. 
- """ - if prompt_tokens_count <= 0: - return "" - - unique_prefix = f"guidellm-trace-request-{request_index}: " - prefix_token_ids = _encode_prompt(processor, unique_prefix) - prompt_token_ids = (prefix_token_ids + base_prompt_token_ids)[:prompt_tokens_count] - if len(prompt_token_ids) < prompt_tokens_count: - raise DataNotSupportedError( - "Could not build a synthetic prompt with " - f"{prompt_tokens_count} tokens from generated base tokens" - ) - - return _decode_prompt(processor, prompt_token_ids) - - -def _load_trace_rows( - path: Path, - timestamp_column: str, - prompt_tokens_column: str, - output_tokens_column: str, -) -> list[dict[str, Any]]: - """Load trace file into list of dicts with timestamp, prompt_tokens, - output_tokens.""" - try: - raw = load_trace_rows( - path, - required_columns=[ - prompt_tokens_column, - output_tokens_column, - ], - timestamp_column=timestamp_column, - ) - except (DatasetGenerationError, KeyError, ValueError) as e: - raise DataNotSupportedError(str(e)) from e - try: - return [ - { - "timestamp": float(row[timestamp_column]), - "prompt_tokens": int(row[prompt_tokens_column]), - "output_tokens": int(row[output_tokens_column]), - } - for row in raw - ] - except (TypeError, ValueError) as e: - raise DataNotSupportedError(str(e)) from e - - -@DataArgs.register("trace_synthetic") -class TraceSyntheticDataArgs(DataArgs): - """Model for synthetic trace dataset deserializer arguments.""" - - kind: Literal["trace_synthetic"] = Field( - default="trace_synthetic", - description="Type identifier for the trace synthetic dataset deserializer.", - ) - path: Path = Field(description="Path to the trace file.") - timestamp_column: str = Field( - default="timestamp", - description="Column name for timestamps in the trace file.", - ) - prompt_tokens_column: str = Field( - default="input_length", - description="Column name for prompt token counts in the trace file.", - ) - output_tokens_column: str = Field( - default="output_length", - 
description="Column name for output token counts in the trace file.", - ) - - @field_serializer("path") - @classmethod - def serialize_path(cls, path: Path) -> str: - """Serialize path as a string because Path is not JSON serializable.""" - return str(path) - - -@DatasetDeserializerFactory.register("trace_synthetic") -class TraceSyntheticDatasetDeserializer(DatasetDeserializer): - """ - Load a trace file and generate a synthetic prompt per row. - - Trace file must have timestamp, and columns for prompt and output token counts - (default: input_length, output_length). Each row becomes one request with - a synthetic prompt of the requested input length. - """ - - def __call__( - self, - config: TraceSyntheticDataArgs, - processor_factory: Callable[[], PreTrainedTokenizerBase], - random_seed: int, - ) -> Dataset: - if not (path := config.path).exists() or not path.is_file(): - raise DataNotSupportedError( - "TraceSyntheticDatasetDeserializer expects a path to a trace file, " - f"got {path}" - ) - rows = _load_trace_rows( - path, - config.timestamp_column, - config.prompt_tokens_column, - config.output_tokens_column, - ) - if not rows: - raise DataNotSupportedError("Trace file is empty") - - processor = processor_factory() - faker = Faker() - faker.seed_instance(random_seed) - max_prompt_tokens = max(row["prompt_tokens"] for row in rows) - base_prompt_token_ids = _create_base_prompt_token_ids( - processor, faker, max_prompt_tokens - ) - - timestamps = [row["timestamp"] for row in rows] - t0 = timestamps[0] - relative_timestamps = [t - t0 for t in timestamps] - - prompts: list[str] = [] - prompt_tokens_counts: list[int] = [] - output_tokens_counts: list[int] = [] - for i, row in enumerate(rows): - n_in = row["prompt_tokens"] - n_out = row["output_tokens"] - if n_in < 0 or n_out < 0: - raise DataNotSupportedError( - "Trace token counts must be non-negative, got " - f"input_length={n_in}, output_length={n_out}" - ) - prompt = _create_prompt( - processor, n_in, 
base_prompt_token_ids, request_index=i - ) - prompts.append(prompt) - prompt_tokens_counts.append(n_in) - output_tokens_counts.append(n_out) - - return Dataset.from_dict( - { - "prompt": prompts, - "prompt_tokens_count": prompt_tokens_counts, - "output_tokens_count": output_tokens_counts, - "relative_timestamp": relative_timestamps, - }, - **config.load_kwargs, - ) diff --git a/src/guidellm/utils/hf_datasets.py b/src/guidellm/utils/hf_datasets.py index 86f04485d..fd65c63ac 100644 --- a/src/guidellm/utils/hf_datasets.py +++ b/src/guidellm/utils/hf_datasets.py @@ -1,6 +1,7 @@ from pathlib import Path +from typing import Any -from datasets import Dataset +from datasets import Dataset, load_dataset SUPPORTED_TYPES = { ".json", @@ -10,6 +11,22 @@ } +def load_dataset_from_file( + path: str | Path, split: str = "train", **data_kwargs: Any +) -> Dataset: + path = Path(path) + suffix = path.suffix.lower() + if suffix in SUPPORTED_TYPES: + suffix = suffix.replace(".jsonl", ".json") + return load_dataset( + suffix.replace(".", ""), data_files=str(path), split=split, **data_kwargs + ) + raise ValueError( + f"Unsupported file suffix '{suffix}' in path '{path}'." + f" Only {SUPPORTED_TYPES} are supported." + ) + + def save_dataset_to_file(dataset: Dataset, output_path: str | Path) -> None: """ Saves a HuggingFace Dataset to file in a supported format. @@ -30,6 +47,6 @@ def save_dataset_to_file(dataset: Dataset, output_path: str | Path) -> None: dataset.to_parquet(output_path) else: raise ValueError( - f"Unsupported file suffix '{suffix}' in output_path'{output_path}'." + f"Unsupported file suffix '{suffix}' in output_path '{output_path}'." f" Only {SUPPORTED_TYPES} are supported." ) diff --git a/src/guidellm/utils/trace_io.py b/src/guidellm/utils/trace_io.py deleted file mode 100644 index a3f1962a9..000000000 --- a/src/guidellm/utils/trace_io.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Shared trace file I/O for replay benchmarks. 
- -Reads trace files (.jsonl only for now) and exposes rows or relative timestamps. -Used by replay profiles and the trace_synthetic deserializer. -""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any - -from datasets import Dataset, load_dataset - -__all__ = ["load_relative_timestamps", "load_trace_rows"] - - -def load_trace_rows( - path: Path | str, - required_columns: list[str] | None = None, - timestamp_column: str | None = None, - **data_kwargs: Any, -) -> Dataset: - """ - Load trace file rows as a HuggingFace Dataset. - - Supports .jsonl only (one JSON object per line). - If required_columns is set, every column must exist in the dataset; - otherwise KeyError is raised with a descriptive message. - If timestamp_column is set, rows are sorted by that column. - - :param path: Path to the trace file. - :param required_columns: Optional list of column/field names that each row - must have. - :param timestamp_column: Optional timestamp column used to sort trace rows. - :param data_kwargs: Additional keyword arguments forwarded to load_dataset. - :return: HuggingFace Dataset (iterable as dicts, column-accessible). - :raises KeyError: If a required column is missing in the dataset. - :raises ValueError: If the file format is not .jsonl. 
- """ - path = Path(path) - suffix = path.suffix.lower() - if suffix != ".jsonl": - raise ValueError(f"Unsupported trace file format: {suffix}") - if path.stat().st_size == 0: - raise ValueError(f"Trace file is empty or has no valid rows: {path}") - - trace_dataset = load_dataset( - "json", data_files=str(path), split="train", **data_kwargs - ) - - required_columns = required_columns or [] - if timestamp_column and timestamp_column not in required_columns: - required_columns = [*required_columns, timestamp_column] - - if required_columns: - missing = [c for c in required_columns if c not in trace_dataset.column_names] - if missing: - raise KeyError(f"Trace row missing required columns: {missing}") - - if timestamp_column: - trace_dataset = trace_dataset.sort(timestamp_column) - - return trace_dataset - - -def load_relative_timestamps( - path: Path | str, - timestamp_column: str = "timestamp", -) -> list[float]: - """ - Load timestamps from a trace file and return times relative to the first event. - - Trace file must be JSONL (one JSON object per line). The first timestamp - becomes 0.0, and all others are relative to it (always >= 0). - - :param path: Path to the trace file. - :param timestamp_column: Name of the column/field containing the timestamp. - :return: List of relative timestamps in seconds (first is 0.0). - :raises ValueError: If the trace file is empty or has no valid rows. 
- """ - trace_dataset = load_trace_rows(path, timestamp_column=timestamp_column) - if len(trace_dataset) == 0: - raise ValueError(f"Trace file is empty or has no valid rows: {path}") - timestamps = [float(t) for t in trace_dataset[timestamp_column]] - t0 = timestamps[0] - return [t - t0 for t in timestamps] diff --git a/tests/integration/scheduler/test_trace_replay_multiprocess.py b/tests/integration/scheduler/test_trace_replay_multiprocess.py index 0096c4f2d..8b4a8c0ba 100644 --- a/tests/integration/scheduler/test_trace_replay_multiprocess.py +++ b/tests/integration/scheduler/test_trace_replay_multiprocess.py @@ -16,9 +16,9 @@ import pytest -from guidellm.data.deserializers.trace_synthetic import ( - TraceSyntheticDataArgs, - TraceSyntheticDatasetDeserializer, +from guidellm.data.deserializers import ( + MinimalTraceFormatArgs, + TraceDatasetDeserializer, ) from guidellm.data.finalizers.generative import ( GenerativeRequestFinalizer, @@ -62,9 +62,9 @@ def _write_trace(path: Path, lines: list[str]) -> Path: def _requests_from_trace( trace_path: Path, ) -> tuple[list[list[GenerationRequest]], list[float]]: - deserializer = TraceSyntheticDatasetDeserializer() + deserializer = TraceDatasetDeserializer() dataset = deserializer( - config=TraceSyntheticDataArgs(path=trace_path), + config=MinimalTraceFormatArgs(path=trace_path), processor_factory=_mock_processor, random_seed=42, ) @@ -75,13 +75,12 @@ def _requests_from_trace( conversations: list[list[GenerationRequest]] = [] relative_timestamps: list[float] = [] - for index in range(len(dataset)): - row = {column: dataset[column][index] for column in dataset.column_names} + for idx, row in enumerate(dataset): mapped = mapper([{"dataset": row}]) requests = finalizer(mapped) assert len(requests) == 1 request = requests[0] - request.request_id = f"req_{index}" + request.request_id = f"req_{idx}" conversations.append([request]) offset = request.settings.relative_timestamp assert offset is not None diff --git 
a/tests/unit/data/deserializers/test_trace_common.py b/tests/unit/data/deserializers/test_trace_common.py new file mode 100644 index 000000000..babd2f389 --- /dev/null +++ b/tests/unit/data/deserializers/test_trace_common.py @@ -0,0 +1,268 @@ +import dataclasses +from collections.abc import Callable +from pathlib import Path +from typing import Any +from unittest.mock import Mock + +import pytest +from datasets import IterableDataset +from faker import Faker +from pydantic import ValidationError + +from guidellm.data.deserializers import DataNotSupportedError +from guidellm.data.deserializers.trace_common import ( + TraceDataArgs, + TraceDatasetDeserializer, + TraceFormatRegistry, + decode_prompt, + generate_token_ids, +) +from guidellm.data.deserializers.trace_minimal import MinimalTraceFormatArgs + + +def _mock_processor() -> Mock: + """Tokenizer where each whitespace-delimited word is one token.""" + proc = Mock() + proc.encode.side_effect = lambda text: list(range(len(text.split()))) + proc.decode.side_effect = lambda tokens, skip_special_tokens=False: " ".join( + f"tok{t}" for t in tokens + ) + return proc + + +@pytest.mark.parametrize( + ("token_ids", "expected"), + [ + ([], ""), + ([0], "tok0"), + ([1, 1], "tok1 tok1"), + ([0, 2, 3, 2], "tok0 tok2 tok3 tok2"), + ], +) +def test_decode_prompt(token_ids, expected): + proc = _mock_processor() + assert decode_prompt(proc, token_ids) == expected + + +@pytest.mark.parametrize( + ("token_count", "expected"), + [ + (0, []), + (1, [0]), + (10, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + (1000, list(range(1000))), + ], +) +def test_generate_token_ids(token_count, expected): + proc = _mock_processor() + faker = Faker() + res = generate_token_ids(token_count, proc, faker) + assert len(res) == len(expected) + assert res == expected + + +class TestTraceFormatRegistry: + def test_unknown_kind_raises(self, tmp_path: Path): + config = TraceDataArgs(kind="unknown_kind", path=tmp_path) + with pytest.raises(DataNotSupportedError, 
match="not registered"): + TraceFormatRegistry.dispatch(config) + + +@dataclasses.dataclass +class TraceColumnGenerator: + name: str + # Function with row index as the one argument + data_generator: Callable[[int], Any] + + +def _write_trace(tmp_path: Path, content: str, suffix: str = ".jsonl") -> Path: + path = tmp_path / f"trace{suffix}" + path.write_text(content) + return path + + +def _generate_trace(num_rows: int, columns: list[TraceColumnGenerator]) -> str: + return "\n".join( + "{" + + ", ".join(f'"{col.name}": {col.data_generator(idx)}' for col in columns) + + "}" + for idx in range(num_rows) + ) + + +def _get_from_kwargs(keys, kwargs) -> dict: + return {k: v for k, v in kwargs.items() if k in keys} + + +class TestTraceDatasetDeserializer: + @pytest.fixture + def deserializer(self) -> TraceDatasetDeserializer: + return TraceDatasetDeserializer() + + def _deserialize(self, deserializer, data, **kwargs): + col_kwargs = _get_from_kwargs( + ( + "timestamp_column", + "prompt_tokens_column", + "output_tokens_column", + ), + kwargs, + ) + config = MinimalTraceFormatArgs(path=data, **col_kwargs) + return deserializer( + config=config, + processor_factory=_mock_processor, + random_seed=42, + ) + + @pytest.mark.sanity + @pytest.mark.parametrize( + "suffix", + [".json", ".jsonl"], + ) + def test_loads_json(self, tmp_path: Path, deserializer, suffix): + trace = _write_trace( + tmp_path, + '{"timestamp": 1, "input_length": 10, "output_length": 1}\n' + '{"timestamp": 2, "input_length": 20, "output_length": 2}\n', + suffix=suffix, + ) + ds = self._deserialize(deserializer, trace) + for i, row in enumerate(ds): + assert row["relative_timestamp"] == i + assert row["prompt_tokens_count"] == (i + 1) * 10 + assert row["output_tokens_count"] == i + 1 + + @pytest.mark.sanity + def test_loads_csv(self, tmp_path: Path, deserializer): + trace = _write_trace( + tmp_path, + "timestamp,input_length,output_length\n1,10,1\n2,20,2\n", + suffix=".csv", + ) + ds = 
self._deserialize(deserializer, trace) + for i, row in enumerate(ds): + assert row["relative_timestamp"] == i + assert row["prompt_tokens_count"] == (i + 1) * 10 + assert row["output_tokens_count"] == i + 1 + + @pytest.mark.smoke + def test_loads_sorted_rows_and_keeps_token_columns_aligned( + self, tmp_path: Path, deserializer + ): + n_rows = 10 + trace = _write_trace( + tmp_path, + _generate_trace( + n_rows, + [ + TraceColumnGenerator("timestamp", lambda i: n_rows - i), + TraceColumnGenerator("input_length", lambda i: n_rows - i), + TraceColumnGenerator("output_length", lambda i: (n_rows - i) * 10), + ], + ), + ) + ds = self._deserialize(deserializer, trace) + assert isinstance(ds, IterableDataset) + proc = _mock_processor() + for i, row in enumerate(ds): + assert row["prompt_tokens_count"] == i + 1 + assert row["output_tokens_count"] == (i + 1) * 10 + assert len(proc.encode(row["prompt"])) == row["prompt_tokens_count"] + + @pytest.mark.smoke + def test_emits_relative_timestamp_column_sorted_from_trace( + self, tmp_path: Path, deserializer + ): + n_rows = 5 + trace = _write_trace( + tmp_path, + _generate_trace( + n_rows, + [ + TraceColumnGenerator("timestamp", lambda i: i + 3), + TraceColumnGenerator("input_length", lambda i: i), + TraceColumnGenerator("output_length", lambda i: i), + ], + ), + ) + ds = self._deserialize(deserializer, trace) + for i, row in enumerate(ds): + assert row["relative_timestamp"] == i + + @pytest.mark.smoke + def test_rejects_invalid_path(self, deserializer): + with pytest.raises(ValidationError, match="not a valid path"): + self._deserialize(deserializer, 123) + with pytest.raises(DataNotSupportedError, match="file not found"): + self._deserialize(deserializer, "bad_path.jsonl") + with pytest.raises(DataNotSupportedError, match="not a file"): + self._deserialize(deserializer, Path.cwd()) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("content", "kwargs", "match"), + [ + ("", {}, "empty"), + ( + '{"ts": 0, "input_length": 10, 
"output_length": 5}\n', + {}, + "timestamp", + ), + ( + '{"timestamp": 0, "input_length": 10}\n', + {}, + "output_length", + ), + ( + '{"timestamp": 0, "prompt_tokens": 10, "output_length": 5}\n', + { + "prompt_tokens_column": "prompt_tokens", + "output_tokens_column": "out", + }, + "out", + ), + ( + '{"timestamp": 0, "input_length": -1, "output_length": 5}\n', + {}, + "non-negative", + ), + ( + '{"timestamp": 0, "input_length": 10, "output_length": -1}\n', + {}, + "non-negative", + ), + ( + '{"timestamp": "bad", "input_length": 10, "output_length": 5}\n', + {}, + "scalar of type float", + ), + ( + '{"timestamp": 0, "input_length": "bad", "output_length": 5}\n', + {}, + "scalar of type int32", + ), + ( + '{"timestamp": 0, "input_length": 10, "output_length": null}\n', + {}, + "Missing column values", + ), + ], + ) + def test_trace_validation_raises( + self, tmp_path: Path, deserializer, content, kwargs, match + ): + trace = _write_trace(tmp_path, content) + with pytest.raises(DataNotSupportedError, match=match): + self._deserialize(deserializer, trace, **kwargs) + + @pytest.mark.sanity + def test_unsupported_file_suffix_raises(self, tmp_path: Path, deserializer): + trace = _write_trace( + tmp_path, + '{"timestamp": 0, "input_length": 10, "output_length": 5, ' + '"hash_ids": [0]}\n', + suffix=".txt", + ) + with pytest.raises(DataNotSupportedError, match=r"Unsupported.*\.txt"): + self._deserialize(deserializer, trace) diff --git a/tests/unit/data/deserializers/test_trace_minimal.py b/tests/unit/data/deserializers/test_trace_minimal.py new file mode 100644 index 000000000..fddda4f38 --- /dev/null +++ b/tests/unit/data/deserializers/test_trace_minimal.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import dataclasses +import random +from collections.abc import Callable +from pathlib import Path +from typing import Any +from unittest.mock import Mock + +import pytest + +from guidellm.data.deserializers import DatasetDeserializerFactory +from 
guidellm.data.deserializers.trace_common import TraceDatasetDeserializer +from guidellm.data.deserializers.trace_minimal import MinimalTraceFormatArgs + + +def _mock_processor() -> Mock: + """Tokenizer where each whitespace-delimited word is one token.""" + proc = Mock() + proc.encode.side_effect = lambda text: list(range(len(text.split()))) + proc.decode.side_effect = lambda tokens, skip_special_tokens=False: " ".join( + f"tok{i}" for i, _ in enumerate(tokens) + ) + return proc + + +def _write_trace(tmp_path: Path, content: str, suffix: str = ".jsonl") -> Path: + path = tmp_path / f"trace{suffix}" + path.write_text(content) + return path + + +@dataclasses.dataclass +class TraceColumnGenerator: + name: str + # Function with row index as the one argument + data_generator: Callable[[int], Any] + + +def _generate_trace(num_rows: int, columns: list[TraceColumnGenerator]) -> str: + return "\n".join( + "{" + + ", ".join(f'"{col.name}": {col.data_generator(idx)}' for col in columns) + + "}" + for idx in range(num_rows) + ) + + +def _get_from_kwargs(keys, kwargs) -> dict: + return {k: v for k, v in kwargs.items() if k in keys} + + +class TestMinimalTraceFormat: + @pytest.mark.regression + def test_format_registered_with_deserializer(self, tmp_path: Path): + trace = _write_trace( + tmp_path, + '{"timestamp": 0.0, "input_length": 10, "output_length": 5}\n', + ) + DatasetDeserializerFactory.deserialize( + config=MinimalTraceFormatArgs(path=trace), + processor_factory=_mock_processor, + random_seed=42, + ) + + @pytest.fixture + def deserializer(self) -> TraceDatasetDeserializer: + return TraceDatasetDeserializer() + + def _deserialize(self, deserializer, data, **kwargs): + col_kwargs = _get_from_kwargs( + ( + "timestamp_column", + "prompt_tokens_column", + "output_tokens_column", + ), + kwargs, + ) + config = MinimalTraceFormatArgs(path=data, **col_kwargs) + return deserializer( + config=config, + processor_factory=_mock_processor, + random_seed=42, + ) + + @pytest.mark.smoke 
+ def test_honors_custom_column_names(self, tmp_path: Path, deserializer): + trace = _write_trace( + tmp_path, + '{"ts": 3.0, "input_tokens": 4, "generated_tokens": 40}\n' + '{"ts": 1.0, "input_tokens": 2, "generated_tokens": 20}\n', + ) + + ds = self._deserialize( + deserializer, + trace, + timestamp_column="ts", + prompt_tokens_column="input_tokens", + output_tokens_column="generated_tokens", + ) + + expected_prompt_count = [2, 4] + expected_output_count = [20, 40] + for i, row in enumerate(ds): + assert row["prompt_tokens_count"] == expected_prompt_count[i] + assert row["output_tokens_count"] == expected_output_count[i] + + @pytest.mark.smoke + def test_generates_large_trace_prompts_from_reusable_base( + self, tmp_path: Path, deserializer + ): + random.seed(0) + n_rows = 25 + prompt_lengths = [random.randint(2000, 100000) for _ in range(n_rows)] + output_lengths = [random.randint(3, 800) for _ in range(n_rows)] + times = [0.0, 0.5, 1.0, 2.0] + timestamps = [times[int(i / n_rows * len(times))] for i in range(n_rows)] + trace = _write_trace( + tmp_path, + _generate_trace( + n_rows, + [ + TraceColumnGenerator("timestamp", lambda i: timestamps[i]), + TraceColumnGenerator("input_length", lambda i: prompt_lengths[i]), + TraceColumnGenerator("output_length", lambda i: output_lengths[i]), + ], + ), + ) + processor = _mock_processor() + config = MinimalTraceFormatArgs(path=trace) + ds = deserializer( + config=config, + processor_factory=lambda: processor, + random_seed=42, + ) + + assert processor.encode.call_count <= len(prompt_lengths) + 4 + for i, row in enumerate(ds): + in_cnt = row["prompt_tokens_count"] + assert in_cnt == prompt_lengths[i] + assert row["output_tokens_count"] == output_lengths[i] + + actual_prompt_length = len(processor.encode(row["prompt"])) + if actual_prompt_length != in_cnt: + pytest.fail(f"{actual_prompt_length} != {in_cnt}") diff --git a/tests/unit/data/deserializers/test_trace_mooncake.py b/tests/unit/data/deserializers/test_trace_mooncake.py 
index e516da89e..ce2c6998a 100644 --- a/tests/unit/data/deserializers/test_trace_mooncake.py +++ b/tests/unit/data/deserializers/test_trace_mooncake.py @@ -8,13 +8,10 @@ from unittest.mock import Mock import pytest -from datasets import IterableDataset -from pydantic import ValidationError -from guidellm.data.deserializers.trace_mooncake import ( - TraceMooncakeDataArgs, - TraceMooncakeDatasetDeserializer, -) +from guidellm.data.deserializers import DatasetDeserializerFactory +from guidellm.data.deserializers.trace_common import TraceDatasetDeserializer +from guidellm.data.deserializers.trace_mooncake import MooncakeTraceFormatArgs from guidellm.data.schemas import DataNotSupportedError @@ -86,13 +83,13 @@ def _all_distinct(items: list): @dataclasses.dataclass -class TraceColumn: +class TraceColumnGenerator: name: str # Function with row index as the one argument data_generator: Callable[[int], Any] -def _generate_trace(num_rows: int, columns: list[TraceColumn]) -> str: +def _generate_trace(num_rows: int, columns: list[TraceColumnGenerator]) -> str: return "\n".join( "{" + ", ".join(f'"{col.name}": {col.data_generator(idx)}' for col in columns) @@ -101,52 +98,46 @@ def _generate_trace(num_rows: int, columns: list[TraceColumn]) -> str: ) -class TestTraceMooncakeDatasetDeserializer: +def _get_from_kwargs(keys, kwargs) -> dict: + return {k: v for k, v in kwargs.items() if k in keys} + + +class TestMooncakeTraceFormat: + @pytest.mark.regression + def test_format_registered_with_deserializer(self, tmp_path: Path): + trace = _write_trace( + tmp_path, + '{"timestamp": 0.0, "input_length": 10, "output_length": 5, ' + '"hash_ids": [0]}\n', + ) + DatasetDeserializerFactory.deserialize( + config=MooncakeTraceFormatArgs(path=trace), + processor_factory=_ascending_processor, + random_seed=42, + ) + @pytest.fixture - def deserializer(self) -> TraceMooncakeDatasetDeserializer: - return TraceMooncakeDatasetDeserializer() + def deserializer(self) -> TraceDatasetDeserializer: + 
return TraceDatasetDeserializer() def _deserialize(self, deserializer, data, **kwargs): - field_names = ( - "timestamp_column", - "prompt_tokens_column", - "output_tokens_column", - "hash_ids_column", - "hash_id_block_size", + col_kwargs = _get_from_kwargs( + ( + "timestamp_column", + "prompt_tokens_column", + "output_tokens_column", + "hash_ids_column", + "hash_id_block_size", + ), + kwargs, ) - col_kwargs = {k: v for k, v in kwargs.items() if k in field_names} - config = TraceMooncakeDataArgs(path=data, **col_kwargs) + config = MooncakeTraceFormatArgs(path=data, **col_kwargs) return deserializer( config=config, processor_factory=_ascending_processor, random_seed=42, ) - @pytest.mark.smoke - def test_loads_sorted_rows_and_keeps_token_columns_aligned( - self, tmp_path: Path, deserializer - ): - n_rows = 10 - trace = _write_trace( - tmp_path, - _generate_trace( - n_rows, - [ - TraceColumn("timestamp", lambda i: n_rows - i), - TraceColumn("input_length", lambda i: n_rows - i), - TraceColumn("output_length", lambda i: (n_rows - i) * 10), - TraceColumn("hash_ids", lambda i: [n_rows - i]), - ], - ), - ) - ds = self._deserialize(deserializer, trace) - assert isinstance(ds, IterableDataset) - proc = _ascending_processor() - for i, row in enumerate(ds): - assert row["prompt_tokens_count"] == i + 1 - assert row["output_tokens_count"] == (i + 1) * 10 - assert len(proc.encode(row["prompt"])) == row["prompt_tokens_count"] - @pytest.mark.smoke def test_honors_custom_column_names(self, tmp_path: Path, deserializer): n_rows = 3 @@ -155,14 +146,14 @@ def test_honors_custom_column_names(self, tmp_path: Path, deserializer): _generate_trace( n_rows, [ - TraceColumn("ts", lambda i: i), - TraceColumn("input_tokens", lambda i: i + 1), - TraceColumn("generated_tokens", lambda i: (i + 1) * 10), - TraceColumn("ids", lambda i: [i]), + TraceColumnGenerator("ts", lambda i: i), + TraceColumnGenerator("input_tokens", lambda i: i + 1), + TraceColumnGenerator("generated_tokens", lambda i: (i + 1) 
* 10), + TraceColumnGenerator("ids", lambda i: [i]), ], ), ) - ds = self._deserialize( + self._deserialize( deserializer, trace, timestamp_column="ts", @@ -170,9 +161,6 @@ def test_honors_custom_column_names(self, tmp_path: Path, deserializer): output_tokens_column="generated_tokens", hash_ids_column="ids", ) - for i, row in enumerate(ds): - assert row["prompt_tokens_count"] == i + 1 - assert row["output_tokens_count"] == (i + 1) * 10 @pytest.mark.smoke def test_custom_hash_id_block_size(self, tmp_path: Path, deserializer): @@ -183,12 +171,12 @@ def test_custom_hash_id_block_size(self, tmp_path: Path, deserializer): _generate_trace( n_rows, [ - TraceColumn("timestamp", lambda i: i), - TraceColumn("input_length", lambda _: n_in), - TraceColumn("output_length", lambda i: i + 1), + TraceColumnGenerator("timestamp", lambda i: i), + TraceColumnGenerator("input_length", lambda _: n_in), + TraceColumnGenerator("output_length", lambda i: i + 1), # Would throw a DataNotSupportedError with default block size 512 # See row validation in trace_mooncake.py - TraceColumn("hash_ids", lambda _: [0, 1, 2, 3, 4]), + TraceColumnGenerator("hash_ids", lambda _: [0, 1, 2, 3, 4]), ], ), ) @@ -202,22 +190,22 @@ def test_generates_large_trace_prompts(self, tmp_path: Path, deserializer): output_lengths = [random.randint(3, 800) for _ in range(n_rows)] times = [0.0, 0.5, 1.0, 2.0] timestamps = [times[int(i / n_rows * len(times))] for i in range(n_rows)] - block_size = TraceMooncakeDataArgs(path=tmp_path).hash_id_block_size + block_size = MooncakeTraceFormatArgs(path=tmp_path).hash_id_block_size hash_ids = _make_valid_hash_ids(n_rows, prompt_lengths, block_size) trace = _write_trace( tmp_path, _generate_trace( n_rows, [ - TraceColumn("timestamp", lambda i: timestamps[i]), - TraceColumn("input_length", lambda i: prompt_lengths[i]), - TraceColumn("output_length", lambda i: output_lengths[i]), - TraceColumn("hash_ids", lambda i: hash_ids[i]), + TraceColumnGenerator("timestamp", lambda i: 
timestamps[i]), + TraceColumnGenerator("input_length", lambda i: prompt_lengths[i]), + TraceColumnGenerator("output_length", lambda i: output_lengths[i]), + TraceColumnGenerator("hash_ids", lambda i: hash_ids[i]), ], ), ) processor = _ascending_processor() - config = TraceMooncakeDataArgs(path=trace) + config = MooncakeTraceFormatArgs(path=trace) ds = deserializer( config=config, processor_factory=lambda: processor, @@ -233,60 +221,15 @@ def test_generates_large_trace_prompts(self, tmp_path: Path, deserializer): if actual_prompt_length != in_cnt: pytest.fail(f"{actual_prompt_length} != {in_cnt}") - @pytest.mark.smoke - def test_rejects_invalid_path(self, deserializer): - with pytest.raises(ValidationError, match="not a valid path"): - self._deserialize(deserializer, 123) - with pytest.raises(DataNotSupportedError, match="path to a trace file"): - self._deserialize(deserializer, "bad_path.jsonl") - @pytest.mark.sanity @pytest.mark.parametrize( ("content", "kwargs", "match"), [ - ("", {}, "empty"), - ( - '{"ts": 0, "input_length": 10, "output_length": 5, "hash_ids": [0]}\n', - {}, - "timestamp", - ), - ( - '{"timestamp": 0, "input_length": 10, "hash_ids": [0]}\n', - {}, - "output_length", - ), - ( - '{"timestamp": 0, "prompt_tokens": 10, "output_length": 5, ' - '"hash_ids": [0]}\n', - { - "prompt_tokens_column": "prompt_tokens", - "output_tokens_column": "out", - }, - "out", - ), - ( - '{"timestamp": "bad", "input_length": 10, "output_length": 5, ' - '"hash_ids": [0]}\n', - {}, - "scalar of type float", - ), - ( - '{"timestamp": 0, "input_length": "bad", "output_length": 5, ' - '"hash_ids": [0]}\n', - {}, - "scalar of type int32", - ), - ( - '{"timestamp": 0, "input_length": 10, "output_length": null, ' - '"hash_ids": [0]}\n', - {}, - "NoneType", - ), ( '{"timestamp": 0, "input_length": 10, "output_length": 5, ' - '"hash_ids": [0]}\nnot-json\n', + '"hash_ids": [-1]}\n', {}, - "generating the dataset", + "non-negative", ), ( '{"timestamp": 0, "input_length": 1024, 
"output_length": 5, ' @@ -303,17 +246,6 @@ def test_trace_validation_raises( with pytest.raises(DataNotSupportedError, match=match): self._deserialize(deserializer, trace, **kwargs) - @pytest.mark.sanity - def test_unsupported_file_suffix_raises(self, tmp_path: Path, deserializer): - trace = _write_trace( - tmp_path, - '{"timestamp": 0, "input_length": 10, "output_length": 5, ' - '"hash_ids": [0]}\n', - suffix=".json", - ) - with pytest.raises(DataNotSupportedError, match=r"Unsupported.*\.json"): - self._deserialize(deserializer, trace) - @pytest.mark.sanity def test_incompatible_encoding_raises(self, tmp_path: Path, deserializer): n_rows = 2 @@ -322,14 +254,14 @@ def test_incompatible_encoding_raises(self, tmp_path: Path, deserializer): _generate_trace( n_rows, [ - TraceColumn("timestamp", lambda i: i), - TraceColumn("input_length", lambda _: 1024), - TraceColumn("output_length", lambda _: 5), - TraceColumn("hash_ids", lambda i: [0, i + 1]), + TraceColumnGenerator("timestamp", lambda i: i), + TraceColumnGenerator("input_length", lambda _: 1024), + TraceColumnGenerator("output_length", lambda _: 5), + TraceColumnGenerator("hash_ids", lambda i: [0, i + 1]), ], ), ) - config = TraceMooncakeDataArgs(path=trace) + config = MooncakeTraceFormatArgs(path=trace) ds = deserializer( config=config, processor_factory=lambda: _ascending_processor(), @@ -348,14 +280,14 @@ def test_token_block_distinctness(self, tmp_path: Path, deserializer): _generate_trace( n_rows, [ - TraceColumn("timestamp", lambda i: i), - TraceColumn("input_length", lambda _: n_in), - TraceColumn("output_length", lambda _: 5), - TraceColumn("hash_ids", lambda i: [0, i + 1]), + TraceColumnGenerator("timestamp", lambda i: i), + TraceColumnGenerator("input_length", lambda _: n_in), + TraceColumnGenerator("output_length", lambda _: 5), + TraceColumnGenerator("hash_ids", lambda i: [0, i + 1]), ], ), ) - config = TraceMooncakeDataArgs(path=trace) + config = MooncakeTraceFormatArgs(path=trace) ds = deserializer( 
config=config, processor_factory=lambda: _compatible_processor(), diff --git a/tests/unit/data/deserializers/test_trace_synthetic.py b/tests/unit/data/deserializers/test_trace_synthetic.py deleted file mode 100644 index c624247b1..000000000 --- a/tests/unit/data/deserializers/test_trace_synthetic.py +++ /dev/null @@ -1,308 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from unittest.mock import Mock - -import pytest -from datasets import Dataset -from pydantic import ValidationError - -from guidellm.data.deserializers.trace_synthetic import ( - TraceSyntheticDataArgs, - TraceSyntheticDatasetDeserializer, -) -from guidellm.data.schemas import DataNotSupportedError - - -def _mock_processor() -> Mock: - """Tokenizer where each whitespace-delimited word is one token.""" - proc = Mock() - proc.encode.side_effect = lambda text: list(range(len(text.split()))) - proc.decode.side_effect = lambda tokens, skip_special_tokens=False: " ".join( - f"tok{i}" for i, _ in enumerate(tokens) - ) - return proc - - -def _write_trace(tmp_path: Path, content: str, suffix: str = ".jsonl") -> Path: - path = tmp_path / f"trace{suffix}" - path.write_text(content) - return path - - -class TestTraceSyntheticDatasetDeserializer: - @pytest.fixture - def deserializer(self) -> TraceSyntheticDatasetDeserializer: - return TraceSyntheticDatasetDeserializer() - - def _deserialize(self, deserializer, data, **kwargs): - col_kwargs = { - k: v - for k, v in kwargs.items() - if k in ("timestamp_column", "prompt_tokens_column", "output_tokens_column") - } - try: - config = TraceSyntheticDataArgs(path=data, **col_kwargs) - except ValidationError as e: - raise DataNotSupportedError( - f"Expected a path to a trace file, got {data!r}" - ) from e - return deserializer( - config=config, - processor_factory=_mock_processor, - random_seed=42, - ) - - @pytest.mark.smoke - def test_loads_sorted_rows_and_keeps_token_columns_aligned( - self, tmp_path: Path, deserializer - ): - trace = 
_write_trace( - tmp_path, - '{"timestamp": 5.0, "input_length": 3, "output_length": 30}\n' - '{"timestamp": 2.0, "input_length": 1, "output_length": 10}\n' - '{"timestamp": 2.0, "input_length": 2, "output_length": 20}\n' - '{"timestamp": 8.0, "input_length": 0, "output_length": 40}\n', - ) - - ds = self._deserialize(deserializer, trace) - - assert isinstance(ds, Dataset) - assert ds["prompt_tokens_count"] == [1, 2, 3, 0] - assert ds["output_tokens_count"] == [10, 20, 30, 40] - for prompt, token_count in zip( - ds["prompt"], ds["prompt_tokens_count"], strict=True - ): - assert len(_mock_processor().encode(prompt)) == token_count - - @pytest.mark.smoke - def test_emits_relative_timestamp_column_sorted_from_trace( - self, tmp_path: Path, deserializer - ): - """Each row gets offset seconds from the earliest sorted timestamp. - - ### WRITTEN BY AI ### - """ - trace = _write_trace( - tmp_path, - '{"timestamp": 5.0, "input_length": 1, "output_length": 10}\n' - '{"timestamp": 2.0, "input_length": 2, "output_length": 20}\n' - '{"timestamp": 8.0, "input_length": 3, "output_length": 30}\n' - '{"timestamp": 2.0, "input_length": 4, "output_length": 40}\n' - '{"timestamp": 5.0, "input_length": 5, "output_length": 50}\n', - ) - - ds = self._deserialize(deserializer, trace) - - assert ds["relative_timestamp"] == pytest.approx( - [0.0, 0.0, 3.0, 3.0, 6.0], abs=1e-9 - ) - - @pytest.mark.smoke - def test_honors_custom_column_names(self, tmp_path: Path, deserializer): - trace = _write_trace( - tmp_path, - '{"ts": 3.0, "input_tokens": 4, "generated_tokens": 40}\n' - '{"ts": 1.0, "input_tokens": 2, "generated_tokens": 20}\n', - ) - - ds = self._deserialize( - deserializer, - trace, - timestamp_column="ts", - prompt_tokens_column="input_tokens", - output_tokens_column="generated_tokens", - ) - - assert ds["prompt_tokens_count"] == [2, 4] - assert ds["output_tokens_count"] == [20, 40] - - @pytest.mark.smoke - def test_generates_large_trace_prompts_from_reusable_base( - self, tmp_path: 
Path, deserializer - ): - prompt_lengths = [ - 6755, - 7319, - 7234, - 2287, - 9013, - 6506, - 4824, - 3119, - 23090, - 3135, - 26874, - 10487, - 17448, - 6253, - 6725, - 13538, - 87162, - 6166, - 6320, - 2007, - 3174, - 3131, - 3159, - 6820, - 3154, - 9416, - 7460, - ] - output_lengths = [ - 500, - 490, - 794, - 316, - 3, - 3, - 173, - 20, - 453, - 19, - 458, - 402, - 610, - 3, - 32, - 71, - 402, - 24, - 548, - 354, - 19, - 23, - 20, - 26, - 21, - 145, - 3, - ] - timestamps = [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - ] - trace = _write_trace( - tmp_path, - "\n".join( - ( - f'{{"timestamp": {timestamp}, ' - f'"input_length": {prompt_length}, ' - f'"output_length": {output_length}}}' - ) - for timestamp, prompt_length, output_length in zip( - timestamps, prompt_lengths, output_lengths, strict=True - ) - ), - ) - processor = _mock_processor() - - config = TraceSyntheticDataArgs(path=trace) - ds = deserializer( - config=config, - processor_factory=lambda: processor, - random_seed=42, - ) - - assert ds["prompt_tokens_count"] == prompt_lengths - assert ds["output_tokens_count"] == output_lengths - assert processor.encode.call_count <= len(prompt_lengths) + 4 - for prompt, token_count in zip( - ds["prompt"], ds["prompt_tokens_count"], strict=True - ): - assert len(_mock_processor().encode(prompt)) == token_count - - @pytest.mark.smoke - def test_rejects_invalid_data(self, deserializer): - with pytest.raises(DataNotSupportedError, match="path to a trace file"): - self._deserialize(deserializer, 123) - - @pytest.mark.sanity - @pytest.mark.parametrize( - ("content", "kwargs", "match"), - [ - ("", {}, "empty"), - ( - '{"ts": 0, "input_length": 10, "output_length": 5}\n', - {}, - "timestamp", - ), - ( - '{"timestamp": 0, "input_length": 10}\n', - {}, - "output_length", - ), - ( - '{"timestamp": 0, "prompt_tokens": 10, 
"output_length": 5}\n', - { - "prompt_tokens_column": "prompt_tokens", - "output_tokens_column": "out", - }, - "out", - ), - ( - '{"timestamp": "bad", "input_length": 10, "output_length": 5}\n', - {}, - "could not convert", - ), - ( - '{"timestamp": 0, "input_length": "bad", "output_length": 5}\n', - {}, - "invalid literal", - ), - ( - '{"timestamp": 0, "input_length": 10, "output_length": null}\n', - {}, - "NoneType", - ), - ( - '{"timestamp": 0, "input_length": 10, "output_length": 5}\nnot-json\n', - {}, - "generating the dataset", - ), - ], - ) - def test_trace_validation_raises( - self, tmp_path: Path, deserializer, content, kwargs, match - ): - trace = _write_trace(tmp_path, content) - - with pytest.raises(DataNotSupportedError, match=match): - self._deserialize(deserializer, trace, **kwargs) - - @pytest.mark.sanity - def test_unsupported_file_suffix_raises(self, tmp_path: Path, deserializer): - trace = _write_trace( - tmp_path, - '{"timestamp": 0, "input_length": 10, "output_length": 5}\n', - suffix=".json", - ) - - with pytest.raises(DataNotSupportedError, match=r"Unsupported.*\.json"): - self._deserialize(deserializer, trace) diff --git a/tests/unit/scheduler/test_trace_replay.py b/tests/unit/scheduler/test_trace_replay.py index e6872341e..884d0be96 100644 --- a/tests/unit/scheduler/test_trace_replay.py +++ b/tests/unit/scheduler/test_trace_replay.py @@ -2,95 +2,11 @@ import asyncio from multiprocessing import get_context -from pathlib import Path import pytest -from datasets.exceptions import DatasetGenerationError from guidellm.scheduler import SchedulingStrategy, TraceReplayStrategy from guidellm.schemas import RequestInfo, RequestSettings -from guidellm.utils.trace_io import load_relative_timestamps - - -def _write_trace(tmp_path: Path, content: str, suffix: str = ".jsonl") -> Path: - path = tmp_path / f"trace{suffix}" - path.write_text(content) - return path - - -class TestLoadRelativeTimestamps: - @pytest.mark.smoke - def 
test_loads_sorted_relative_timestamps_with_duplicates(self, tmp_path: Path): - trace = _write_trace( - tmp_path, - '{"timestamp": 5.0, "input_length": 10, "output_length": 10}\n' - '{"timestamp": 2.0, "input_length": 20, "output_length": 20}\n' - '{"timestamp": 2.0, "input_length": 30, "output_length": 30}\n' - '{"timestamp": 8.0, "input_length": 40, "output_length": 40}\n', - ) - - assert load_relative_timestamps(trace) == pytest.approx( - [0.0, 0.0, 3.0, 6.0], abs=1e-9 - ) - - @pytest.mark.smoke - def test_loads_custom_timestamp_column(self, tmp_path: Path): - trace = _write_trace( - tmp_path, - '{"ts": 10.0, "input_length": 10, "output_length": 10}\n' - '{"ts": 10.25, "input_length": 20, "output_length": 20}\n', - ) - - assert load_relative_timestamps(trace, timestamp_column="ts") == pytest.approx( - [0.0, 0.25], abs=1e-9 - ) - - @pytest.mark.smoke - @pytest.mark.parametrize( - ("suffix", "content", "error_type", "match"), - [ - (".jsonl", "", ValueError, "no valid rows"), - ( - ".json", - '[{"timestamp": 0, "input_length": 10, "output_length": 100}]', - ValueError, - r"Unsupported.*\.json", - ), - ( - ".csv", - "timestamp,input_length,output_length\n0,10,100\n", - ValueError, - r"Unsupported.*\.csv", - ), - ( - ".jsonl", - '{"ts": 0, "input_length": 10, "output_length": 100}\n', - KeyError, - "timestamp", - ), - ( - ".jsonl", - '{"timestamp": "bad", "input_length": 10, "output_length": 100}\n', - ValueError, - "could not convert", - ), - ( - ".jsonl", - '{"timestamp": 0, "input_length": 10, "output_length": 100}\n' - "not-json\n", - DatasetGenerationError, - "generating the dataset", - ), - ], - ) - def test_invalid_trace_inputs_raise( - self, tmp_path: Path, suffix, content, error_type, match - ): - trace = _write_trace(tmp_path, content, suffix=suffix) - - with pytest.raises(error_type, match=match): - load_relative_timestamps(trace) - TRACE_TIMESTAMPS = [0.0, 0.0, 0.0, 0.1, 0.1, 1.5, 2.0, 2.0, 3.5, 7.0]