diff --git a/suzieq/restServer/query.py b/suzieq/restServer/query.py index b1eaf2692d..bae2743daf 100755 --- a/suzieq/restServer/query.py +++ b/suzieq/restServer/query.py @@ -116,7 +116,7 @@ def get_log_config_level(cfg): return log_config, loglevel -def rest_main(config_file: str, no_https: bool) -> None: +def rest_main(*args) -> None: """The main function for the REST server Args: @@ -124,7 +124,24 @@ def rest_main(config_file: str, no_https: bool) -> None: no_https (bool): If true, disable https """ - config_file = sq_get_config_file(config_file) + if not args: + args = sys.argv + + parser = argparse.ArgumentParser(args) + parser.add_argument( + "-c", + "--config", + type=str, help="alternate config file", + default=None + ) + parser.add_argument( + "--no-https", + help="Turn off HTTPS", + default=False, action='store_true', + ) + userargs = parser.parse_args() + + config_file = sq_get_config_file(userargs.config) app = app_init(config_file) cfg = load_sq_config(config_file=config_file) try: @@ -135,7 +152,7 @@ def rest_main(config_file: str, no_https: bool) -> None: logcfg, loglevel = get_log_config_level(cfg) - no_https = cfg.get('rest', {}).get('no-https', False) or no_https + no_https = cfg.get('rest', {}).get('no-https', False) or userargs.no_https srvr_addr = cfg.get('rest', {}).get('address', '127.0.0.1') srvr_port = cfg.get('rest', {}).get('port', 8000) diff --git a/suzieq/restServer/sq_rest_server.py b/suzieq/restServer/sq_rest_server.py index 6ebdf8a3b4..28142a3499 100755 --- a/suzieq/restServer/sq_rest_server.py +++ b/suzieq/restServer/sq_rest_server.py @@ -1,23 +1,8 @@ #!/usr/bin/env python3 import sys -import argparse - from suzieq.restServer.query import rest_main if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-c", - "--config", - type=str, help="alternate config file", - default=None - ) - parser.add_argument( - "--no-https", - help="Turn off HTTPS", - default=False, action='store_true', - ) - userargs = parser.parse_args() - rest_main(userargs.config, userargs.no_https) + rest_main(sys.argv) diff --git a/tests/integration/test_rest.py b/tests/integration/test_rest.py index afd5ac938b..1cc51c1ea1 100644 --- a/tests/integration/test_rest.py +++ b/tests/integration/test_rest.py @@ -292,29 +292,27 @@ def app_initialize(): # For some reason, putting the no_https in a for loop didn't work either @pytest.mark.rest def test_rest_server(): - from multiprocessing import Process + import subprocess from time import sleep import requests cfgfile = create_dummy_config_file( datadir='./tests/data/multidc/parquet-out') - server = Process(target=rest_main, args=(cfgfile, True)) - server.start() - assert (server.is_alive()) - sleep(1) - assert (server.is_alive()) + server = subprocess.Popen( + f'./suzieq/restServer/sq_rest_server.py -c {cfgfile} --no-https'.split()) + sleep(5) + assert(server.pid) assert(requests.get('http://localhost:8000/api/docs')) - server.terminate() - server.join() - - server = Process(target=rest_main, args=(cfgfile, False)) - server.start() - assert (server.is_alive()) - sleep(1) - assert (server.is_alive()) - assert (requests.get("https://localhost:8000/api/docs", verify=False)) - server.terminate() - server.join() + server.kill() + sleep(5) + + server = subprocess.Popen( + f'./suzieq/restServer/sq_rest_server.py -c {cfgfile} '.split()) + sleep(5) + assert(server.pid) + assert(requests.get('https://localhost:8000/api/docs', verify=False)) + server.kill() + sleep(5) os.remove(cfgfile)