-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from zStupan/feature-cli
CLI
- Loading branch information
Showing
5 changed files
with
267 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import sys | ||
from niaarm import cli | ||
|
||
|
||
if __name__ == '__main__': | ||
sys.exit(cli.main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import argparse | ||
from inspect import getmodule, getmembers, isfunction | ||
import os | ||
from pathlib import Path | ||
import platform | ||
import subprocess | ||
import sys | ||
import tempfile | ||
|
||
import numpy as np | ||
from niaarm import NiaARM, Dataset, Stats | ||
from niapy.task import OptimizationType, Task | ||
from niapy.util.factory import get_algorithm | ||
from niapy.util import distances, repair | ||
from niapy.algorithms.other import mts | ||
from niapy.algorithms.basic import de | ||
|
||
|
||
def get_parser(): | ||
parser = argparse.ArgumentParser(prog='niaarm', | ||
description='Perform ARM, output mined rules as csv, get mined rules\' statistics') | ||
parser.add_argument('-i', '--input-file', type=str, required=True, help='Input file containing a csv dataset') | ||
parser.add_argument('-o', '--output-file', type=str, help='Output file for mined rules') | ||
parser.add_argument('-a', '--algorithm', type=str, required=True, | ||
help='Algorithm to use (niapy class name, e.g. DifferentialEvolution)') | ||
parser.add_argument('-s', '--seed', type=int, help='Seed for the algorithm\'s random number generator') | ||
parser.add_argument('--max-evals', type=int, default=np.inf, help='Maximum number of fitness function evaluations') | ||
parser.add_argument('--max-iters', type=int, default=np.inf, help='Maximum number of iterations') | ||
parser.add_argument('--alpha', type=float, default=0.0, help='Alpha parameter. Default 0') | ||
parser.add_argument('--beta', type=float, default=0.0, help='Beta parameter. Default 0') | ||
parser.add_argument('--gamma', type=float, default=0.0, help='Gamma parameter. Default 0') | ||
parser.add_argument('--delta', type=float, default=0.0, help='Delta parameter. Default 0') | ||
parser.add_argument('--log', action='store_true', help='Enable logging of fitness improvements') | ||
parser.add_argument('--show-stats', action='store_true', help='Display stats about mined rules') | ||
|
||
return parser | ||
|
||
|
||
def text_editor(): | ||
return os.getenv('VISUAL') or os.getenv('EDITOR') or ('notepad' if platform.system() == 'Windows' else 'vi') | ||
|
||
|
||
def parameters_string(parameters): | ||
params_txt = '# You can edit the algorithm\'s parameter values here\n' \ | ||
'# Save and exit to continue\n' \ | ||
'# WARNING: Do not edit parameter names\n' | ||
for parameter, value in parameters.items(): | ||
if isinstance(value, tuple): | ||
if callable(value[0]): | ||
value = tuple(v.__name__ for v in value) | ||
else: | ||
value = tuple(str(v) for v in value) | ||
value = ', '.join(value) | ||
params_txt += f'{parameter} = {value.__name__ if callable(value) else value}\n' | ||
return params_txt | ||
|
||
|
||
def functions(algorithm): | ||
funcs = {} | ||
algorithm_funcs = dict(getmembers(getmodule(algorithm.__class__), isfunction)) | ||
repair_funcs = dict(getmembers(repair, isfunction)) | ||
distance_funcs = dict(getmembers(distances, isfunction)) | ||
de_funcs = dict(getmembers(de, isfunction)) | ||
mts_funcs = dict(getmembers(mts, isfunction)) | ||
funcs.update(algorithm_funcs) | ||
funcs.update(repair_funcs) | ||
funcs.update(distance_funcs) | ||
funcs.update(de_funcs) | ||
funcs.update(mts_funcs) | ||
return funcs | ||
|
||
|
||
def find_function(name, algorithm): | ||
return functions(algorithm)[name] | ||
|
||
|
||
def convert_string(string): | ||
try: | ||
value = float(string) | ||
if value.is_integer(): | ||
value = int(value) | ||
except ValueError: | ||
return string | ||
return value | ||
|
||
|
||
def parse_parameters(text, algorithm): | ||
lines: list[str] = text.strip().split('\n') | ||
lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')] | ||
parameters = {} | ||
for line in lines: | ||
key, value = line.split('=') | ||
key = key.strip() | ||
value = convert_string(value.strip()) | ||
if isinstance(value, str): | ||
if len(value.split(', ')) > 1: # tuple | ||
value = list(map(str.strip, value.split(', '))) | ||
value = tuple(map(convert_string, value)) | ||
value = tuple(find_function(v, algorithm) for v in value if type(v) == str) | ||
elif value.lower() == 'true' or value.lower() == 'false': # boolean | ||
value = value.lower() == 'true' | ||
else: # probably a function | ||
value = find_function(value, algorithm) | ||
parameters[key] = value | ||
return parameters | ||
|
||
|
||
def edit_parameters(parameters, algorithm): | ||
parameters.pop('individual_type', None) | ||
parameters.pop('initialization_function', None) | ||
fd, filename = tempfile.mkstemp() | ||
os.close(fd) | ||
|
||
new_parameters = None | ||
try: | ||
path = Path(filename) | ||
path.write_text(parameters_string(parameters)) | ||
command = f'{text_editor()} {filename}' | ||
subprocess.run(command, shell=True, check=True) | ||
params_txt = path.read_text() | ||
new_parameters = parse_parameters(params_txt, algorithm) | ||
finally: | ||
try: | ||
os.unlink(filename) | ||
except Exception as e: | ||
print('Error:', e, file=sys.stderr) | ||
return new_parameters | ||
|
||
|
||
def main(): | ||
parser = get_parser() | ||
args = parser.parse_args() | ||
|
||
if len(sys.argv) == 1: | ||
parser.print_help() | ||
if args.max_evals == np.inf and args.max_iters == np.inf: | ||
print('--max-evals and/or --max-iters missing', file=sys.stderr) | ||
return 1 | ||
|
||
try: | ||
dataset = Dataset(args.input_file) | ||
problem = NiaARM(dataset.dimension, dataset.features, dataset.transactions, args.alpha, args.beta, args.gamma, | ||
args.delta, args.log) | ||
task = Task(problem, max_iters=args.max_iters, max_evals=args.max_evals, | ||
optimization_type=OptimizationType.MAXIMIZATION) | ||
|
||
algorithm = get_algorithm(args.algorithm, seed=args.seed) | ||
params = algorithm.get_parameters() | ||
new_params = edit_parameters(params, algorithm.__class__) | ||
if new_params is None: | ||
print('Invalid parameters', file=sys.stderr) | ||
return 1 | ||
|
||
for param in new_params: | ||
if param not in params: | ||
print(f'Invalid parameter: {param}', file=sys.stderr) | ||
return 1 | ||
|
||
algorithm.set_parameters(**new_params) | ||
|
||
algorithm.run(task) | ||
|
||
if args.output_file: | ||
problem.sort_rules() | ||
problem.export_rules(args.output_file) | ||
|
||
if args.show_stats: | ||
stats = Stats(problem.rules) | ||
print('\nSTATS:') | ||
print(f'Total rules: {stats.total_rules}') | ||
print(f'Average fitness: {stats.avg_fitness}') | ||
print(f'Average support: {stats.avg_support}') | ||
print(f'Average confidence: {stats.avg_confidence}') | ||
print(f'Average coverage: {stats.avg_coverage}') | ||
print(f'Average shrinkage: {stats.avg_shrinkage}') | ||
print(f'Average length of antecedent: {stats.avg_ant_len}') | ||
print(f'Average length of consequent: {stats.avg_con_len}') | ||
|
||
except Exception as e: | ||
print('Error:', e, file=sys.stderr) | ||
return 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters