feat: allow multiple columns input for llm models #998
bigframes/ml/llm.py
Outdated
@@ -244,7 +244,7 @@ def predict(

 Args:
     X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-        Input DataFrame or Series, which contains only one column of prompts.
+        Input DataFrame or Series, can contain one or more columns. If multiple columns in the DataFrame, it must contain a "prompt" column for prediction.
nit: "If multiple columns are in the DataFrame, they must ..." and for other docs too
"it" refers to the DataFrame. Can add "are" in "If multiple columns are in the DataFrame"
done
tests/system/small/ml/test_llm.py
Outdated
assert "text_embedding" in df.columns
series = df["text_embedding"]
value = series[0]
assert len(value) == 768
nit: maybe we could coalesce line 323 - 325 into a single line?
assert len(df[..][0]) == 768
sure, actually I'll rewrite the tests. Also some are already removed in a recent PR.
done
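For reference, the coalescing suggested in this thread could look roughly like the sketch below. It uses a plain pandas DataFrame as a stand-in for the prediction result, which is an assumption made only so the snippet runs on its own; in the real test `df` comes from the model's output.

```python
import pandas as pd

# Stand-in for the prediction result (assumption: the real test obtains df
# from the embedding model rather than building it by hand).
df = pd.DataFrame({"text_embedding": [[0.0] * 768]})

assert "text_embedding" in df.columns
# The three lines (series, value, len check) collapsed into one assertion.
assert len(df["text_embedding"][0]) == 768
```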
# BQML identified the column by name
col_label = cast(blocks.Label, X.columns[0])
X = X.rename(columns={col_label: "prompt"})
if len(X.columns) == 1:
I think we should make another check in the else clause - that the multi-column input does have a "prompt" column. Also add negative test for that scenario
@tswast had a suggestion that we shouldn't do much client side checks. I'm trying to follow: if the error message is meaningful to the user, then rely on server side checks. Otherwise we have to wrap server error messages or return client side error messages.
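To make the trade-off in this thread concrete, here is a minimal sketch of what the proposed client-side check might look like. It is written against plain pandas so it runs standalone; the helper name `prepare_prompt_input` and the error message are hypothetical, and this is not the code in the PR, whose author prefers to rely on the server-side error instead.

```python
import pandas as pd


def prepare_prompt_input(X: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper illustrating the reviewer's suggestion."""
    if len(X.columns) == 1:
        # Single column: rename it so the model can find it by name.
        return X.rename(columns={X.columns[0]: "prompt"})
    # Multiple columns: the suggested extra check in the else branch.
    if "prompt" not in X.columns:
        raise ValueError(
            'Multi-column input must contain a "prompt" column.'
        )
    return X
```

The matching negative test would pass a two-column frame without a "prompt" column and expect a `ValueError`. Under the server-side approach described in the reply, the same test would instead assert on whatever error the backend returns.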