Skip to content

Commit

Permalink
Merge pull request #1 from ChristophWurst/feature/mlp-trainer
Browse files Browse the repository at this point in the history
Add MLP trainer based on DB data
  • Loading branch information
ChristophWurst authored Dec 13, 2018
2 parents 43297b2 + 43239ed commit 442aae7
Show file tree
Hide file tree
Showing 7 changed files with 598 additions and 1 deletion.
6 changes: 5 additions & 1 deletion appinfo/info.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
</repository>

<dependencies>
<php min-version="7.1" max-version="7.2"></php>
<php min-version="7.1" max-version="7.3"></php>
<nextcloud min-version="15" max-version="16"/>
</dependencies>

<commands>
<command>OCA\SuspiciousLogin\Command\TrainMLP</command>
</commands>
</info>
99 changes: 99 additions & 0 deletions lib/Command/TrainMLP.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
<?php

declare(strict_types=1);

/**
* @author Christoph Wurst <christoph@winzerhof-wurst.at>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/

namespace OCA\SuspiciousLogin\Command;

use OCA\SuspiciousLogin\Service\MLPTrainer;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
use Symfony\Component\Console\Input\InputOption;
use Symfony\Component\Console\Output\OutputInterface;

/**
 * Console command "suspiciouslogin:train:mlp".
 *
 * Reads the training hyper-parameters from the command line options and
 * delegates the actual work to MLPTrainer::train().
 */
class TrainMLP extends Command {

	/** @var MLPTrainer */
	private $trainer;

	/**
	 * @param MLPTrainer $trainer service that performs the training run
	 */
	public function __construct(MLPTrainer $trainer) {
		parent::__construct("suspiciouslogin:train:mlp");
		$this->trainer = $trainer;

		// All options use VALUE_REQUIRED: the option itself may be omitted (the
		// default is used then), but if it is given it must carry a value.
		// With VALUE_OPTIONAL a bare "--epochs" would yield null, which
		// execute() would silently cast to 0.
		$this->addOption(
			'shuffled',
			null,
			InputOption::VALUE_REQUIRED,
			"ratio of shuffled negative samples",
			1.0
		);
		$this->addOption(
			'random',
			null,
			InputOption::VALUE_REQUIRED,
			"ratio of random negative samples",
			1.0
		);
		$this->addOption(
			'epochs',
			'e',
			InputOption::VALUE_REQUIRED,
			"number of epochs to train",
			5000
		);
		$this->addOption(
			'layers',
			'l',
			InputOption::VALUE_REQUIRED,
			"number of hidden layers",
			6
		);
		$this->addOption(
			'learn-rate',
			null,
			InputOption::VALUE_REQUIRED,
			"learning rate",
			0.05
		);
		$this->addOption(
			'validation-rate',
			null,
			InputOption::VALUE_REQUIRED,
			"relative size of the validation data set",
			0.15
		);
	}

	/**
	 * Run a training pass with the options given on the command line.
	 *
	 * Option values arrive as strings (or as the typed defaults declared in
	 * the constructor) and are cast to the scalar types MLPTrainer::train()
	 * declares.
	 *
	 * @param InputInterface $input parsed command line input
	 * @param OutputInterface $output console the trainer reports to
	 *
	 * @return int exit code, 0 once the trainer returned
	 */
	protected function execute(InputInterface $input, OutputInterface $output) {
		$this->trainer->train(
			$output,
			(float) $input->getOption('shuffled'),
			(float) $input->getOption('random'),
			(int) $input->getOption('epochs'),
			(int) $input->getOption('layers'),
			(float) $input->getOption('learn-rate'),
			(float) $input->getOption('validation-rate')
		);

		return 0;
	}

}
11 changes: 11 additions & 0 deletions lib/Db/LoginAddressMapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,15 @@ public function __construct(IDBConnection $db) {
parent::__construct($db, 'login_address');
}

/**
 * Fetch the distinct (uid, ip) pairs stored in the login address table.
 *
 * Rows are grouped by uid and ip, so each pair is returned once.
 *
 * @param int $minEntries minimum number of rows a (uid, ip) pair must have
 *                        to be returned. NOTE(review): this meaning is
 *                        inferred from the parameter name, the parameter was
 *                        previously ignored — confirm the intended semantics.
 *                        With the default of 1 no extra filter is applied.
 *
 * @return array entities as produced by findEntities() for the query
 */
public function findAll(int $minEntries = 1) {
	$qb = $this->db->getQueryBuilder();

	$query = $qb
		->select('uid', 'ip')
		->from($this->getTableName())
		->groupBy('uid', 'ip');

	// Every group has at least one row, so a HAVING clause is only needed
	// when the caller asks for more than one entry per pair.
	if ($minEntries > 1) {
		$query->having(
			$qb->expr()->gte(
				$qb->createFunction('COUNT(*)'),
				$qb->createNamedParameter($minEntries, \OCP\DB\QueryBuilder\IQueryBuilder::PARAM_INT)
			)
		);
	}

	return $this->findEntities($query);
}

}
118 changes: 118 additions & 0 deletions lib/Service/DataSet.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
<?php

