Skip to content

Commit

Permalink
Merge remote-tracking branch 'mlrun/development' into ML-1325
Browse files Browse the repository at this point in the history
  • Loading branch information
Gal Topper committed Jun 8, 2023
2 parents c5e5f0c + c2573e1 commit 4d1619d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
5 changes: 4 additions & 1 deletion storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,10 @@ async def _run_loop(self):
else:
line_id = self._get_uuid()
element = self._get_element(body, columns)
keys = keys[0] if len(keys) == 1 else (None if not keys else keys)
if len(keys) == 0:
keys = None
elif not isinstance(self._key_field, list):
keys = keys[0]
event = Event(element, keys, id=line_id)
await self._do_downstream(event)
return await self._do_downstream(_termination_obj)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,8 @@ def test_indexed_dataframe_source():
assert termination_result == expected


def test_dataframe_source_with_metadata():
@pytest.mark.parametrize("key_field ,key1, key2", [("my_key", "key1", "key2"), (["my_key"], ["key1"], ["key2"])])
def test_dataframe_source_with_metadata(key_field, key1, key2):
t1 = datetime(2020, 2, 15)
t2 = datetime(2020, 2, 16)
df = pd.DataFrame(
Expand All @@ -698,7 +699,7 @@ def test_dataframe_source_with_metadata():
)
controller = build_flow(
[
DataframeSource(df, key_field="my_key", time_field="my_time", id_field="my_id"),
DataframeSource(df, key_field=key_field, time_field="my_time", id_field="my_id"),
Reduce([], append_and_return, full_event=True),
]
).run()
Expand All @@ -707,16 +708,17 @@ def test_dataframe_source_with_metadata():
expected = [
Event(
{"my_key": "key1", "my_time": t1, "my_id": "id1", "my_value": 1.1},
key="key1",
key=key1,
id="id1",
),
Event(
{"my_key": "key2", "my_time": t2, "my_id": "id2", "my_value": 2.2},
key="key2",
key=key2,
id="id2",
),
]
assert termination_result == expected
assert list(map(lambda event: event.key, termination_result)) == [key1, key2]


async def async_dataframe_source():
Expand Down

0 comments on commit 4d1619d

Please sign in to comment.