From 43239ed89abfca8283bed9da0e18e9df46c9404c Mon Sep 17 00:00:00 2001 From: Christoph Wurst Date: Thu, 6 Dec 2018 13:15:54 +0100 Subject: [PATCH] Add MLP trainer based on DB data Signed-off-by: Christoph Wurst --- appinfo/info.xml | 6 +- lib/Command/TrainMLP.php | 99 +++++++++++++++++ lib/Db/LoginAddressMapper.php | 11 ++ lib/Service/DataSet.php | 118 ++++++++++++++++++++ lib/Service/MLPTrainer.php | 116 +++++++++++++++++++ lib/Service/NegativeSampleGenerator.php | 141 ++++++++++++++++++++++++ lib/Service/UidIPVector.php | 108 ++++++++++++++++++ 7 files changed, 598 insertions(+), 1 deletion(-) create mode 100644 lib/Command/TrainMLP.php create mode 100644 lib/Service/DataSet.php create mode 100644 lib/Service/MLPTrainer.php create mode 100644 lib/Service/NegativeSampleGenerator.php create mode 100644 lib/Service/UidIPVector.php diff --git a/appinfo/info.xml b/appinfo/info.xml index 5d107c39..0fc3fcbb 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -22,7 +22,11 @@ - + + + + OCA\SuspiciousLogin\Command\TrainMLP + diff --git a/lib/Command/TrainMLP.php b/lib/Command/TrainMLP.php new file mode 100644 index 00000000..a4f24dbb --- /dev/null +++ b/lib/Command/TrainMLP.php @@ -0,0 +1,99 @@ + + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
namespace OCA\SuspiciousLogin\Command;

use OCA\SuspiciousLogin\Service\MLPTrainer;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;

/**
 * Console command "suspiciouslogin:train:mlp".
 *
 * Reads the tuning options from the command line, casts them to the scalar
 * types MLPTrainer::train() declares and delegates the whole run to it.
 */
class TrainMLP extends Command {

	/** @var MLPTrainer */
	private $trainer;

	/**
	 * @param MLPTrainer $trainer service that performs the actual training
	 */
	public function __construct(MLPTrainer $trainer) {
		parent::__construct("suspiciouslogin:train:mlp");
		$this->trainer = $trainer;

		// All options are VALUE_REQUIRED: the default applies when the option
		// is absent, and passing the bare flag (e.g. `--epochs`) is rejected
		// by the console instead of arriving here as null and being cast to 0.
		$this->addOption(
			'shuffled',
			null,
			InputOption::VALUE_REQUIRED,
			"ratio of shuffled negative samples",
			1.0
		);
		$this->addOption(
			'random',
			null,
			InputOption::VALUE_REQUIRED,
			"ratio of random negative samples",
			1.0
		);
		$this->addOption(
			'epochs',
			'e',
			InputOption::VALUE_REQUIRED,
			"number of epochs to train",
			5000
		);
		$this->addOption(
			'layers',
			'l',
			InputOption::VALUE_REQUIRED,
			"number of hidden layers",
			6
		);
		$this->addOption(
			'learn-rate',
			null,
			InputOption::VALUE_REQUIRED,
			"learning rate",
			0.05
		);
		$this->addOption(
			'validation-rate',
			null,
			InputOption::VALUE_REQUIRED,
			"relative size of the validation data set",
			0.15
		);
	}

	/**
	 * Run one training session with the options given on the command line.
	 *
	 * @param InputInterface $input parsed command line
	 * @param OutputInterface $output sink the trainer writes its progress to
	 *
	 * @return int process exit code, 0 once the trainer returned
	 */
	protected function execute(InputInterface $input, OutputInterface $output) {
		$this->trainer->train(
			$output,
			(float) $input->getOption('shuffled'),
			(float) $input->getOption('random'),
			(int) $input->getOption('epochs'),
			(int) $input->getOption('layers'),
			(float) $input->getOption('learn-rate'),
			(float) $input->getOption('validation-rate')
		);

		// Report success explicitly instead of relying on the implicit null
		return 0;
	}

}
/**
 * Load every distinct (uid, ip) combination from the login address table.
 *
 * The rows are selected with only the uid and ip columns and grouped by
 * both, then handed to findEntities().
 *
 * NOTE(review): $minEntries is accepted but the query below never uses it.
 *
 * @param int $minEntries currently without effect
 *
 * @return mixed the result of findEntities() for the grouped query
 */
public function findAll(int $minEntries = 1) {
	$builder = $this->db->getQueryBuilder();

	$selected = $builder->select('uid', 'ip');
	$fromTable = $selected->from($this->getTableName());
	$grouped = $fromTable->groupBy('uid', 'ip');

	return $this->findEntities($grouped);
}
namespace OCA\SuspiciousLogin\Service;

use function array_key_exists;
use function array_map;
use function array_merge;
use ArrayAccess;
use Countable;
use OCA\SuspiciousLogin\Db\LoginAddress;
use function shuffle;

/**
 * Array-like, countable collection of UidIPVector samples.
 *
 * Built from plain arrays with the keys 'uid', 'ip' and 'label'; each entry
 * is wrapped into a UidIPVector on construction.
 */
class DataSet implements ArrayAccess, Countable {

	/** @var UidIPVector[] */
	private $data;

	/**
	 * @param array $data list of arrays, each with 'uid', 'ip' and 'label'
	 */
	public function __construct(array $data) {
		$this->data = array_map(function (array $item) {
			return new UidIPVector($item['uid'], $item['ip'], $item['label']);
		}, $data);
	}

	/**
	 * Build a data set of positive samples from login address entities.
	 *
	 * @param LoginAddress[] $loginAddresses
	 *
	 * @return DataSet every entry labelled MLPTrainer::LABEL_POSITIVE
	 */
	public static function fromLoginAddresses(array $loginAddresses): DataSet {
		return new DataSet(array_map(function (LoginAddress $addr) {
			return [
				'uid' => $addr->getUid(),
				'ip' => $addr->getIp(),
				'label' => MLPTrainer::LABEL_POSITIVE,
			];
		}, $loginAddresses));
	}

	/**
	 * @return array one feature vector per sample, keys as in the data set
	 */
	public function asTrainingData(): array {
		return array_map(function (UidIPVector $vec) {
			return $vec->asFeatureVector();
		}, $this->data);
	}

	/**
	 * Whether an offset exists
	 */
	public function offsetExists($offset) {
		return array_key_exists($offset, $this->data);
	}

	/**
	 * Offset to retrieve
	 */
	public function offsetGet($offset) {
		return $this->data[$offset];
	}

	/**
	 * Offset to set
	 *
	 * A null offset is what PHP passes for `$set[] = $value`; in that case
	 * the value is appended rather than stored under the key ''.
	 */
	public function offsetSet($offset, $value) {
		if ($offset === null) {
			$this->data[] = $value;
			return;
		}

		$this->data[$offset] = $value;
	}

	/**
	 * Offset to unset
	 */
	public function offsetUnset($offset) {
		unset($this->data[$offset]);
	}

	/**
	 * @return int number of samples
	 */
	public function count() {
		return count($this->data);
	}

	/**
	 * @return string[] the label of every sample, keys as in the data set
	 */
	public function getLabels(): array {
		return array_map(function (UidIPVector $vec) {
			return $vec->getLabel();
		}, $this->data);
	}

	/**
	 * Combine this set with another one into a new DataSet.
	 *
	 * Neither operand is modified.
	 */
	public function merge(DataSet $other): DataSet {
		$merged = array_merge($this->data, $other->data);
		// The samples are already UidIPVector objects, so bypass the
		// constructor's array-to-vector mapping and assign them directly
		$new = new DataSet([]);
		$new->data = $merged;
		return $new;
	}

	/**
	 * Randomise the order of the samples in place (keys are renumbered).
	 */
	public function shuffle() {
		shuffle($this->data);
	}

}
namespace OCA\SuspiciousLogin\Service;

use function array_slice;
use OCA\SuspiciousLogin\Db\LoginAddressMapper;
use OCP\AppFramework\Utility\ITimeFactory;
use Phpml\Classification\MLPClassifier;
use Phpml\Metric\ClassificationReport;
use function shuffle;
use Symfony\Component\Console\Output\OutputInterface;

/**
 * Trains an MLP classifier on the stored (uid, ip) login pairs and prints a
 * classification report for a held-back validation set.
 */
class MLPTrainer {

	const LABEL_POSITIVE = 'y';
	const LABEL_NEGATIVE = 'n';

	/** @var LoginAddressMapper */
	private $loginAddressMapper;

	/** @var NegativeSampleGenerator */
	private $negativeSampleGenerator;

	/** @var ITimeFactory */
	private $timeFactory;

