Introduce `truss train` subcommand, API stubs #1422
536a1f7 to c81c767
@@ -18,6 +18,7 @@ keywords = [
 packages = [
     { include = "truss", from = "." },
     { include = "truss_chains", from = "./truss-chains" },
+    { include = "truss_train", from = "./truss-train" },
@marius-baseten curious why this package structure was needed for chains? I think we can still use the import aliasing here without this, but I feel like there are other benefits I'm not aware of
It was not strictly needed, but since the truss subtree has imports all over the place and chains had some different version requirements, this extra dir was added as a salient way to keep some isolation/structure between the two subtrees.
@@ -189,7 +190,7 @@ markers = [
 addopts = "--ignore=smoketests"

 [tool.ruff]
-src = ["truss", "truss-chains"]
+src = ["truss", "truss-chains", "truss-train"]
This allows `truss-train` imports to be labeled as first-party for import sorting. Note for follow-up: I think we actually want to specify `.` instead of `truss` here, since right now `truss` is being categorized as third-party across the board.
@@ -27,6 +27,7 @@
 import pydantic
 from truss.base import truss_config
 from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
+from truss.base.custom_types import SafeModel, SafeModelNonSerializable
There might be some shared paradigms between chains / train that aren't applicable to traditional truss. For now I'm opting to put these into `base`, but I'm open to suggestions on a better file structure.
This code requires pydantic v2, which was selectively required for chains only: https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/__init__.py#L3
Are we ready now to require v2 for the entire truss package? If so, we should remove those guards, update the project requirements, etc.
I ended up reverting this for other reasons, but I think we can't upgrade to pydantic v2 until TaT ships, since we might unfortunately break older trusses that were built with v1.
    name: str


class Compute(SafeModel):
I don't think we need the extra layer like chains, and I'm not sure chains needs it either?
The layer was mainly for dealing with the conversion of `gpu`, but I believe this could be done with pydantic validators too.
One thing that is really important here: try to find a common denominator for the "new-style" APIs, so that we can reuse this for chains. I'd even suggest putting these defs into `truss.base`, so that both train and chains can depend on them and use the same definitions. It's a really bad user experience if these two products have "almost" the same APIs but with subtle differences, plus it's more upkeep and repo clutter if we have so many definitions.
It was discussed in the API design/PRD doc to strive for this consolidation.
I'm happy to consolidate the definitions for compute, but I think after our simplifications that's the only one that makes sense? Unfortunately I believe env variables / secrets are different enough to warrant new APIs.
Chains can register secrets to be made available to user code, but there is no similar entrypoint in training. I think the best UX would be to have a clean way to define both traditional env vars and ones derived from secrets (as proposed here), and then the user can consume via bash / python scripts as needed.
c81c767 to 670b6ff
@@ -566,3 +567,27 @@ def get_all_secrets(self) -> Any:

         secrets_info = resp.json()
         return secrets_info

+    def upsert_training_project(self, training_project):
I personally think the REST API structure should mirror the truss SDK very closely. The server code is already going to have to transform things (i.e. get the actual instance type, worker planes, user/org, etc.), but we can keep the truss integration simple (basically a `model_dump()` here instead of explicit transformation code).
        return resp.json()

    def create_training_job(self, project_id: str, job):
Here (and in the remote wrapper) I ran into a circular dep issue when trying to annotate `TrainingProject` / `TrainingJob`, because `truss.base` ends up depending on the API. I'm sure we can disentangle this in a follow-up, but I don't think it needs to block for now.
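One common way to keep annotations without a runtime import cycle is to import the annotated types only under `typing.TYPE_CHECKING` and write the annotations as strings. The sketch below is an assumption about how such a follow-up could look, not code from this PR: `ApiSketch` is a hypothetical stand-in for the API client, the import path `truss_train.definitions` is inferred from the snippets in this review, and the methods only record calls instead of talking to a server.

```python
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

if TYPE_CHECKING:
    # Only evaluated by type checkers, never at runtime, so this import
    # cannot take part in an import cycle.
    from truss_train.definitions import TrainingJob, TrainingProject


class ApiSketch:
    """Hypothetical stand-in for the API client; it records calls only."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any]] = []

    def upsert_training_project(
        self, training_project: "TrainingProject"
    ) -> Dict[str, Any]:
        # String annotations are not evaluated when the function is defined.
        self.calls.append(("upsert_training_project", training_project))
        return {"training_project": training_project}

    def create_training_job(
        self, project_id: str, job: "TrainingJob"
    ) -> Dict[str, Any]:
        self.calls.append(("create_training_job", (project_id, job)))
        return {"project_id": project_id, "training_job": job}
```

This only removes the cycle for annotations; anything that needs the classes at runtime (for example `isinstance` checks) would still require untangling the module dependencies.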
670b6ff to 5567716
    accelerator: Optional[truss_config.AcceleratorSpec] = None


class Runtime(SafeModel):
We can add client side validation code to these models in follow ups as well
from truss_train import definitions


@contextlib.contextmanager
I know we just went over this in the offsite, but I actually think this code is different enough from chains to warrant the duplication, thoughts? Differences include:
- no module modifications needed
- targets aren't subclasses, but actual instances of TrainingProject
- error messages
I think this looks pretty good, and I agree with the thought you put into the decision. Could you explain a little bit about what the context manager helps with here?
In practice I don't think it's super necessary, was following the pattern from chains. The context manager allows us to defer execution to the block but ensure that we always perform cleanup code. Chains has more complicated import logic that tries to clean up modified modules, but we can probably get away without it for now
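To illustrate the "defer execution to the block but always perform cleanup" point, here is a minimal sketch using only the standard library. It is not the PR's `import_target`: the function name `import_module_from_path` and the `_sketch_` module-name prefix are hypothetical, and it yields the whole module rather than a `TrainingProject` instance.

```python
import contextlib
import importlib.util
import pathlib
import sys
from typing import Any, Iterator


@contextlib.contextmanager
def import_module_from_path(module_path: pathlib.Path) -> Iterator[Any]:
    """Import a file as a module, yield it, and always undo the import."""
    module_name = f"_sketch_{module_path.stem}"  # hypothetical naming scheme
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
        yield module  # the caller's `with` block runs here
    finally:
        # Runs whether the block finished normally or raised.
        sys.modules.pop(module_name, None)
```

A caller would write `with import_module_from_path(path) as module: ...`; once the block exits, the temporary entry in `sys.modules` is gone even if the block raised.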
    runtime: Runtime = Runtime()


class TrainingProject(SafeModel):
I intentionally left off the `blob` stuff for now; I think it'll be easier to add that in targeted follow-ups. I imagine we'll eventually zip up the directory and pass it through to the server to upload to an S3 bucket of our choosing.
LGTM! Interested to get Marius's thoughts, but I think this is pretty much in line with what I was thinking.
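For the "zip up the directory" step mentioned above, a rough standard-library sketch could look like the following. This is an assumption about a possible follow-up, not part of this PR: `zip_directory` is a hypothetical helper, and the upload to the server or S3 is deliberately left out.

```python
import io
import pathlib
import zipfile


def zip_directory(root: pathlib.Path) -> bytes:
    """Return the contents of `root` as an in-memory zip archive."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for path in sorted(root.rglob("*")):
            if path.is_file():
                # Store paths relative to `root` so the archive is relocatable.
                archive.write(path, arcname=path.relative_to(root).as_posix())
    return buffer.getvalue()
```

Building the archive in memory keeps the sketch simple; for large training directories, writing to a temporary file would likely be the safer choice.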
truss/cli/cli.py (outdated)


@train.command(name="push")
@click.argument("source", type=Path, required=True)
Any reason for `source` instead of `config`? I think we're closer to a config than any source code...
No strong reason, mainly (1) consistency with chains terminology and (2) future-proofing in case we extend this to accept user-written code. I'll switch to `config` for now, since this is easy to change until we show it to customers!
class TrainingJob(SafeModel):
    image: Image
This is not following the PRD / Design doc - why the deviation?
The design had a separation of defining an image as a semi-permanent resource and then referencing that by an ID (not nesting the definition) in the training job.
class SafeModelNonSerializable(pydantic.BaseModel):
    """Pydantic base model with reasonable config - allowing arbitrary types."""

    model_config = pydantic.ConfigDict(
See other comment, this requires pydantic v2 in the entire truss package now.
@contextlib.contextmanager
def import_target(module_path: pathlib.Path) -> Iterator[definitions.TrainingProject]:
I never really liked that we have to import stuff from a path dynamically for chains and truss. Since training is a completely fresh product, is there a way we can avoid those brittle imports and use dependency injection or something like that?
For truss the problem is that it essentially works like this:

    class TrussServer:
        def run(self):
            # The server locates and imports the user's module by path.
            import_from_path(user_module_path)

    if __name__ == "__main__":
        TrussServer().run()

The better pattern would be:

    class UserStuff:
        ...

    if __name__ == "__main__":
        # The user's class is passed in explicitly instead of being imported by path.
        TrussServer(UserStuff).run()
Synced offline - since this is purely on the CLI side for now, there's unfortunately no great way around this. Can revisit if this ever becomes runtime code!
🚀 What
This PR adds the foundation for new `truss_train` functionality:
- `truss train push` CLI stub to create a training project + job
💻 How
🔬 Testing
`truss train push` with the following config: