
Update base.py
ZebinYang authored Feb 6, 2023
1 parent 706d2ca commit d115b51
Showing 1 changed file with 42 additions and 38 deletions.
gaminet/base.py (42 additions, 38 deletions)
@@ -665,16 +665,16 @@ def initmaineffect(idx):
     np.random.seed(self.random_state)
     torch.manual_seed(self.random_state)
     simu_xx = np.zeros((self.mlp_sample_size, self.n_features_))
-    simu_xx[:, idx] = np.random.uniform(self.min_value_[idx],
-                                        self.max_value_[idx], self.mlp_sample_size)
+    simu_xx[:, idx] = np.random.uniform(self.min_value_[idx].cpu().numpy(),
+                                        self.max_value_[idx].cpu().numpy(), self.mlp_sample_size)
     if self.normalize:
         simu_xx[:, idx] = ((simu_xx[:, idx] - self.mu_list_[idx].detach().cpu().numpy()) /
                            self.std_list_[idx].detach().cpu().numpy())
     simu_yy = surrogate_estimator[idx](simu_xx)
     self._fit_individual_subnet(simu_xx[:, [idx]], simu_yy, self.net_.main_effect_blocks.nsubnets,
                                 self.nfeature_index_list_.index(idx), loss_fn=torch.nn.MSELoss(reduction="none"))
 
-    xgrid = np.linspace(self.min_value_[idx], self.max_value_[idx], 100)
+    xgrid = np.linspace(self.min_value_[idx].cpu().numpy(), self.max_value_[idx].cpu().numpy(), 100)
     gam_input_grid = np.zeros((100, self.n_features_))
     gam_input_grid[:, idx] = xgrid
     gaminet_input_grid = torch.tensor(xgrid.reshape(-1, 1), dtype=torch.float32, device=self.device)
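Note on the recurring change: `self.min_value_` and `self.max_value_` are stored as torch tensors, and NumPy routines such as `np.random.uniform` and `np.linspace` cannot consume CUDA tensors implicitly (they raise a TypeError), so the bounds are moved to host memory first. A minimal sketch of the failure mode and the fix, assuming tensor-valued per-feature bounds as above:

    import numpy as np
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    min_value_ = torch.zeros(5, device=device)  # per-feature lower bounds
    max_value_ = torch.ones(5, device=device)   # per-feature upper bounds

    # On a CUDA device this raises:
    # TypeError: can't convert cuda:0 device type tensor to numpy
    # simu = np.random.uniform(min_value_[0], max_value_[0], 1000)

    # The pattern applied throughout this commit: copy to host first.
    simu = np.random.uniform(min_value_[0].cpu().numpy(),
                             max_value_[0].cpu().numpy(), 1000)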
@@ -699,8 +699,8 @@ def initmaineffect(idx):
     for idx in self.cfeature_index_list_:
         i = self.cfeature_index_list_.index(idx)
         simu_xx = np.zeros((self.num_classes_list_[i], self.n_features_))
-        simu_xx[:, idx] = np.linspace(self.min_value_[idx],
-                                      self.max_value_[idx], self.num_classes_list_[i])
+        simu_xx[:, idx] = np.linspace(self.min_value_[idx].cpu().numpy(),
+                                      self.max_value_[idx].cpu().numpy(), self.num_classes_list_[i])
         simu_yy = surrogate_estimator[idx](simu_xx)
         self.net_.main_effect_blocks.csubnets.class_bias[i].data = torch.tensor(simu_yy.reshape(-1, 1),
                                                                                 dtype=torch.float32, device=self.device)
@@ -751,36 +751,40 @@ def initinteraction(i, idx1, idx2):
     simu_xx = np.zeros((self.mlp_sample_size, self.n_features_))
     if idx1 in self.cfeature_index_list_:
         num_classes = self.num_classes_list_[self.cfeature_index_list_.index(idx1)]
-        simu_xx[:, idx1] = np.random.randint(self.min_value_[idx1],
-                                             self.max_value_[idx1] + 1, self.mlp_sample_size)
+        simu_xx[:, idx1] = np.random.randint(self.min_value_[idx1].cpu().numpy(),
+                                             self.max_value_[idx1].cpu().numpy() + 1, self.mlp_sample_size)
         x1 = torch.nn.functional.one_hot(torch.tensor(simu_xx[:, idx1]).to(torch.int64),
                                          num_classes=num_classes).to(torch.float32).detach().cpu().numpy()
-        x1grid = np.linspace(self.min_value_[idx1], self.max_value_[idx1], num_classes).reshape(-1, 1)
+        x1grid = np.linspace(self.min_value_[idx1].cpu().numpy(),
+                             self.max_value_[idx1].cpu().numpy(), num_classes).reshape(-1, 1)
     else:
-        simu_xx[:, idx1] = np.random.uniform(self.min_value_[idx1],
-                                             self.max_value_[idx1], self.mlp_sample_size)
+        simu_xx[:, idx1] = np.random.uniform(self.min_value_[idx1].cpu().numpy(),
+                                             self.max_value_[idx1].cpu().numpy(), self.mlp_sample_size)
         if self.normalize:
             simu_xx[:, idx1] = ((simu_xx[:, idx1] -
                                  self.mu_list_[idx1].detach().cpu().numpy()) /
                                 self.std_list_[idx1].detach().cpu().numpy())
         x1 = simu_xx[:, [idx1]]
-        x1grid = np.linspace(self.min_value_[idx1], self.max_value_[idx1], 20).reshape(-1, 1)
+        x1grid = np.linspace(self.min_value_[idx1].cpu().numpy(),
+                             self.max_value_[idx1].cpu().numpy(), 20).reshape(-1, 1)
     if idx2 in self.cfeature_index_list_:
         num_classes = self.num_classes_list_[self.cfeature_index_list_.index(idx2)]
-        simu_xx[:, idx2] = np.random.randint(self.min_value_[idx2],
-                                             self.max_value_[idx2] + 1, self.mlp_sample_size)
+        simu_xx[:, idx2] = np.random.randint(self.min_value_[idx2].cpu().numpy(),
+                                             self.max_value_[idx2].cpu().numpy() + 1, self.mlp_sample_size)
         x2 = torch.nn.functional.one_hot(torch.tensor(simu_xx[:, idx2]).to(torch.int64),
                                          num_classes=num_classes).to(torch.float32).detach().cpu().numpy()
-        x2grid = np.linspace(self.min_value_[idx2], self.max_value_[idx2], num_classes).reshape(-1, 1)
+        x2grid = np.linspace(self.min_value_[idx2].cpu().numpy(),
+                             self.max_value_[idx2].cpu().numpy(), num_classes).reshape(-1, 1)
     else:
-        simu_xx[:, idx2] = np.random.uniform(self.min_value_[idx2],
-                                             self.max_value_[idx2], self.mlp_sample_size)
+        simu_xx[:, idx2] = np.random.uniform(self.min_value_[idx2].cpu().numpy(),
+                                             self.max_value_[idx2].cpu().numpy(), self.mlp_sample_size)
         if self.normalize:
             simu_xx[:, idx2] = ((simu_xx[:, idx2] -
                                  self.mu_list_[idx2].detach().cpu().numpy()) /
                                 self.std_list_[idx2].detach().cpu().numpy())
         x2 = simu_xx[:, [idx2]]
-        x2grid = np.linspace(self.min_value_[idx2], self.max_value_[idx2], 20).reshape(-1, 1)
+        x2grid = np.linspace(self.min_value_[idx2].cpu().numpy(),
+                             self.max_value_[idx2].cpu().numpy(), 20).reshape(-1, 1)
 
     xx = np.hstack([x1, x2])
     xx = np.hstack([xx, np.zeros((xx.shape[0],
@@ -946,8 +950,8 @@ def _fit_main_effect(self):
 
     if self.monotonicity_:
         mono_loss_reg = self.reg_mono * self.net_.mono_loss
-        simu_inputs = np.random.uniform(self.min_value_,
-                                        self.max_value_, size=(self.mono_sample_size, len(self.max_value_)))
+        simu_inputs = np.random.uniform(self.min_value_.cpu().numpy(),
+                                        self.max_value_.cpu().numpy(), size=(self.mono_sample_size, len(self.max_value_)))
         simu_inputs = torch.tensor(simu_inputs, dtype=torch.float32, device=self.device)
         self.net_(simu_inputs,
                   main_effect=True, interaction=False,
@@ -1138,8 +1142,8 @@ def _fit_interaction(self):
 
     if self.monotonicity_ > 0:
         mono_loss_reg = self.reg_mono * self.net_.mono_loss
-        simu_inputs = np.random.uniform(self.min_value_,
-                                        self.max_value_, size=(self.mono_sample_size, len(self.max_value_)))
+        simu_inputs = np.random.uniform(self.min_value_.cpu().numpy(),
+                                        self.max_value_.cpu().numpy(), size=(self.mono_sample_size, len(self.max_value_)))
         simu_inputs = torch.tensor(simu_inputs, dtype=torch.float32, device=self.device)
         self.net_(simu_inputs,
                   main_effect=True, interaction=True,
@@ -1306,8 +1310,8 @@ def _fine_tune_all(self):
 
     if self.monotonicity_:
         mono_loss_reg = self.reg_mono * self.net_.mono_loss
-        simu_inputs = np.random.uniform(self.min_value_,
-                                        self.max_value_, size=(self.mono_sample_size, len(self.max_value_)))
+        simu_inputs = np.random.uniform(self.min_value_.cpu().numpy(),
+                                        self.max_value_.cpu().numpy(), size=(self.mono_sample_size, len(self.max_value_)))
         simu_inputs = torch.tensor(simu_inputs, dtype=torch.float32, device=self.device)
         self.net_(simu_inputs,
                   main_effect=True, interaction=True,
@@ -1441,7 +1445,7 @@ def get_main_effect_raw_output(self, x):
 
     pred = []
     self.net_.eval()
-    x = np.asarray(x).reshape(-1, self.n_features_)
+    x = x.reshape(-1, self.n_features_)
     xx = x if torch.is_tensor(x) else torch.from_numpy(x.astype(np.float32)).to(self.device)
     batch_size = int(np.minimum(self.batch_size_inference, x.shape[0]))
     data_generator = FastTensorDataLoader(xx, batch_size=batch_size, shuffle=False)
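Note on the `np.asarray(x)` removal: the very next line dispatches on `torch.is_tensor(x)`, and `np.asarray` would raise on CUDA tensors before that check is reached, while `.reshape(-1, n)` works on NumPy arrays and torch tensors alike. A hedged, self-contained sketch of the same dispatch pattern (helper name illustrative, not from this file):

    import numpy as np
    import torch

    def as_model_input(x, n_features, device):
        # reshape is shared by np.ndarray and torch.Tensor; np.asarray is
        # unsafe here because it fails on CUDA tensors.
        x = x.reshape(-1, n_features)
        if torch.is_tensor(x):
            return x
        return torch.from_numpy(x.astype(np.float32)).to(device)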
@@ -1471,7 +1475,7 @@ def get_interaction_raw_output(self, x):
 
     pred = []
     self.net_.eval()
-    x = np.asarray(x).reshape(-1, self.n_features_)
+    x = x.reshape(-1, self.n_features_)
     xx = x if torch.is_tensor(x) else torch.from_numpy(x.astype(np.float32)).to(self.device)
     batch_size = int(np.minimum(self.batch_size_inference, x.shape[0]))
     data_generator = FastTensorDataLoader(xx, batch_size=batch_size, shuffle=False)
@@ -1505,7 +1509,7 @@ def get_aggregate_output(self, x, main_effect=True, interaction=True):
 
     pred = []
     self.net_.eval()
-    x = np.asarray(x).reshape(-1, self.n_features_)
+    x = x.reshape(-1, self.n_features_)
     xx = x if torch.is_tensor(x) else torch.from_numpy(x.astype(np.float32)).to(self.device)
     batch_size = int(np.minimum(self.batch_size_inference, x.shape[0]))
     data_generator = FastTensorDataLoader(xx, batch_size=batch_size, shuffle=False)
@@ -1594,8 +1598,8 @@ def certify_mono(self, n_samples=10000):
     mono_status : boolean
         True means monotonicity constraint is satisfied.
     """
-    x = np.random.uniform(self.min_value_,
-                          self.max_value_, size=(n_samples, self.n_features_))
+    x = np.random.uniform(self.min_value_.cpu().numpy(),
+                          self.max_value_.cpu().numpy(), size=(n_samples, self.n_features_))
     mono_loss = self.get_mono_loss(x)
     mono_status = mono_loss <= 0
     return mono_status
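Usage sketch for the updated `certify_mono` (the estimator name `model` is illustrative, assuming a fitted GAMiNet estimator): it draws uniform samples inside the fitted feature ranges, which is why the bounds must be host-side NumPy values before reaching `np.random.uniform`:

    # Monte Carlo check of the monotonicity constraint.
    status = model.certify_mono(n_samples=10000)
    print("monotonicity satisfied:", status)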
@@ -1613,13 +1617,13 @@ def partial_derivatives(self, feature_idx, n_samples=10000):
         by default 10000.
     """
     np.random.seed(self.random_state)
-    inputs = np.random.uniform(self.min_value_,
-                               self.max_value_, size=(n_samples, len(self.max_value_)))
+    inputs = np.random.uniform(self.min_value_.cpu().numpy(),
+                               self.max_value_.cpu().numpy(), size=(n_samples, len(self.max_value_)))
     inputs = torch.tensor(inputs, dtype=torch.float32, device=self.device)
     outputs = self.net_(inputs)
     grad = torch.autograd.grad(outputs=torch.sum(outputs),
-                               inputs=inputs, create_graph=True)[0].detach().numpy()
-    plt.scatter(inputs.detach().numpy()[:, feature_idx], grad[:, feature_idx])
+                               inputs=inputs, create_graph=True)[0].cpu().detach().numpy()
+    plt.scatter(inputs.cpu().detach().numpy()[:, feature_idx], grad[:, feature_idx])
     plt.axhline(0, linestyle="--", linewidth=0.5, color="red")
     plt.ylabel("First-order Derivatives")
     plt.xlabel(self.feature_names_[feature_idx])
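Note on the plotting fix: `.numpy()` only works on CPU tensors, so both the gradients and the plotted inputs need an explicit `.cpu()` hop when the model lives on a GPU. A minimal sketch mirroring the `.cpu().detach().numpy()` order used above:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = torch.randn(8, 3, device=device, requires_grad=True)
    outputs = (inputs ** 2).sum()
    grad = torch.autograd.grad(outputs=outputs, inputs=inputs,
                               create_graph=True)[0]

    # .numpy() requires a CPU tensor detached from the autograd graph.
    grad_np = grad.cpu().detach().numpy()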
@@ -1695,8 +1699,8 @@ def global_explain(self, main_grid_size=100, interact_grid_size=20):
     feature_name = self.feature_names_[idx]
     if idx in self.nfeature_index_list_:
         main_effect_inputs = np.zeros((main_grid_size, self.n_features_))
-        main_effect_inputs[:, idx] = np.linspace(self.min_value_[idx],
-                                                 self.max_value_[idx], main_grid_size)
+        main_effect_inputs[:, idx] = np.linspace(self.min_value_[idx].cpu().numpy(),
+                                                 self.max_value_[idx].cpu().numpy(), main_grid_size)
         main_effect_inputs_original = main_effect_inputs[:, [idx]]
         main_effect_outputs = (self.net_.main_effect_weights.cpu().detach().numpy()[idx] *
                                self.net_.main_effect_switcher.cpu().detach().numpy()[idx] *
@@ -1753,8 +1757,8 @@ def global_explain(self, main_grid_size=100, interact_grid_size=20):
         interact_input_list.append(interact_input1)
         axis_extent.extend([-0.5, len(interact_input1_original) - 0.5])
     else:
-        interact_input1 = np.array(np.linspace(self.min_value_[idx1],
-                                               self.max_value_[idx1], interact_grid_size), dtype=np.float32)
+        interact_input1 = np.array(np.linspace(self.min_value_[idx1].cpu().numpy(),
+                                               self.max_value_[idx1].cpu().numpy(), interact_grid_size), dtype=np.float32)
         interact_input1_original = interact_input1.reshape(-1, 1)
         interact_input1_ticks = []
         interact_input1_labels = []
@@ -1773,8 +1777,8 @@ def global_explain(self, main_grid_size=100, interact_grid_size=20):
         interact_input_list.append(interact_input2)
         axis_extent.extend([-0.5, len(interact_input2_original) - 0.5])
     else:
-        interact_input2 = np.array(np.linspace(self.min_value_[idx2],
-                                               self.max_value_[idx2], interact_grid_size), dtype=np.float32)
+        interact_input2 = np.array(np.linspace(self.min_value_[idx2].cpu().numpy(),
+                                               self.max_value_[idx2].cpu().numpy(), interact_grid_size), dtype=np.float32)
         interact_input2_original = interact_input2.reshape(-1, 1)
         interact_input2_ticks = []
         interact_input2_labels = []
