
Commit

update
Zebin Yang committed Apr 18, 2022
1 parent 964550d commit c2de441
Showing 13 changed files with 6,126 additions and 2,082 deletions.
1,080 changes: 666 additions & 414 deletions examples/FicoHeloc.ipynb


450 changes: 261 additions & 189 deletions examples/GAMINet-bike-share.ipynb


550 changes: 336 additions & 214 deletions examples/GAMINet-demo.ipynb


778 changes: 778 additions & 0 deletions examples/cocircle.ipynb


932 changes: 932 additions & 0 deletions examples/friedman.ipynb


Binary file removed examples/simu_dict.pickle
Binary file removed examples/simu_model.pickle
1,195 changes: 1,005 additions & 190 deletions examples/twiwan credit.ipynb


489 changes: 374 additions & 115 deletions gaminet/api.py


2,375 changes: 1,516 additions & 859 deletions gaminet/base.py


7 changes: 6 additions & 1 deletion gaminet/interpret.py
@@ -11,6 +11,7 @@
 from sklearn.utils.validation import check_is_fitted
 from sklearn.base import BaseEstimator, TransformerMixin
 from contextlib import AbstractContextManager
+import pkg_resources
 
 try:
     from pandas.api.types import is_numeric_dtype, is_string_dtype
@@ -20,6 +21,9 @@
 log = logging.getLogger(__name__)
 
 
+# All the code in this file is from InterpretML by Microsoft.
+
+
 def autogen_schema(X, ordinal_max_items=2, feature_names=None, feature_types=None):
     """ Generates data schema for a given dataset as JSON representable.
     Args:
@@ -918,7 +922,8 @@ def _get_ebm_lib_path(debug=False):
     bitsize = struct.calcsize("P") * 8
     is_64_bit = bitsize == 64
 
-    script_path = os.path.dirname(os.path.abspath(__file__))
+    # script_path = os.path.dirname(os.path.abspath(__file__))
+    script_path = pkg_resources.resource_filename('piml', 'models/ebm_module/ebm')
     package_path = script_path  # os.path.join(script_path, "..", "..")
 
     debug_str = ""  # "_debug" if debug else ""
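Editor's note: the only functional change in this file swaps a `__file__`-relative lookup of the EBM native library for a package-resource lookup. A minimal sketch of the difference (the 'piml' resource path is taken from the diff; the surrounding script is illustrative and assumes piml is installed):

import os
import pkg_resources

# Old approach: resolve the shared library relative to this source file.
script_path_old = os.path.dirname(os.path.abspath(__file__))

# New approach: resolve it inside the installed `piml` package, which still
# works when this module is vendored into piml rather than run from its own tree.
script_path_new = pkg_resources.resource_filename('piml', 'models/ebm_module/ebm')
print(script_path_old, script_path_new)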
101 changes: 63 additions & 38 deletions gaminet/layers.py
@@ -12,7 +12,7 @@ def __init__(self, n_subnets, subnet_arch, n_input_nodes, activation_func, device):
         self.activation_func = activation_func
         self.n_hidden_layers = len(subnet_arch)
 
-        all_biases = []
+        all_biases = []
         all_weights = []
         n_hidden_nodes_prev = n_input_nodes
         for i, n_hidden_nodes in enumerate(subnet_arch + [1]):
@@ -44,18 +44,20 @@ def individual_forward(self, inputs, idx):
 
         xs = inputs
         for i in range(self.n_hidden_layers):
-            xs = self.activation_func(torch.matmul(xs, self.all_weights[i][idx]) + self.all_biases[i][idx])
+            xs = self.activation_func(torch.matmul(xs, self.all_weights[i][idx]) +
+                                      self.all_biases[i][idx])
         outputs = torch.matmul(xs, self.all_weights[-1][idx]) + self.all_biases[-1][idx]
         return outputs
 
     def forward(self, inputs):
 
         xs = inputs
         for i in range(self.n_hidden_layers):
-            xs = self.activation_func(torch.matmul(xs, self.all_weights[i])
-                                      + torch.reshape(self.all_biases[i], [self.n_subnets, 1, -1]))
+            xs = self.activation_func(torch.matmul(xs, self.all_weights[i]) +
+                                      torch.reshape(self.all_biases[i], [self.n_subnets, 1, -1]))
 
-        outputs = torch.matmul(xs, self.all_weights[-1]) + torch.reshape(self.all_biases[-1], [self.n_subnets, 1, -1])
+        outputs = (torch.matmul(xs, self.all_weights[-1]) +
+                   torch.reshape(self.all_biases[-1], [self.n_subnets, 1, -1]))
         outputs = torch.squeeze(torch.transpose(outputs, 0, 1), dim=2)
         return outputs
@@ -84,7 +86,7 @@ def forward(self, inputs, sample_weight=None, training=False):
         output = []
         for i in range(len(self.num_classes_list)):
             dummy = torch.nn.functional.one_hot(inputs[:, i].to(torch.int64),
-                                                num_classes=self.num_classes_list[i]).to(torch.float)
+                                                num_classes=self.num_classes_list[i]).to(torch.float)
             output.append(torch.matmul(dummy, self.class_bias[i]) + self.global_bias[i])
         output = torch.squeeze(torch.hstack(output))
         return output
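Editor's note: the wrapped call above is the core of the categorical layer, where a one-hot matrix times a per-class bias acts as a lookup table. A toy sketch (the values are made up):

import torch

codes = torch.tensor([0, 2, 1])                   # integer category codes
class_bias = torch.tensor([[0.1], [0.2], [0.3]])  # one learnable bias per class
dummy = torch.nn.functional.one_hot(codes.to(torch.int64), num_classes=3).to(torch.float)
print(torch.matmul(dummy, class_bias).ravel())    # tensor([0.1000, 0.3000, 0.2000])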
@@ -104,21 +106,23 @@ def __init__(self, nfeature_index_list, cfeature_index_list, num_classes_list,
         self.activation_func = activation_func
 
         if len(self.nfeature_index_list) > 0:
-            self.nsubnets = TensorLayer(len(nfeature_index_list), subnet_arch, 1, activation_func, device)
+            self.nsubnets = TensorLayer(len(nfeature_index_list), subnet_arch,
+                                        1, activation_func, device)
         if len(self.cfeature_index_list) > 0:
             self.csubnets = UnivariateOneHotEncodingLayer(num_classes_list, device)
 
     def forward(self, inputs):
 
         output = torch.zeros(size=(inputs.shape[0], inputs.shape[1]), dtype=torch.float)
         if len(self.nfeature_index_list) > 0:
-            ntensor_inputs = torch.unsqueeze(torch.transpose(inputs[:, self.nfeature_index_list], 0, 1), 2)
+            ntensor_inputs = torch.unsqueeze(torch.transpose(inputs[:,
+                self.nfeature_index_list], 0, 1), 2)
             output[:, self.nfeature_index_list] = self.nsubnets(ntensor_inputs)
         if len(self.cfeature_index_list) > 0:
             ctensor_inputs = inputs[:, self.cfeature_index_list]
             output[:, self.cfeature_index_list] = self.csubnets(ctensor_inputs)
         return output
 
 
 class pyInteractionNet(torch.nn.Module):
 
@@ -139,17 +143,21 @@ def __init__(self, interaction_list, nfeature_index_list, cfeature_index_list, num_classes_list,
         self.n_inputs2 = []
         for i in range(self.n_interactions):
             if self.interaction_list[i][0] in self.cfeature_index_list:
-                self.n_inputs1.append(self.num_classes_list[self.cfeature_index_list.index(self.interaction_list[i][0])])
+                self.n_inputs1.append(self.num_classes_list[
+                    self.cfeature_index_list.index(self.interaction_list[i][0])])
             else:
                 self.n_inputs1.append(1)
 
             if self.interaction_list[i][1] in self.cfeature_index_list:
-                self.n_inputs2.append(self.num_classes_list[self.cfeature_index_list.index(self.interaction_list[i][1])])
+                self.n_inputs2.append(self.num_classes_list[
+                    self.cfeature_index_list.index(self.interaction_list[i][1])])
             else:
                 self.n_inputs2.append(1)
 
-        self.max_n_inputs = max([self.n_inputs1[i] + self.n_inputs2[i] for i in range(self.n_interactions)])
-        self.subnets = TensorLayer(self.n_interactions, subnet_arch, self.max_n_inputs, activation_func, device)
+        self.max_n_inputs = max([self.n_inputs1[i] + self.n_inputs2[i]
+                                 for i in range(self.n_interactions)])
+        self.subnets = TensorLayer(self.n_interactions, subnet_arch, self.max_n_inputs,
+                                   activation_func, device)
 
     def preprocessing(self, inputs):
 
@@ -172,8 +180,9 @@ def preprocessing(self, inputs):
                 interact_input_list.append(inputs[:, [idx2]])
 
             if (self.n_inputs1[i] + self.n_inputs2[i]) < self.max_n_inputs:
-                interact_input_list.append(torch.zeros(size=(inputs.shape[0], self.max_n_inputs - (self.n_inputs1[i] + self.n_inputs2[i])),
-                    dtype=torch.float, requires_grad=True, device=self.device))
+                interact_input_list.append(torch.zeros(size=(inputs.shape[0],
+                    self.max_n_inputs - (self.n_inputs1[i] + self.n_inputs2[i])),
+                    dtype=torch.float, requires_grad=True, device=self.device))
             preprocessed_inputs.append(torch.hstack(interact_input_list))
         preprocessed_inputs = torch.hstack(preprocessed_inputs)
         return preprocessed_inputs
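Editor's note: the rewrapped torch.zeros call above pads each interaction's (possibly one-hot) input pair with zero columns so every pair reaches the shared width max_n_inputs expected by the common TensorLayer. A minimal sketch of the padding (the sizes are assumptions):

import torch

batch, n1, n2, max_n_inputs = 4, 1, 3, 6   # e.g. numeric x 3-level categorical
pair = torch.randn(batch, n1 + n2)         # the interaction's own columns
pad = torch.zeros(size=(batch, max_n_inputs - (n1 + n2)), dtype=torch.float)
print(torch.hstack([pair, pad]).shape)     # torch.Size([4, 6])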
@@ -190,7 +199,7 @@ class pyGAMINet(torch.nn.Module):
 
     def __init__(self, nfeature_index_list, cfeature_index_list, num_classes_list,
                  subnet_size_main_effect, subnet_size_interaction, activation_func,
-                 heredity, mono_increasing_list, mono_decreasing_list,
+                 heredity, mono_increasing_list, mono_decreasing_list,
                  boundary_clip, min_value, max_value, mu_list, std_list, device):
 
         super(pyGAMINet, self).__init__()
@@ -201,7 +210,7 @@ def __init__(self, nfeature_index_list, cfeature_index_list, num_classes_list,
         self.num_classes_list = num_classes_list
         self.subnet_size_main_effect = subnet_size_main_effect
         self.subnet_size_interaction = subnet_size_interaction
-        self.activation_func= activation_func
+        self.activation_func = activation_func
         self.heredity = heredity
         self.mono_increasing_list = mono_increasing_list
         self.mono_decreasing_list = mono_decreasing_list
@@ -211,7 +220,7 @@ def __init__(self, nfeature_index_list, cfeature_index_list, num_classes_list,
         self.max_value = max_value
         self.mu_list = mu_list
         self.std_list = std_list
-
+
         self.device = device
         self.interaction_status = False
         self.main_effect_blocks = pyGAMNet(nfeature_index_list=nfeature_index_list,
@@ -244,27 +253,39 @@ def init_interaction_blocks(self, interaction_list):
                                                        activation_func=self.activation_func,
                                                        device=self.device)
         self.interaction_weights = torch.nn.Parameter(torch.empty(size=(self.n_interactions, 1),
-                                                      dtype=torch.float, requires_grad=True, device=self.device))
+                                                      dtype=torch.float, requires_grad=True, device=self.device))
         self.interaction_switcher = torch.nn.Parameter(torch.empty(size=(self.n_interactions, 1),
-                                                       dtype=torch.float, requires_grad=False, device=self.device))
+                                                       dtype=torch.float, requires_grad=False, device=self.device))
         torch.nn.init.ones_(self.interaction_switcher)
         torch.nn.init.ones_(self.interaction_weights)
 
-    def get_mono_loss(self, inputs, outputs=None, monotonicity=False):
+    def get_mono_loss(self, inputs, outputs=None, monotonicity=False, sample_weight=None):
 
         mono_loss = torch.tensor(0.0, requires_grad=True)
         if not monotonicity:
             return mono_loss
 
         grad = torch.autograd.grad(outputs=torch.sum(outputs),
                                    inputs=inputs, create_graph=True)[0]
-        if len(self.mono_increasing_list) > 0:
-            mono_loss = mono_loss + torch.mean(torch.nn.ReLU()(-grad[:, self.mono_increasing_list]))
-        if len(self.mono_decreasing_list) > 0:
-            mono_loss = mono_loss + torch.mean(torch.nn.ReLU()(grad[:, self.mono_decreasing_list]))
+        if sample_weight is not None:
+            if len(self.mono_increasing_list) > 0:
+                mono_loss = mono_loss + torch.mean(torch.nn.ReLU()(
+                    -grad[:, self.mono_increasing_list]) * sample_weight.reshape(-1, 1))
+            if len(self.mono_decreasing_list) > 0:
+                mono_loss = mono_loss + torch.mean(torch.nn.ReLU()(
+                    grad[:, self.mono_decreasing_list]) * sample_weight.reshape(-1, 1))
+        else:
+            if len(self.mono_increasing_list) > 0:
+                mono_loss = mono_loss + torch.mean(torch.nn.ReLU()(
+                    -grad[:, self.mono_increasing_list]))
+            if len(self.mono_decreasing_list) > 0:
+                mono_loss = mono_loss + torch.mean(torch.nn.ReLU()(
+                    grad[:, self.mono_decreasing_list]))
         return mono_loss
 
-    def get_clarity_loss(self, main_effect_outputs=None, interaction_outputs=None, sample_weight=None, clarity=False):
+    def get_clarity_loss(self, main_effect_outputs=None, interaction_outputs=None,
+                         sample_weight=None, clarity=False):
 
         clarity_loss = torch.tensor(0.0, requires_grad=True)
         if main_effect_outputs is None:
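Editor's note: the get_mono_loss change above is one of the substantive edits of this commit; the method now accepts sample_weight, so per-sample weights scale the monotonicity violations before averaging. A self-contained sketch of the penalty on a toy linear model (only the penalty logic mirrors the diff; the model and data are made up):

import torch

x = torch.randn(8, 3, requires_grad=True)
w = torch.tensor([[-1.0], [2.0], [0.5]])   # feature 0 wrongly decreases
y = torch.matmul(x, w)                     # stand-in for the model output
sample_weight = torch.rand(8)

# d(sum y)/dx gives one gradient row per sample; for a feature that should
# be increasing, negative gradients are violations, which ReLU isolates.
grad = torch.autograd.grad(outputs=torch.sum(y), inputs=x, create_graph=True)[0]
mono_increasing_list = [0]
penalty = torch.mean(torch.nn.ReLU()(-grad[:, mono_increasing_list]) *
                     sample_weight.reshape(-1, 1))
print(penalty)  # > 0 because feature 0 violates monotonicity everywhere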
@@ -276,32 +297,35 @@ def get_clarity_loss(self, main_effect_outputs=None, interaction_outputs=None, sample_weight=None, clarity=False):
 
         for i, (k1, k2) in enumerate(self.interaction_blocks.interaction_list):
             if sample_weight is not None:
-                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k1]
-                    * interaction_outputs[:, i] * sample_weight.ravel()).mean())
-                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k2]
-                    * interaction_outputs[:, i] * sample_weight.ravel()).mean())
+                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k1] *
+                    interaction_outputs[:, i] * sample_weight.ravel()).mean())
+                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k2] *
+                    interaction_outputs[:, i] * sample_weight.ravel()).mean())
             else:
-                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k1]
+                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k1]
                     * interaction_outputs[:, i]).mean())
-                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k2]
+                clarity_loss = clarity_loss + torch.abs((main_effect_outputs[:, k2]
                     * interaction_outputs[:, i]).mean())
         return clarity_loss
 
     def forward_main_effect(self, inputs):
 
         inputs = torch.max(torch.min(inputs, self.max_value), self.min_value) if self.boundary_clip else inputs
         inputs = (inputs - self.mu_list) / self.std_list
-        outputs = self.main_effect_blocks(inputs)
+        main_effect_weights = self.main_effect_switcher * self.main_effect_weights
+        outputs = self.main_effect_blocks(inputs) * main_effect_weights.ravel()
         return outputs
 
     def forward_interaction(self, inputs):
 
         inputs = torch.max(torch.min(inputs, self.max_value), self.min_value) if self.boundary_clip else inputs
         inputs = (inputs - self.mu_list) / self.std_list
-        outputs = self.interaction_blocks(inputs)
+        interaction_weights = self.interaction_switcher * self.interaction_weights
+        outputs = self.interaction_blocks(inputs) * interaction_weights.ravel()
         return outputs
 
-    def forward(self, inputs, sample_weight=None, main_effect=True, interaction=True, clarity=False, monotonicity=False):
+    def forward(self, inputs, sample_weight=None, main_effect=True, interaction=True,
+                clarity=False, monotonicity=False):
 
         main_effect_outputs = None
         interaction_outputs = None
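Editor's note: the clarity term above penalizes correlation between each interaction subnet's output and the main effects of its two parent features, optionally sample-weighted; driving it toward zero keeps main effects and interactions identifiable. A toy sketch of one (k1, i) term (random stand-in tensors):

import torch

main_k1 = torch.randn(100)      # main-effect output for feature k1
interact = torch.randn(100)     # output of the (k1, k2) interaction subnet
sample_weight = torch.rand(100)

# Weighted mean cross-product, taken in absolute value as in the diff.
clarity_term = torch.abs((main_k1 * interact * sample_weight.ravel()).mean())
print(clarity_term)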
@@ -318,6 +342,7 @@ def forward(self, inputs, sample_weight=None, main_effect=True, interaction=True,
             interaction_outputs = self.interaction_blocks(inputs) * interaction_weights.ravel()
             outputs = outputs + interaction_outputs.sum(1, keepdim=True)
 
-        self.mono_loss = self.get_mono_loss(inputs, outputs, monotonicity)
-        self.clarity_loss = self.get_clarity_loss(main_effect_outputs, interaction_outputs, sample_weight, clarity)
+        self.mono_loss = self.get_mono_loss(inputs, outputs, monotonicity, sample_weight)
+        self.clarity_loss = self.get_clarity_loss(main_effect_outputs, interaction_outputs,
+                                                  sample_weight, clarity)
         return outputs
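Editor's note: forward() does not return the two penalties; it stores them as self.mono_loss and self.clarity_loss, and after this commit both are computed with the same sample_weight as the data-fit term. Presumably the training loop in base.py reads these attributes to assemble the composite objective.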
