Merge pull request #127 from sissa-data-science/improvements_ii
information imbalance with automatic subsampling
imacocco authored Mar 27, 2024
2 parents 517b9df + 6ed5aa7 commit 24c4c60
Showing 4 changed files with 131 additions and 52 deletions.
8 changes: 4 additions & 4 deletions dadapy/base.py
@@ -54,13 +54,13 @@ def __init__(
njobs (int): number of cores to be used
"""
self.X = coordinates
self.maxk = maxk
self.maxk = maxk # remove from here
self.verb = verbose
self.n_jobs = n_jobs
self.dims = None
self.N = None
self.metric = "euclidean"
self.period = period
self.metric = "euclidean" # remove from here
self.period = period # remove from here
self.rng = np.random.default_rng(rng_seed)

if self.X is not None:
@@ -83,7 +83,7 @@ def __init__(
self.dims = coordinates.shape[1]
self.distances = None
self.dist_indices = None
if self.maxk is None:
if self.maxk is None: # remove from here
self.maxk = min(100, self.N - 1)

if distances is not None:
113 changes: 89 additions & 24 deletions dadapy/metric_comparisons.py
@@ -73,40 +73,91 @@ def __init__(
n_jobs=n_jobs,
)

def return_information_imbalace(self, coordinates, k=1):
def return_information_imbalace(
self, coordinates, k=1, subset_size=2000, repeats=None, avg=True
):
"""Return the imbalance with another dataset X.
Args:
coordinates (np.ndarray(float)): the coordinates of the other dataset (N, dimension of embedding space)
k (int): order of nearest neighbour considered for the calculation of the imbalance, default is 1
coordinates (np.ndarray(float)): the coordinates of the other dataset (N, dimension of embedding space).
k (int): order of nearest neighbour considered for the calculation of the imbalance, default is 1.
subset_size (int): size of the subsets on which the information imbalance is computed.
repeats (int): the number of repetitions for the information imbalance calculation.
avg (bool): whether to return the mean imbalance and its standard error over the repeats; if False, the per-subset values are returned. Default is True.
Returns:
(float, float): the information imbalance from distance i to distance j and vice versa
(np.array, np.array): the information imbalances and their standard errors
"""
distances = None
dist_indices = None
assert self.X is not None, "information imbalance requires coordinate matrix."
assert (
self.X.shape[0] == coordinates.shape[0]
), "the two datasets must have the same number of samples"

assert any(
var is not None for var in [self.X, self.distances, self.dist_indices]
), "MetricComparisons should be initialized with a dataset."
if repeats is None:
repeats = self.N // subset_size

assert any(
var is not None for var in [coordinates, distances, dist_indices]
), "The overlap with data requires a second dataset. \
Provide at least one of coordinates, distances, dist_indices."
if self.N <= subset_size:
warnings.warn(
"Subset size greater than the dataset size. \
Computing information imbalance once on the entire dataset.",
stacklevel=2,
)
repeats = 1
subset_size = self.N
elif repeats > self.N // subset_size:
warnings.warn(
"repeats * subset_size > dataset size. \
setting repeats = dataset_size // subset_size.",
stacklevel=2,
)
repeats = self.N // subset_size

# subsets is a list of arrays. Each array contains the indices of the points
# belonging to one subset.
subsets = [np.arange(self.N)]
if repeats > 1:
# shuffling the integers from 0 to self.N - 1
indices = self.rng.choice(self.N, self.N, replace=False)
# splitting the shuffled indices into 'repeats' subsets
subsets = np.array_split(indices, repeats)
if len(subsets[-1]) != len(subsets[-2]):
# all groups should have the same size
subsets = subsets[:-1]
repeats -= 1

imb_ij = np.zeros(repeats)
imb_ji = np.zeros(repeats)
for i, idx in enumerate(subsets):
x_base = self.X[idx]
x_other = coordinates[idx]

dist_indices_base, _ = self._get_nn_indices(
x_base, None, None, subset_size - 1, force_computation=True
)
dist_indices_other, _ = self._get_nn_indices(
x_other, None, None, subset_size - 1, force_computation=True
)

dist_indices_base, _ = self._get_nn_indices(
self.X, self.distances, self.dist_indices, self.maxk
)
assert dist_indices_base.shape[0] == dist_indices_other.shape[0]

dist_indices_other, _ = self._get_nn_indices(
coordinates, distances, dist_indices, self.maxk
)
imb_ij[i] = _return_imbalance(
dist_indices_base, dist_indices_other, self.rng, k=k
)

assert dist_indices_base.shape[0] == dist_indices_other.shape[0]
imb_ji[i] = _return_imbalance(
dist_indices_other, dist_indices_base, self.rng, k=k
)

imb_ij = _return_imbalance(dist_indices_base, dist_indices_other, self.rng, k=k)
imb_ji = _return_imbalance(dist_indices_other, dist_indices_base, self.rng, k=k)
if avg:
if repeats == 1:
return np.array([imb_ij[0], 0]), np.array([imb_ji[0], 0])
mean_ij, err_ij = (
np.mean(imb_ij),
np.std(imb_ij, ddof=1) / repeats**0.5,
)
mean_ji, err_ji = (
np.mean(imb_ji),
np.std(imb_ji, ddof=1) / repeats**0.5,
)
return np.array([mean_ij, err_ij]), np.array([mean_ji, err_ji])

return imb_ij, imb_ji
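For context (not part of the diff itself), here is a minimal usage sketch of the new subsampling interface. The dataset sizes, the variable names X and Y, and the top-level import path are illustrative assumptions rather than something this commit prescribes:

```python
# Hedged usage sketch of return_information_imbalace with automatic subsampling.
# Everything below (data, sizes, import path) is illustrative, not from the commit.
import numpy as np
from dadapy import MetricComparisons  # assumed top-level export, as in released DADApy

rng = np.random.default_rng(0)
X = rng.standard_normal((600, 5))                    # base representation, 600 samples
Y = X[:, :2] + 0.1 * rng.standard_normal((600, 2))   # second representation of the same samples

mc = MetricComparisons(coordinates=X)

# subset_size=200 and repeats=None -> the 600 points are split into 3 disjoint
# subsets of 200; the imbalance is computed on each subset and the mean and the
# standard error of the mean are returned for both directions.
imb_xy, imb_yx = mc.return_information_imbalace(Y, k=1, subset_size=200)
print(imb_xy)  # array([mean_xy, stderr_xy])
print(imb_yx)  # array([mean_yx, stderr_yx])

# avg=False returns the raw per-subset imbalances instead (one value per repeat).
raw_xy, raw_yx = mc.return_information_imbalace(Y, subset_size=200, avg=False)
```

As the diff shows, when subset_size is at least the dataset size the computation falls back to a single pass over the whole dataset and the reported error is 0; when repeats * subset_size exceeds the dataset size, repeats is clipped to N // subset_size.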

@@ -506,7 +557,21 @@ def return_inf_imb_target_all_dplets(self, target_ranks, d, k=1):

return np.array(coord_list), np.array(imbalances)

def _get_nn_indices(self, coordinates, distances, dist_indices, k, coords=None):
def _get_nn_indices(
self,
coordinates,
distances,
dist_indices,
k,
coords=None,
force_computation=False,
):
if force_computation:
_, dist_indices = compute_nn_distances(
coordinates, k, self.metric, self.period
)
return dist_indices, k

if coords is not None:
assert (
coordinates is not None
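A side note on the helper change: with force_computation=True, _get_nn_indices ignores any cached distances / dist_indices and recomputes the neighbour indices directly from the passed coordinates (the subsampling loop calls it with k = subset_size - 1). A rough stand-in sketch of that step follows; the scikit-learn routine is only an assumed substitute for the package's internal compute_nn_distances:

```python
# Illustrative stand-in for force_computation=True (assumption, not the package's code).
import numpy as np
from sklearn.neighbors import NearestNeighbors

def nn_indices_from_scratch(points: np.ndarray, k: int) -> np.ndarray:
    """Recompute the k-nearest-neighbour index table for `points` from scratch."""
    nbrs = NearestNeighbors(n_neighbors=k + 1).fit(points)  # +1: each point is its own 0-th neighbour
    _, indices = nbrs.kneighbors(points)
    return indices  # shape (n_points, k + 1); column 0 is the point itself

subset = np.random.default_rng(0).standard_normal((200, 3))
dist_indices = nn_indices_from_scratch(subset, k=subset.shape[0] - 1)  # k = subset_size - 1, as in the diff
```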
60 changes: 37 additions & 23 deletions examples/notebook_on_differentiable_imbalance.ipynb
@@ -229,13 +229,15 @@
}
],
"source": [
"n_epochs = 50 # number of training epochs\n",
"n_epochs = 50 # number of training epochs\n",
"\n",
"f = FeatureWeighting(coordinates=X)\n",
"f_target = FeatureWeighting(coordinates=X_target)\n",
"final_weights = f.return_weights_optimize_dii(target_data=f_target, \n",
" initial_weights=None, # (default) automatic weights as inverse std.dev. of features\n",
" n_epochs=n_epochs)"
"final_weights = f.return_weights_optimize_dii(\n",
" target_data=f_target,\n",
" initial_weights=None, # (default) automatic weights as inverse std.dev. of features\n",
" n_epochs=n_epochs,\n",
")"
]
},
{
@@ -267,11 +269,11 @@
"dii_per_epoch = f.history[\"dii_per_epoch\"]\n",
"weights_per_epoch = f.history[\"weights_per_epoch\"]\n",
"\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (10, 4))\n",
"ax1.plot(np.arange(n_epochs+1), dii_per_epoch)\n",
"ax2.plot(np.arange(n_epochs+1), weights_per_epoch[:,0], label=\"$w_1$\")\n",
"ax2.plot(np.arange(n_epochs+1), weights_per_epoch[:,1], label=\"$w_2$\")\n",
"ax2.plot(np.arange(n_epochs+1), weights_per_epoch[:,2], label=\"$w_3$\")\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n",
"ax1.plot(np.arange(n_epochs + 1), dii_per_epoch)\n",
"ax2.plot(np.arange(n_epochs + 1), weights_per_epoch[:, 0], label=\"$w_1$\")\n",
"ax2.plot(np.arange(n_epochs + 1), weights_per_epoch[:, 1], label=\"$w_2$\")\n",
"ax2.plot(np.arange(n_epochs + 1), weights_per_epoch[:, 2], label=\"$w_3$\")\n",
"\n",
"ax1.set(ylabel=\"Differentiable Information Imbalance\", xlabel=\"Epoch number\")\n",
"ax2.set(ylabel=\"Sample weights\", xlabel=\"Epoch number\")\n",
@@ -305,10 +307,22 @@
}
],
"source": [
"weights_names = (\"$w_1$\", \"$w_2$\", \"$w_3$\", \"$w_4$\", \"$w_5$\",\n",
" \"$w_6$\", \"$w_7$\", \"$w_8$\", \"$w_9$\", \"$w_{10}$\")\n",
"weights_grouped = {'Ground-truth weights': weights / max(weights),\n",
" 'Learnt weights (DII)': final_weights / max(final_weights)}\n",
"weights_names = (\n",
" \"$w_1$\",\n",
" \"$w_2$\",\n",
" \"$w_3$\",\n",
" \"$w_4$\",\n",
" \"$w_5$\",\n",
" \"$w_6$\",\n",
" \"$w_7$\",\n",
" \"$w_8$\",\n",
" \"$w_9$\",\n",
" \"$w_{10}$\",\n",
")\n",
"weights_grouped = {\n",
" \"Ground-truth weights\": weights / max(weights),\n",
" \"Learnt weights (DII)\": final_weights / max(final_weights),\n",
"}\n",
"\n",
"x = np.arange(len(weights_names)) # label locations\n",
"width = 0.2 # bar widths\n",
@@ -516,16 +530,16 @@
}
],
"source": [
"n_epochs = 80 # number of training epochs\n",
"n_epochs = 80 # number of training epochs\n",
"\n",
"f = FeatureWeighting(coordinates=X_monomials, verbose=True)\n",
"f_target = FeatureWeighting(coordinates=X_monomials_target)\n",
"\n",
"final_imbs, final_weights = f.return_backward_greedy_dii_elimination(\n",
" target_data=f_target,\n",
" initial_weights=None, # set automatically (default)\n",
" initial_weights=None, # set automatically (default)\n",
" n_epochs=n_epochs,\n",
" learning_rate=None, # set automatically (default)\n",
" learning_rate=None, # set automatically (default)\n",
")"
]
},
@@ -594,7 +608,7 @@
"weights_names = []\n",
"for monomial_index in union_set:\n",
" monomial = monomials_list[monomial_index]\n",
" weights_names.append( f\"$w({''.join((coords_names[list(monomial)]))})$\" )\n",
" weights_names.append(f\"$w({''.join((coords_names[list(monomial)]))})$\")\n",
"\n",
"x = np.arange(len(weights_names)) # label locations\n",
"width = 0.2 # bar widths\n",
@@ -751,7 +765,7 @@
}
],
"source": [
"n_epochs = 80 # number of training epochs\n",
"n_epochs = 80 # number of training epochs\n",
"\n",
"f = FeatureWeighting(coordinates=X_monomials, verbose=True)\n",
"f_target = FeatureWeighting(coordinates=X_monomials_target)\n",
@@ -763,11 +777,11 @@
" weights_opt_per_nfeatures,\n",
") = f.return_lasso_optimization_dii_search(\n",
" target_data=f_target,\n",
" initial_weights=None, # (default) set automatically\n",
" initial_weights=None, # (default) set automatically\n",
" n_epochs=n_epochs,\n",
" learning_rate=None, # (default) set automatically\n",
" refine=False, # only 10 values of the L1 strength are tested\n",
" plotlasso=True # automatically show DII vs number of non-zero features\n",
" learning_rate=None, # (default) set automatically\n",
" refine=False, # only 10 values of the L1 strength are tested\n",
" plotlasso=True, # automatically show DII vs number of non-zero features\n",
")"
]
},
@@ -810,7 +824,7 @@
"weights_names = []\n",
"for monomial_index in union_set:\n",
" monomial = monomials_list[monomial_index]\n",
" weights_names.append( f\"$w({''.join((coords_names[list(monomial)]))})$\" )\n",
" weights_names.append(f\"$w({''.join((coords_names[list(monomial)]))})$\")\n",
"\n",
"x = np.arange(len(weights_names)) # label locations\n",
"width = 0.2 # bar widths\n",
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">=3.7"

dependencies = ["numpy", "scipy", "scikit-learn", "matplotlib"]
dependencies = ["numpy", "scipy", "scikit-learn", "matplotlib", "seaborn"]

[project.urls]
homepage = "https://github.com/sissa-data-science/DADApy"