	public function __construct(LoginAddressMapper $loginAddressMapper,
								NegativeSampleGenerator $negativeSampleGenerator,
								ITimeFactory $timeFactory) {
		$this->loginAddressMapper = $loginAddressMapper;
		$this->negativeSampleGenerator = $negativeSampleGenerator;
		$this->timeFactory = $timeFactory;
	}

	/**
	 * Train a classifier and report its quality on the validation split.
	 *
	 * @param OutputInterface $output receives progress and the final report
	 * @param float $shuffledNegativeRate shuffled negatives per positive sample
	 * @param float $randomNegativeRate random negatives per positive sample
	 * @param int $epochs passed to MLPClassifier as the iteration count
	 * @param int $layers passed to MLPClassifier as `[$layers]`
	 * @param float $learningRate passed to MLPClassifier
	 * @param float $validationRate share of the positives held back for validation
	 */
	public function train(OutputInterface $output,
						  float $shuffledNegativeRate,
						  float $randomNegativeRate,
						  int $epochs,
						  int $layers,
						  float $learningRate,
						  float $validationRate) {
		$raw = $this->loginAddressMapper->findAll();
		shuffle($raw);
		$all = DataSet::fromLoginAddresses($raw);

		// Split the shuffled positives into a training and a validation part
		$validationOffset = (int)min(count($all), max(0, count($raw) * (1 - $validationRate)));
		$positives = DataSet::fromLoginAddresses(array_slice($raw, 0, $validationOffset));
		$validationPositives = DataSet::fromLoginAddresses(array_slice($raw, $validationOffset));
		$numValidation = count($validationPositives);
		$numPositives = count($positives);

		if ($numPositives === 0) {
			// Negative samples are derived from the positives' UIDs, so there
			// is nothing to train on without at least one positive sample
			$output->writeln("No positive samples available for training, aborting");
			return;
		}

		// At least one negative sample of each kind; integer literal so the
		// counts stay ints for the generator's int parameter
		$numRandomNegatives = max((int)floor($numPositives * $randomNegativeRate), 1);
		$numShuffledNegative = max((int)floor($numPositives * $shuffledNegativeRate), 1);
		$randomNegatives = $this->negativeSampleGenerator->generateRandomFromPositiveSamples($positives, $numRandomNegatives);
		// Use the shuffled generator with its own count, so the "shuffled"
		// rate actually controls this part of the training set
		$shuffledNegatives = $this->negativeSampleGenerator->generateShuffledFromPositiveSamples($positives, $numShuffledNegative);

		// Validation negatives are generated from all data (to have all UIDs), but shuffled
		$validationSamples = $validationPositives;
		if ($numValidation > 0) {
			$all->shuffle();
			$validationNegatives = $this->negativeSampleGenerator->generateRandomFromPositiveSamples($all, $numValidation);
			$validationSamples = $validationPositives->merge($validationNegatives);
		}

		$total = $numPositives + $numRandomNegatives + $numShuffledNegative;
		$output->writeln("Got $total samples for training: $numPositives positive, $numRandomNegatives random negative and $numShuffledNegative shuffled negative");
		$output->writeln("Got $numValidation positive and $numValidation negative samples for validation (rate: $validationRate)");
		$output->writeln("Number of epochs: " . $epochs);
		$output->writeln("Number of hidden layers: " . $layers);
		$output->writeln("Learning rate: " . $learningRate);
		$output->writeln("Vecor dimensions: " . UidIPVector::SIZE);

		$allSamples = $positives->merge($randomNegatives)->merge($shuffledNegatives);
		$allSamples->shuffle();

		$output->writeln("Start training");
		$start = $this->timeFactory->getDateTime();
		// NOTE(review): php-ml reads each entry of the hidden-layer array as
		// the neuron count of one layer, so `[$layers]` presumably builds a
		// single hidden layer with $layers neurons — confirm the intent.
		$classifier = new MLPClassifier(
			UidIPVector::SIZE,
			[$layers],
			[self::LABEL_POSITIVE, self::LABEL_NEGATIVE],
			$epochs,
			null,
			$learningRate
		);
		$classifier->train(
			$allSamples->asTrainingData(),
			$allSamples->getLabels()
		);
		$finished = $this->timeFactory->getDateTime();
		$elapsed = $finished->getTimestamp() - $start->getTimestamp();
		$output->writeln("Training finished after " . $elapsed . "s");

		if ($numValidation === 0) {
			// Nothing was held back, so there is nothing to score
			$output->writeln("No validation samples available, skipping validation");
			return;
		}

		$predicted = $classifier->predict($validationSamples->asTrainingData());
		$result = new ClassificationReport($validationSamples->getLabels(), $predicted);
		$output->writeln("Prescision(y): " . $result->getPrecision()['y']);
		$output->writeln("Prescision(n): " . $result->getPrecision()['n']);
		$output->writeln("Recall(y): " . $result->getRecall()['y']);
		$output->writeln("Recall(n): " . $result->getRecall()['n']);
		$output->writeln("Average(precision): " . $result->getAverage()['precision']);
		$output->writeln("Average(recall): " . $result->getAverage()['recall']);
		$output->writeln("Average(f1score): " . $result->getAverage()['f1score']);
	}

}
namespace OCA\SuspiciousLogin\Service;

