Commit

Remove redundant code for osx
Signed-off-by: dylan-fan <289765648@qq.com>
dylan-fan committed Dec 18, 2023
1 parent 6a7fd19 commit 3677dc9
Showing 3 changed files with 47 additions and 137 deletions.
77 changes: 23 additions & 54 deletions python/fate/arch/federation/backends/osx/_federation.py
@@ -23,7 +23,6 @@
from ._mq_channel import MQChannel

LOGGER = getLogger(__name__)
# default message max size in bytes = 1MB


class MQ(object):
@@ -51,13 +50,13 @@ def __str__(self) -> str:
class OSXFederation(MessageQueueBasedFederation):
@staticmethod
def from_conf(
federation_session_id: str,
computing_session,
party: PartyMeta,
parties: typing.List[PartyMeta],
host: str,
port: int,
max_message_size: typing.Optional[int] = None,
federation_session_id: str,
computing_session,
party: PartyMeta,
parties: typing.List[PartyMeta],
host: str,
port: int,
max_message_size: typing.Optional[int] = None,
):
mq = MQ(host, port)

@@ -71,13 +70,13 @@ def from_conf(
)

def __init__(
self,
federation_session_id,
computing_session,
party: PartyMeta,
parties: typing.List[PartyMeta],
max_message_size,
mq,
self,
federation_session_id,
computing_session,
party: PartyMeta,
parties: typing.List[PartyMeta],
max_message_size,
mq,
):
super().__init__(
session_id=federation_session_id,
@@ -123,7 +122,14 @@ def _maybe_create_topic_and_replication(self, party, topic_suffix):
return topic_pair

def _get_channel(
self, topic_pair: _TopicPair, src_party_id, src_role, dst_party_id, dst_role, mq: MQ, conf: dict = None
self,
topic_pair,
src_party_id,
src_role,
dst_party_id,
dst_role,
mq=None,
conf: dict = None
):
LOGGER.debug(
f"_get_channel, topic_pari={topic_pair}, src_party_id={src_party_id}, src_role={src_role}, dst_party_id={dst_party_id}, dst_role={dst_role}"
@@ -142,41 +148,6 @@ def _get_channel(

_topic_ip_map = {}

# @nretry
# def _query_receive_topic(self, channel_info):
# # LOGGER.debug(f"_query_receive_topic, channel_info={channel_info}")
# # topic = channel_info._receive_topic
# # if topic not in self._topic_ip_map:
# # LOGGER.info(f"query topic {topic} miss cache ")
# # response = channel_info.query()
# # if response.code == "0":
# # topic_info = osx_pb2.TopicInfo()
# # topic_info.ParseFromString(response.payload)
# # self._topic_ip_map[topic] = (topic_info.ip, topic_info.port)
# # LOGGER.info(f"query result {topic} {topic_info}")
# # else:
# # raise LookupError(f"{response}")
# # host, port = self._topic_ip_map[topic]
# #
# # new_channel_info = channel_info
# # if channel_info._host != host or channel_info._port != port:
# # LOGGER.info(
# # f"channel info missmatch, host: {channel_info._host} vs {host} and port: {channel_info._port} vs {port}"
# # )
# # new_channel_info = MQChannel(
# # host=host,
# # port=port,
# # namespace=channel_info._namespace,
# # send_topic=channel_info._send_topic,
# # receive_topic=channel_info._receive_topic,
# # src_party_id=channel_info._src_party_id,
# # src_role=channel_info._src_role,
# # dst_party_id=channel_info._dst_party_id,
# # dst_role=channel_info._dst_role,
# # )
# # return new_channel_info
# return channel_info;

def _get_consume_message(self, channel_info):
LOGGER.debug(f"_get_comsume_message, channel_info={channel_info}")
while True:
@@ -186,7 +157,7 @@ def _get_consume_message(self, channel_info):
raise LookupError(f"{response}")
message = osx_pb2.Message()
message.ParseFromString(response.payload)
# offset = response.metadata["MessageOffSet"]

head_str = str(message.head, encoding="utf-8")
# LOGGER.debug(f"head str {head_str}")
properties = json.loads(head_str)
@@ -196,5 +167,3 @@

def _consume_ack(self, channel_info, id):
return
# LOGGER.debug(f"_comsume_ack, channel_info={channel_info}, id={id}")
# channel_info.ack(offset=id)
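
The hunks above trim most of the consume loop; purely for orientation, a minimal sketch of the flow the remaining _get_consume_message code implements is given below. The channel_info.consume() call, the "0" success code, and the shape of the yielded values are assumptions inferred from the surrounding hunks, not the exact committed code.

import json
from fate.arch.federation.backends.osx import osx_pb2

def consume_messages(channel_info):
    # Sketch only: consume(), the "0" success code, and the yielded pair are assumptions.
    while True:
        response = channel_info.consume()            # pop one message from the receive topic
        if response.code != "0":                     # any non-zero code is treated as a failure
            raise LookupError(f"{response}")
        message = osx_pb2.Message()
        message.ParseFromString(response.payload)    # the payload carries a serialized osx_pb2.Message
        properties = json.loads(str(message.head, encoding="utf-8"))  # the head is a JSON property map
        yield properties, message.body               # the body holds the raw message bytes
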
105 changes: 23 additions & 82 deletions python/fate/arch/federation/backends/osx/_mq_channel.py
@@ -15,20 +15,21 @@


import json
from logging import getLogger
import time
from enum import Enum
from logging import getLogger
from typing import Dict, List, Any
import time

import grpc
import numpy as np

from fate.arch.federation.backends.osx import osx_pb2
from fate.arch.federation.backends.osx.osx_pb2_grpc import PrivateTransferTransportStub
import numpy as np

LOGGER = getLogger(__name__)

class Metadata(Enum):


class Metadata(Enum):
PTP_VERSION = "x-ptp-version"
PTP_TECH_PROVIDER_CODE = "x-ptp-tech-provider-code"
PTP_TRACE_ID = "x-ptp-trace-id"
@@ -56,14 +57,16 @@ def append(self, attachments: List[Any], v: str):
if attachments is not None and '' != v:
attachments.append((self.key(), v))

def build_trace_id():

def build_trace_id():
timestamp = int(time.time())
timestamp_str = str(timestamp)
return timestamp_str+"_"+str(np.random.randint(10000))
return timestamp_str + "_" + str(np.random.randint(10000))


class MQChannel(object):
def __init__(
self, host, port, namespace, send_topic, receive_topic, src_party_id, src_role, dst_party_id, dst_role
self, host, port, namespace, send_topic, receive_topic, src_party_id, src_role, dst_party_id, dst_role
):
self._host = host
self._port = port
@@ -82,11 +85,6 @@ def __init__(
def __str__(self):
return f"<MQChannel namespace={self._namespace}, host={self._host},port={self._port}, src=({self._src_role}, {self._src_party_id}), dst=({self._dst_role}, {self._dst_party_id}), send_topic={self._send_topic}, receive_topic={self._receive_topic}>"






def prepare_metadata_consume(self):
metadata = []
# Metadata.PTP_TRACE_ID.append(metadata, )
@@ -98,95 +96,41 @@ def prepare_metadata_consume(self):
Metadata.PTP_FROM_NODE_ID.append(metadata, str(self._src_party_id))
# Metadata.PTP_TOPIC.append(metadata,str(self._receive_topic))
Metadata.PTP_TECH_PROVIDER_CODE.append(metadata, "FATE")
Metadata.PTP_TRACE_ID.append(metadata,build_trace_id())
Metadata.PTP_TRACE_ID.append(metadata, build_trace_id())
return metadata;

def prepare_metadata(self,):
def prepare_metadata(self, ):
metadata = []
Metadata.PTP_TRACE_ID.append(metadata,build_trace_id() )
Metadata.PTP_TRACE_ID.append(metadata, build_trace_id())
if not self._namespace is None:
Metadata.PTP_SESSION_ID.append(metadata,self._namespace)
Metadata.PTP_SESSION_ID.append(metadata, self._namespace)
if not self._dst_party_id is None:
Metadata.PTP_TARGET_NODE_ID.append(metadata,str(self._dst_party_id))
Metadata.PTP_TARGET_NODE_ID.append(metadata, str(self._dst_party_id))
if not self._src_party_id is None:
Metadata.PTP_FROM_NODE_ID.append(metadata, str(self._src_party_id))
# Metadata.PTP_TOPIC.append(metadata,str(self._receive_topic))
Metadata.PTP_TECH_PROVIDER_CODE.append(metadata,"FATE")
Metadata.PTP_TECH_PROVIDER_CODE.append(metadata, "FATE")
return metadata;


# @nretry
def consume(self):

self._get_or_create_channel()
# meta = dict(
# MessageTopic=self._receive_topic,
# TechProviderCode="FATE",
# SourceNodeID=self._src_party_id,
# TargetNodeID=self._dst_party_id,
# TargetComponentName=self._dst_role,
# SourceComponentName=self._src_role,
# TargetMethod="CONSUME_MSG",
# SessionID=self._namespace,
# MessageOffSet=str(offset),
# )
# inbound = osx_pb2.Inbound(metadata=meta)
# LOGGER.debug(f"consume, inbound={inbound}, mq={self}")
# result = self._stub.invoke(inbound)
inbound = osx_pb2.PopInbound(topic=self._receive_topic,timeout=36000000)
inbound = osx_pb2.PopInbound(topic=self._receive_topic, timeout=36000000)
metadata = self.prepare_metadata_consume();
result = self._stub.pop(request=inbound,metadata=metadata)
result = self._stub.pop(request=inbound, metadata=metadata)
# LOGGER.debug(f"consume, result={result.code}, mq={self}")
return result

# @nretry
# def query(self):
# LOGGER.debug(f"query, mq={self}")
# self._get_or_create_channel()
# meta = dict(
# MessageTopic=self._receive_topic,
# TechProviderCode="FATE",
# SourceNodeID=self._src_party_id,
# TargetNodeID=self._dst_party_id,
# TargetComponentName=self._dst_role,
# SourceComponentName=self._src_role,
# TargetMethod="QUERY_TOPIC",
# SessionID=self._namespace,
# )
# inbound = osx_pb2.Inbound(metadata=meta)
# LOGGER.debug(f"query, inbound={inbound}, mq={self}")
# result = self._stub.invoke(inbound)
# LOGGER.debug(f"query, result={result}, mq={self}")
# return result

# @nretry
def produce(self, body, properties):
# LOGGER.debug(f"produce body={body}, properties={properties}, mq={self}")
self._get_or_create_channel()
# meta = dict(
# MessageTopic=self._send_topic,
# TechProviderCode="FATE",
# SourceNodeID=self._src_party_id,
# TargetNodeID=self._dst_party_id,
# TargetComponentName=self._dst_role,
# SourceComponentName=self._src_role,
# TargetMethod="PRODUCE_MSG",
# SessionID=self._namespace,
# )
# msg = osx_pb2.Message(head=bytes(json.dumps(properties), encoding="utf-8"), body=body)
# inbound = osx_pb2.Inbound(metadata=meta, payload=msg.SerializeToString())
# LOGGER.debug(f"produce inbound={inbound}, mq={self}")
msg = osx_pb2.Message(head=bytes(json.dumps(properties), encoding="utf-8"), body=body)
inbound = osx_pb2.PushInbound(topic= self._send_topic,payload=msg.SerializeToString())
inbound = osx_pb2.PushInbound(topic=self._send_topic, payload=msg.SerializeToString())
metadata = self.prepare_metadata()

# result = self._stub.push(inbound)
result = self._stub.push(inbound,metadata=metadata)
# print(f"produce result {result}")
# LOGGER.debug(f"produce {self._receive_topic} index {self._index} result={result.code}, mq={self}")
# if result.code!="0":
# raise RuntimeError(f"produce msg error ,code : {result.code} msg : {result.message}")
# self._index+=1
result = self._stub.push(inbound, metadata=metadata)

return result

# @nretry
@@ -198,9 +142,8 @@ def cleanup(self):
self._get_or_create_channel()
inbound = osx_pb2.ReleaseInbound()
metadata = self.prepare_metadata()
# result = self._stub.push(inbound)
result = self._stub.release(inbound,metadata=metadata)

result = self._stub.release(inbound, metadata=metadata)

def cancel(self):
LOGGER.debug(f"cancel channel")
@@ -262,6 +205,4 @@ def _check_alive(self):
timestamp = int(time.time())
timestamp_str = str(timestamp)


print(build_trace_id())
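
For orientation, a minimal usage sketch of the channel defined in this file follows; the host, port, topic names, party ids, and payload are invented illustration values, not anything taken from this commit.

from fate.arch.federation.backends.osx._mq_channel import MQChannel

# Illustration only: every literal below is an invented example value.
channel = MQChannel(
    host="127.0.0.1",
    port=9370,
    namespace="session-1",        # surfaces as the x-ptp-session-id metadata entry
    send_topic="send-topic",
    receive_topic="recv-topic",
    src_party_id="9999",
    src_role="guest",
    dst_party_id="10000",
    dst_role="host",
)

# produce() wraps the body plus the JSON-encoded properties into an osx_pb2.Message
# and pushes it with the x-ptp-* headers built by prepare_metadata().
channel.produce(body=b"hello", properties={"k": "v"})

# consume() pops one message from the receive topic (long timeout) and returns the raw response.
response = channel.consume()

# cleanup() releases the topic through the stub's release() call.
channel.cleanup()
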

@@ -221,7 +221,7 @@ def _get_channel(
src_role,
dst_party_id,
dst_role,
mq,
mq=None,
conf: dict = None,
):
LOGGER.debug(
