Skip to content

Commit

Permalink
Add support for supertables and tags (#521)
Browse files Browse the repository at this point in the history
* Add support for supertables and tags

[ML-6367](https://iguazio.atlassian.net/browse/ML-6367)

* Remove redundant if and accidental print
  • Loading branch information
gtopper authored May 15, 2024
1 parent 9a1b9f8 commit 265575a
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 60 deletions.
89 changes: 56 additions & 33 deletions integration/test_tdengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@pytest.fixture()
def tdengine():
db_name = "storey"
table_name = "test"
supertable_name = "test_supertable"

# Setup
if url.startswith("taosws"):
Expand All @@ -37,68 +37,88 @@ def tdengine():
)

try:
connection.execute(f"CREATE DATABASE {db_name};")
connection.execute(f"DROP DATABASE {db_name};")
except (ConnectError, QueryError) as err: # websocket connection raises QueryError
if "Database already exists" not in str(err):
if "Database not exist" not in str(err):
raise err

connection.execute(f"CREATE DATABASE {db_name};")

if not db_prefix:
connection.execute(f"USE {db_name}")

try:
connection.execute(f"DROP TABLE {db_prefix}{table_name};")
connection.execute(f"DROP STABLE {db_prefix}{supertable_name};")
except (ConnectError, QueryError) as err: # websocket connection raises QueryError
if "Table does not exist" not in str(err):
if "STable not exist" not in str(err):
raise err

connection.execute(f"CREATE TABLE {db_prefix}{table_name} (time TIMESTAMP, my_int INT, my_string NCHAR(10));")
connection.execute(
f"CREATE STABLE {db_prefix}{supertable_name} (time TIMESTAMP, my_string NCHAR(10)) TAGS (my_int INT);"
)

# Test runs
yield connection, url, user, password, db_name, table_name, db_prefix
yield connection, url, user, password, db_name, supertable_name, db_prefix

# Teardown
connection.execute(f"DROP TABLE {db_prefix}{table_name};")
connection.execute(f"DROP DATABASE {db_name};")
connection.close()


@pytest.mark.parametrize("dynamic_table", [None, "$key", "table"])
@pytest.mark.parametrize("table_col", [None, "$key", "table"])
@pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password")
def test_tdengine_target(tdengine, dynamic_table):
connection, url, user, password, db_name, table_name, db_prefix = tdengine
def test_tdengine_target(tdengine, table_col):
connection, url, user, password, db_name, supertable_name, db_prefix = tdengine
time_format = "%d/%m/%y %H:%M:%S UTC%z"

table_name = "test_table"

# Table is created automatically only when using a supertable
if not table_col:
connection.execute(f"CREATE TABLE {db_prefix}{table_name} (time TIMESTAMP, my_string NCHAR(10), my_int INT);")

controller = build_flow(
[
SyncEmitSource(),
TDEngineTarget(
url=url,
time_col="time",
columns=["my_string"] if table_col else ["my_string", "my_int"],
user=user,
password=password,
database=db_name,
table=None if dynamic_table else table_name,
dynamic_table=dynamic_table,
time_col="time",
columns=["my_int", "my_string"],
table=None if table_col else table_name,
table_col=table_col,
supertable=supertable_name if table_col else None,
tag_cols=["my_int"] if table_col else None,
time_format=time_format,
max_events=2,
max_events=10,
),
]
).run()

date_time_str = "18/09/19 01:55:1"
for i in range(9):
for i in range(5):
timestamp = f"{date_time_str}{i} UTC-0000"
event_body = {"time": timestamp, "my_int": i, "my_string": f"hello{i}"}
event_key = None
if dynamic_table == "$key":
event_key = table_name
elif dynamic_table:
event_body[dynamic_table] = table_name
subtable_name = f"{table_name}{i}"
if table_col == "$key":
event_key = subtable_name
elif table_col:
event_body[table_col] = subtable_name
controller.emit(event_body, event_key)

controller.terminate()
controller.await_termination()

result = connection.query(f"SELECT * FROM {db_prefix}{table_name};")
if table_col:
query_table = supertable_name
where_clause = " WHERE my_int > 0 AND my_int < 3"
else:
query_table = table_name
where_clause = ""
result = connection.query(f"SELECT * FROM {db_prefix}{query_table} {where_clause} ORDER BY my_int;")
result_list = []
for row in result:
row = list(row)
Expand All @@ -116,14 +136,17 @@ def test_tdengine_target(tdengine, dynamic_table):
t = t.astimezone(pytz.UTC).replace(tzinfo=None)
row[field_index] = t
result_list.append(row)
assert result_list == [
[datetime(2019, 9, 18, 1, 55, 10), 0, "hello0"],
[datetime(2019, 9, 18, 1, 55, 11), 1, "hello1"],
[datetime(2019, 9, 18, 1, 55, 12), 2, "hello2"],
[datetime(2019, 9, 18, 1, 55, 13), 3, "hello3"],
[datetime(2019, 9, 18, 1, 55, 14), 4, "hello4"],
[datetime(2019, 9, 18, 1, 55, 15), 5, "hello5"],
[datetime(2019, 9, 18, 1, 55, 16), 6, "hello6"],
[datetime(2019, 9, 18, 1, 55, 17), 7, "hello7"],
[datetime(2019, 9, 18, 1, 55, 18), 8, "hello8"],
]
if table_col:
expected_result = [
[datetime(2019, 9, 18, 1, 55, 11), "hello1", 1],
[datetime(2019, 9, 18, 1, 55, 12), "hello2", 2],
]
else:
expected_result = [
[datetime(2019, 9, 18, 1, 55, 10), "hello0", 0],
[datetime(2019, 9, 18, 1, 55, 11), "hello1", 1],
[datetime(2019, 9, 18, 1, 55, 12), "hello2", 2],
[datetime(2019, 9, 18, 1, 55, 13), "hello3", 3],
[datetime(2019, 9, 18, 1, 55, 14), "hello4", 4],
]
assert result_list == expected_result
95 changes: 68 additions & 27 deletions storey/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,9 @@ class TDEngineTarget(_Batching, _Writer):
"""Writes incoming events to a TDEngine table.
:param url: TDEngine Websocket or REST API URL.
:param time_col: Name of the time column.
:param columns: List of column names to be passed to the DataFrame constructor. Use = notation for renaming fields
(e.g. write_this=event_field). Use $ notation to refer to metadata ($key, event_time=$time).
:param user: Username with which to connect. This is ignored when url is a Websocket URL, which should already
contain the username.
:param password: Password with which to connect. This is ignored when url is a Websocket URL, which should already
Expand All @@ -786,12 +789,13 @@ class TDEngineTarget(_Batching, _Writer):
:param table: Name of the table in the database where events will be written. To set the table dynamically on a
per-event basis, use the $ prefix to indicate the field that should be used for the table name, or $$ prefix to
indicate the event attribute (e.g. key or path) that should be used.
:param dynamic_table: Alternative to the table parameter (exactly one of these must be set). The name of the field
:param table_col: Alternative to the table parameter (exactly one of these must be set). The name of the field
in the event body to use for the table, or the name of the event attribute preceded by a dollar sign (e.g.
$key or $path).
:param time_col: Name of the time column.
:param columns: List of column names to be passed to the DataFrame constructor. Use = notation for renaming fields
(e.g. write_this=event_field). Use $ notation to refer to metadata ($key, event_time=$time).
:param supertable: The supertable associated with the writes. Must be specified together with tag_cols or not at
all.
:param tag_cols: List of column names to be used as tags. Must be specified together with supertable or not at
all.
:param timeout: REST API timeout in seconds.
:param time_format: If time_col is a string column, and its format is not compatible with ISO-8601, use this
parameter to determine the expected format.
Expand All @@ -806,13 +810,15 @@ class TDEngineTarget(_Batching, _Writer):
def __init__(
self,
url: str,
user: Optional[str],
password: Optional[str],
database: Optional[str],
table: Optional[str],
dynamic_table: Optional[str],
time_col: str,
columns: List[str],
user: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
table: Optional[str] = None,
table_col: Optional[str] = None,
supertable: Optional[str] = None,
tag_cols: Union[str, List[str], None] = None,
timeout: Optional[int] = None,
time_format: Optional[str] = None,
**kwargs,
Expand All @@ -821,42 +827,60 @@ def __init__(
if parsed_url.scheme not in ("taosws", "http", "https"):
raise ValueError("URL must start with taosws://, http://, or https://")

if table and dynamic_table:
raise ValueError("Cannot set both table and dynamic_table")
if table and table_col:
raise ValueError("Cannot set both table and table_col")

if not table and not table_col:
raise ValueError("table or table_col must be set")

if not table and not dynamic_table:
raise ValueError("table or dynamic_table must be set")
if supertable and not tag_cols:
raise ValueError("supertable must be used in conjunction with tag_cols")

if tag_cols and not supertable:
raise ValueError("tag_cols must be used in conjunction with supertable")

kwargs["url"] = url
kwargs["user"] = user
kwargs["password"] = password
kwargs["database"] = database
kwargs["table"] = table
kwargs["time_col"] = time_col
kwargs["columns"] = columns
if user:
kwargs["user"] = user
if password:
kwargs["password"] = password
if database:
kwargs["database"] = database
if table:
kwargs["table"] = table
if table_col:
kwargs["table_col"] = table_col
if supertable:
kwargs["supertable"] = supertable
if tag_cols:
kwargs["tag_cols"] = tag_cols
if timeout:
kwargs["timeout"] = timeout
if time_format:
kwargs["time_format"] = time_format

self._table = table
self._supertable = supertable

if dynamic_table:
kwargs["key_field"] = dynamic_table
if table_col:
kwargs["key_field"] = table_col
if kwargs.get("drop_key_field") is None:
kwargs["drop_key_field"] = True

_Batching.__init__(self, **kwargs)
self._time_col = time_col
tag_cols = tag_cols or []
self._number_of_tags = len(tag_cols)
_Writer.__init__(
self,
[time_col] + columns,
tag_cols + [time_col] + columns,
infer_columns_from_data=False,
retain_dict=True,
time_field=time_col,
time_format=time_format,
)

self._url = url
self._user = user
self._password = password
Expand Down Expand Up @@ -892,6 +916,14 @@ def _init(self):
def _event_to_batch_entry(self, event):
return self._event_to_writer_entry(event)

@staticmethod
def _sanitize_value(value):
if isinstance(value, datetime.datetime):
value = round(value.timestamp() * 1000)
elif isinstance(value, str):
value = f"'{value}'"
return str(value)

async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_time=None):
with StringIO() as b:
b.write("INSERT INTO ")
Expand All @@ -902,16 +934,25 @@ async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_tim
b.write(self._table)
else: # table is dynamic
b.write(batch_key)
if self._supertable:
b.write(" USING ")
if not self._using_websocket:
b.write(self._database)
b.write(".")
b.write(self._supertable)
b.write(" TAGS (")
for column_index in range(self._number_of_tags):
value = batch[0].get(self._columns[column_index], "NULL")
b.write(self._sanitize_value(value))
if column_index < self._number_of_tags - 1:
b.write(",")
b.write(")")
b.write(" VALUES ")
for record in batch:
b.write("(")
for column_index in range(len(self._columns)):
for column_index in range(self._number_of_tags, len(self._columns)):
value = record.get(self._columns[column_index], "NULL")
if isinstance(value, datetime.datetime):
value = round(value.timestamp() * 1000)
elif isinstance(value, str):
value = f"'{value}'"
b.write(str(value))
b.write(self._sanitize_value(value))
if column_index < len(self._columns) - 1:
b.write(",")
b.write(") ")
Expand Down

0 comments on commit 265575a

Please sign in to comment.