diff --git a/openapi/ga/individual/platform.openapi.yaml b/openapi/ga/individual/platform.openapi.yaml index 361a63db74..d6d9711c3a 100644 --- a/openapi/ga/individual/platform.openapi.yaml +++ b/openapi/ga/individual/platform.openapi.yaml @@ -3731,10 +3731,13 @@ paths: explode: true schema: $ref: '#/components/schemas/ExperimentFilter' - description: Filter experiments by name, experiment_group_id, dataset_name, + description: 'Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or false) to filter by pinned state; omit to return both. + Filter by a rollup metric with numeric range operators ($gte/$lte/$gt/$lt/$eq): + filter[run_count][$gte]=5, filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, + or filter[evaluators.{evaluator_name}.mean][$gte]=0.8.' responses: '200': description: Successful Response @@ -3743,11 +3746,11 @@ paths: schema: $ref: '#/components/schemas/ExperimentResponsesPage' '400': - description: Unsupported sort field + description: Unsupported sort or filter field '413': description: Too many experiments selected to sort in one request '503': - description: Telemetry store unavailable for a metric-based sort + description: Telemetry store unavailable for a metric-based sort or filter '422': description: Validation Error content: @@ -10322,6 +10325,24 @@ components: only unpinned experiments. Omit to return both. title: Is Pinned type: boolean + run_count: + allOf: + - $ref: '#/components/schemas/NumberFilter' + description: Filter by run count, e.g. filter[run_count][$gte]=5. + cost_usd: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a cost_usd rollup stat, e.g. filter[cost_usd.mean][$lte]=0.5. + latency_ms: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a latency_ms rollup stat, e.g. 
filter[latency_ms.p95][$lte]=1000. + evaluators: + description: Filter by an evaluator rollup stat, e.g. filter[evaluators.{evaluator_name}.mean][$gte]=0.8. + title: Evaluators + additionalProperties: + $ref: '#/components/schemas/MetricStatFilters' + type: object title: ExperimentFilter type: object ExperimentGroupFilter: @@ -12911,6 +12932,38 @@ components: - metadata title: MetadataAnnotationInput description: Structured key/value metadata attached to a span or session. + MetricStatFilters: + additionalProperties: false + description: 'Numeric range filters keyed by rollup aggregate stat. + + + Declaring each stat explicitly (rather than an open ``dict[str, NumberFilter]``) + makes the valid + + stats visible in the OpenAPI schema, e.g. ``filter[cost_usd.mean][$lte]=0.5``. + These stats must + + stay in sync with the runtime sort/filter grammar (``_METRIC_STATS`` in the + experiments + + endpoints); a unit test guards the parity.' + properties: + sum: + $ref: '#/components/schemas/NumberFilter' + mean: + $ref: '#/components/schemas/NumberFilter' + median: + $ref: '#/components/schemas/NumberFilter' + p90: + $ref: '#/components/schemas/NumberFilter' + p95: + $ref: '#/components/schemas/NumberFilter' + p99: + $ref: '#/components/schemas/NumberFilter' + count: + $ref: '#/components/schemas/NumberFilter' + title: MetricStatFilters + type: object MiddlewareCall: properties: name: @@ -14225,6 +14278,32 @@ components: - text title: NoteAnnotationInput description: Free-text note attached to a span or session. + NumberFilter: + additionalProperties: false + minProperties: 1 + properties: + $gte: + description: Filter for results greater than or equal to this value. + title: $Gte + type: number + $lte: + description: Filter for results less than or equal to this value. + title: $Lte + type: number + $gt: + description: Filter for results greater than this value. + title: $Gt + type: number + $lt: + description: Filter for results less than this value. 
+ title: $Lt + type: number + $eq: + description: Filter for results equal to this value. + title: $Eq + type: number + title: NumberFilter + type: object NumericFilter: additionalProperties: false description: "Range filter for numeric annotation values.\n\nAt least one of\ diff --git a/openapi/ga/openapi.yaml b/openapi/ga/openapi.yaml index 361a63db74..d6d9711c3a 100644 --- a/openapi/ga/openapi.yaml +++ b/openapi/ga/openapi.yaml @@ -3731,10 +3731,13 @@ paths: explode: true schema: $ref: '#/components/schemas/ExperimentFilter' - description: Filter experiments by name, experiment_group_id, dataset_name, + description: 'Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or false) to filter by pinned state; omit to return both. + Filter by a rollup metric with numeric range operators ($gte/$lte/$gt/$lt/$eq): + filter[run_count][$gte]=5, filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, + or filter[evaluators.{evaluator_name}.mean][$gte]=0.8.' responses: '200': description: Successful Response @@ -3743,11 +3746,11 @@ paths: schema: $ref: '#/components/schemas/ExperimentResponsesPage' '400': - description: Unsupported sort field + description: Unsupported sort or filter field '413': description: Too many experiments selected to sort in one request '503': - description: Telemetry store unavailable for a metric-based sort + description: Telemetry store unavailable for a metric-based sort or filter '422': description: Validation Error content: @@ -10322,6 +10325,24 @@ components: only unpinned experiments. Omit to return both. title: Is Pinned type: boolean + run_count: + allOf: + - $ref: '#/components/schemas/NumberFilter' + description: Filter by run count, e.g. filter[run_count][$gte]=5. 
+ cost_usd: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a cost_usd rollup stat, e.g. filter[cost_usd.mean][$lte]=0.5. + latency_ms: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a latency_ms rollup stat, e.g. filter[latency_ms.p95][$lte]=1000. + evaluators: + description: Filter by an evaluator rollup stat, e.g. filter[evaluators.{evaluator_name}.mean][$gte]=0.8. + title: Evaluators + additionalProperties: + $ref: '#/components/schemas/MetricStatFilters' + type: object title: ExperimentFilter type: object ExperimentGroupFilter: @@ -12911,6 +12932,38 @@ components: - metadata title: MetadataAnnotationInput description: Structured key/value metadata attached to a span or session. + MetricStatFilters: + additionalProperties: false + description: 'Numeric range filters keyed by rollup aggregate stat. + + + Declaring each stat explicitly (rather than an open ``dict[str, NumberFilter]``) + makes the valid + + stats visible in the OpenAPI schema, e.g. ``filter[cost_usd.mean][$lte]=0.5``. + These stats must + + stay in sync with the runtime sort/filter grammar (``_METRIC_STATS`` in the + experiments + + endpoints); a unit test guards the parity.' + properties: + sum: + $ref: '#/components/schemas/NumberFilter' + mean: + $ref: '#/components/schemas/NumberFilter' + median: + $ref: '#/components/schemas/NumberFilter' + p90: + $ref: '#/components/schemas/NumberFilter' + p95: + $ref: '#/components/schemas/NumberFilter' + p99: + $ref: '#/components/schemas/NumberFilter' + count: + $ref: '#/components/schemas/NumberFilter' + title: MetricStatFilters + type: object MiddlewareCall: properties: name: @@ -14225,6 +14278,32 @@ components: - text title: NoteAnnotationInput description: Free-text note attached to a span or session. + NumberFilter: + additionalProperties: false + minProperties: 1 + properties: + $gte: + description: Filter for results greater than or equal to this value. 
+ title: $Gte + type: number + $lte: + description: Filter for results less than or equal to this value. + title: $Lte + type: number + $gt: + description: Filter for results greater than this value. + title: $Gt + type: number + $lt: + description: Filter for results less than this value. + title: $Lt + type: number + $eq: + description: Filter for results equal to this value. + title: $Eq + type: number + title: NumberFilter + type: object NumericFilter: additionalProperties: false description: "Range filter for numeric annotation values.\n\nAt least one of\ diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 361a63db74..d6d9711c3a 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -3731,10 +3731,13 @@ paths: explode: true schema: $ref: '#/components/schemas/ExperimentFilter' - description: Filter experiments by name, experiment_group_id, dataset_name, + description: 'Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or false) to filter by pinned state; omit to return both. + Filter by a rollup metric with numeric range operators ($gte/$lte/$gt/$lt/$eq): + filter[run_count][$gte]=5, filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, + or filter[evaluators.{evaluator_name}.mean][$gte]=0.8.' responses: '200': description: Successful Response @@ -3743,11 +3746,11 @@ paths: schema: $ref: '#/components/schemas/ExperimentResponsesPage' '400': - description: Unsupported sort field + description: Unsupported sort or filter field '413': description: Too many experiments selected to sort in one request '503': - description: Telemetry store unavailable for a metric-based sort + description: Telemetry store unavailable for a metric-based sort or filter '422': description: Validation Error content: @@ -10322,6 +10325,24 @@ components: only unpinned experiments. 
Omit to return both. title: Is Pinned type: boolean + run_count: + allOf: + - $ref: '#/components/schemas/NumberFilter' + description: Filter by run count, e.g. filter[run_count][$gte]=5. + cost_usd: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a cost_usd rollup stat, e.g. filter[cost_usd.mean][$lte]=0.5. + latency_ms: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a latency_ms rollup stat, e.g. filter[latency_ms.p95][$lte]=1000. + evaluators: + description: Filter by an evaluator rollup stat, e.g. filter[evaluators.{evaluator_name}.mean][$gte]=0.8. + title: Evaluators + additionalProperties: + $ref: '#/components/schemas/MetricStatFilters' + type: object title: ExperimentFilter type: object ExperimentGroupFilter: @@ -12911,6 +12932,38 @@ components: - metadata title: MetadataAnnotationInput description: Structured key/value metadata attached to a span or session. + MetricStatFilters: + additionalProperties: false + description: 'Numeric range filters keyed by rollup aggregate stat. + + + Declaring each stat explicitly (rather than an open ``dict[str, NumberFilter]``) + makes the valid + + stats visible in the OpenAPI schema, e.g. ``filter[cost_usd.mean][$lte]=0.5``. + These stats must + + stay in sync with the runtime sort/filter grammar (``_METRIC_STATS`` in the + experiments + + endpoints); a unit test guards the parity.' + properties: + sum: + $ref: '#/components/schemas/NumberFilter' + mean: + $ref: '#/components/schemas/NumberFilter' + median: + $ref: '#/components/schemas/NumberFilter' + p90: + $ref: '#/components/schemas/NumberFilter' + p95: + $ref: '#/components/schemas/NumberFilter' + p99: + $ref: '#/components/schemas/NumberFilter' + count: + $ref: '#/components/schemas/NumberFilter' + title: MetricStatFilters + type: object MiddlewareCall: properties: name: @@ -14225,6 +14278,32 @@ components: - text title: NoteAnnotationInput description: Free-text note attached to a span or session. 
+ NumberFilter: + additionalProperties: false + minProperties: 1 + properties: + $gte: + description: Filter for results greater than or equal to this value. + title: $Gte + type: number + $lte: + description: Filter for results less than or equal to this value. + title: $Lte + type: number + $gt: + description: Filter for results greater than this value. + title: $Gt + type: number + $lt: + description: Filter for results less than this value. + title: $Lt + type: number + $eq: + description: Filter for results equal to this value. + title: $Eq + type: number + title: NumberFilter + type: object NumericFilter: additionalProperties: false description: "Range filter for numeric annotation values.\n\nAt least one of\ diff --git a/packages/nmp_common/src/nmp/common/entities/values.py b/packages/nmp_common/src/nmp/common/entities/values.py index 71fe66def9..1948b9b797 100644 --- a/packages/nmp_common/src/nmp/common/entities/values.py +++ b/packages/nmp_common/src/nmp/common/entities/values.py @@ -269,3 +269,44 @@ class StringFilter(Filter): protected_namespaces=(), populate_by_name=True, # Accept both "eq" and "$eq" as input ) + + +class NumberFilter(Filter): + gte: Optional[float] = Field( + None, + alias="$gte", + serialization_alias="$gte", + description="Filter for results greater than or equal to this value.", + ) + lte: Optional[float] = Field( + None, + alias="$lte", + serialization_alias="$lte", + description="Filter for results less than or equal to this value.", + ) + gt: Optional[float] = Field( + None, + alias="$gt", + serialization_alias="$gt", + description="Filter for results greater than this value.", + ) + lt: Optional[float] = Field( + None, + alias="$lt", + serialization_alias="$lt", + description="Filter for results less than this value.", + ) + eq: Optional[float] = Field( + None, + alias="$eq", + serialization_alias="$eq", + description="Filter for results equal to this value.", + ) + + model_config = ConfigDict( + extra="forbid", + 
protected_namespaces=(), + populate_by_name=True, # Accept both "gte" and "$gte" as input + # Reject an empty predicate ({}) at the schema/contract level — it carries no comparison. + json_schema_extra={"minProperties": 1}, + ) diff --git a/sdk/python/nemo-platform/.nmpcontext/openapi.yaml b/sdk/python/nemo-platform/.nmpcontext/openapi.yaml index 361a63db74..d6d9711c3a 100644 --- a/sdk/python/nemo-platform/.nmpcontext/openapi.yaml +++ b/sdk/python/nemo-platform/.nmpcontext/openapi.yaml @@ -3731,10 +3731,13 @@ paths: explode: true schema: $ref: '#/components/schemas/ExperimentFilter' - description: Filter experiments by name, experiment_group_id, dataset_name, + description: 'Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or false) to filter by pinned state; omit to return both. + Filter by a rollup metric with numeric range operators ($gte/$lte/$gt/$lt/$eq): + filter[run_count][$gte]=5, filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, + or filter[evaluators.{evaluator_name}.mean][$gte]=0.8.' responses: '200': description: Successful Response @@ -3743,11 +3746,11 @@ paths: schema: $ref: '#/components/schemas/ExperimentResponsesPage' '400': - description: Unsupported sort field + description: Unsupported sort or filter field '413': description: Too many experiments selected to sort in one request '503': - description: Telemetry store unavailable for a metric-based sort + description: Telemetry store unavailable for a metric-based sort or filter '422': description: Validation Error content: @@ -10322,6 +10325,24 @@ components: only unpinned experiments. Omit to return both. title: Is Pinned type: boolean + run_count: + allOf: + - $ref: '#/components/schemas/NumberFilter' + description: Filter by run count, e.g. filter[run_count][$gte]=5. 
+ cost_usd: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a cost_usd rollup stat, e.g. filter[cost_usd.mean][$lte]=0.5. + latency_ms: + allOf: + - $ref: '#/components/schemas/MetricStatFilters' + description: Filter by a latency_ms rollup stat, e.g. filter[latency_ms.p95][$lte]=1000. + evaluators: + description: Filter by an evaluator rollup stat, e.g. filter[evaluators.{evaluator_name}.mean][$gte]=0.8. + title: Evaluators + additionalProperties: + $ref: '#/components/schemas/MetricStatFilters' + type: object title: ExperimentFilter type: object ExperimentGroupFilter: @@ -12911,6 +12932,38 @@ components: - metadata title: MetadataAnnotationInput description: Structured key/value metadata attached to a span or session. + MetricStatFilters: + additionalProperties: false + description: 'Numeric range filters keyed by rollup aggregate stat. + + + Declaring each stat explicitly (rather than an open ``dict[str, NumberFilter]``) + makes the valid + + stats visible in the OpenAPI schema, e.g. ``filter[cost_usd.mean][$lte]=0.5``. + These stats must + + stay in sync with the runtime sort/filter grammar (``_METRIC_STATS`` in the + experiments + + endpoints); a unit test guards the parity.' + properties: + sum: + $ref: '#/components/schemas/NumberFilter' + mean: + $ref: '#/components/schemas/NumberFilter' + median: + $ref: '#/components/schemas/NumberFilter' + p90: + $ref: '#/components/schemas/NumberFilter' + p95: + $ref: '#/components/schemas/NumberFilter' + p99: + $ref: '#/components/schemas/NumberFilter' + count: + $ref: '#/components/schemas/NumberFilter' + title: MetricStatFilters + type: object MiddlewareCall: properties: name: @@ -14225,6 +14278,32 @@ components: - text title: NoteAnnotationInput description: Free-text note attached to a span or session. + NumberFilter: + additionalProperties: false + minProperties: 1 + properties: + $gte: + description: Filter for results greater than or equal to this value. 
+ title: $Gte + type: number + $lte: + description: Filter for results less than or equal to this value. + title: $Lte + type: number + $gt: + description: Filter for results greater than this value. + title: $Gt + type: number + $lt: + description: Filter for results less than this value. + title: $Lt + type: number + $eq: + description: Filter for results equal to this value. + title: $Eq + type: number + title: NumberFilter + type: object NumericFilter: additionalProperties: false description: "Range filter for numeric annotation values.\n\nAt least one of\ diff --git a/sdk/python/nemo-platform/.nmpcontext/stainless.yaml b/sdk/python/nemo-platform/.nmpcontext/stainless.yaml index 2f0a999c90..6ec8c209c3 100644 --- a/sdk/python/nemo-platform/.nmpcontext/stainless.yaml +++ b/sdk/python/nemo-platform/.nmpcontext/stainless.yaml @@ -913,6 +913,8 @@ resources: experiment_request: ExperimentRequest experiment_response: ExperimentResponse experiment_responses_page: ExperimentResponsesPage + metric_stat_filters: MetricStatFilters + number_filter: NumberFilter methods: create: post /apis/intake/v2/workspaces/{workspace}/experiments list: get /apis/intake/v2/workspaces/{workspace}/experiments diff --git a/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/api.md b/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/api.md index d6275f4e35..0c6e16973f 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/api.md +++ b/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/api.md @@ -9,6 +9,8 @@ from nemo_platform.types.experiments import ( ExperimentRequest, ExperimentResponse, ExperimentResponsesPage, + MetricStatFilters, + NumberFilter, ) ``` diff --git a/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/experiments.py b/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/experiments.py index dff0352ec2..cc5b0b6818 100644 --- 
a/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/experiments.py +++ b/sdk/python/nemo-platform/src/nemo_platform/resources/experiments/experiments.py @@ -300,7 +300,10 @@ def list( filter: Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or - false) to filter by pinned state; omit to return both. + false) to filter by pinned state; omit to return both. Filter by a rollup metric + with numeric range operators ($gte/$lte/$gt/$lt/$eq): filter[run_count][$gte]=5, + filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, or + filter[evaluators.{evaluator_name}.mean][$gte]=0.8. page: Page number. @@ -716,7 +719,10 @@ def list( filter: Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or - false) to filter by pinned state; omit to return both. + false) to filter by pinned state; omit to return both. Filter by a rollup metric + with numeric range operators ($gte/$lte/$gt/$lt/$eq): filter[run_count][$gte]=5, + filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, or + filter[evaluators.{evaluator_name}.mean][$gte]=0.8. page: Page number. 
diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/__init__.py b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/__init__.py index 1fae468515..b046d1fc20 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/__init__.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/__init__.py @@ -19,12 +19,14 @@ from .evaluator_aggregate import EvaluatorAggregate as EvaluatorAggregate from .experiment_response import ExperimentResponse as ExperimentResponse +from .number_filter_param import NumberFilterParam as NumberFilterParam from .session_list_params import SessionListParams as SessionListParams from .experiment_list_params import ExperimentListParams as ExperimentListParams from .experiment_filter_param import ExperimentFilterParam as ExperimentFilterParam from .experiment_create_params import ExperimentCreateParams as ExperimentCreateParams from .experiment_update_params import ExperimentUpdateParams as ExperimentUpdateParams from .experiment_responses_page import ExperimentResponsesPage as ExperimentResponsesPage +from .metric_stat_filters_param import MetricStatFiltersParam as MetricStatFiltersParam from .experiment_session_response import ExperimentSessionResponse as ExperimentSessionResponse from .experiment_session_filter_param import ExperimentSessionFilterParam as ExperimentSessionFilterParam from .experiment_session_responses_page import ExperimentSessionResponsesPage as ExperimentSessionResponsesPage diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_filter_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_filter_param.py index 93e43c31a4..d6c5354181 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_filter_param.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_filter_param.py @@ -17,8 +17,11 @@ from __future__ import annotations +from typing import Dict from 
typing_extensions import TypedDict +from .number_filter_param import NumberFilterParam +from .metric_stat_filters_param import MetricStatFiltersParam from ..shared_params.datetime_filter import DatetimeFilter __all__ = ["ExperimentFilterParam"] @@ -27,6 +30,16 @@ class ExperimentFilterParam(TypedDict, total=False): """Filter for listing Experiments.""" + cost_usd: MetricStatFiltersParam + """Numeric range filters keyed by rollup aggregate stat. + + Declaring each stat explicitly (rather than an open `dict[str, NumberFilter]`) + makes the valid stats visible in the OpenAPI schema, e.g. + `filter[cost_usd.mean][$lte]=0.5`. These stats must stay in sync with the + runtime sort/filter grammar (`_METRIC_STATS` in the experiments endpoints); a + unit test guards the parity. + """ + created_at: DatetimeFilter """ Filter experiments by creation timestamp; supports `$gte` and `$lte` for ranges. @@ -41,6 +54,12 @@ class ExperimentFilterParam(TypedDict, total=False): dataset_version: str """Filter experiments by dataset version.""" + evaluators: Dict[str, MetricStatFiltersParam] + """Filter by an evaluator rollup stat, e.g. + + filter[evaluators.{evaluator_name}.mean][$gte]=0.8. + """ + experiment_group_id: str """Filter experiments by owning group id.""" @@ -56,9 +75,22 @@ class ExperimentFilterParam(TypedDict, total=False): When false, returns only unpinned experiments. Omit to return both. """ + latency_ms: MetricStatFiltersParam + """Numeric range filters keyed by rollup aggregate stat. + + Declaring each stat explicitly (rather than an open `dict[str, NumberFilter]`) + makes the valid stats visible in the OpenAPI schema, e.g. + `filter[cost_usd.mean][$lte]=0.5`. These stats must stay in sync with the + runtime sort/filter grammar (`_METRIC_STATS` in the experiments endpoints); a + unit test guards the parity. + """ + name: str """Filter experiments by name.""" + run_count: NumberFilterParam + """Filter by run count, e.g. 
filter[run_count][$gte]=5.""" + updated_at: DatetimeFilter """ Filter experiments by last-updated timestamp; supports `$gte` and `$lte` for diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_list_params.py b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_list_params.py index b46a82cf72..678b114380 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_list_params.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/experiment_list_params.py @@ -32,7 +32,10 @@ class ExperimentListParams(TypedDict, total=False): Filter experiments by name, experiment_group_id, dataset_name, dataset_version, created_by, created_at, or updated_at. Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. Pass is_pinned=true (or - false) to filter by pinned state; omit to return both. + false) to filter by pinned state; omit to return both. Filter by a rollup metric + with numeric range operators ($gte/$lte/$gt/$lt/$eq): filter[run_count][$gte]=5, + filter[cost_usd.mean][$lte]=0.5, filter[latency_ms.p95][$lte]=1000, or + filter[evaluators.{evaluator_name}.mean][$gte]=0.8. """ page: int diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/metric_stat_filters_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/metric_stat_filters_param.py new file mode 100644 index 0000000000..45ee5ddc7e --- /dev/null +++ b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/metric_stat_filters_param.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +from .number_filter_param import NumberFilterParam + +__all__ = ["MetricStatFiltersParam"] + + +class MetricStatFiltersParam(TypedDict, total=False): + """Numeric range filters keyed by rollup aggregate stat. + + Declaring each stat explicitly (rather than an open ``dict[str, NumberFilter]``) makes the valid + stats visible in the OpenAPI schema, e.g. ``filter[cost_usd.mean][$lte]=0.5``. These stats must + stay in sync with the runtime sort/filter grammar (``_METRIC_STATS`` in the experiments + endpoints); a unit test guards the parity. + """ + + count: NumberFilterParam + + mean: NumberFilterParam + + median: NumberFilterParam + + p90: NumberFilterParam + + p95: NumberFilterParam + + p99: NumberFilterParam + + sum: NumberFilterParam diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/experiments/number_filter_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/number_filter_param.py new file mode 100644 index 0000000000..5ae6a6313c --- /dev/null +++ b/sdk/python/nemo-platform/src/nemo_platform/types/experiments/number_filter_param.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Annotated, TypedDict + +from ..._utils import PropertyInfo + +__all__ = ["NumberFilterParam"] + + +class NumberFilterParam(TypedDict, total=False): + eq: Annotated[float, PropertyInfo(alias="$eq")] + """Filter for results equal to this value.""" + + gt: Annotated[float, PropertyInfo(alias="$gt")] + """Filter for results greater than this value.""" + + gte: Annotated[float, PropertyInfo(alias="$gte")] + """Filter for results greater than or equal to this value.""" + + lt: Annotated[float, PropertyInfo(alias="$lt")] + """Filter for results less than this value.""" + + lte: Annotated[float, PropertyInfo(alias="$lte")] + """Filter for results less than or equal to this value.""" diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/log_adapter_config_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/log_adapter_config_param.py index ca984872d3..bca63291f6 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/log_adapter_config_param.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/log_adapter_config_param.py @@ -22,6 +22,10 @@ __all__ = ["LogAdapterConfigParam"] -class LogAdapterConfigParam(TypedDict, total=False, extra_items=object): # type: ignore[call-arg] +class LogAdapterConfigParam( # type: ignore[call-arg] + TypedDict, + total=False, + extra_items=object, # pyright: ignore[reportGeneralTypeIssues] +): name: str """The name 
of the adapter.""" diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/model_parameters_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/model_parameters_param.py index 45b1a1668d..f9aa189268 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/model_parameters_param.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/guardrail/model_parameters_param.py @@ -23,7 +23,11 @@ __all__ = ["ModelParametersParam"] -class ModelParametersParam(TypedDict, total=False, extra_items=object): # type: ignore[call-arg] +class ModelParametersParam( # type: ignore[call-arg] + TypedDict, + total=False, + extra_items=object, # pyright: ignore[reportGeneralTypeIssues] +): """Parameters for configuring how to interact with a model in a guardrails config.""" base_url: str diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_request_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_request_param.py index b4f4b30335..165a966d21 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_request_param.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_request_param.py @@ -25,7 +25,11 @@ __all__ = ["CapturedChatCompletionsRequestParam"] -class CapturedChatCompletionsRequestParam(TypedDict, total=False, extra_items=object): # type: ignore[call-arg] +class CapturedChatCompletionsRequestParam( # type: ignore[call-arg] + TypedDict, + total=False, + extra_items=object, # pyright: ignore[reportGeneralTypeIssues] +): """Flexible captured chat-completions request.""" messages: Required[Iterable[CapturedChatMessageParam]] diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_response_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_response_param.py index 
4d61750240..c95d0d5931 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_response_param.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_completions_response_param.py @@ -23,7 +23,11 @@ __all__ = ["CapturedChatCompletionsResponseParam"] -class CapturedChatCompletionsResponseParam(TypedDict, total=False, extra_items=object): # type: ignore[call-arg] +class CapturedChatCompletionsResponseParam( # type: ignore[call-arg] + TypedDict, + total=False, + extra_items=object, # pyright: ignore[reportGeneralTypeIssues] +): """Flexible captured chat-completions response.""" choices: Iterable[Dict[str, object]] diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_message_param.py b/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_message_param.py index 741ea03a77..798b7bfd24 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_message_param.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/intake/ingest/captured_chat_message_param.py @@ -24,7 +24,11 @@ __all__ = ["CapturedChatMessageParam"] -class CapturedChatMessageParam(TypedDict, total=False, extra_items=object): # type: ignore[call-arg] +class CapturedChatMessageParam( # type: ignore[call-arg] + TypedDict, + total=False, + extra_items=object, # pyright: ignore[reportGeneralTypeIssues] +): """ A flexible message model that requires a valid role field but allows provider-specific fields. 
""" diff --git a/sdk/python/nemo-platform/src/nemo_platform/types/shared_params/inference_params.py b/sdk/python/nemo-platform/src/nemo_platform/types/shared_params/inference_params.py index fe38ddffa9..2209f4c44f 100644 --- a/sdk/python/nemo-platform/src/nemo_platform/types/shared_params/inference_params.py +++ b/sdk/python/nemo-platform/src/nemo_platform/types/shared_params/inference_params.py @@ -24,7 +24,11 @@ __all__ = ["InferenceParams"] -class InferenceParams(TypedDict, total=False, extra_items=object): # type: ignore[call-arg] +class InferenceParams( # type: ignore[call-arg] + TypedDict, + total=False, + extra_items=object, # pyright: ignore[reportGeneralTypeIssues] +): """Parameters for model inference. Extra fields can be supplied for additional options applied to the inference request directly. Fields not supported by the model may cause inference errors during evaluation. diff --git a/sdk/python/nemo-platform/tests/api_resources/test_experiments.py b/sdk/python/nemo-platform/tests/api_resources/test_experiments.py index 4f4fb2e308..04392c117d 100644 --- a/sdk/python/nemo-platform/tests/api_resources/test_experiments.py +++ b/sdk/python/nemo-platform/tests/api_resources/test_experiments.py @@ -256,6 +256,57 @@ def test_method_list_with_all_params(self, client: NeMoPlatform) -> None: experiment = client.experiments.list( workspace="workspace", filter={ + "cost_usd": { + "count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "mean": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "median": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p90": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p95": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p99": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "sum": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + }, "created_at": { "gte": parse_datetime("2019-12-27T18:11:19.117Z"), "lte": 
parse_datetime("2019-12-27T18:11:19.117Z"), @@ -263,10 +314,121 @@ def test_method_list_with_all_params(self, client: NeMoPlatform) -> None: "created_by": "created_by", "dataset_name": "dataset_name", "dataset_version": "dataset_version", + "evaluators": { + "foo": { + "count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "mean": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "median": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p90": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p95": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p99": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "sum": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + } + }, "experiment_group_id": "experiment_group_id", "is_deleted": True, "is_pinned": True, + "latency_ms": { + "count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "mean": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "median": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p90": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p95": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p99": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "sum": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + }, "name": "name", + "run_count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, "updated_at": { "gte": parse_datetime("2019-12-27T18:11:19.117Z"), "lte": parse_datetime("2019-12-27T18:11:19.117Z"), @@ -694,6 +856,57 @@ async def test_method_list_with_all_params(self, async_client: AsyncNeMoPlatform experiment = await async_client.experiments.list( workspace="workspace", filter={ + "cost_usd": { + "count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "mean": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "median": { + 
"eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p90": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p95": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p99": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "sum": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + }, "created_at": { "gte": parse_datetime("2019-12-27T18:11:19.117Z"), "lte": parse_datetime("2019-12-27T18:11:19.117Z"), @@ -701,10 +914,121 @@ async def test_method_list_with_all_params(self, async_client: AsyncNeMoPlatform "created_by": "created_by", "dataset_name": "dataset_name", "dataset_version": "dataset_version", + "evaluators": { + "foo": { + "count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "mean": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "median": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p90": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p95": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p99": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "sum": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + } + }, "experiment_group_id": "experiment_group_id", "is_deleted": True, "is_pinned": True, + "latency_ms": { + "count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "mean": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "median": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p90": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p95": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "p99": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + "sum": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, + }, "name": "name", + "run_count": { + "eq": 0, + "gt": 0, + "gte": 0, + "lt": 0, + "lte": 0, + }, "updated_at": { "gte": 
parse_datetime("2019-12-27T18:11:19.117Z"), "lte": parse_datetime("2019-12-27T18:11:19.117Z"), diff --git a/sdk/stainless.yaml b/sdk/stainless.yaml index 2f0a999c90..6ec8c209c3 100644 --- a/sdk/stainless.yaml +++ b/sdk/stainless.yaml @@ -913,6 +913,8 @@ resources: experiment_request: ExperimentRequest experiment_response: ExperimentResponse experiment_responses_page: ExperimentResponsesPage + metric_stat_filters: MetricStatFilters + number_filter: NumberFilter methods: create: post /apis/intake/v2/workspaces/{workspace}/experiments list: get /apis/intake/v2/workspaces/{workspace}/experiments diff --git a/services/intake/src/nmp/intake/api/v2/experiments/endpoints.py b/services/intake/src/nmp/intake/api/v2/experiments/endpoints.py index f95204b1dc..9382f74932 100644 --- a/services/intake/src/nmp/intake/api/v2/experiments/endpoints.py +++ b/services/intake/src/nmp/intake/api/v2/experiments/endpoints.py @@ -15,11 +15,11 @@ import secrets import time from datetime import datetime, timezone -from typing import Annotated, Any, Literal, TypeVar +from typing import Annotated, Any, Literal, NamedTuple, TypeVar from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from nmp.common.api.common import Page, PaginationData -from nmp.common.api.filter import ComparisonOperation, FilterOperator, LogicalOperation +from nmp.common.api.filter import ComparisonOperation, FilterOperation, FilterOperator, LogicalOperation from nmp.common.api.parsed_filter import ParsedFilter, make_filter_dep from nmp.common.api.utils import generate_openapi_extra_params from nmp.common.entities.client import EntityClient, EntityConflictError, EntityNotFoundError @@ -336,9 +336,9 @@ async def create_experiment( response_model=Page[ExperimentResponse], tags=[EXPERIMENTS_TAG], responses={ - 400: {"description": "Unsupported sort field"}, + 400: {"description": "Unsupported sort or filter field"}, 413: {"description": "Too many experiments selected to sort in one request"}, - 503: 
{"description": "Telemetry store unavailable for a metric-based sort"}, + 503: {"description": "Telemetry store unavailable for a metric-based sort or filter"}, }, openapi_extra=generate_openapi_extra_params( filter_schema=ExperimentFilter, @@ -346,7 +346,10 @@ async def create_experiment( "Filter experiments by name, experiment_group_id, " "dataset_name, dataset_version, created_by, created_at, or updated_at. " "Pass is_deleted=true to return only soft-deleted experiments; omit to see only live ones. " - "Pass is_pinned=true (or false) to filter by pinned state; omit to return both." + "Pass is_pinned=true (or false) to filter by pinned state; omit to return both. " + "Filter by a rollup metric with numeric range operators ($gte/$lte/$gt/$lt/$eq): " + "filter[run_count][$gte]=5, filter[cost_usd.mean][$lte]=0.5, " + "filter[latency_ms.p95][$lte]=1000, or filter[evaluators..mean][$gte]=0.8." ), ), ) @@ -374,13 +377,18 @@ async def list_experiments( _validate_sort_field(sort_field) _apply_is_deleted_filter(parsed) _apply_is_pinned_filter(parsed) - # Compute-on-read: fetch the whole (entity-filtered) group, hydrate every rollup, then sort and - # paginate in memory so a single request can sort by a ClickHouse metric that lives outside the - # entity store. Bounded to hundreds of experiments per group (see _MAX_GROUP_EXPERIMENTS). + # Rollup-metric predicates live in ClickHouse, not the entity store, so they can't be pushed to + # Postgres. Split them out of the filter tree: only the entity predicates go to entity_client.list; + # the metric ones are applied in memory after hydration. parsed (the full user filter) is left + # intact so the response still echoes it. 
+ entity_operation, metric_predicates = _extract_metric_predicates(parsed.operation) + # Compute-on-read: fetch the whole (entity-filtered) group, hydrate every rollup, then filter, sort, + # and paginate in memory so a single request can sort/filter by a ClickHouse metric that lives + # outside the entity store. Bounded to hundreds of experiments per group (see _MAX_GROUP_EXPERIMENTS). result = await entity_client.list( Experiment, workspace=workspace, - filter_operation=parsed.operation, + filter_operation=entity_operation, page=1, page_size=_MAX_GROUP_EXPERIMENTS, ) @@ -405,15 +413,18 @@ async def list_experiments( ), ) hydrated = await _hydrate_rollups(workspace=workspace, responses=responses, rollup_repository=rollup_repository) - # A metric-backed sort (anything other than an entity column) is meaningless without rollups: if - # hydration was skipped (ClickHouse disabled or down) every metric value would be unset and the - # result would silently collapse to name order. Reject the request instead of returning a - # misleading 200. Entity-column sorts still work and an empty group still hydrates fine. - if not hydrated and sort_field not in _ENTITY_SORT_FIELDS: + # A metric-backed sort or filter is meaningless without rollups: if hydration was skipped (ClickHouse + # disabled or down) every metric value would be unset, so a metric sort would silently collapse to + # name order and a metric filter would drop everything. Reject the request instead of returning a + # misleading 200. Entity-column sorts/filters still work and an empty group still hydrates fine. 
+ metric_sort = sort_field not in _ENTITY_SORT_FIELDS + if not hydrated and (metric_sort or metric_predicates): raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=f"Cannot sort experiments by '{sort_field}': the telemetry store is unavailable.", + detail="Cannot sort or filter experiments by a rollup metric: the telemetry store is unavailable.", ) + if metric_predicates: + responses = [r for r in responses if _matches_metric_predicates(r, metric_predicates)] ordered = _sort_experiments(responses, field=sort_field, descending=descending) start = (page - 1) * page_size page_items = ordered[start : start + page_size] @@ -888,21 +899,141 @@ def _apply_is_pinned_filter(parsed: ParsedFilter) -> None: parsed.and_with(null_clause) -def _validate_sort_field(field: str) -> None: - """Reject a sort field that isn't an entity column or a known rollup-metric path.""" - if field in _ENTITY_SORT_FIELDS or field == "run_count": - return +# Metric heads whose dotted sub-paths address a ClickHouse rollup (not an entity column). Declared as +# self-mapping namespaces on ExperimentFilter so paths survive filter validation untranslated. +_METRIC_NAMESPACES = frozenset({"cost_usd", "latency_ms", "evaluators"}) +_NUMERIC_FILTER_OPERATORS = frozenset( + {FilterOperator.GTE, FilterOperator.LTE, FilterOperator.GT, FilterOperator.LT, FilterOperator.EQ} +) + + +class _MetricPredicate(NamedTuple): + field: str + operator: FilterOperator + threshold: float + + +def _is_valid_metric_path(field: str) -> bool: + """True if `field` is a rollup-metric path: run_count, ., or evaluators...""" + if field == "run_count": + return True head, _, rest = field.partition(".") - if head in ("cost_usd", "latency_ms") and rest in _METRIC_STATS: - return + if head in ("cost_usd", "latency_ms"): + return rest in _METRIC_STATS if head == "evaluators": # Evaluator names can contain dots (e.g. "harbor.verifier"); the stat is the last segment. 
name, _, stat = rest.rpartition(".") - if name and stat in _METRIC_STATS: - return + return bool(name) and stat in _METRIC_STATS + return False + + +def _validate_sort_field(field: str) -> None: + """Reject a sort field that isn't an entity column or a known rollup-metric path.""" + if field in _ENTITY_SORT_FIELDS or _is_valid_metric_path(field): + return raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported sort field: {field}") +def _is_metric_field(field: str) -> bool: + """True if `field` is *intended* as a rollup metric (by head), valid path or not. + + Looser than ``_is_valid_metric_path``: classifies e.g. ``cost_usd.bogus`` as a metric so it gets + extracted and rejected with a 400 rather than forwarded to the entity store. Entity fields (already + translated to ``data.*`` by the filter dep) never match. + """ + return field == "run_count" or field.split(".", 1)[0] in _METRIC_NAMESPACES + + +def _operation_references_metric(operation: FilterOperation | None) -> bool: + if isinstance(operation, ComparisonOperation): + return _is_metric_field(operation.field) + if isinstance(operation, LogicalOperation): + return any(_operation_references_metric(child) for child in operation.operations) + return False + + +def _validated_metric_predicate(operation: ComparisonOperation) -> _MetricPredicate: + field = operation.field + if not _is_valid_metric_path(field): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported metric filter field: {field}") + if operation.operator not in _NUMERIC_FILTER_OPERATORS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Metric filter '{field}' supports only numeric operators ($gte/$lte/$gt/$lt/$eq).", + ) + try: + threshold = float(operation.value) + except (TypeError, ValueError) as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Metric filter '{field}' requires a numeric value, got {operation.value!r}.", + ) from exc + 
return _MetricPredicate(field=field, operator=operation.operator, threshold=threshold) + + +def _extract_metric_predicates( + operation: FilterOperation | None, +) -> tuple[FilterOperation | None, list[_MetricPredicate]]: + """Split rollup-metric comparisons out of the filter tree. + + Returns ``(entity_operation, metric_predicates)``: the entity operation is forwarded to the entity + store, the metric predicates are applied in memory after hydration. Metric filters must be AND-ed + (at any nesting depth) with entity filters; a metric field under OR/NOT raises 400, since we can't + evaluate half a boolean tree in SQL and half in the application layer. Nested ANDs are flattened by + recursion, so a metric comparison inside a sub-AND is accepted. + """ + if operation is None: + return None, [] + if isinstance(operation, ComparisonOperation): + if _is_metric_field(operation.field): + return None, [_validated_metric_predicate(operation)] + return operation, [] + if isinstance(operation, LogicalOperation): + if operation.operator != FilterOperator.AND: + if _operation_references_metric(operation): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Metric filters can only be combined with AND, not OR/NOT.", + ) + return operation, [] + entity_ops: list[FilterOperation] = [] + metric_predicates: list[_MetricPredicate] = [] + for child in operation.operations: + # Recurse so metric comparisons nested inside sub-ANDs are extracted too (and OR/NOT + # children that reference a metric still raise inside this call). 
+ child_entity, child_metrics = _extract_metric_predicates(child) + if child_entity is not None: + entity_ops.append(child_entity) + metric_predicates.extend(child_metrics) + if not entity_ops: + return None, metric_predicates + if len(entity_ops) == 1: + return entity_ops[0], metric_predicates + return LogicalOperation(operator=FilterOperator.AND, operations=entity_ops), metric_predicates + return operation, [] + + +def _matches_metric_predicates(response: ExperimentResponse, predicates: list[_MetricPredicate]) -> bool: + """True if the response satisfies every metric predicate. A missing metric never matches.""" + for predicate in predicates: + value = _experiment_sort_value(response, predicate.field) + if value is None or not _compare_metric(value, predicate.operator, predicate.threshold): + return False + return True + + +def _compare_metric(value: float, operator: FilterOperator, threshold: float) -> bool: + if operator == FilterOperator.GTE: + return value >= threshold + if operator == FilterOperator.LTE: + return value <= threshold + if operator == FilterOperator.GT: + return value > threshold + if operator == FilterOperator.LT: + return value < threshold + return value == threshold # EQ + + def _experiment_sort_value(response: ExperimentResponse, field: str) -> Any: """Value for `field` on a hydrated response, or None when the metric is absent (sorts last).""" if field in _ENTITY_SORT_FIELDS: diff --git a/services/intake/src/nmp/intake/api/v2/experiments/schemas.py b/services/intake/src/nmp/intake/api/v2/experiments/schemas.py index 0adff4c3b7..2eabe0dc03 100644 --- a/services/intake/src/nmp/intake/api/v2/experiments/schemas.py +++ b/services/intake/src/nmp/intake/api/v2/experiments/schemas.py @@ -10,9 +10,9 @@ from __future__ import annotations from datetime import datetime -from typing import Any +from typing import Annotated, Any -from nmp.common.entities.values import DatetimeFilter, Filter +from nmp.common.entities.values import DatetimeFilter, Filter, 
NumberFilter, map_entity_field from nmp.intake.entities.experiments import Experiment, ExperimentGroup from nmp.intake.spans.domain import SpanStatus from nmp.intake.spans.experiment_session_repository import ExperimentSessionRow @@ -152,6 +152,26 @@ def from_entity(cls, entity: Experiment) -> ExperimentResponse: ) +class MetricStatFilters(BaseModel): + """Numeric range filters keyed by rollup aggregate stat. + + Declaring each stat explicitly (rather than an open ``dict[str, NumberFilter]``) makes the valid + stats visible in the OpenAPI schema, e.g. ``filter[cost_usd.mean][$lte]=0.5``. These stats must + stay in sync with the runtime sort/filter grammar (``_METRIC_STATS`` in the experiments + endpoints); a unit test guards the parity. + """ + + model_config = ConfigDict(extra="forbid") + + sum: NumberFilter | None = None + mean: NumberFilter | None = None + median: NumberFilter | None = None + p90: NumberFilter | None = None + p95: NumberFilter | None = None + p99: NumberFilter | None = None + count: NumberFilter | None = None + + class ExperimentGroupFilter(Filter): """Filter for listing ExperimentGroups.""" @@ -189,6 +209,23 @@ class ExperimentFilter(Filter): "Omit to return both." ), ) + # Rollup-metric filters. These live in ClickHouse, not the entity store, so they're declared as + # self-mapping namespaces (the path is left untranslated) and applied in the application layer + # after rollup hydration rather than forwarded to Postgres. Stat sub-paths mirror the sort grammar: + # filter[cost_usd.mean][gte]=0.8, filter[evaluators..mean][lte]=0.5, filter[run_count][gte]=5. + run_count: Annotated[NumberFilter | None, map_entity_field("run_count")] = Field( + default=None, description="Filter by run count, e.g. filter[run_count][$gte]=5." + ) + cost_usd: Annotated[MetricStatFilters | None, map_entity_field("cost_usd", namespace=True)] = Field( + default=None, description="Filter by a cost_usd rollup stat, e.g. filter[cost_usd.mean][$lte]=0.5." 
+ ) + latency_ms: Annotated[MetricStatFilters | None, map_entity_field("latency_ms", namespace=True)] = Field( + default=None, description="Filter by a latency_ms rollup stat, e.g. filter[latency_ms.p95][$lte]=1000." + ) + evaluators: Annotated[dict[str, MetricStatFilters] | None, map_entity_field("evaluators", namespace=True)] = Field( + default=None, + description="Filter by an evaluator rollup stat, e.g. filter[evaluators..mean][$gte]=0.8.", + ) class ExperimentSessionFilter(Filter): diff --git a/services/intake/tests/integration/spans/test_experiment_metric_sort.py b/services/intake/tests/integration/spans/test_experiment_metric_sort.py index f387fbb192..a08f5627ce 100644 --- a/services/intake/tests/integration/spans/test_experiment_metric_sort.py +++ b/services/intake/tests/integration/spans/test_experiment_metric_sort.py @@ -82,6 +82,39 @@ def test_list_sorts_by_cost_metric_missing_last(client: TestClient) -> None: assert names == [pricey, mid, cheap, norun] +def test_list_filters_by_cost_metric(client: TestClient) -> None: + # Same shape as the sort test, but filtering. Combine an entity filter (group) with two metric + # filters on different fields: cost_usd.mean <= 0.50 excludes pricey; run_count >= 1 excludes the + # never-ingested experiment (whose cost rollup is also missing). Sort by cost so the order is + # deterministic by value rather than by creation time. 
+ suffix = uuid.uuid4().hex + group_id = _ensure_group(client, name=f"metric-filter-group-{suffix}") + started_at = datetime.now(timezone.utc).replace(microsecond=0) + cheap, pricey, mid = f"exp-cheap-{suffix}", f"exp-pricey-{suffix}", f"exp-mid-{suffix}" + for index, (name, cost) in enumerate([(cheap, 0.10), (pricey, 0.90), (mid, 0.50)]): + _create_experiment(client, group_id, name) + response = client.post( + ATIF_INGEST, + json=_atif_body(started_at=started_at, experiment_id=name, cost_usd=cost, offset_seconds=index * 10), + ) + assert response.status_code == 201, response.text + _create_experiment(client, group_id, f"exp-norun-{suffix}") # no ingest -> excluded by both predicates + + listed = client.get( + EXPERIMENTS, + params={ + "filter[experiment_group_id]": group_id, + "filter[cost_usd.mean][lte]": "0.50", + "filter[run_count][gte]": "1", + "sort": "cost_usd.mean", + "page_size": 50, + }, + ) + assert listed.status_code == 200, listed.text + names = [row["name"] for row in listed.json()["data"]] + assert names == [cheap, mid] + + def test_list_rejects_unknown_sort_field(client: TestClient) -> None: response = client.get(EXPERIMENTS, params={"sort": "bogus.field"}) assert response.status_code == 400, response.text diff --git a/services/intake/tests/test_experiment_metric_filter.py b/services/intake/tests/test_experiment_metric_filter.py new file mode 100644 index 0000000000..de56047df4 --- /dev/null +++ b/services/intake/tests/test_experiment_metric_filter.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Metric filtering on the experiments list (Option A app-merge). + +Two layers: pure helpers (split/validate/match) and endpoint wiring. 
The shared ``client`` fixture +overrides the rollup repository to ``None`` (ClickHouse unavailable), so a metric filter that passes +field validation must surface as 503 rather than silently dropping every row. +""" + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient +from nmp.common.api.filter import ComparisonOperation, FilterOperator, LogicalOperation +from nmp.intake.api.v2.experiments.endpoints import ( + _METRIC_STATS, + _extract_metric_predicates, + _is_metric_field, + _is_valid_metric_path, + _matches_metric_predicates, + _operation_references_metric, +) +from nmp.intake.api.v2.experiments.schemas import EvaluatorAggregate, ExperimentResponse, MetricStatFilters + +EXPERIMENTS = "/apis/intake/v2/workspaces/default/experiments" +GROUPS = "/apis/intake/v2/workspaces/default/experiment-groups" + + +def _exp(name: str, *, run_count: int = 0, cost_mean: float | None = None) -> ExperimentResponse: + return ExperimentResponse( + id=name, + name=name, + workspace="default", + experiment_group_id="grp", + dataset_name="ds", + run_count=run_count, + cost_usd=EvaluatorAggregate(mean=cost_mean) if cost_mean is not None else None, + ) + + +def _cmp(field: str, op: FilterOperator, value: object) -> ComparisonOperation: + return ComparisonOperation(operator=op, field=field, value=value) + + +# ----------------------------- pure helpers ----------------------------- + + +def test_is_metric_field_classifies_by_head() -> None: + assert _is_metric_field("run_count") + assert _is_metric_field("cost_usd.mean") + assert _is_metric_field("cost_usd.bogus") # intentionally loose: extracted, then rejected + assert _is_metric_field("evaluators.harbor.verifier.mean") + assert not _is_metric_field("data.name") + assert not _is_metric_field("name") + + +def test_metric_stat_filters_match_runtime_stats() -> None: + # The stats enumerated in the OpenAPI-visible schema must mirror the runtime grammar, or the spec + # would advertise stats the 
server rejects (or omit ones it accepts). + assert set(MetricStatFilters.model_fields) == set(_METRIC_STATS) + + +def test_is_valid_metric_path() -> None: + assert _is_valid_metric_path("run_count") + assert _is_valid_metric_path("cost_usd.p95") + assert _is_valid_metric_path("evaluators.harbor.verifier.mean") + assert not _is_valid_metric_path("cost_usd.bogus") + assert not _is_valid_metric_path("cost_usd") # missing stat + assert not _is_valid_metric_path("evaluators.reward") # missing stat + + +def test_extract_splits_metric_from_entity_predicates() -> None: + tree = LogicalOperation( + operator=FilterOperator.AND, + operations=[ + _cmp("data.name", FilterOperator.EQ, "foo"), + _cmp("cost_usd.mean", FilterOperator.LTE, "0.5"), + _cmp("run_count", FilterOperator.GTE, "3"), + ], + ) + entity_op, predicates = _extract_metric_predicates(tree) + # Only the entity predicate is forwarded to the store. + assert isinstance(entity_op, ComparisonOperation) + assert entity_op.field == "data.name" + assert {p.field for p in predicates} == {"cost_usd.mean", "run_count"} + assert all(isinstance(p.threshold, float) for p in predicates) + + +def test_extract_single_metric_comparison() -> None: + entity_op, predicates = _extract_metric_predicates(_cmp("cost_usd.mean", FilterOperator.GT, "0.1")) + assert entity_op is None + assert predicates[0].field == "cost_usd.mean" + + +def test_extract_rejects_bad_stat() -> None: + with pytest.raises(HTTPException) as exc: + _extract_metric_predicates(_cmp("cost_usd.bogus", FilterOperator.GTE, "1")) + assert exc.value.status_code == 400 + + +def test_extract_rejects_non_numeric_operator() -> None: + with pytest.raises(HTTPException) as exc: + _extract_metric_predicates(_cmp("cost_usd.mean", FilterOperator.LIKE, "x")) + assert exc.value.status_code == 400 + + +def test_extract_rejects_non_numeric_value() -> None: + with pytest.raises(HTTPException) as exc: + _extract_metric_predicates(_cmp("cost_usd.mean", FilterOperator.GTE, "not-a-number")) 
+ assert exc.value.status_code == 400 + + +def test_extract_flattens_nested_and() -> None: + # A metric comparison nested inside a sub-AND is still AND-combined, so it must be accepted. + tree = LogicalOperation( + operator=FilterOperator.AND, + operations=[ + _cmp("data.name", FilterOperator.EQ, "foo"), + LogicalOperation( + operator=FilterOperator.AND, + operations=[_cmp("cost_usd.mean", FilterOperator.LTE, "0.5")], + ), + ], + ) + entity_op, predicates = _extract_metric_predicates(tree) + assert [p.field for p in predicates] == ["cost_usd.mean"] + # The entity predicate survives; the metric one is stripped out for in-app evaluation. + assert entity_op is not None + assert not _operation_references_metric(entity_op) + + +def test_extract_rejects_metric_under_or() -> None: + tree = LogicalOperation( + operator=FilterOperator.OR, + operations=[ + _cmp("cost_usd.mean", FilterOperator.GTE, "0.5"), + _cmp("data.name", FilterOperator.EQ, "foo"), + ], + ) + with pytest.raises(HTTPException) as exc: + _extract_metric_predicates(tree) + assert exc.value.status_code == 400 + + +def test_matches_predicates_excludes_missing_metric() -> None: + cheap = _exp("cheap", cost_mean=0.2) + pricey = _exp("pricey", cost_mean=0.9) + norun = _exp("norun") # no cost rollup + _, predicates = _extract_metric_predicates(_cmp("cost_usd.mean", FilterOperator.LTE, "0.5")) + assert _matches_metric_predicates(cheap, predicates) + assert not _matches_metric_predicates(pricey, predicates) + assert not _matches_metric_predicates(norun, predicates) # missing metric never matches + + +# ----------------------------- endpoint wiring ----------------------------- + + +def _make_experiment(client: TestClient, name: str = "exp-1", group: str = "grp-1") -> None: + group_resp = client.post(GROUPS, json={"name": group}) + assert group_resp.status_code == 201, group_resp.text + exp_resp = client.post( + EXPERIMENTS, + json={"name": name, "experiment_group_id": group_resp.json()["id"], "dataset_name": "ds"}, 
+ ) + assert exp_resp.status_code == 201, exp_resp.text + + +def test_metric_filter_passes_validation_and_503s_without_rollups(client: TestClient) -> None: + # If the namespace declaration works, these paths get past field validation and reach the + # metric-filter path, which 503s because the (mocked) rollup repository is None. Needs a non-empty + # result set: an empty group has nothing to hydrate and correctly returns 200 empty. + _make_experiment(client) + for param in ( + {"filter[cost_usd.mean][gte]": "0.5"}, + {"filter[latency_ms.p95][lte]": "1000"}, + {"filter[evaluators.harbor.verifier.mean][gte]": "0.8"}, + {"filter[run_count][gte]": "5"}, + ): + response = client.get(EXPERIMENTS, params=param) + assert response.status_code == 503, (param, response.text) + + +def test_metric_filter_bad_stat_returns_400(client: TestClient) -> None: + response = client.get(EXPERIMENTS, params={"filter[cost_usd.bogus][gte]": "0.5"}) + assert response.status_code == 400, response.text + + +def test_metric_filter_non_numeric_value_returns_400(client: TestClient) -> None: + response = client.get(EXPERIMENTS, params={"filter[cost_usd.mean][gte]": "abc"}) + assert response.status_code == 400, response.text + + +def test_metric_filter_under_or_returns_400(client: TestClient) -> None: + json_filter = '{"$or": [{"cost_usd.mean": {"$gte": 0.5}}, {"name": {"$eq": "x"}}]}' + response = client.get(EXPERIMENTS, params={"filter": json_filter}) + assert response.status_code == 400, response.text