use function array_diff;
use function array_filter;
use function array_keys;
use function array_search;
use Exception;
use function random_int;

/**
 * Produces negative (uid, ip) samples for the UIDs found in a set of
 * positive samples, either with random IPs or with IPs of other users.
 */
class NegativeSampleGenerator {

	/**
	 * Map every UID to the IPs that only this UID was seen with.
	 *
	 * Reads the samples at the offsets 0 .. count-1 of the given set.
	 *
	 * @param DataSet $positives
	 *
	 * @return array uid => list of IPs no other UID in the set has
	 */
	private function getUniqueIPsPerUser(DataSet $positives): array {
		$ips = [];

		// First, let's map (uid,ip) to uid -> [ip]
		$max = count($positives);
		for ($i = 0; $i < $max; $i++) {
			$positive = $positives[$i];
			$ips[$positive->getUid()][] = $positive->getIp();
		}

		$uniqueIps = [];
		foreach ($ips as $uid => $userIps) {
			$unique = array_filter($userIps, function (string $ip) use ($ips, $uid) {
				foreach ($ips as $other => $otherIps) {
					if ($other === $uid) {
						// Skip the user's own list, keep checking the others
						continue;
					}
					if (in_array($ip, $otherIps, true)) {
						return false;
					}
				}

				// The IP was not found for any other user, so it's unique
				return true;
			});
			// array_filter keeps the original keys; renumber so the list can
			// be addressed with a random index
			$uniqueIps[$uid] = array_values($unique);
		}

		return $uniqueIps;
	}

	/**
	 * Pick a random IP that is unique to a user other than $uid.
	 *
	 * First a user is chosen uniformly among the other users that have at
	 * least one unique IP, then one of that user's IPs.
	 *
	 * @param string $uid user the negative sample is generated for
	 * @param array $uniqueIps uid => list of IPs unique to that uid
	 * @param int $maxRec kept for signature compatibility; 0 still fails
	 *                    immediately, other values have no effect any more
	 *
	 * @return string
	 * @throws Exception if no other user with a unique IP exists
	 */
	private function findRandomIp(string $uid, array $uniqueIps, int $maxRec = 10): string {
		$candidates = [];
		foreach ($uniqueIps as $otherUid => $otherIps) {
			// Numeric UIDs become int array keys, hence the string cast
			if ((string)$otherUid === $uid || count($otherIps) === 0) {
				continue;
			}
			$candidates[] = array_values($otherIps);
		}

		if ($maxRec === 0 || count($candidates) === 0) {
			throw new Exception("Could not generate negative sample off real data for $uid. Is there enough data for training?");
		}

		$userIps = $candidates[random_int(0, count($candidates) - 1)];
		return $userIps[random_int(0, count($userIps) - 1)];
	}

	/**
	 * @return string four random numbers from 0 to 255 joined by dots
	 */
	private function randomIpv4(): string {
		return implode('.', [
			random_int(0, 255),
			random_int(0, 255),
			random_int(0, 255),
			random_int(0, 255),
		]);
	}

	/**
	 * Negative sample pairing $uid with an IP that is unique to another user.
	 *
	 * @throws Exception if no such IP is available
	 */
	private function generateFromRealData(string $uid, array $uniqueIps): array {
		return [
			'uid' => $uid,
			'ip' => $this->findRandomIp($uid, $uniqueIps),
			'label' => MLPTrainer::LABEL_NEGATIVE,
		];
	}

	/**
	 * Negative sample pairing $uid with a random dotted address.
	 */
	private function generateRandom(string $uid): array {
		return [
			'uid' => $uid,
			'ip' => $this->randomIpv4(),
			'label' => MLPTrainer::LABEL_NEGATIVE,
		];
	}

