Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev 2.0.0 rc nn update #5332

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fate_client
Submodule fate_client updated 0 files
2 changes: 2 additions & 0 deletions python/fate/ml/nn/hetero/hetero_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def prediction_step(
# (features, labels), this format is used in FATE-1.x
# now the model is in eval status
if isinstance(inputs, tuple) or isinstance(inputs, list):
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
if len(inputs) == 2: # data & label
feats, labels = inputs
Expand Down Expand Up @@ -174,6 +175,7 @@ def prediction_step(
):
# (features, labels), this format is used in FATE-1.x
# now the model is in eval status
inputs = self._prepare_inputs(inputs)
if isinstance(inputs, torch.Tensor):
feats = inputs
elif isinstance(inputs, tuple) or isinstance(inputs, list):
Expand Down
15 changes: 14 additions & 1 deletion python/fate/ml/nn/model_zoo/agg_layer/agg_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self):
self._has_ctx = False
self._model = None
self.training = True
self.device = None

def forward(self, x):
raise NotImplementedError()
Expand All @@ -38,8 +39,13 @@ def predict(self, x):
def set_context(self, ctx: Context):
self._ctx = ctx
self._has_ctx = True

def has_context(self):
return self._has_ctx

def set_device(self, device):
self.device = device

@property
def ctx(self):
if self._ctx is None or self._has_ctx == False:
Expand All @@ -62,6 +68,7 @@ def eval(self: T) -> T:


class AggLayerGuest(_AggLayerBase):

def __init__(self, merge_type: Literal['sum', 'concat'] = 'sum', concat_dim = 1):
super(AggLayerGuest, self).__init__()
self._host_input_caches = None
Expand All @@ -77,6 +84,11 @@ def _forward(self, x_g: t.Tensor = None, x_h: List[t.Tensor] = None) -> t.Tensor

if x_g is None and x_h is None:
raise ValueError("guest input and host inputs cannot be both None")

if x_g is not None:
x_g = x_g.to(self.device)
if x_h is not None:
x_h = [h.to(self.device) for h in x_h]

can_cat = True
if x_g is None:
Expand Down Expand Up @@ -184,7 +196,7 @@ def forward(self, x: t.Tensor) -> None:
if self.training:
assert isinstance(x, t.Tensor), 'x should be a tensor'
if self._model is not None:
self._input_cache = t.from_numpy(x.detach().numpy()).requires_grad_(True)
self._input_cache = t.from_numpy(x.cpu().detach().numpy()).to(self.device).requires_grad_(True)
out_ = self._model(self._input_cache)
self._out_cache = out_
else:
Expand All @@ -197,6 +209,7 @@ def backward(self, error=None) -> t.Tensor:

error = self._get_error_from_guest()
if self._input_cache is not None and self._model is not None:
error = error.to(self.device)
loss = backward_loss(self._out_cache, error)
loss.backward()
error = self._input_cache.grad
Expand Down
17 changes: 17 additions & 0 deletions python/fate/ml/nn/model_zoo/hetero_nn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def __init__(self):
self._top_model = None
self._agg_layer = None
self._ctx = None
self.device = None

def _auto_setup(self):
self._agg_layer = AggLayerGuest()
self._agg_layer.set_context(self._ctx)

def get_device(self, module):
return next(module.parameters()).device


class HeteroNNModelGuest(HeteroNNModelBase):
Expand Down Expand Up @@ -230,6 +238,10 @@ def forward(self, x = None):
if self._agg_layer is None:
self._auto_setup()

if self.device is None:
self.device = self.get_device(self._top_model)
self._agg_layer.set_device(self.device)

if self._bottom_model is None:
b_out = None
else:
Expand Down Expand Up @@ -338,6 +350,10 @@ def forward(self, x):
if self._agg_layer is None:
self._auto_setup()

if self.device is None:
self.device = self.get_device(self._bottom_model)
self._agg_layer.set_device(self.device)

