Skip to content

Commit

Permalink
Merge pull request #57 from flipkart-incubator/dev
Browse files Browse the repository at this point in the history
CLI improvement
  • Loading branch information
prajal authored Jul 27, 2018
2 parents 317b705 + 07021c1 commit 2597894
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 24 deletions.
Empty file added API/__init__.py
Empty file.
35 changes: 33 additions & 2 deletions API/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,41 @@
import hashlib
import time
import json
import threading
import logging

sys.path.append('../')

from flask import Flask, render_template
from flask import Response, make_response
from flask import request
from flask import Flask
from astra import scan_single_api
#from astra import scan_single_api
from flask import jsonify
from pymongo import MongoClient
from pymongo.errors import ServerSelectionTimeoutError
from utils.vulnerabilities import alerts
from jinja2 import utils


if os.getcwd().split('/')[-1] == 'API':
from astra import scan_single_api


app = Flask(__name__, template_folder='../Dashboard/templates', static_folder='../Dashboard/static')


class ServerThread(threading.Thread):

def __init__(self):
threading.Thread.__init__(self)

def run(self):
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
app.run(host='0.0.0.0', port= 8094)


# Mongo DB connection
maxSevSelDelay = 1
try:
Expand Down Expand Up @@ -182,5 +201,17 @@ def return_alerts(scanid):
def view_dashboard(page):
return render_template('{}'.format(page))

app.run(host='0.0.0.0', port= 8094,debug=False)
def start_server():
app.run(host='0.0.0.0', port= 8094)

#if __name__ == "__main__":

def main():
if os.getcwd().split('/')[-1] == 'API':
start_server()
else:
thread = ServerThread()
thread.daemon = True
thread.start()

main()
53 changes: 32 additions & 21 deletions astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import utils.logger as logger
import utils.logs as logs
import urlparse

import hashlib
import webbrowser

from core.zapscan import *
from core.parsers import *
Expand All @@ -27,6 +28,11 @@
from multiprocessing import Process
from utils.db import Database_update


if os.getcwd().split('/')[-1] != 'API':
from API.api import main


dbupdate = Database_update()

def parse_collection(collection_name,collection_type):
Expand All @@ -36,6 +42,17 @@ def parse_collection(collection_name,collection_type):
print "[-]Failed to Parse collection"
sys.exit(1)

def scan_complete():
print "[+]Scan has been completed"
webbrowser.open("http://127.0.0.1:8094/reports.html#"+scanid)
while True:
pass

def generate_scanid():
global scanid
scanid = hashlib.md5(str(time.time())).hexdigest()
return scanid

def add_headers(headers):
# This function deals with adding custom header and auth value .
auth_type = get_value('config.property','login','auth_type')
Expand Down Expand Up @@ -63,15 +80,6 @@ def add_headers(headers):

return headers

def generate_report():
# Generating report once the scan is complete.
result = api_scan.generate_report()
if result is True:
print "%s[+]Report is generated successfully%s"% (api_logger.G, api_logger.W)
else:
print "%s[-]Failed to generate a report%s"% (api_logger.R, api_logger.W)


def read_scan_policy():
try:
scan_policy = get_value('scan.property','scan-policy','attack')
Expand Down Expand Up @@ -112,7 +120,7 @@ def modules_scan(url,method,headers,body,scanid=None):
status = zap_start()
if status is True:
api_scan.start_scan(url,method,headers,body,scanid)

# Custom modules scan
if attack['cors'] == 'Y' or attack['cors'] == 'y':
cors_main(url,method,headers,body,scanid)
Expand All @@ -139,7 +147,6 @@ def modules_scan(url,method,headers,body,scanid=None):
open_redirect_check(url,method,headers,body,scanid)
update_scan_status(scanid, "open-redirection")


def validate_data(url,method):
''' Validate HTTP request data and return boolean value'''
validate_url = urlparse.urlparse(url)
Expand All @@ -155,7 +162,7 @@ def scan_single_api(url, method, headers, body, api, scanid=None):
''' This function deals with scanning a single API. '''
if headers is None or headers == '':
headers = {'Content-Type' : 'application/json'}

try:
# Convert header and body in dict format
if type(headers) is not dict:
Expand All @@ -175,20 +182,23 @@ def scan_single_api(url, method, headers, body, api, scanid=None):
print "[-]Invalid Arguments"
return False

p = Process(target=modules_scan,args=(url,method,headers,body,scanid),name='module-scan')
p.start()
if api == "Y":
return True
p = Process(target=modules_scan,args=(url,method,headers,body,scanid),name='module-scan')
p.start()
if api == "Y":
return True
else:
modules_scan(url,method,headers,body,scanid)


def scan_core(collection_type,collection_name,url,headers,method,body,loginurl,loginheaders,logindata,login_require):
''' Scan API through different engines '''
scanid = ''
scanid = generate_scanid()
if collection_type and collection_name is not None:
parse_collection(collection_name,collection_type)
if login_require is True:
api_login.verify_login(parse_data.api_lst)
msg = True

for data in parse_data.api_lst:
try:
url = data['url']['raw']
Expand All @@ -197,14 +207,13 @@ def scan_core(collection_type,collection_name,url,headers,method,body,loginurl,l
headers,method,body = data['headers'],data['method'],''
if headers:
try:
headhers = add_headers(headers)
headers = add_headers(headers)
except:
pass

if data['body'] != '':
body = json.loads(base64.b64decode(data['body']))


modules_scan(url,method,headers,body,scanid)

else:
Expand Down Expand Up @@ -299,8 +308,10 @@ def main():
if collection_type and collection_name is not None:
scan_core(collection_type,collection_name,url,headers,method,body,loginurl,loginheaders,logindata,login_require)
else:
scan_single_api(url, method, headers, body, "False")
scanid = generate_scanid()
scan_single_api(url, method, headers, body, "F", scanid)

scan_complete()

if __name__ == '__main__':
api_login = APILogin()
Expand Down
8 changes: 7 additions & 1 deletion utils/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,10 @@
else:
path = 'logs/scan.log'

logging.basicConfig(filename=path, level=logging.INFO)

logger = logging.getLogger()
fh = logging.FileHandler(path)
logger.addHandler(fh)
logger.setLevel(logging.INFO)

#logging.basicConfig(filename=path, level=logging.INFO)

0 comments on commit 2597894

Please sign in to comment.