diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 1f6b05e8631a8..51616cb0fdb44 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -45,7 +45,7 @@ class GraphCaptureContext:
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
     prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
@@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:
 
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -558,9 +558,9 @@ def broadcast_tensor_dict(
 
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ def send_tensor_dict(
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -615,7 +615,7 @@ def recv_tensor_dict(
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
@@ -623,7 +623,7 @@ def recv_tensor_dict(
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ def recv_tensor_dict(
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict
 
     def barrier(self):
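
The new `_update_nested_dict` helper is called in `recv_tensor_dict` above but defined outside the hunks shown. A minimal sketch of what such a helper could look like, assuming `_split_tensor_dict` flattens nested dictionaries into single string keys joined by a "%" separator (the separator choice is an assumption, not confirmed by this diff):

# Hypothetical sketch, not part of this diff: rebuild the nesting that
# _split_tensor_dict flattened, so that a flattened key like "key1%key2"
# lands at tensor_dict["key1"]["key2"]. The "%" separator is an assumption.
from typing import Any, Dict


def _update_nested_dict(nested_dict: Dict[str, Any], flattened_key: str,
                        value: Any) -> None:
    key_splits = flattened_key.split("%")
    cur_dict = nested_dict
    for k in key_splits[:-1]:
        # Create intermediate dicts on demand while walking the key path.
        if k not in cur_dict:
            cur_dict[k] = {}
        cur_dict = cur_dict[k]
    cur_dict[key_splits[-1]] = value


# Example: after _update_nested_dict(d, "a%b", t), d == {"a": {"b": t}};
# a key without the separator degenerates to a plain d[key] = value, which
# is why the receive path can route every entry through this one call.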