Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions test/unit_test/utils_test/test_record_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import tico
import torch
from tico.utils.record_input import RecordingInput
from torch.export import export, save

from test.modules.op.add import SimpleAdd


class RecordInputTest(unittest.TestCase):
    # Exercises RecordingInput against SimpleAdd for each calling convention
    # (positional, keyword, mixed) and for both optional settings.

    def test_args(self):
        # Purely positional call.
        model = SimpleAdd()
        example_inputs = model.get_example_inputs()
        with RecordingInput(model) as recorder:
            model.eval()
            model(*example_inputs)
            recorded = recorder.captured_input

        self.assertIsNotNone(recorded)
        self.assertEqual(recorded, example_inputs)
        # The recording is handed to the converter as its example input.
        tico.convert(model, recorded)

    def test_kwargs(self):
        # Purely keyword call; the recording is still compared to the
        # positional example inputs.
        model = SimpleAdd()
        example_inputs = model.get_example_inputs()
        with RecordingInput(model) as recorder:
            model.eval()
            model(x=example_inputs[0], y=example_inputs[1])
            recorded = recorder.captured_input

        self.assertIsNotNone(recorded)
        self.assertEqual(recorded, example_inputs)
        tico.convert(model, recorded)

    def test_args_kwargs(self):
        # First argument positional, second by keyword.
        model = SimpleAdd()
        example_inputs = model.get_example_inputs()
        with RecordingInput(model) as recorder:
            model.eval()
            model(example_inputs[0], y=example_inputs[1])
            recorded = recorder.captured_input

        self.assertIsNotNone(recorded)
        self.assertEqual(recorded, example_inputs)
        tico.convert(model, recorded)

    def test_input_to_remove(self):
        # A name listed in input_to_remove is recorded as None.
        model = SimpleAdd()
        example_inputs = model.get_example_inputs()
        with RecordingInput(model, input_to_remove=["x"]) as recorder:
            model.eval()
            model(*example_inputs)
            recorded = recorder.captured_input

        self.assertIsNotNone(recorded)
        self.assertIsNone(recorded[0])  # arg[0] = 'x'

    def test_condition(self):
        # A condition that always rejects leaves the recording empty.
        def never_record(arg_dict):
            return False

        model = SimpleAdd()
        example_inputs = model.get_example_inputs()
        with RecordingInput(model, never_record) as recorder:
            model.eval()
            model(*example_inputs)
            recorded = recorder.captured_input

        self.assertEqual(recorded, ())
89 changes: 89 additions & 0 deletions tico/utils/record_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import inspect
from contextlib import contextmanager
from typing import Callable, List, Optional

import torch.nn as nn


class RecordingInput:
r"""Context-manager that records the input values of model::forward()

Recording input is useful for preparing example input for torch.export

Args:
condition: lambda to provide the condition whether to record or not

For example, to record only when args_dict["past_key_values"] is not None:
condition = lambda args_dict: args_dict["past_key_values"] is not None

input_to_remove: list of arg names to remove

Sometimes you want to drop certain argument values so that the exported graph is tidy or correct.
For example, "past_key_values" may be not None but just an empty cache. In that case,
input_to_remove = ["past_key_values"] makes life easier.

Example::
>>> with RecordingInput(model, input_to_remove=input_to_remove) as rec:
... outputs = model.generate(
... **inputs,
... )
... captured_input = rec.captured_input
>>> circle_model = tico.convert(model, captured_input)
"""

def __init__(
self,
module: nn.Module,
condition: Callable[[dict], bool] = lambda args_dict: True,
*,
input_to_remove: Optional[List[str]] = [],
):
self.module = module
self.forward_org = module.forward
self.condition = condition
self.input_to_remove = input_to_remove
sig = inspect.signature(self.forward_org)
self.args_names = [
name for name in sig.parameters.keys() if name not in ("self", "kwargs")
]
self.captured_input = ()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.captured_input = ()
self.captured_input = None

How about initializing this with None?
In some rare cases the captured input itself could be empty (void)...

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glistening Could you check this..?


def __enter__(self):
def capture_and_forward(*args, **kwargs):
args_dict = dict(zip(self.args_names, args))
args_dict.update(kwargs)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this works!

args_dict = dict(sig.bind(*args, **kwargs).arguments)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I am not sure it is perfectly equivalent to my original code. I would have to look into sig.bind, arguments, and so on. Even if it is equivalent, I see no benefit in using this API. Is there any advantage in replacing 2 lines of straightforward code with your suggestion?

@glistening glistening Jul 29, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've looked into the benefits of your approach. Your suggestion works better in several respects: 1) when the method has default values, ...

I will update it as you suggested after checking further.

@glistening glistening Jul 29, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dayo09 I am investigating sig.bind stuffs.

Your suggestion binds only the explicitly passed arguments ( len() = 7 ), not all arguments.

{
 'input_ids': tensor([[    1, 21075,  7727,   550,   260, 12584, 31843,     2,     2,     2,  2]]), 
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ..., 0, 0, 0, 0, 0]]), 
 'past_key_values': DynamicCache(), 
 'inputs_embeds': None, 
 'use_cache': True, 
 'return_dict': True, 
 'cache_position': tensor([ 0,  1,  2,  3,  4,  5,  6, ... 28, 29, 30, 31])
}

whereas I wanted args_dict to contain all of the positional arguments ( len() = 12 ), in exactly the same order. [1]

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:

I have to call bound.apply_defaults()

bound = self.sig.bind(*args, **kwargs)
bound.apply_defaults()
args_dict = dict(bound.arguments)

Now I will check whether it works better than my original code.
I guess it will work better for arguments with a non-None default, which did not occur in my target models. I will modify the target model and see what happens.

Footnotes

  1. If not all of the arguments are passed, torch.export complains that the number of arguments differs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glistening I don't have a strong preference - go with what works for you 😄


def populate_args(args_dict, names_to_drop):
    """Build the positional-argument tuple that gets recorded.

    Args:
        args_dict: Mapping of argument name to the value passed to forward.
            It is only read, never modified.
        names_to_drop: Argument names whose value must not be recorded;
            may be None or empty.

    Returns:
        A deep-copied tuple with one slot per name in ``self.args_names``,
        in that order. Dropped names, and names missing from ``args_dict``,
        yield ``None``.
    """
    # ``or ()`` tolerates None, which the Optional annotation on
    # ``input_to_remove`` allows callers to pass.
    dropped = set(names_to_drop or ())
    args_tuple = tuple(
        None if name in dropped else args_dict.get(name)
        for name in self.args_names
    )
    # Deep copy so the recording does not share objects with the live inputs.
    return copy.deepcopy(args_tuple)

if self.condition(args_dict) and self.captured_input == ():
self.captured_input = populate_args(args_dict, self.input_to_remove)

return self.forward_org(*args, **kwargs)

self.module.forward = capture_and_forward
return self

def __exit__(self, exc_type, exc_value, traceback):
self.module.forward = self.forward_org