Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions projects/fal/src/fal/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ class FalServerlessHost(Host):
"health_check_config",
"skip_retry_conditions",
"termination_grace_period_seconds",
"mounted_secrets",
}
)

Expand Down Expand Up @@ -645,6 +646,7 @@ def register(
termination_grace_period_seconds = options.host.get(
"termination_grace_period_seconds"
)
mounted_secrets = options.host.get("mounted_secrets")
machine_requirements = MachineRequirements(
machine_types=machine_type, # type: ignore
num_gpus=options.host.get("num_gpus"),
Expand Down Expand Up @@ -696,6 +698,7 @@ def register(
skip_retry_conditions=skip_retry_conditions,
environment_name=environment_name,
termination_grace_period_seconds=termination_grace_period_seconds,
mounted_secrets=mounted_secrets,
):
for log in partial_result.logs:
self._log_printer.print(log)
Expand Down
4 changes: 4 additions & 0 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ class App(BaseServable):
image: ClassVar[Optional[ContainerImage]] = None
local_file_path: ClassVar[Optional[str]] = None
skip_retry_conditions: ClassVar[Optional[list[RetryConditionLiteral]]] = None
mounted_secrets: ClassVar[Optional[list[str]]] = None
termination_grace_period_seconds: ClassVar[Optional[int]] = None

isolate_channel: async_grpc.Channel | None = None
Expand Down Expand Up @@ -667,6 +668,9 @@ def __init_subclass__(cls, **kwargs):
if cls.skip_retry_conditions is not None:
cls.host_kwargs["skip_retry_conditions"] = cls.skip_retry_conditions

if cls.mounted_secrets is not None:
cls.host_kwargs["mounted_secrets"] = cls.mounted_secrets

if cls.termination_grace_period_seconds is not None:
cls.host_kwargs["termination_grace_period_seconds"] = (
cls.termination_grace_period_seconds
Expand Down
4 changes: 4 additions & 0 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ def register(
skip_retry_conditions: list[RetryConditionLiteral] | None = None,
environment_name: str | None = None,
termination_grace_period_seconds: int | None = None,
mounted_secrets: list[str] | None = None,
) -> Iterator[RegisterApplicationResult]:
wrapped_function = to_serialized_object(function, serialization_method)
if machine_requirements:
Expand Down Expand Up @@ -853,6 +854,7 @@ def register(
skip_retry_conditions=wrapped_skip_retry_conditions,
environment_name=environment_name,
termination_grace_period_seconds=termination_grace_period_seconds,
mounted_secrets=mounted_secrets or ["*"],
)
for partial_result in self.stub.RegisterApplication(request):
yield from_grpc(partial_result)
Expand All @@ -874,6 +876,7 @@ def update_application(
startup_timeout: int | None = None,
valid_regions: list[str] | None = None,
machine_types: list[str] | None = None,
mounted_secrets: list[str] | None = None,
*,
environment_name: str | None = None,
) -> AliasInfo:
Expand All @@ -892,6 +895,7 @@ def update_application(
startup_timeout=startup_timeout,
valid_regions=valid_regions,
machine_types=machine_types,
mounted_secrets=mounted_secrets,
environment_name=environment_name,
)
res: isolate_proto.UpdateApplicationResult = self.stub.UpdateApplication(
Expand Down
103 changes: 103 additions & 0 deletions projects/fal/tests/e2e/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,109 @@ def test_runner_machine_type(host: api.FalServerlessHost, test_sleep_app: str):
assert target_runner.machine_type == "XS"


class SecretsOutput(BaseModel):
has_mounted: bool
has_not_mounted: bool


class MountedSecretsApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "XS"
mounted_secrets = ["MOUNTED_TEST_SECRET"]

@fal.endpoint("/")
def check_secrets(self) -> SecretsOutput:
return SecretsOutput(
has_mounted="MOUNTED_TEST_SECRET" in os.environ,
has_not_mounted="NOT_MOUNTED_SECRET" in os.environ,
)


class MountAllSecretsApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "XS"

@fal.endpoint("/")
def check_secrets(self) -> SecretsOutput:
return SecretsOutput(
has_mounted="MOUNTED_TEST_SECRET" in os.environ,
has_not_mounted="NOT_MOUNTED_SECRET" in os.environ,
)


class MountNoSecretsApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "XS"
mounted_secrets = []

@fal.endpoint("/")
def check_secrets(self) -> SecretsOutput:
return SecretsOutput(
has_mounted="MOUNTED_TEST_SECRET" in os.environ,
has_not_mounted="NOT_MOUNTED_SECRET" in os.environ,
)


@pytest.fixture(scope="module")
def _ensure_test_secrets(host: api.FalServerlessHost):
"""Create two secrets so we can test selective mounting."""
with host._connection as client:
client.set_secret("MOUNTED_TEST_SECRET", "secret-value-1")
client.set_secret("NOT_MOUNTED_SECRET", "secret-value-2")
yield
with host._connection as client:
client.delete_secret("MOUNTED_TEST_SECRET")
client.delete_secret("NOT_MOUNTED_SECRET")


@pytest.fixture(scope="module")
def test_mounted_secrets_app(
host: api.FalServerlessHost, user: User, _ensure_test_secrets
):
app = wrap_app(MountedSecretsApp)
with register_app(host, app, "mounted-secrets") as (app_alias, _):
yield f"{user.username}/{app_alias}"


@pytest.fixture(scope="module")
def test_mount_all_secrets_app(
host: api.FalServerlessHost, user: User, _ensure_test_secrets
):
app = wrap_app(MountAllSecretsApp)
with register_app(host, app, "mount-all-secrets") as (app_alias, _):
yield f"{user.username}/{app_alias}"


@pytest.fixture(scope="module")
def test_mount_no_secrets_app(
host: api.FalServerlessHost, user: User, _ensure_test_secrets
):
app = wrap_app(MountNoSecretsApp)
with register_app(host, app, "mount-no-secrets") as (app_alias, _):
yield f"{user.username}/{app_alias}"


@pytest.mark.flaky(max_runs=3)
def test_mounted_secrets_only_specified(test_mounted_secrets_app: str):
"""Only the secret listed in mounted_secrets should be available."""
result = apps.run(test_mounted_secrets_app, arguments={})
assert result["has_mounted"] is True
assert result["has_not_mounted"] is False


@pytest.mark.flaky(max_runs=3)
def test_mount_all_secrets_default(test_mount_all_secrets_app: str):
"""Without mounted_secrets set, all secrets should be available (default ["*"])."""
result = apps.run(test_mount_all_secrets_app, arguments={})
assert result["has_mounted"] is True
assert result["has_not_mounted"] is True


@pytest.mark.flaky(max_runs=3)
def test_mount_no_secrets(test_mount_no_secrets_app: str):
"""With mounted_secrets=[], no secrets should be available."""
result = apps.run(test_mount_no_secrets_app, arguments={})
assert result["has_mounted"] is False
assert result["has_not_mounted"] is False


class RequestContextOutput(BaseModel):
request_id_from_context: Optional[str]
endpoint_from_context: Optional[str]
Expand Down
4 changes: 4 additions & 0 deletions projects/isolate_proto/src/isolate_proto/controller.proto
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ message RegisterApplicationRequest {
repeated RetryCondition skip_retry_conditions = 17;
// Grace period in seconds before forced termination of runners after a shutdown request.
optional int32 termination_grace_period_seconds = 18;
// Which secrets to mount. ["*"] means all, [] means none.
repeated string mounted_secrets = 19;
}

message RegisterApplicationResultType {
Expand All @@ -320,6 +322,8 @@ message UpdateApplicationRequest {
optional int32 concurrency_buffer_perc = 11;
optional int32 scaling_delay_seconds = 12;
optional string environment_name = 13;
// Which secrets to mount. ["*"] means all, [] means none.
repeated string mounted_secrets = 14;
}

message UpdateApplicationResult {
Expand Down
196 changes: 98 additions & 98 deletions projects/isolate_proto/src/isolate_proto/controller_pb2.py

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions projects/isolate_proto/src/isolate_proto/controller_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ class RegisterApplicationRequest(google.protobuf.message.Message):
HEALTH_CHECK_CONFIG_FIELD_NUMBER: builtins.int
SKIP_RETRY_CONDITIONS_FIELD_NUMBER: builtins.int
TERMINATION_GRACE_PERIOD_SECONDS_FIELD_NUMBER: builtins.int
MOUNTED_SECRETS_FIELD_NUMBER: builtins.int
@property
def environments(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[server_pb2.EnvironmentDefinition]:
"""Environment definitions."""
Expand Down Expand Up @@ -693,6 +694,9 @@ class RegisterApplicationRequest(google.protobuf.message.Message):
"""Skip retry on certain conditions"""
termination_grace_period_seconds: builtins.int
"""Grace period in seconds before forced termination of runners after a shutdown request."""
@property
def mounted_secrets(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Which secrets to mount. ["*"] means all, [] means none."""
def __init__(
self,
*,
Expand All @@ -714,9 +718,10 @@ class RegisterApplicationRequest(google.protobuf.message.Message):
health_check_config: global___ApplicationHealthCheckConfig | None = ...,
skip_retry_conditions: collections.abc.Iterable[global___RetryCondition.ValueType] | None = ...,
termination_grace_period_seconds: builtins.int | None = ...,
mounted_secrets: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_application_name", b"_application_name", "_auth_mode", b"_auth_mode", "_deployment_strategy", b"_deployment_strategy", "_environment_name", b"_environment_name", "_health_check_config", b"_health_check_config", "_health_check_path", b"_health_check_path", "_machine_requirements", b"_machine_requirements", "_max_concurrency", b"_max_concurrency", "_metadata", b"_metadata", "_private_logs", b"_private_logs", "_scale", b"_scale", "_setup_func", b"_setup_func", "_source_code", b"_source_code", "_termination_grace_period_seconds", b"_termination_grace_period_seconds", "application_name", b"application_name", "auth_mode", b"auth_mode", "deployment_strategy", b"deployment_strategy", "environment_name", b"environment_name", "function", b"function", "health_check_config", b"health_check_config", "health_check_path", b"health_check_path", "machine_requirements", b"machine_requirements", "max_concurrency", b"max_concurrency", "metadata", b"metadata", "private_logs", b"private_logs", "scale", b"scale", "setup_func", b"setup_func", "source_code", b"source_code", "termination_grace_period_seconds", b"termination_grace_period_seconds"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_application_name", b"_application_name", "_auth_mode", b"_auth_mode", "_deployment_strategy", b"_deployment_strategy", "_environment_name", b"_environment_name", "_health_check_config", b"_health_check_config", "_health_check_path", b"_health_check_path", "_machine_requirements", b"_machine_requirements", "_max_concurrency", b"_max_concurrency", "_metadata", b"_metadata", "_private_logs", b"_private_logs", "_scale", b"_scale", "_setup_func", b"_setup_func", "_source_code", b"_source_code", "_termination_grace_period_seconds", b"_termination_grace_period_seconds", "application_name", b"application_name", "auth_mode", b"auth_mode", "deployment_strategy", b"deployment_strategy", "environment_name", b"environment_name", "environments", b"environments", "files", b"files", "function", b"function", "health_check_config", b"health_check_config", "health_check_path", b"health_check_path", "machine_requirements", b"machine_requirements", "max_concurrency", b"max_concurrency", "metadata", b"metadata", "private_logs", b"private_logs", "scale", b"scale", "setup_func", b"setup_func", "skip_retry_conditions", b"skip_retry_conditions", "source_code", b"source_code", "termination_grace_period_seconds", b"termination_grace_period_seconds"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["_application_name", b"_application_name", "_auth_mode", b"_auth_mode", "_deployment_strategy", b"_deployment_strategy", "_environment_name", b"_environment_name", "_health_check_config", b"_health_check_config", "_health_check_path", b"_health_check_path", "_machine_requirements", b"_machine_requirements", "_max_concurrency", b"_max_concurrency", "_metadata", b"_metadata", "_private_logs", b"_private_logs", "_scale", b"_scale", "_setup_func", b"_setup_func", "_source_code", b"_source_code", "_termination_grace_period_seconds", b"_termination_grace_period_seconds", "application_name", b"application_name", "auth_mode", b"auth_mode", "deployment_strategy", b"deployment_strategy", "environment_name", b"environment_name", "environments", b"environments", "files", b"files", "function", b"function", "health_check_config", b"health_check_config", "health_check_path", b"health_check_path", "machine_requirements", b"machine_requirements", "max_concurrency", b"max_concurrency", "metadata", b"metadata", "mounted_secrets", b"mounted_secrets", "private_logs", b"private_logs", "scale", b"scale", "setup_func", b"setup_func", "skip_retry_conditions", b"skip_retry_conditions", "source_code", b"source_code", "termination_grace_period_seconds", b"termination_grace_period_seconds"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_application_name", b"_application_name"]) -> typing_extensions.Literal["application_name"] | None: ...
@typing.overload
Expand Down Expand Up @@ -809,6 +814,7 @@ class UpdateApplicationRequest(google.protobuf.message.Message):
CONCURRENCY_BUFFER_PERC_FIELD_NUMBER: builtins.int
SCALING_DELAY_SECONDS_FIELD_NUMBER: builtins.int
ENVIRONMENT_NAME_FIELD_NUMBER: builtins.int
MOUNTED_SECRETS_FIELD_NUMBER: builtins.int
application_name: builtins.str
keep_alive: builtins.int
max_multiplexing: builtins.int
Expand All @@ -824,6 +830,9 @@ class UpdateApplicationRequest(google.protobuf.message.Message):
concurrency_buffer_perc: builtins.int
scaling_delay_seconds: builtins.int
environment_name: builtins.str
@property
def mounted_secrets(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Which secrets to mount. ["*"] means all, [] means none."""
def __init__(
self,
*,
Expand All @@ -840,9 +849,10 @@ class UpdateApplicationRequest(google.protobuf.message.Message):
concurrency_buffer_perc: builtins.int | None = ...,
scaling_delay_seconds: builtins.int | None = ...,
environment_name: builtins.str | None = ...,
mounted_secrets: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_concurrency_buffer", b"_concurrency_buffer", "_concurrency_buffer_perc", b"_concurrency_buffer_perc", "_environment_name", b"_environment_name", "_keep_alive", b"_keep_alive", "_max_concurrency", b"_max_concurrency", "_max_multiplexing", b"_max_multiplexing", "_min_concurrency", b"_min_concurrency", "_request_timeout", b"_request_timeout", "_scaling_delay_seconds", b"_scaling_delay_seconds", "_startup_timeout", b"_startup_timeout", "concurrency_buffer", b"concurrency_buffer", "concurrency_buffer_perc", b"concurrency_buffer_perc", "environment_name", b"environment_name", "keep_alive", b"keep_alive", "max_concurrency", b"max_concurrency", "max_multiplexing", b"max_multiplexing", "min_concurrency", b"min_concurrency", "request_timeout", b"request_timeout", "scaling_delay_seconds", b"scaling_delay_seconds", "startup_timeout", b"startup_timeout"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_concurrency_buffer", b"_concurrency_buffer", "_concurrency_buffer_perc", b"_concurrency_buffer_perc", "_environment_name", b"_environment_name", "_keep_alive", b"_keep_alive", "_max_concurrency", b"_max_concurrency", "_max_multiplexing", b"_max_multiplexing", "_min_concurrency", b"_min_concurrency", "_request_timeout", b"_request_timeout", "_scaling_delay_seconds", b"_scaling_delay_seconds", "_startup_timeout", b"_startup_timeout", "application_name", b"application_name", "concurrency_buffer", b"concurrency_buffer", "concurrency_buffer_perc", b"concurrency_buffer_perc", "environment_name", b"environment_name", "keep_alive", b"keep_alive", "machine_types", b"machine_types", "max_concurrency", b"max_concurrency", "max_multiplexing", b"max_multiplexing", "min_concurrency", b"min_concurrency", "request_timeout", b"request_timeout", "scaling_delay_seconds", b"scaling_delay_seconds", "startup_timeout", b"startup_timeout", "valid_regions", b"valid_regions"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["_concurrency_buffer", b"_concurrency_buffer", "_concurrency_buffer_perc", b"_concurrency_buffer_perc", "_environment_name", b"_environment_name", "_keep_alive", b"_keep_alive", "_max_concurrency", b"_max_concurrency", "_max_multiplexing", b"_max_multiplexing", "_min_concurrency", b"_min_concurrency", "_request_timeout", b"_request_timeout", "_scaling_delay_seconds", b"_scaling_delay_seconds", "_startup_timeout", b"_startup_timeout", "application_name", b"application_name", "concurrency_buffer", b"concurrency_buffer", "concurrency_buffer_perc", b"concurrency_buffer_perc", "environment_name", b"environment_name", "keep_alive", b"keep_alive", "machine_types", b"machine_types", "max_concurrency", b"max_concurrency", "max_multiplexing", b"max_multiplexing", "min_concurrency", b"min_concurrency", "mounted_secrets", b"mounted_secrets", "request_timeout", b"request_timeout", "scaling_delay_seconds", b"scaling_delay_seconds", "startup_timeout", b"startup_timeout", "valid_regions", b"valid_regions"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_concurrency_buffer", b"_concurrency_buffer"]) -> typing_extensions.Literal["concurrency_buffer"] | None: ...
@typing.overload
Expand Down
Loading