Introduce truss train subcommand, API stubs #1422

Merged
merged 2 commits into main from nikhil/introduce-truss-train
Mar 5, 2025

@nnarayen nnarayen commented Mar 4, 2025

🚀 What

This PR adds the foundation for new truss_train functionality:

  • truss train push CLI stub to create a training project + job
  • framework code to import + validate training job definitions
  • pydantic types (not finalized)

💻 How

🔬 Testing

  • New unit tests
  • Sample truss train push with the following config:
from truss_train import definitions

runtime_config = definitions.RuntimeConfig(
    start_commands=["/bin/bash ./my-entrypoint.sh"],
    environment_variables={
        "FOO_VAR": "FOO_VAL",
        "BAR_VAR": definitions.SecretReference(name="BAR_SECRET"),
    },
)

training_job = definitions.TrainingJob(
    compute=definitions.Compute(node_count=1, cpu_count=4),
    runtime_config=runtime_config,
)

first_project = definitions.TrainingProject(name="first-project", job=training_job)

@nnarayen nnarayen force-pushed the nikhil/introduce-truss-train branch 2 times, most recently from 536a1f7 to c81c767 Compare March 4, 2025 21:40
@@ -18,6 +18,7 @@ keywords = [
packages = [
{ include = "truss", from = "." },
{ include = "truss_chains", from = "./truss-chains" },
{ include = "truss_train", from = "./truss-train" },
Contributor Author

@marius-baseten curious why this package structure was needed for chains? I think we can still use the import aliasing here without this, but I feel like there are other benefits I'm not aware of

Contributor

It was not strictly needed, but the truss subtree has imports all over the place and chains had some different version requirements, so this extra dir was done as a salient way to keep some isolation/structure between the two subtrees.

@@ -189,7 +190,7 @@ markers = [
addopts = "--ignore=smoketests"

[tool.ruff]
- src = ["truss", "truss-chains"]
+ src = ["truss", "truss-chains", "truss-train"]
Contributor Author

This allows truss-train imports to be labeled as first-party for import sorting. Note for followup, I think we actually want to specify . instead of truss here, since right now truss is being categorized as third party across the board

@@ -27,6 +27,7 @@
import pydantic
from truss.base import truss_config
from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.base.custom_types import SafeModel, SafeModelNonSerializable
Contributor Author

There might be some shared paradigms between chains / train that aren't applicable to traditional truss. For now I'm opting to put into base, but open to suggestions on a better file structure

Contributor

This code requires pydantic v2, which was selectively required for chains only: https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/__init__.py#L3

Are we ready now to require v2 for the entire truss package? If so, we should remove those guards, update the project requirements, etc.

Contributor Author

@nnarayen nnarayen Mar 5, 2025

I ended up reverting this for other reasons, but unfortunately I think we can't upgrade to pydantic V2 until TaT ships, since we might break older trusses that are built with V1.

name: str


class Compute(SafeModel):
Contributor Author

I don't think we need the extra layer like chains, and I'm not sure chains needs it either?

Contributor

The layer was mainly for dealing with the conversion of gpu, but I believe this could be done with pydantic validators too.

One thing that is really important here: try to find a common denominator for the "new-style" APIs, so that we can re-use this for chains. In a way I'd even suggest putting these defs into truss.base, and then both train and chains can depend on them and use the same definitions. It's a really bad user experience if these two products have "almost" the same APIs but with subtle differences, plus it's more upkeep and repo clutter if we have so many definitions.

It was discussed in the API design/PRD doc to strive for this consolidation.
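As a rough illustration of the "normalize in the model instead of an extra layer" idea above, here is a minimal sketch. It is not code from this PR: it uses a stdlib dataclass with `__post_init__` as a dependency-free stand-in for a pydantic validator, `AcceleratorSpec` here is a hypothetical simplification of `truss_config.AcceleratorSpec`, and the `KIND:COUNT` string form is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class AcceleratorSpec:
    """Hypothetical stand-in for truss_config.AcceleratorSpec."""

    kind: str
    count: int = 1


def parse_accelerator(
    value: Union[str, AcceleratorSpec, None],
) -> Optional[AcceleratorSpec]:
    """Normalize 'KIND' or 'KIND:COUNT' strings into an AcceleratorSpec."""
    if value is None or isinstance(value, AcceleratorSpec):
        return value
    kind, _, count = value.partition(":")
    if not kind:
        raise ValueError(f"Missing accelerator kind in {value!r}")
    return AcceleratorSpec(kind=kind, count=int(count) if count else 1)


@dataclass
class Compute:
    """Sketch of a compute definition that accepts user-friendly input."""

    node_count: int = 1
    cpu_count: int = 1
    accelerator: Union[str, AcceleratorSpec, None] = None

    def __post_init__(self) -> None:
        # Plays the role a pydantic field validator would play in the real model.
        self.accelerator = parse_accelerator(self.accelerator)
```

With the conversion living inside the model, both train and chains could in principle share one `Compute` definition without a separate wrapper layer.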

Contributor Author

I'm happy to consolidate the definitions for compute, but I think after our simplifications that's the only one that makes sense? Unfortunately I believe env variables / secrets are different enough to warrant new APIs.

Chains can register secrets to be made available to user code, but there is no similar entrypoint in training. I think the best UX would be to have a clean way to define both traditional env vars and ones derived from secrets (as proposed here), and then the user can consume via bash / python scripts as needed.
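To make the proposed shape concrete, here is a minimal sketch of how a mixed mapping of plain values and secret references could be resolved before a job starts. This is an illustration only: `SecretReference` below is a hypothetical stand-in for `definitions.SecretReference`, `resolve_environment` is an invented helper, and where resolution actually happens (client or server) is not stated in this PR.

```python
from dataclasses import dataclass
from typing import Dict, Mapping, Union


@dataclass(frozen=True)
class SecretReference:
    """Hypothetical stand-in for definitions.SecretReference."""

    name: str


EnvValue = Union[str, SecretReference]


def resolve_environment(
    variables: Mapping[str, EnvValue], secrets: Mapping[str, str]
) -> Dict[str, str]:
    """Replace each SecretReference with its stored value; pass plain strings through."""
    resolved: Dict[str, str] = {}
    for key, value in variables.items():
        if isinstance(value, SecretReference):
            if value.name not in secrets:
                raise KeyError(
                    f"Secret {value.name!r} referenced by {key!r} is not defined"
                )
            resolved[key] = secrets[value.name]
        else:
            resolved[key] = value
    return resolved
```

The user's bash or Python entrypoint would then see only ordinary environment variables, regardless of which ones came from secrets.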

@nnarayen nnarayen force-pushed the nikhil/introduce-truss-train branch from c81c767 to 670b6ff Compare March 4, 2025 21:49
@@ -566,3 +567,27 @@ def get_all_secrets(self) -> Any:

secrets_info = resp.json()
return secrets_info

def upsert_training_project(self, training_project):
Contributor Author

I personally think the REST API structure should mirror the truss SDK very closely. The server code is already going to have to transform things (e.g. get the actual instance type, worker planes, user/org, etc.), but we can keep the truss integration simple (basically a model_dump() here instead of explicit transformation code).


return resp.json()

def create_training_job(self, project_id: str, job):
Contributor Author

@nnarayen nnarayen Mar 4, 2025

Here (and in the remote wrapper) I ran into a circular dep issue when trying to annotate TrainingProject / TrainingJob because truss.base ends up depending on the api. I'm sure we can disentangle this in a followup, but I don't think it needs to block for now
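One common way to annotate across an import cycle like this, sketched here as a possible followup rather than what the PR does, is to import the type only under `typing.TYPE_CHECKING` and use a string annotation. `ApiSketch` is a hypothetical stand-in for the API client, and the returned dict is a placeholder for the real request.

```python
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
    # Only evaluated by type checkers, so nothing is imported at runtime
    # and the import cycle is never triggered.
    from truss_train import definitions


class ApiSketch:
    """Hypothetical stand-in for the API client touched in this PR."""

    def create_training_job(
        self, project_id: str, job: "definitions.TrainingJob"
    ) -> Dict[str, Any]:
        # The string annotation is never resolved at runtime.
        # Placeholder body: a real client would send a request here.
        return {"project_id": project_id, "job": job}
```

This keeps the signature typed for tooling while leaving the runtime dependency graph unchanged; it only helps when the import is needed purely for annotations.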

@nnarayen nnarayen force-pushed the nikhil/introduce-truss-train branch from 670b6ff to 5567716 Compare March 4, 2025 21:57
accelerator: Optional[truss_config.AcceleratorSpec] = None


class Runtime(SafeModel):
Contributor Author

We can add client side validation code to these models in follow ups as well

from truss_train import definitions


@contextlib.contextmanager
Contributor Author

@nnarayen nnarayen Mar 4, 2025

I know we just went over this in the offsite, but I actually think this code is different enough from chains to warrant the duplication, thoughts? Differences include:

  • no module modifications needed
  • targets aren't subclasses, but actual instances of TrainingProject
  • error messages

Contributor

I think this looks pretty good, and I agree with the thought you put into the decision. Could you explain a little about what the context manager helps with here?

Contributor Author

In practice I don't think it's super necessary, was following the pattern from chains. The context manager allows us to defer execution to the block but ensure that we always perform cleanup code. Chains has more complicated import logic that tries to clean up modified modules, but we can probably get away without it for now
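The "defer execution to the block but always clean up" behavior described above can be sketched as follows. This is a simplified illustration, not the PR's `import_target`: the function name, the `_sketch_` module-name prefix, and the choice to clean up only the `sys.modules` entry are all assumptions.

```python
import contextlib
import importlib.util
import pathlib
import sys
from types import ModuleType
from typing import Iterator


@contextlib.contextmanager
def import_module_from_path(module_path: pathlib.Path) -> Iterator[ModuleType]:
    """Import a file as a temporary module and always unregister it afterwards."""
    module_name = f"_sketch_{module_path.stem}"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
        yield module  # The caller's `with` block runs here.
    finally:
        # Cleanup runs even if the import or the caller's block raises.
        sys.modules.pop(module_name, None)
```

A caller would write `with import_module_from_path(path) as module: ...` and look up the `TrainingProject` instance on `module` inside the block; once the block exits, the temporary module is no longer registered.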

runtime: Runtime = Runtime()


class TrainingProject(SafeModel):
Contributor Author

I intentionally left off the blob stuff for now, I think it'll be easier to add that in targeted followups. I imagine we'll eventually zip up the directory and pass it through to server to upload to an S3 bucket of our choosing
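For the "zip up the directory" step imagined above, a minimal client-side sketch could look like this. It is purely illustrative: `archive_directory` is a hypothetical helper, and the upload to the server or S3 is deliberately left out since its API is not defined in this PR.

```python
import pathlib
import shutil
import tempfile


def archive_directory(source_dir: pathlib.Path) -> pathlib.Path:
    """Zip source_dir into a fresh temporary directory and return the archive path."""
    if not source_dir.is_dir():
        raise NotADirectoryError(str(source_dir))
    out_dir = pathlib.Path(tempfile.mkdtemp())
    # make_archive appends ".zip" to base_name and returns the final path.
    archive = shutil.make_archive(
        base_name=str(out_dir / source_dir.name),
        format="zip",
        root_dir=str(source_dir),
    )
    return pathlib.Path(archive)
```

Writing the archive outside the source directory avoids the archive accidentally including itself.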

Contributor

@rcano-baseten rcano-baseten left a comment

LGTM! Interested to get Marius thoughts but I think this is pretty much in line with what I was thinking

truss/cli/cli.py Outdated


@train.command(name="push")
@click.argument("source", type=Path, required=True)
Contributor

reason for source instead of config? I think we're closer to a config than any source code...

Contributor Author

No strong reason, mainly (1) consistent w chains terminology (2) future proofing in case we extend to some way to have user written code provided. I'll switch to config for now since this is easy to change until we show to customers!

@nnarayen nnarayen merged commit a3a471e into main Mar 5, 2025
5 checks passed
@nnarayen nnarayen deleted the nikhil/introduce-truss-train branch March 5, 2025 15:14


class TrainingJob(SafeModel):
image: Image
Contributor

This is not following the PRD / Design doc - why the deviation?

The design had a separation of defining an image as a semi-permanent resource and then referencing that by an ID (not nesting the definition) in the training job.

class SafeModelNonSerializable(pydantic.BaseModel):
"""Pydantic base model with reasonable config - allowing arbitrary types."""

model_config = pydantic.ConfigDict(
Contributor

See other comment, this requires pydantic v2 in the entire truss package now.



@contextlib.contextmanager
def import_target(module_path: pathlib.Path) -> Iterator[definitions.TrainingProject]:
Contributor

I never really liked that we have to import stuff from a path dynamically for chains and truss. Since training is a completely fresh product, is there a way we can avoid those brittle imports and use dependency injection or something like that?

For truss the problem is that it essentially works like this:

class TrussServer:

    def run(self):
        import_from_path(user_module_path)

if __name__ == "__main__":
    TrussServer().run()

The better pattern would be:


class UserStuff:
    ...

if __name__ == "__main__":
    TrussServer(UserStuff).run()

Contributor Author

Synced offline - since this is purely on the CLI side for now, there's unfortunately no great way around this. Can revisit if this ever becomes runtime code!
