Skip to content

Commit

Permalink
Remove use of DataFrame merge from test_embedding_cat_export_import (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy authored Jul 3, 2023
1 parent f9a2891 commit aad1112
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions tests/unit/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,33 +800,36 @@ def test_embedding_cat_export_import(tmpdir, cpu):
npy_path = str(tmpdir / "embeddings.npy")
emb_res.to_npy(npy_path)

embeddings = np.load(npy_path)
ids_and_embeddings = np.load(npy_path)
# second workflow that categorifies the embedding table data
df = make_df({"string_id": np.random.choice(string_ids, 30)})
graph2 = ["string_id"] >> cat_op
train_res = Workflow(graph2).transform(Dataset(df, cpu=(cpu is not None)))

ids = ids_and_embeddings[:, 0].astype(int)
embeddings = ids_and_embeddings[:, 1:]

data_loader = Loader(
train_res,
batch_size=1,
transforms=[
EmbeddingOperator(
embeddings[:, 1:],
id_lookup_table=embeddings[:, 0].astype(int),
embeddings,
id_lookup_table=ids,
lookup_key="string_id",
)
],
shuffle=False,
device=cpu,
)
origin_df = train_res.to_ddf().merge(emb_res.to_ddf(), on="string_id", how="left").compute()
embeddings_by_id = dict(zip(ids, embeddings))
for idx, batch in enumerate(data_loader):
batch
b_df = batch[0].to_df()
org_df = origin_df.iloc[idx]
if not cpu:
assert (b_df["string_id"].to_numpy() == org_df["string_id"].to_numpy()).all()
assert (b_df["embeddings"].list.leaves == org_df["embeddings"].list.leaves).all()
else:
assert (b_df["string_id"].values == org_df["string_id"]).all()
assert b_df["embeddings"].values[0] == org_df["embeddings"].tolist()
x, _ = batch
b_df = x.to_df()
org_df = make_df(
{
"string_id": x["string_id"].values,
"embeddings": [embeddings_by_id[_id] for _id in x["string_id"].values.tolist()],
}
)
assert_eq(b_df, org_df)

0 comments on commit aad1112

Please sign in to comment.