b_out = self._bottom_model(x)
# bottom layer
self._bottom_fw = b_out
Expand All @@ -356,6 +372,7 @@ def backward(self):
self._clear_state()
else:
error = self._agg_layer.backward()
error = error.to(self.device)
loss = backward_loss(self._bottom_fw, error)
loss.backward()
self._clear_state()
Expand Down
44 changes: 28 additions & 16 deletions python/fate/ml/nn/test/test_fedpass_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_ctx(local, context_name):
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# init fate context
computing = CSession()
computing = CSession(data_dir='./session_dir')
return Context(
computing=computing, federation=StandaloneFederation(computing, context_name, local, [guest, host])
)
Expand Down Expand Up @@ -189,22 +189,30 @@ def set_seed(seed):
train=True,
download=True,
transform=torchvision.transforms.ToTensor())

test_data = torchvision.datasets.MNIST(root='./mnist',
train=False,
download=True,
transform=torchvision.transforms.ToTensor())

digit_indices = [[] for _ in range(10)]
for idx, (_, label) in enumerate(train_data):
digit_indices[label].append(idx)
# digit_indices = [[] for _ in range(10)]
# for idx, (_, label) in enumerate(train_data):
# digit_indices[label].append(idx)

selected_train_indices = []
for indices in digit_indices:
selected_train_indices.extend(torch.randperm(len(indices))[:300].tolist())
# selected_train_indices = []
# for indices in digit_indices:
# selected_train_indices.extend(torch.randperm(len(indices))[:10000].tolist())

selected_val_indices = []
for indices in digit_indices:
remaining_indices = [idx for idx in indices if idx not in selected_train_indices]
selected_val_indices.extend(torch.randperm(len(remaining_indices))[:100].tolist())
# selected_val_indices = []
# for indices in digit_indices:
# remaining_indices = [idx for idx in indices if idx not in selected_train_indices]
# selected_val_indices.extend(torch.randperm(len(remaining_indices))[:1000].tolist())

subset_train_data = torch.utils.data.Subset(train_data, selected_train_indices)
subset_val_data = torch.utils.data.Subset(train_data, selected_val_indices)
# subset_train_data = torch.utils.data.Subset(train_data, selected_train_indices)
# subset_val_data = torch.utils.data.Subset(train_data, selected_val_indices)

subset_train_data = train_data
subset_val_data = test_data

epochs = 10

Expand All @@ -220,9 +228,11 @@ def __len__(self):
def __getitem__(self, item):
return [self.ds[item][1]]

arg = TrainingArguments(num_train_epochs=20, per_device_train_batch_size=16, disable_tqdm=False,
arg = TrainingArguments(num_train_epochs=10, per_device_train_batch_size=512, disable_tqdm=False,
per_gpu_eval_batch_size=512,
eval_steps=1,
evaluation_strategy='epoch'
evaluation_strategy='epoch',
no_cuda=False
)

if party == 'guest':
Expand Down Expand Up @@ -264,6 +274,8 @@ def __getitem__(self, item):
ctx = create_ctx(host, get_current_datetime_str())

bottom_model = LeNetBottom()
passport_mode = 'multi'
print('passport mode is {}'.format(passport_mode))
model = HeteroNNModelHost(
bottom_model=bottom_model,
agglayer_arg=FedPassArgument(
Expand All @@ -272,7 +284,7 @@ def __getitem__(self, item):
out_channels_or_features=16,
kernel_size=(5, 5),
stride=(1, 1),
passport_mode='multi',
passport_mode=passport_mode,
activation='relu',
num_passport=64
)
Expand Down
1 change: 1 addition & 0 deletions python/fate/ml/nn/trainer/trainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,7 @@ def prediction_step(
else:
# (features, labels), this format is used in FATE-1.x
# now the model is in eval status
inputs = self._prepare_inputs(inputs)
if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2:
with torch.no_grad():
feats, labels = inputs
Expand Down