-
Notifications
You must be signed in to change notification settings - Fork 31
[utils] Add forward's input recorder #249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| import tico | ||
| import torch | ||
| from tico.utils.record_input import RecordingInput | ||
| from torch.export import export, save | ||
|
|
||
| from test.modules.op.add import SimpleAdd | ||
|
|
||
|
|
||
| class RecordInputTest(unittest.TestCase): | ||
| def test_args(self): | ||
| m = SimpleAdd() | ||
| inputs = m.get_example_inputs() | ||
| with RecordingInput(m) as rec: | ||
| m.eval() | ||
| m(*inputs) | ||
| captured_input = rec.captured_input | ||
|
|
||
| self.assertIsNotNone(captured_input) | ||
| self.assertEqual(captured_input, inputs) | ||
| tico.convert(m, captured_input) | ||
|
|
||
| def test_kwargs(self): | ||
| m = SimpleAdd() | ||
| inputs = m.get_example_inputs() | ||
| kwargs = {"x": inputs[0], "y": inputs[1]} | ||
| with RecordingInput(m) as rec: | ||
| m.eval() | ||
| m(**kwargs) | ||
| captured_input = rec.captured_input | ||
|
|
||
| self.assertIsNotNone(captured_input) | ||
| self.assertEqual(captured_input, inputs) | ||
| tico.convert(m, captured_input) | ||
|
|
||
| def test_args_kwargs(self): | ||
| m = SimpleAdd() | ||
| inputs = m.get_example_inputs() | ||
| args = (inputs[0],) | ||
| kwargs = {"y": inputs[1]} | ||
| with RecordingInput(m) as rec: | ||
| m.eval() | ||
| m(*args, **kwargs) | ||
| captured_input = rec.captured_input | ||
|
|
||
| self.assertIsNotNone(captured_input) | ||
| self.assertEqual(captured_input, inputs) | ||
| tico.convert(m, captured_input) | ||
|
|
||
| def test_input_to_remove(self): | ||
| m = SimpleAdd() | ||
| inputs = m.get_example_inputs() | ||
| with RecordingInput(m, input_to_remove=["x"]) as rec: | ||
| m.eval() | ||
| m(*inputs) | ||
| captured_input = rec.captured_input | ||
|
|
||
| self.assertIsNotNone(captured_input) | ||
| self.assertIsNone(captured_input[0]) # arg[0] = 'x' | ||
|
|
||
| def test_condition(self): | ||
| m = SimpleAdd() | ||
| inputs = m.get_example_inputs() | ||
| condition = lambda arg_dict: False | ||
| with RecordingInput(m, condition) as rec: | ||
| m.eval() | ||
| m(*inputs) | ||
| captured_input = rec.captured_input | ||
|
|
||
| self.assertEqual(captured_input, ()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import copy | ||
|
|
||
| import inspect | ||
| from contextlib import contextmanager | ||
| from typing import Callable, List, Optional | ||
|
|
||
| import torch.nn as nn | ||
|
|
||
|
|
||
| class RecordingInput: | ||
| r"""Context-manager that records the input values of model::forward() | ||
|
|
||
| Recording input is useful for preparing example input for torch.export | ||
|
|
||
| Args: | ||
| condition: lambda to provide the condition whether to record or not | ||
|
|
||
| For examples, if you want to capture only args["past_key_values"] is not None, | ||
| conditon = lambda args_dict: args_dict["past_key_value"] is not None | ||
|
|
||
| input_to_remove: list of arg names to remove | ||
|
|
||
| Sometimes you would like to remove some arg values to make exported graph tidy or correct | ||
| For example, "past_key_values" may be not None, but just an empty cache. Then, | ||
| input_to_remove = [ "past_key_values" ]; makes the life easy | ||
|
|
||
| Example:: | ||
| >>> with RecordingInput(model, input_to_remove=input_to_remove) as rec: | ||
| ... outputs = model.generate( | ||
| ... **inputs, | ||
| ... ) | ||
| ... captured_input = rec.captured_input | ||
| >>> circle_model = tico.convert(model, captured_input) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| module: nn.Module, | ||
| condition: Callable[[dict], bool] = lambda args_dict: True, | ||
| *, | ||
| input_to_remove: Optional[List[str]] = [], | ||
| ): | ||
| self.module = module | ||
| self.forward_org = module.forward | ||
| self.condition = condition | ||
| self.input_to_remove = input_to_remove | ||
| sig = inspect.signature(self.forward_org) | ||
| self.args_names = [ | ||
| name for name in sig.parameters.keys() if name not in ("self", "kwargs") | ||
| ] | ||
| self.captured_input = () | ||
|
|
||
| def __enter__(self): | ||
| def capture_and_forward(*args, **kwargs): | ||
| args_dict = dict(zip(self.args_names, args)) | ||
| args_dict.update(kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI, this works!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I am not sure it is perfectly equivalent to my original code. I have to refer to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've searched the benefits of your way. Your suggestion works better in several points: 1) when the method has default value, ... I will update as you suggested after checking more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dayo09 I am investigating Your suggestion binds only the explicitly passed arguments ( len() = 7 ), not all arguments. while I wanted def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:I have to call bound.apply_defaults() bound = self.sig.bind(*args, **kwargs)
bound.apply_defaults()
args_dict = dict(bound.arguments)Now, I will check whether it works better than my original code. Footnotes
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @glistening I don't have a strong preference - go with what works for you 😄 |
||
|
|
||
| def populate_args(args_dict, filter): | ||
|
glistening marked this conversation as resolved.
Outdated
|
||
| for key in filter: | ||
| args_dict.pop(key, None) | ||
| args_tuple = tuple( | ||
| args_dict.get(name, None) for name in self.args_names | ||
| ) | ||
| return copy.deepcopy(args_tuple) | ||
|
|
||
| if self.condition(args_dict) and self.captured_input == (): | ||
| self.captured_input = populate_args(args_dict, self.input_to_remove) | ||
|
|
||
| return self.forward_org(*args, **kwargs) | ||
|
|
||
| self.module.forward = capture_and_forward | ||
| return self | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| self.module.forward = self.forward_org | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about initializing this with None?
Because, in some rare case, captured input could be void...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glistening Could you check this..?