feat: allow multiple columns input for llm models #998

Merged
merged 12 commits into from
Sep 25, 2024

GarrettWu
Contributor

Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

  • Make sure to open an issue as a bug/issue before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
  • Ensure the tests and linter pass
  • Code coverage does not decrease (if any source code was changed)
  • Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕

@GarrettWu GarrettWu self-assigned this Sep 18, 2024
@product-auto-label product-auto-label bot added the labels "size: m" (pull request size is medium) and "api: bigquery" (issues related to the googleapis/python-bigquery-dataframes API) Sep 18, 2024
@GarrettWu GarrettWu requested review from shobsi and sycai September 20, 2024 18:25
@GarrettWu GarrettWu marked this pull request as ready for review September 20, 2024 18:25
@GarrettWu GarrettWu requested review from a team as code owners September 20, 2024 18:25
@@ -244,7 +244,7 @@ def predict(

 Args:
     X (bigframes.dataframe.DataFrame or bigframes.series.Series):
-        Input DataFrame or Series, which contains only one column of prompts.
+        Input DataFrame or Series, can contain one or more columns. If multiple columns in the DataFrame, it must contain a "prompt" column for prediction.
Contributor

nit: "If multiple columns are in the DataFrame, they must ..." and for other docs too

Contributor Author

"it" refers to the DataFrame. Can add "are" in "If multiple columns are in the DataFrame"

Contributor Author

done
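
The input rule in the revised docstring can be sketched as follows. This is a minimal illustration, not the library's code: a plain dict of column name to values stands in for the DataFrame, and `normalize_prompt_input` is a hypothetical helper name. It assumes, as the docstring says, that a single column is treated as the prompt and that multi-column input must already carry a "prompt" column.

```python
from typing import Dict, List

# Stand-in for a DataFrame: column name -> column values (assumption for this sketch).
Columns = Dict[str, List[str]]


def normalize_prompt_input(X: Columns) -> Columns:
    """Rename a lone column to "prompt"; pass multi-column input through unchanged."""
    if len(X) == 1:
        # The model identifies the input column by name, so a single
        # column is renamed to "prompt" whatever its original label was.
        (only_label,) = X
        return {"prompt": X[only_label]}
    # With several columns nothing is renamed; the caller is expected
    # to supply a column that is already named "prompt".
    return dict(X)


single = normalize_prompt_input({"question": ["What is BigQuery?"]})
multi = normalize_prompt_input(
    {"prompt": ["What is BigQuery?"], "source": ["docs"]}
)
print(list(single))  # ['prompt']
print(list(multi))   # ['prompt', 'source']
```

Note that the sketch does not reject multi-column input that lacks a "prompt" column; whether that check belongs on the client is discussed later in this review.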

assert "text_embedding" in df.columns
series = df["text_embedding"]
value = series[0]
assert len(value) == 768
Contributor

nit: maybe we could coalesce lines 323-325 into a single line?
assert len(df[..][0]) == 768

Contributor Author

sure, actually I'll rewrite the tests. Also some are already removed in a recent PR.

Contributor Author

done
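
For reference, the coalescing suggested in this thread amounts to the following. This is a sketch only: a dict holding one fake 768-element vector stands in for the prediction result, and the column name is taken from the snippet under review rather than from the rewritten tests, which are not shown here.

```python
# Stand-in for the prediction result: one row with a 768-dimensional embedding.
df = {"text_embedding": [[0.0] * 768]}

# Three-step form from the snippet under review.
series = df["text_embedding"]
value = series[0]
assert len(value) == 768

# Coalesced form: index the column and the first row in one expression.
embedding_length = len(df["text_embedding"][0])
assert embedding_length == 768
```

Both forms check the same thing; the one-line version just drops the intermediate names.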

tests/system/small/ml/test_llm.py (outdated, resolved)
# BQML identified the column by name
col_label = cast(blocks.Label, X.columns[0])
X = X.rename(columns={col_label: "prompt"})
if len(X.columns) == 1:
Contributor

@shobsi shobsi Sep 23, 2024

I think we should add another check in the else clause: that the multi-column input does have a "prompt" column. Also add a negative test for that scenario.

Contributor Author

@tswast suggested that we shouldn't do many client-side checks. The rule I'm trying to follow: if the server's error message is meaningful to the user, rely on server-side checks. Otherwise we have to wrap server error messages or return client-side error messages.
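
To make the trade-off in this thread concrete, here is a sketch of the client-side check that was proposed, together with the negative test it would need. It is illustrative only: the author's reply leans toward relying on server-side checks, so this is not presented as what the PR does. `check_prompt_column` and `raises_value_error` are hypothetical names, the error message is invented for the example, and a dict of column name to values again stands in for the DataFrame.

```python
from typing import Dict, List

# Stand-in for a DataFrame: column name -> column values (assumption for this sketch).
Columns = Dict[str, List[str]]


def check_prompt_column(X: Columns) -> None:
    """Hypothetical client-side guard for multi-column input."""
    if len(X) > 1 and "prompt" not in X:
        raise ValueError(
            'Multi-column input must contain a "prompt" column; '
            f"got columns {sorted(X)}."
        )


def raises_value_error(X: Columns) -> bool:
    """Report whether the guard rejects the given input."""
    try:
        check_prompt_column(X)
    except ValueError:
        return True
    return False


# Negative case: several columns, none named "prompt".
missing_prompt = raises_value_error({"question": ["hi"], "source": ["docs"]})
# Positive cases: a "prompt" column is present, or there is only one column.
has_prompt = raises_value_error({"prompt": ["hi"], "source": ["docs"]})
single_column = raises_value_error({"question": ["hi"]})
```

The cost of a guard like this is that the client has to keep its rule and message in step with the server; the benefit is an error raised before any request is sent. Relying on the server avoids the duplication as long as its message is clear enough for the user.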

@GarrettWu GarrettWu requested review from shobsi and sycai September 23, 2024 22:39
@GarrettWu GarrettWu enabled auto-merge (squash) September 24, 2024 18:58
@GarrettWu GarrettWu merged commit 2fe5e48 into main Sep 25, 2024
22 of 23 checks passed
@GarrettWu GarrettWu deleted the garrettwu-cols branch September 25, 2024 18:52
Labels
api: bigquery Issues related to the googleapis/python-bigquery-dataframes API. size: m Pull request size is medium.
3 participants