from scapy.all import *
import os
import pickle
import numpy as np
import json
import tqdm
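# process.py
#
# Converts the raw captures produced by the SSI data collector into per-host,
# per-stream packet metadata and stores the result in a pickle file.
#
# Usage:
#   python process.py --root <results_dir> [--root <another_results_dir>] \
#       --out processed.pkl --min_pkts 10
#
# Output layout (see the pickle.dump at the bottom of this script):
#   'data':  {sample_ID: {host_ID: [(3, N) arrays of (time, size, direction)]}}
#   'IPs':   {sample_ID: {host_ID: {'src': ..., 'dst': ...}}}
#   'proto': {sample_ID: {host_ID: [tunnel app name per stream]}}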
# Server port used by tunnel apps
port_map = {
"ssh": 22,
"socat": 80,
}
# Transportation protocol used by tunnel apps
trans_map = {
"ssh": "TCP",
"socat": "TCP",
}
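# NOTE: these keys must match the application names recorded in each sample's
# tunnel.json, since those values are looked up directly in port_map/trans_map.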
def get_streams(filepath, protocols):
# Read in PCAP file
pcap = rdpcap(filepath)
# Initialize dictionary to store SSH streams
ssh_streams = {}
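    # streams are keyed by (src_ip, dst_ip, protocol); packets flowing in the
    # reverse direction are appended to the same entry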
# Loop through packets in PCAP file
for pkt in pcap:
for proto in protocols:
trans_proto = trans_map[proto]
port_num = port_map[proto]
if pkt.haslayer(trans_proto) \
and (pkt[trans_proto].dport == port_num or pkt[trans_proto].sport == port_num):
                # avoid non-data packets: only keep TCP segments with the PSH flag (0x08) set
if trans_proto == "TCP" and not (pkt[trans_proto].flags & 0x08):
continue
# Get source and destination IP addresses
src = pkt['IP'].src
dst = pkt['IP'].dst
# Check if this is a new SSH stream
if (src, dst, proto) not in ssh_streams and (dst, src, proto) not in ssh_streams:
ssh_streams[(src, dst, proto)] = []
# Add packet to appropriate SSH stream list
if (src, dst, proto) in ssh_streams:
ssh_streams[(src, dst, proto)].append(pkt)
else:
ssh_streams[(dst, src, proto)].append(pkt)
# Print number of streams found
#print("Number of streams:", len(ssh_streams))
# Print number of packets in each stream
#for stream in ssh_streams:
# print("Number of packets in stream:", len(ssh_streams[stream]))
return ssh_streams
def get_args():
from argparse import ArgumentParser
parser = ArgumentParser(prog="SSIDDataProcessor",
description="Process raw dataset generated by the SSI data collector into pickle file.")
parser.add_argument('--root',
help="Path to root \'results\' directory. May use flag multiple times to parse multiple roots.",
required=True, type=str, action='append')
parser.add_argument('--out',
help="Path to file in which to store processed data.",
default="processed.pkl", type=str)
parser.add_argument('--min_pkts',
help="Filter out streams with packets lower than this threshold..",
default=10, type=int)
return parser.parse_args()
if __name__ == "__main__":
data = {} # metadata for SSH streams
IP_data = {} # extra IP information for each stream
proto_data = {}
args = get_args()
    pckt_cutoff = args.min_pkts  # skip streams with fewer than this number of packets
cur_sample_idx = 0
# loop through sample directories
for rt_dir in args.root:
dirs = os.listdir(rt_dir)
for dirname in tqdm.tqdm(dirs):
if not dirname.isnumeric():
continue
#sample_ID = dirname
sample_ID = cur_sample_idx
cur_sample_idx += 1
# key paths
dirpath = os.path.join(rt_dir, dirname)
infopath = os.path.join(dirpath, "tunnel.json")
pcaproot = os.path.join(dirpath, "tcpdump")
if not os.path.exists(infopath):
continue
with open(infopath, 'r') as fi:
tunnel_info = json.load(fi)
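            # tunnel.json maps device names ("dev1", "dev2", ...) to the tunnel
            # application running on that host ("ssh" or "socat")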
fnames = os.listdir(pcaproot)
data[sample_ID] = {}
IP_data[sample_ID] = {}
proto_data[sample_ID] = {}
for fname in fnames:
# get the host number for the pcap sample
host_ID = fname.replace('dev','').replace('.pcap', '')
if host_ID.isnumeric():
host_ID = int(host_ID)
else:
continue
protocols = []
if f'dev{host_ID}' in tunnel_info:
protocols.append(tunnel_info[f'dev{host_ID}'])
if f'dev{host_ID+1}' in tunnel_info:
protocols.append(tunnel_info[f'dev{host_ID+1}'])
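                # a capture on host N may contain traffic for its own tunnel (devN)
                # as well as the next hop's tunnel (devN+1), so both protocols are
                # considered when splitting streams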
# load and split the pcap into distinct ssh streams (defined by IP src/dst tuples)
pth = os.path.join(pcaproot, fname)
try:
streams = get_streams(pth, protocols)
if len(streams) < 1:
continue
except Exception as e:
print(e)
continue
data_t = []
IP_t = {}
proto_t = []
# process the stream scapy packets into metadata per stream
for src_ip,dst_ip,proto in streams:
stream = streams[(src_ip, dst_ip,proto)]
metadata = []
#init_time = float(stream[0].time)
for pkt in stream:
cur_time = float(pkt.time)# - init_time
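                        # direction is +1.0 for packets sent by the stream's keyed
                        # source IP, -1.0 for packets in the opposite direction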
pkt_dir = 1. if pkt['IP'].src == src_ip else -1.
pkt_size = len(pkt)
metadata += [(cur_time, pkt_size, pkt_dir)]
if len(metadata) < pckt_cutoff:
continue
metadata = np.array(metadata).T
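                    # metadata now has shape (3, N): row 0 = timestamps,
                    # row 1 = packet sizes, row 2 = directions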
#metadata[0,:] -= metadata[0,0] # adjust timestamp sequence to begin at zero
data_t.append(metadata)
proto_t.append(proto)
# store IP information in case it's needed
IP_t['src'] = src_ip
IP_t['dst'] = dst_ip
data[sample_ID][host_ID] = data_t
IP_data[sample_ID][host_ID] = IP_t
proto_data[sample_ID][host_ID] = proto_t
# filter out chain samples with odd stream counts per host
# (first & last hosts should have one stream, stepping stones should have two)
print("* Filter bad samples from processed data...")
for sample_ID in list(data.keys()):
        if len(data[sample_ID]) <= 2:  # not enough valid pcaps in this sample
print(f"\t[{sample_ID}] Too few host captures.")
del data[sample_ID]
del IP_data[sample_ID]
del proto_data[sample_ID]
continue
host_IDs = set(data[sample_ID].keys())
# if first and last hosts do not have exactly one stream, something is odd with the sample
if (len(data[sample_ID][min(host_IDs)]) != 1) or (len(data[sample_ID][max(host_IDs)]) != 1):
print(f"\t[{sample_ID}] First/last hosts do not have exactly one stream.")
del data[sample_ID]
del IP_data[sample_ID]
del proto_data[sample_ID]
continue
host_IDs.remove(min(host_IDs))
host_IDs.remove(max(host_IDs))
# check if stepping stones all correctly have two streams
bad = False
for host_ID in host_IDs:
if len(data[sample_ID][host_ID]) != 2:
print(f"\t[{sample_ID}] Host {host_ID} does not have exactly two streams.")
del data[sample_ID]
del IP_data[sample_ID]
del proto_data[sample_ID]
bad = True
break
if bad:
continue
# align all streams in chain sample to the same relative start time
stream_min_time = 1e100
host_IDs = set(data[sample_ID].keys())
for host_ID in host_IDs:
streams = data[sample_ID][host_ID]
for stream in streams:
stream_min_time = min(stream[0][0], stream_min_time)
for host_ID in host_IDs:
streams = data[sample_ID][host_ID]
for stream in streams:
stream[0][:] -= stream_min_time
data[sample_ID][host_ID] = streams
print(f"Total sample count after filtering: {len(data)}")
# store dataset to file
with open(args.out, 'wb') as fi:
pickle.dump({
'data': data,
'IPs': IP_data,
'proto': proto_data
}, fi)
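# Example (for reference) of reading the processed dataset back:
#   with open("processed.pkl", 'rb') as fi:
#       dataset = pickle.load(fi)
#   dataset['data'][sample_ID][host_ID] is a list of (3, N) numpy arrays
#   (timestamps, sizes, directions); dataset['proto'] gives the tunnel
#   application per stream and dataset['IPs'] the stream endpoints.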