declare(strict_types=1);

/**
* @copyright 2018 Christoph Wurst <christoph@winzerhof-wurst.at>
*
* @author 2018 Christoph Wurst <christoph@winzerhof-wurst.at>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

namespace OCA\SuspiciousLogin\Service;

use function array_key_exists;
use function array_map;
use function array_merge;
use ArrayAccess;
use Countable;
use OCA\SuspiciousLogin\Db\LoginAddress;
use function shuffle;

/**
 * Array-like collection of UidIPVector samples.
 *
 * Provides the conversions needed for training (feature vectors and labels)
 * as well as merging and in-place shuffling of samples.
 */
class DataSet implements ArrayAccess, Countable {

	/** @var UidIPVector[] */
	private $data;

	/**
	 * @param array $data list of arrays with the keys 'uid', 'ip' and 'label';
	 *                    each one is turned into a UidIPVector
	 */
	public function __construct(array $data) {
		$this->data = array_map(static function (array $item) {
			return new UidIPVector($item['uid'], $item['ip'], $item['label']);
		}, $data);
	}

	/**
	 * Build a data set in which every given login address is a positive sample.
	 *
	 * @param LoginAddress[] $loginAddresses
	 */
	public static function fromLoginAddresses(array $loginAddresses): DataSet {
		return new DataSet(array_map(static function (LoginAddress $addr) {
			return [
				'uid' => $addr->getUid(),
				'ip' => $addr->getIp(),
				'label' => MLPTrainer::LABEL_POSITIVE,
			];
		}, $loginAddresses));
	}

	/**
	 * @return array the feature vector of every sample, in the order (and
	 *               with the keys) of the underlying data
	 */
	public function asTrainingData(): array {
		return array_map(static function (UidIPVector $vec) {
			return $vec->asFeatureVector();
		}, $this->data);
	}

	/**
	 * Whether an offset exists
	 */
	public function offsetExists($offset) {
		return array_key_exists($offset, $this->data);
	}

	/**
	 * Offset to retrieve
	 */
	public function offsetGet($offset) {
		return $this->data[$offset];
	}

	/**
	 * Offset to set
	 *
	 * PHP passes null as the offset for the append syntax ($set[] = $value).
	 * In that case the value is appended instead of being stored under the
	 * key '' (which is what a plain $this->data[null] assignment would do).
	 */
	public function offsetSet($offset, $value) {
		if ($offset === null) {
			$this->data[] = $value;
			return;
		}

		$this->data[$offset] = $value;
	}

	/**
	 * Offset to unset
	 */
	public function offsetUnset($offset) {
		unset($this->data[$offset]);
	}

	/**
	 * @return int number of samples in this set
	 */
	public function count() {
		return count($this->data);
	}

	/**
	 * @return string[] the label of every sample, aligned with asTrainingData()
	 */
	public function getLabels(): array {
		return array_map(static function (UidIPVector $vec) {
			return $vec->getLabel();
		}, $this->data);
	}

	/**
	 * Combine this set with another one into a new set.
	 *
	 * Neither this set nor $other is modified.
	 */
	public function merge(DataSet $other): DataSet {
		$merged = array_merge($this->data, $other->data);
		// The constructor expects raw arrays, so start from an empty set and
		// inject the already built vectors directly.
		$new = new DataSet([]);
		$new->data = $merged;
		return $new;
	}

	/**
	 * Randomize the order of the samples in place.
	 */
	public function shuffle() {
		shuffle($this->data);
	}

}
116 changes: 116 additions & 0 deletions lib/Service/MLPTrainer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
<?php

declare(strict_types=1);

/**
* @author Christoph Wurst <christoph@winzerhof-wurst.at>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/

namespace OCA\SuspiciousLogin\Service;

use function array_slice;
use OCA\SuspiciousLogin\Db\LoginAddressMapper;
use OCP\AppFramework\Utility\ITimeFactory;
use Phpml\Classification\MLPClassifier;
use Phpml\Metric\ClassificationReport;
use function shuffle;
use Symfony\Component\Console\Output\OutputInterface;

/**
 * Trains a multilayer perceptron classifier on the stored login addresses
 * and reports validation metrics on the console.
 */
class MLPTrainer {

	const LABEL_POSITIVE = 'y';
	const LABEL_NEGATIVE = 'n';

	/** @var LoginAddressMapper */
	private $loginAddressMapper;

	/** @var NegativeSampleGenerator */
	private $negativeSampleGenerator;

	/** @var ITimeFactory */
	private $timeFactory;

	public function __construct(LoginAddressMapper $loginAddressMapper,
								NegativeSampleGenerator $negativeSampleGenerator,
								ITimeFactory $timeFactory) {
		$this->loginAddressMapper = $loginAddressMapper;
		$this->negativeSampleGenerator = $negativeSampleGenerator;
		$this->timeFactory = $timeFactory;
	}

