Skip to content
1 change: 1 addition & 0 deletions packages/reflex-base/news/6659.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preserve runtime value types for unannotated `@rx.memo` component parameters.
53 changes: 46 additions & 7 deletions packages/reflex-base/src/reflex_base/components/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ class MemoComponentDefinition(MemoDefinition):

export_name: str
_component: _LazyBody[Component]
_runtime_param_values: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
)
# For passthrough wrappers built by the auto-memoize plugin: the
# ``Bare``-wrapped ``{children}`` placeholder used when rendering the memo
# body. The ``component`` keeps its ORIGINAL children so compile-time
Expand Down Expand Up @@ -756,16 +759,26 @@ def _rest_placeholder(name: str) -> RestProp:
return RestProp(_js_expr=name, _var_type=dict[str, Any])


def _var_placeholder(name: str, annotation: Any) -> Var:
def _var_placeholder(
name: str,
annotation: Any,
runtime_value: Any | None = None,
) -> Var:
"""Create a placeholder Var for a memo parameter.

Args:
name: The JavaScript identifier.
annotation: The parameter annotation.
runtime_value: Optional runtime value used to infer unannotated params.

Returns:
The placeholder Var.
"""
if _annotation_inner_type(annotation) is Any and runtime_value is not None:
runtime_type = (
runtime_value._var_type if isinstance(runtime_value, Var) else type(runtime_value)
)
return Var(_js_expr=name, _var_type=runtime_type).guess_type()
return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type()


Expand Down Expand Up @@ -1033,12 +1046,14 @@ def finalize(
def _evaluate_memo_function(
fn: Callable[..., Any],
params: tuple[MemoParam, ...],
runtime_values: Mapping[str, Any] | None = None,
) -> Any:
"""Evaluate a memo function with placeholder vars.

Args:
fn: The function to evaluate.
params: The memo parameters.
runtime_values: Optional runtime values keyed by parameter name.

Returns:
The return value from the function.
Expand All @@ -1047,7 +1062,14 @@ def _evaluate_memo_function(
keyword_args = {}

for param in params:
placeholder = param.make_placeholder()
if param.kind is MemoParamKind.VALUE:
placeholder = _var_placeholder(
param.placeholder_name,
param.annotation,
runtime_values.get(param.name) if runtime_values is not None else None,
)
else:
placeholder = param.make_placeholder()
if param.parameter_kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
Expand Down Expand Up @@ -1299,21 +1321,26 @@ def _build_args_function(


def _evaluate_component_body(
fn: Callable[..., Any], params: tuple[MemoParam, ...]
fn: Callable[..., Any],
params: tuple[MemoParam, ...],
runtime_values: Mapping[str, Any] | None = None,
) -> Component:
"""Run a component memo's body and return its compiled component.

Args:
fn: The decorated function.
params: The analyzed memo parameters.
runtime_values: Optional runtime values keyed by parameter name.

Returns:
The wrapped component the body returned.

Raises:
TypeError: If the body does not return a component.
"""
body = _normalize_component_return(_evaluate_memo_function(fn, params))
body = _normalize_component_return(
_evaluate_memo_function(fn, params, runtime_values)
)
if body is None:
msg = (
f"Component-returning `@rx.memo` `{fn.__name__}` must return an "
Expand Down Expand Up @@ -1359,13 +1386,17 @@ def _create_component_definition(
TypeError: If the function does not return a component.
"""
params = _analyze_params(fn, for_component=True)
runtime_param_values: dict[str, Any] = {}
return MemoComponentDefinition(
fn=fn,
python_name=fn.__name__,
params=params,
source_module=source_module,
export_name=format.to_title_case(fn.__name__),
_component=_LazyBody.ready(_evaluate_component_body(fn, params)),
_component=_LazyBody(
lambda: _evaluate_component_body(fn, params, runtime_param_values)
),
_runtime_param_values=runtime_param_values,
)


Expand Down Expand Up @@ -1628,9 +1659,15 @@ def __call__(self, *children: Any, **props: Any) -> MemoComponent:

# Reading ``component`` materializes the deferred body, so ``type(...)``
# reflects the real wrapped class rather than the placeholder.
definition._runtime_param_values.clear()
definition._runtime_param_values.update(explicit_values)
try:
component_type = type(definition.component)
finally:
definition._runtime_param_values.clear()
return _get_memo_component_class(
definition.export_name,
type(definition.component),
component_type,
definition.source_module,
)._create(
children=list(children),
Expand Down Expand Up @@ -1925,16 +1962,18 @@ def memo(fn: Callable[..., Any]) -> _MemoComponentWrapper | _MemoFunctionWrapper
# where the name resolves to ``wrapper`` (already bound by first use).
definition: MemoComponentDefinition | MemoFunctionDefinition
if is_component:
runtime_param_values: dict[str, Any] = {}
definition = MemoComponentDefinition(
fn=fn,
python_name=fn.__name__,
params=params,
source_module=source_module,
export_name=format.to_title_case(fn.__name__),
_component=_LazyBody(
lambda: _evaluate_component_body(fn, params),
lambda: _evaluate_component_body(fn, params, runtime_param_values),
placeholder=Fragment.create(),
),
_runtime_param_values=runtime_param_values,
)
wrapper = _create_component_wrapper(definition)
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/units/components/test_memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,42 @@ def soft_missing(value) -> rx.Component:
assert "`value`" in kwargs["reason"]


def test_memo_uses_first_call_value_type_for_missing_param_annotation():
"""Component memos should infer missing parameter types from the first call."""

@rx.memo
def user_card(user) -> rx.Component:
return rx.box(
rx.heading(user["name"]),
rx.text(user["email"]),
)
Comment thread
harsh21234i marked this conversation as resolved.

component = user_card(
user={"name": "Ada", "email": "ada@example.com"},
)

assert isinstance(component, MemoComponent)


def test_memo_uses_var_runtime_value_type_for_missing_param_annotation():
"""Component memos should infer missing parameter types from runtime Vars."""

@rx.memo
def user_card(user) -> rx.Component:
return rx.box(
rx.heading(user["name"]),
rx.text(user["email"]),
)

component = user_card(
user=Var(_js_expr="user", _var_type=dict),
)

assert isinstance(component, MemoComponent)
assert isinstance(component.user, Var)
assert component.user._var_type is dict


def test_memo_warns_on_missing_return_annotation():
"""A missing return annotation should default to ``rx.Component`` with a warning."""
with patch.object(console, "deprecate") as mock_deprecate:
Comment thread
harsh21234i marked this conversation as resolved.
Expand Down