# Databricks notebook source
# This notebook is designed to run as a task within a multi-task job workflow.
# The time window input widgets below enable backfill and re-processing within
# the multi-task job workflow.
# dbutils.widgets.removeAll()
dbutils.widgets.text("okta_start_time", "", "start time (YYYY-mm-ddTHH:MM:SSZ): ")
start_time = dbutils.widgets.get("okta_start_time")
dbutils.widgets.text("okta_end_time", "", "end time (YYYY-mm-ddTHH:MM:SSZ): ")
end_time = dbutils.widgets.get("okta_end_time")
print(start_time + " to " + end_time)
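# Example of a backfill window passed in as task parameters (the values below
# are hypothetical; use whatever window you need to re-process):
#   okta_start_time = 2023-01-01T00:00:00Z
#   okta_end_time   = 2023-01-08T00:00:00Z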
# COMMAND ----------
import json

# Here we use an Okta API token directly; in production we recommend storing
# this token in the Databricks secret store:
# https://docs.databricks.com/security/secrets/index.html
cfg = {
    "base_url": "https://dev-74006068.okta.com/api/v1/logs",
    "token": "CHANGEME",
    "start_time": start_time,
    "end_time": end_time,
    "batch_size": 1000,
    "target_db": "{{target_db_name}}",
    "target_table": "okta_bronze",
    "storage_path": "{{storage_path}}"
}
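# A minimal sketch of pulling the token from the Databricks secret store instead
# of hardcoding it. The scope and key names ("okta" / "api_token") are
# hypothetical; create them with the Databricks CLI or Secrets API first.
# cfg["token"] = dbutils.secrets.get(scope="okta", key="api_token")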
# TODO: figure out where/when to execute this DDL; ideally it should not live in
# the collector task. For now it is needed so that we can query the bronze table
# for the latest event timestamp.
sql_str = f"""
CREATE TABLE IF NOT EXISTS {cfg['target_db']}.{cfg['target_table']} (
    ingest_ts TIMESTAMP,
    event_ts TIMESTAMP,
    event_date TIMESTAMP,
    rid STRING,
    raw STRING)
USING DELTA
PARTITIONED BY (event_date)
LOCATION '{cfg['storage_path']}'
"""
print(sql_str)
spark.sql(sql_str)
# If the task parameters (i.e. widgets) are empty, we default to the latest
# event timestamp from the bronze table.
import datetime

if len(cfg["start_time"]) == 0 and len(cfg["end_time"]) == 0:
    sql_str = f"""
        select max(event_ts) as latest_event_ts
        from {cfg['target_db']}.{cfg['target_table']}"""
    df = spark.sql(sql_str)
    latest_ts = df.first()["latest_event_ts"]
    if latest_ts is None:
        print("latest_ts is None - defaulting to 7 days ago")
        default_ts = datetime.datetime.today() - datetime.timedelta(days=7)
        cfg["start_time"] = default_ts.strftime("%Y-%m-%dT%H:%M:%SZ")
    else:
        print("latest_ts from bronze table is " + latest_ts.isoformat())
        cfg["start_time"] = latest_ts.strftime("%Y-%m-%dT%H:%M:%SZ")
print(json.dumps(cfg, indent=2))
# COMMAND ----------
import requests
import json
import re
import datetime
from pyspark.sql import Row
import pyspark.sql.functions as f
def poll_okta_logs(cfg, debug=False):
    MINIMUM_COUNT = 5  # Must be >= 2, see note below
    headers = {'Authorization': 'SSWS ' + cfg["token"]}
    query_params = {
        "limit": str(cfg["batch_size"]),
        "sortOrder": "ASCENDING",
        "since": cfg["start_time"]
    }
    if cfg["end_time"]:
        query_params["until"] = cfg["end_time"]

    url = cfg["base_url"]
    total_cnt = 0
    while True:
        # Request the next link in our sequence:
        r = requests.get(url, headers=headers, params=query_params)
        if not r.status_code == requests.codes.ok:
            break
        ingest_ts = datetime.datetime.now(datetime.timezone.utc)

        # Break apart the records into individual rows
        jsons = [json.dumps(x) for x in r.json()]
        # Make sure we have something to add to the table
        if len(jsons) == 0:
            break

        # Load into a dataframe
        df = (
            sc.parallelize([Row(raw=x) for x in jsons]).toDF()
            .selectExpr(f"'{ingest_ts.isoformat()}'::timestamp AS ingest_ts",
                        "date_trunc('DAY', raw:published::timestamp) AS event_date",
                        "raw:published::timestamp AS event_ts",
                        "uuid() AS rid",
                        "raw AS raw")
        )
        # print("%d %s" % (df.count(), url))
        total_cnt += len(jsons)
        if debug:
            display(df)
        else:
            # Append to the Delta table
            df.write \
              .option("mergeSchema", "true") \
              .format('delta') \
              .mode('append') \
              .partitionBy("event_date") \
              .save(cfg["storage_path"])

        # When we make an API call, we cause an event, so there is the potential
        # to get into a self-perpetuating loop. Thus we ensure there is a certain
        # minimum number of entries before we are willing to loop again.
        if len(jsons) < MINIMUM_COUNT:
            break

        # print(r.headers["Link"])
        # Look for the 'next' link; note there is also a 'self' link, so we need
        # to pick the right one, e.g.:
        #   Link: <https://dev-74006068.okta.com/api/v1/logs?after=...>; rel="next"
        rgx = re.search(r"<([^>]+)>; rel=\"next\"", str(r.headers['Link']), re.I)
        if rgx:
            # We got a next-link match; set it as the new URL and repeat. The
            # 'next' link already embeds the full query string, so drop the
            # original params to avoid sending duplicate query keys.
            url = rgx.group(1)
            query_params = None
            continue
        else:
            # No next link, we are done
            break

    return total_cnt
cnt = poll_okta_logs(cfg)
print(f"Total records polled = {cnt}")
# COMMAND ----------