	/**
	 * Train a classifier from the login addresses in the database and print
	 * the set sizes, the parameters, the elapsed time and the validation
	 * metrics.
	 *
	 * The stored addresses are shuffled and split into a training part and a
	 * validation part. Negative samples are generated for both parts.
	 *
	 * @param OutputInterface $output console to report to
	 * @param float $shuffledNegativeRate shuffled negatives per positive training sample
	 * @param float $randomNegativeRate random negatives per positive training sample
	 * @param int $epochs passed to the classifier as number of iterations
	 * @param int $layers passed to the classifier as its only hidden layer entry
	 * @param float $learningRate learning rate of the classifier
	 * @param float $validationRate share of the data that is held back for validation
	 */
	public function train(OutputInterface $output,
						  float $shuffledNegativeRate,
						  float $randomNegativeRate,
						  int $epochs,
						  int $layers,
						  float $learningRate,
						  float $validationRate) {
		$raw = $this->loginAddressMapper->findAll();
		if (count($raw) === 0) {
			// Without any positive sample neither training nor validation is possible
			$output->writeln("No login data found, nothing to train");
			return;
		}

		shuffle($raw);
		$all = DataSet::fromLoginAddresses($raw);
		// Index at which the validation part starts, clamped to [0, count]
		$validationOffset = (int)min(count($all), max(0, count($raw) * (1 - $validationRate)));
		$positives = DataSet::fromLoginAddresses(array_slice($raw, 0, $validationOffset));
		$validationPositives = DataSet::fromLoginAddresses(array_slice($raw, $validationOffset));
		$numValidation = count($validationPositives);
		$numPositives = count($positives);
		// At least one negative sample of each kind. The lower bound has to be
		// an int: max(0, 1.0) would yield the float 1.0, which is not a valid
		// int argument under strict_types.
		$numRandomNegatives = max((int)floor($numPositives * $randomNegativeRate), 1);
		$numShuffledNegative = max((int)floor($numPositives * $shuffledNegativeRate), 1);
		$randomNegatives = $this->negativeSampleGenerator->generateRandomFromPositiveSamples($positives, $numRandomNegatives);
		// NOTE(review): this still uses the random generator. If
		// NegativeSampleGenerator offers a dedicated method for shuffled
		// negatives it should probably be called here — confirm.
		$shuffledNegatives = $this->negativeSampleGenerator->generateRandomFromPositiveSamples($positives, $numShuffledNegative);

		// Validation negatives are generated from all data (to have all UIDs), but shuffled
		$all->shuffle();
		$validationNegatives = $this->negativeSampleGenerator->generateRandomFromPositiveSamples($all, $numValidation);
		$validationSamples = $validationPositives->merge($validationNegatives);

		$total = $numPositives + $numRandomNegatives + $numShuffledNegative;
		$output->writeln("Got $total samples for training: $numPositives positive, $numRandomNegatives random negative and $numShuffledNegative shuffled negative");
		$output->writeln("Got $numValidation positive and $numValidation negative samples for validation (rate: $validationRate)");
		$output->writeln("Number of epochs: " . $epochs);
		$output->writeln("Number of hidden layers: " . $layers);
		$output->writeln("Learning rate: " . $learningRate);
		$output->writeln("Vector dimensions: " . UidIPVector::SIZE);

		$allSamples = $positives->merge($randomNegatives)->merge($shuffledNegatives);
		$allSamples->shuffle();

		$output->writeln("Start training");
		$start = $this->timeFactory->getDateTime();
		// NOTE(review): [$layers] is a one-element hidden layer definition,
		// while the CLI describes the value as the number of hidden layers —
		// confirm which of the two is intended.
		$classifier = new MLPClassifier(
			UidIPVector::SIZE,
			[$layers],
			[self::LABEL_POSITIVE, self::LABEL_NEGATIVE],
			$epochs,
			null,
			$learningRate
		);
		$classifier->train(
			$allSamples->asTrainingData(),
			$allSamples->getLabels()
		);
		$finished = $this->timeFactory->getDateTime();
		$elapsed = $finished->getTimestamp() - $start->getTimestamp();
		$output->writeln("Training finished after " . $elapsed . "s");

		$predicted = $classifier->predict($validationSamples->asTrainingData());
		$result = new ClassificationReport($validationSamples->getLabels(), $predicted);
		$precision = $result->getPrecision();
		$recall = $result->getRecall();
		$average = $result->getAverage();
		$output->writeln("Precision(y): " . $precision[self::LABEL_POSITIVE]);
		$output->writeln("Precision(n): " . $precision[self::LABEL_NEGATIVE]);
		$output->writeln("Recall(y): " . $recall[self::LABEL_POSITIVE]);
		$output->writeln("Recall(n): " . $recall[self::LABEL_NEGATIVE]);
		$output->writeln("Average(precision): " . $average['precision']);
		$output->writeln("Average(recall): " . $average['recall']);
		$output->writeln("Average(f1score): " . $average['f1score']);
	}

}
Loading

0 comments on commit 442aae7

Please sign in to comment.