	/**
	 * Generate $num negative samples with random IPs, cycling through the
	 * UIDs of the positive samples.
	 *
	 * @param DataSet $positives
	 * @param int $num
	 *
	 * @return DataSet empty if $num is not positive or there are no positives
	 */
	public function generateRandomFromPositiveSamples(DataSet $positives, int $num): DataSet {
		$max = count($positives);
		if ($num <= 0 || $max === 0) {
			// range(0, -1) would count downwards and `% 0` is an error
			return new DataSet([]);
		}

		return new DataSet(array_map(function (int $id) use ($positives, $max) {
			return $this->generateRandom($positives[$id % $max]->getUid());
		}, range(0, $num - 1)));
	}

	/**
	 * Generate $num negative samples by pairing the UIDs of the positive
	 * samples with IPs that are unique to other users.
	 *
	 * @param DataSet $positives
	 * @param int $num
	 *
	 * @return DataSet empty if $num is not positive or there are no positives
	 * @throws Exception if a UID has no other user with a unique IP to borrow from
	 */
	public function generateShuffledFromPositiveSamples(DataSet $positives, int $num): DataSet {
		$max = count($positives);
		if ($num <= 0 || $max === 0) {
			return new DataSet([]);
		}
		$uniqueIps = $this->getUniqueIPsPerUser($positives);

		return new DataSet(array_map(function (int $id) use ($uniqueIps, $positives, $max) {
			return $this->generateFromRealData($positives[$id % $max]->getUid(), $uniqueIps);
		}, range(0, $num - 1)));
	}

}
namespace OCA\SuspiciousLogin\Service;

use function array_merge;
use function str_pad;
use function substr;

/**
 * One labelled (uid, ip) sample and its encoding as a vector of 0.0/1.0
 * floats: bits derived from the UID hash followed by bits of the IP.
 */
class UidIPVector {

	const SIZE = 16 + 32;

	/** @var string */
	private $uid;

	/** @var string */
	private $ip;

	/** @var string */
	private $label;

	public function __construct(string $uid, string $ip, string $label) {
		$this->label = $label;
		$this->ip = $ip;
		$this->uid = $uid;
	}

	/**
	 * Convert one number string of the given base into its binary digits,
	 * left-padded with '0' to at least $padding digits.
	 *
	 * @return string[] one '0'/'1' character per element
	 */
	private function numStringToBitArray(string $s, int $base, int $padding): array {
		// make sure we get 00000000 to 11111111
		return str_split(
			str_pad(base_convert($s, $base, 2), $padding, '0', STR_PAD_LEFT)
		);
	}

	/**
	 * Convert several number strings and concatenate their bits in order.
	 *
	 * @return float[] every bit as 0.0 or 1.0
	 */
	private function numStringsToBitArray(array $strings, $base, $padding): array {
		$bits = [];
		foreach ($strings as $numString) {
			foreach ($this->numStringToBitArray($numString, $base, $padding) as $digit) {
				$bits[] = (float)$digit;
			}
		}
		return $bits;
	}

	/**
	 * Encode the first four hex characters of md5(uid), four bits each.
	 */
	private function uidAsFeatureVector(): array {
		// TODO: just convert to binary and do substr of that
		$hashPrefix = substr(md5($this->uid), 0, 4);
		$hexDigits = str_split($hashPrefix);

		return $this->numStringsToBitArray($hexDigits, 16, 4);
	}

	/**
	 * Convert the decimal ip notation w.x.y.z to a binary (32bit) vector
	 *
	 * NOTE(review): the address is split on '.' only, so input in another
	 * notation (e.g. IPv6) does not produce 32 entries — confirm callers
	 * only pass dotted IPv4 addresses.
	 *
	 * @return array
	 */
	private function ipAsFeatureVector(): array {
		$octets = explode('.', $this->ip);

		return $this->numStringsToBitArray($octets, 10, 8);
	}

	/**
	 * @return float[] the UID bits followed by the IP bits
	 */
	public function asFeatureVector(): array {
		$uidBits = $this->uidAsFeatureVector();
		$ipBits = $this->ipAsFeatureVector();

		return array_merge($uidBits, $ipBits);
	}

	public function getIp(): string {
		return $this->ip;
	}

	public function getUid(): string {
		return $this->uid;
	}

	public function getLabel(): string {
		return $this->label;
	}

}