From f4ea541d0fed45c8e72bd89a8bbb822898b5a398 Mon Sep 17 00:00:00 2001 From: weiwee Date: Mon, 19 Jun 2023 12:38:04 +0800 Subject: [PATCH] add magic slot for dtensor Signed-off-by: weiwee --- python/fate/arch/context/io/model/__init__.py | 14 ------ python/fate/arch/context/io/model/file.py | 40 ---------------- python/fate/arch/context/io/model/http.py | 46 ------------------- .../fate/arch/tensor/distributed/_tensor.py | 30 ++++++++++++ 4 files changed, 30 insertions(+), 100 deletions(-) delete mode 100644 python/fate/arch/context/io/model/__init__.py delete mode 100644 python/fate/arch/context/io/model/file.py delete mode 100644 python/fate/arch/context/io/model/http.py diff --git a/python/fate/arch/context/io/model/__init__.py b/python/fate/arch/context/io/model/__init__.py deleted file mode 100644 index ae946a49c4..0000000000 --- a/python/fate/arch/context/io/model/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/fate/arch/context/io/model/file.py b/python/fate/arch/context/io/model/file.py deleted file mode 100644 index 386f236372..0000000000 --- a/python/fate/arch/context/io/model/file.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json - -from ....unify import URI - - -class FileModelWriter: - def __init__(self, ctx, name: str, uri: URI) -> None: - self.ctx = ctx - self.name = name - self.path = uri.path - - def write_model(self, model): - with open(self.path, "w") as f: - json.dump(model, f) - - -class FileModelReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.metadata = metadata - - def read_model(self): - with open(self.uri.path, "r") as fin: - return json.loads(fin.read()) diff --git a/python/fate/arch/context/io/model/http.py b/python/fate/arch/context/io/model/http.py deleted file mode 100644 index 903225bd0d..0000000000 --- a/python/fate/arch/context/io/model/http.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -import requests - -from ....unify import URI - -logger = logging.getLogger(__name__) - - -class HTTPModelWriter: - def __init__(self, ctx, name: str, uri: URI, metadata) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.entrypoint = f"{self.uri.schema}://{self.uri.authority}{self.uri.path}" - - def write_model(self, model): - logger.debug(self.entrypoint) - response = requests.post(url=self.entrypoint, json={"data": model}) - logger.debug(response.text) - - -class HTTPModelReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.entrypoint = f"{self.uri.schema}://{self.uri.authority}{self.uri.path}" - self.metadata = metadata - - def read_model(self): - return requests.get(url=self.entrypoint).json().get("data", {}) diff --git a/python/fate/arch/tensor/distributed/_tensor.py b/python/fate/arch/tensor/distributed/_tensor.py index baf394eff1..91c093787f 100644 --- a/python/fate/arch/tensor/distributed/_tensor.py +++ b/python/fate/arch/tensor/distributed/_tensor.py @@ -47,6 +47,36 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def __init__(self, shardings: "Shardings") -> None: self.shardings = shardings + def __add__(self, other): + return torch.add(self, other) + + def __radd__(self, other): + return torch.add(other, self) + + def __sub__(self, other): + return torch.sub(self, other) + + def __rsub__(self, other): + return torch.rsub(self, other) + + def __mul__(self, other): + return torch.mul(self, other) + + def __rmul__(self, other): + return torch.mul(other, self) + + def __truediv__(self, other): + return torch.div(self, other) + + def __rtruediv__(self, other): + return torch.div(other, self) + + def __matmul__(self, other): + return torch.matmul(self, other) + + def __rmatmul__(self, other): + return torch.matmul(other, self) + @property def shape(self): return self.shardings.shape