-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from elastic/partition_fix
Partition fix
- Loading branch information
Showing
64 changed files
with
2,053 additions
and
119,021 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,5 +5,6 @@ models | |
build | ||
dist | ||
*.egg-info | ||
*.png | ||
__pycache__ | ||
add_license.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
#!/usr/bin/env python | ||
|
||
""" | ||
This is an example that tracks chronological drift in the ember dataset. We train on the ember dataset on data before 2018-07, | ||
and then run everything through it. There's a massive increase in the total KL-div after the cutoff, so this does detect a | ||
shift in the dataset. | ||
""" | ||
|
||
import os | ||
import pandas as pd | ||
import ember | ||
import argparse | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from pygoko import CoverTree | ||
|
||
|
||
def main():
    """
    Track chronological drift in the ember dataset and plot the result.

    Fits a cover tree on every sample that appeared before the split month, then pushes the
    whole dataset through a windowed KL-divergence tracker in the order returned by
    `sort_ember_dataset`. The baseline-normalized 'moment1_nz' statistic is recorded at a fixed
    interval, plotted against the sample index, saved to "drift.png" and shown on screen.
    A red vertical line marks the first sample of the split month.

    Command line: a single positional argument DATADIR, the directory with the raw features.
    """
    prog = "ember_drift_calc"
    descr = "Train an ember model from a directory with raw feature files"
    parser = argparse.ArgumentParser(prog=prog, description=descr)
    parser.add_argument("datadir", metavar="DATADIR", type=str, help="Directory with raw features")
    args = parser.parse_args()

    # Defined once: the same month selects the training data and positions the cutoff marker.
    split_date = "2018-07"
    # Number of pushed samples between two recorded statistics. The x positions of the plot are
    # collected in the same loop, so the curve and its x-axis cannot fall out of step.
    stats_interval = 500

    training_data, all_data, X_month = sort_ember_dataset(datadir=args.datadir, split_date=split_date)

    # Build the tree
    tree = CoverTree()
    tree.set_leaf_cutoff(50)
    tree.set_scale_base(1.5)
    tree.set_min_res_index(0)
    tree.fit(training_data)

    # Gather a baseline
    prior_weight = 1.0
    observation_weight = 1.3
    # 0 sets the window to be infinite, otherwise the "dataset" you're computing against is only
    # the last `window_size` elements
    window_size = 5000
    # Actually computes the KL div this often. All other values are linearly interpolated between
    # these sample points. It's too slow to calculate each value and this is accurate enough.
    sample_rate = 10
    # Gets the mean and variance over this number of simulated sequences.
    sequence_count = 50

    # We gather a baseline object. When you feed the entire dataset the covertree was created from
    # to itself, you will get a non-zero KL-div on any node that is non-trivial. This process will
    # weight the node's posterior Dirichlet distribution, multiplying the internal weights by
    # (prior_weight + observation_weight). This posterior distribution has a lower variance than
    # the prior and the expected KL-divergence between the unknown distributions we're modeling is
    # thus non-zero. This slowly builds up, but we expect a non-zero KL-div over the nodes as we
    # feed in-distribution data in. This object estimates that, and allows us to normalize this
    # natural variance away.
    baseline = tree.kl_div_dirichlet_baseline(
        prior_weight,
        observation_weight,
        window_size,
        sequence_count,
        sample_rate)

    # This is the actual object that computes the KL divergence statistics between the samples the
    # tree was fitted on and the new samples we feed in. Internally, it is an evidence hashmap
    # containing categorical distributions, and a queue of paths. The sample's path is computed, we
    # then push it onto the queue and update the evidence by incrementing the correct buckets in
    # the evidence hashmap. If the queue is full, we pop off the oldest path and decrement its
    # buckets in the evidence hashmap.
    run_tracker = tree.kl_div_dirichlet(
        prior_weight,
        observation_weight,
        window_size)

    sample_positions = []
    total_kl_div = []
    for i, datum in enumerate(all_data):
        run_tracker.push(datum)
        if i % stats_interval == 0:
            normalized = normalize(baseline, run_tracker.stats())
            sample_positions.append(i)
            total_kl_div.append(normalized['moment1_nz'])

    fig, ax = plt.subplots()
    ax.plot(sample_positions, total_kl_div)
    ax.set_ylabel('KL Divergence')
    ax.set_xlabel('Sample Timestamp')

    # Each month gets a tick at the index just past its last sample; the cutoff line sits at the
    # index of the first sample of the split month.
    tick_len = 0
    cutoff_len = 0
    tick_locations = []
    dates = list(X_month)
    for date in dates:
        if date == split_date:
            cutoff_len = tick_len
        tick_len += len(X_month[date])
        tick_locations.append(tick_len)
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(dates)
    ax.axvline(x=cutoff_len, linewidth=4, color='r')
    fig.tight_layout()
    fig.savefig("drift.png", bbox_inches='tight')
    plt.show()
    plt.close()
