From 5aabaf1a3417a82643d70ab85aedb31a33363a82 Mon Sep 17 00:00:00 2001 From: Anna Sun <13106449+annasun28@users.noreply.github.com> Date: Wed, 11 Oct 2023 13:51:54 -0700 Subject: [PATCH] add fp16 and upstream states for tree pipeline (#80) --- simuleval/agents/pipeline.py | 12 +++++++++--- simuleval/options.py | 3 +++ simuleval/utils/agent.py | 4 ++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/simuleval/agents/pipeline.py b/simuleval/agents/pipeline.py index 47ad297a..5489729e 100644 --- a/simuleval/agents/pipeline.py +++ b/simuleval/agents/pipeline.py @@ -269,6 +269,7 @@ def push_impl( module: GenericAgent, segment: Segment, states: Optional[Dict[GenericAgent, AgentStates]], + upstream_states: Dict[int, AgentStates], ): # DFS over the tree children = self.module_dict[module] @@ -276,10 +277,12 @@ def push_impl( module.push(segment, states[module]) return [] - segment = module.pushpop(segment, states[module]) + segment = module.pushpop(segment, states[module], upstream_states) + assert len(upstream_states) not in upstream_states + upstream_states[len(upstream_states)] = states[module] for child in children: - self.push_impl(child, segment, states) + self.push_impl(child, segment, states, upstream_states) def pushpop( self, @@ -300,8 +303,11 @@ def push( states = {module: None for module in self.module_dict} else: assert len(states) == len(self.module_dict) + + if upstream_states is None: + upstream_states = {} - self.push_impl(self.source_module, segment, states) + self.push_impl(self.source_module, segment, states, upstream_states) def pop( self, states: Optional[Dict[GenericAgent, AgentStates]] = None diff --git a/simuleval/options.py b/simuleval/options.py index 2659785f..8eced511 100644 --- a/simuleval/options.py +++ b/simuleval/options.py @@ -178,6 +178,9 @@ def general_parser(): parser.add_argument( "--device", type=str, default="cpu", help="Device to run the model." ) + parser.add_argument( + "--fp16", action="store_true", default=False, help="Use fp16." + ) return parser diff --git a/simuleval/utils/agent.py b/simuleval/utils/agent.py index 19a7354b..81434fd1 100644 --- a/simuleval/utils/agent.py +++ b/simuleval/utils/agent.py @@ -141,8 +141,8 @@ def build_system_args( args = parser.parse_args(cli_argument_list(config_dict)) - logger.info(f"System will run on device: {args.device}.") - system.to(args.device) + logger.info(f"System will run on device: {args.device}. fp16: {args.fp16}") + system.to(args.device, fp16=args.fp16) args.source_type = system.source_type args.target_type = system.target_type