|
||
def normalize(baseline, stats):
    """
    Express the tracker statistics as deviations from the baseline.

    The baseline is queried at the sequence length reported in `stats["sequence_len"]`. For every
    key the baseline returns, the baseline mean is subtracted from `stats[key]`; when the baseline
    variance for that key is positive the difference is also divided by the baseline standard
    deviation. Keys with a non-positive variance are only mean-centred.

    Returns a new dict holding one normalized value per baseline key.
    """
    reference = baseline.stats(stats["sequence_len"])
    result = {}
    for key in reference.keys():
        moments = reference[key]
        centered = stats[key] - moments["mean"]
        variance = moments["var"]
        # Skip the division for degenerate keys so a zero variance cannot blow the value up.
        result[key] = centered / np.sqrt(variance) if variance > 0 else centered
    return result
|
||
def sort_ember_dataset(datadir, split_date):
    """
    Load the vectorized ember training features and group them by their 'appeared' value.

    Reads the "train" features through `ember.read_vectorized_features` and the matching
    "train_metadata.csv" from `datadir`, then buckets the feature rows by the metadata column
    'appeared'. Buckets are ordered by sorting the distinct 'appeared' values.

    Returns a 3-tuple:
      * training_data - contiguous array of every bucket whose key sorts strictly before `split_date`
      * all_data      - contiguous array of every bucket, in sorted key order
      * X_month       - dict mapping each 'appeared' value to its rows, inserted in sorted key order
    """
    features, _ = ember.read_vectorized_features(datadir, "train")
    metadata_path = os.path.join(datadir, "train_metadata.csv")
    metadata = pd.read_csv(metadata_path, index_col=0)

    appeared = metadata['appeared']
    months = sorted(set(appeared))

    # NOTE(review): assumes the metadata rows line up one-to-one with the feature rows - confirm.
    X_month = {month: features[appeared == month] for month in months}

    # NOTE(review): presumably the keys are "YYYY-MM" strings, so this comparison is
    # chronological - verify against the metadata file.
    training_months = [month for month in months if month < split_date]

    training_data = np.ascontiguousarray(
        np.concatenate([X_month[month] for month in training_months]))
    all_data = np.ascontiguousarray(
        np.concatenate([X_month[month] for month in months]))

    return training_data, all_data, X_month
|
||
|
||
# Run the drift calculation only when this file is executed as a script.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
''' | ||
This produces 2 gaussians: one that is fixed at 0, and another whose mean moves slowly away from it. | ||
''' | ||
|
||
import pygoko | ||
import numpy as np | ||
import pandas as pd | ||
import plotly.graph_objects as go | ||
from collections import defaultdict | ||
import matplotlib.pyplot as plt | ||
|
||
def main():
    """
    Measure how the KL-divergence statistics react as a gaussian drifts away from the training data.

    A cover tree is fitted on samples from a gaussian centred at 0. For every offset in
    `timestamps` a tracker is created and ten batches, drawn from a gaussian centred at that
    offset, are pushed into it. After each batch the tracker's stats are normalized against the
    baseline and recorded. The per-offset results are then handed to `plot`.
    """
    # How many samples we grab from the fixed gaussian
    fixed_sample_count = 200000
    # How many samples we grab from the moving gaussian, at each timestamp
    moving_sample_count = 1000
    timestamps = np.linspace(0, 2, 50)

    # We multiply the weight vector of the Dirichlet prior by this before computing the KL-div
    prior_weight = 1.0
    # We weight the evidence by this before we add it to the prior to get the posterior.
    observation_weight = 1.0
    # How often we sample the KL div of the sequences
    sample_rate = 25
    # How many artificial sequences we average over
    sequence_count = 32

    fixed_data = sample_from_gaussian(0, fixed_sample_count)
    tree = build_covertree(fixed_data)

    # See the ember chronological drift example for what the baseline is
    baseline = tree.kl_div_dirichlet_baseline(
        prior_weight,
        observation_weight,
        moving_sample_count,  # We don't need to sample sequences longer than this
        sequence_count,
        sample_rate)

    tracking_stats = []
    for offset in timestamps:
        # NOTE(review): the tracker is created once per offset with an unbounded window (0), so it
        # keeps accumulating across the ten batches below, while the baseline is always queried at
        # `moving_sample_count` - confirm this is intended.
        run_only_tracker = tree.kl_div_dirichlet(
            prior_weight,
            observation_weight,
            0)
        timestamps_stats = defaultdict(list)

        for _ in range(10):
            batch = sample_from_gaussian(offset, moving_sample_count)
            for point in batch:
                run_only_tracker.push(point)
            unpack_stats(
                timestamps_stats,
                run_only_tracker.stats(),
                baseline.stats(moving_sample_count))
        tracking_stats.append(timestamps_stats)

    plot(timestamps, tracking_stats)
|
||
def sample_from_gaussian(x_mean, count):
    """
    Draw `count` float32 samples from a 100-dimensional gaussian centred at [x_mean, 0, ... 0].

    The covariance is diagonal: variance 1 on the first 10 coordinates and 0.001 on the
    remaining 90. Sampling goes through numpy's global random state.
    """
    center = np.zeros([100], dtype=np.float32)
    center[0] = x_mean

    wide_variances = np.ones([10], dtype=np.float32)
    narrow_variances = 0.001 * np.ones([90], dtype=np.float32)
    covariance = np.diag(np.concatenate([wide_variances, narrow_variances]))

    samples = np.random.multivariate_normal(center, covariance, count)
    return samples.astype(np.float32)
|
||
def build_covertree(data):
    """
    Create a pygoko cover tree, configure it, and fit it on `data`.

    Uses a leaf cutoff of 100, a scale base of 1.5 and a minimum resolution index of -30.
    Returns the fitted tree.
    """
    cover_tree = pygoko.CoverTree()

    # All parameters are set before fitting.
    cover_tree.set_leaf_cutoff(100)
    cover_tree.set_scale_base(1.5)
    cover_tree.set_min_res_index(-30)

    cover_tree.fit(data)
    return cover_tree
|
||
def unpack_stats(dataframe, stats, baseline):
    """
    Normalize `stats` by `baseline` and append the results to `dataframe`.

    For every key in `baseline`, the baseline mean is subtracted from `stats[key]`; when the
    baseline variance is positive the difference is also divided by the baseline standard
    deviation. The normalized value is appended to `dataframe[key]`.

    Parameters:
      dataframe - mapping from key to list, mutated in place (the caller passes a defaultdict(list))
      stats     - mapping from key to the observed value
      baseline  - mapping from key to a dict with "mean" and "var" entries

    Returns nothing.
    """
    for k in baseline.keys():
        normalized = stats[k] - baseline[k]["mean"]
        if baseline[k]["var"] > 0:
            # The quotient has to be assigned back: the bare expression used previously threw
            # the result away, leaving the value mean-centred but never scaled.
            normalized = normalized / np.sqrt(baseline[k]["var"])
        dataframe[k].append(normalized)
|
||
def plot(timestamps, dataframes, statistic="moment1_nz"):
    """
    Plot the mean of one statistic against the timestamps and save the figure.

    `dataframes` holds one mapping per timestamp, each from statistic name to a list of values.
    The lists are stacked per statistic and averaged along axis 1, giving one mean per timestamp.
    The curve for `statistic` is drawn, written to "GaussianDrift.png" and shown on screen.
    """
    per_statistic = defaultdict(list)
    for frame in dataframes:
        for name, values in frame.items():
            per_statistic[name].append(values)

    # One row per timestamp for each statistic, then collapse each row to its mean.
    stacked = {name: np.stack(rows) for name, rows in per_statistic.items()}
    averaged = {name: np.mean(block, axis=1) for name, block in stacked.items()}

    figure, axes = plt.subplots()
    axes.plot(timestamps, averaged[statistic])
    axes.set_ylabel('KL Divergence')
    axes.set_xlabel('Distance between mean of Multinomial in 100d')
    figure.tight_layout()
    figure.savefig("GaussianDrift.png", bbox_inches='tight')
    plt.show()
    plt.close()
|
||
|
||
# Run the gaussian drift experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/usr/bin/env python | ||
|
||
""" | ||
This example shows loading from a YAML file. You can specify all the parameters in the yaml file. | ||
It then runs a KNN query, a routing KNN query, and a path query against the fitted tree. | ||
""" | ||
|
||
import numpy as np | ||
from pygoko import CoverTree | ||
|
||
# Configure the tree entirely from the YAML file, then fit it.
tree = CoverTree()
tree.load_yaml_config("data/mnist_complex.yml")
tree.fit()

# Query point: an all-zero float32 vector with 784 entries.
point = np.zeros([784], dtype=np.float32)

# This is a standard KNN, returning the 5 nearest nbrs.
print(tree.knn(point, 5))

# This is the KNN that ignores singletons: the outliers attached to each node and the leftover
# indexes on the leaf are ignored.
print(tree.routing_knn(point, 5))

# This returns the indexes of the nodes along the path. We can then ask for the label summaries
# of each node along the path.
path = tree.path(point)
print(path)

print("Summary of the labels of points covered by the node at address")
# Each path entry unpacks into a pair; only the second element (the address) is needed here.
for _dist, address in path:
    label_summary = tree.node(address).label_summary()
    print(f"Address {address}: Summary: {label_summary}")
|
This file was deleted.
Oops, something went wrong.
Empty file.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.