Skip to content

Commit

Permalink
Add checkpoint/_src/checkpointers/BUILD, add tests and update depende…
Browse files Browse the repository at this point in the history
…ncies on checkpointers/BUILD to these targets.

PiperOrigin-RevId: 716510286
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 17, 2025
1 parent af82636 commit 6ed02f3
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 9 deletions.
61 changes: 61 additions & 0 deletions checkpoint/orbax/checkpoint/_src/checkpointers/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "abstract_checkpointer",
srcs = ["abstract_checkpointer.py"],
)

py_library(
name = "checkpointer",
srcs = ["checkpointer.py"],
deps = [
":abstract_checkpointer",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/handlers:checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/handlers:composite_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
],
)

py_library(
name = "pytree_checkpointer",
srcs = ["pytree_checkpointer.py"],
deps = [
":checkpointer",
"//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler",
],
)

py_library(
name = "standard_checkpointer",
srcs = ["standard_checkpointer.py"],
deps = [
":async_checkpointer",
"//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
],
)

py_library(
name = "async_checkpointer",
srcs = ["async_checkpointer.py"],
deps = [
":checkpointer",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:async_utils",
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
],
)
3 changes: 3 additions & 0 deletions checkpoint/orbax/checkpoint/_src/handlers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

Expand All @@ -75,6 +76,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

Expand Down Expand Up @@ -132,6 +134,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

Expand Down
23 changes: 20 additions & 3 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ py_library(
deps = [
":pytree_metadata_options",
":value_metadata_entry",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

py_test(
name = "tree_rich_types_test",
srcs = ["tree_rich_types_test.py"],
deps = [":tree_rich_types"],
deps = [
":tree_rich_types",
"//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
],
)

py_library(
Expand All @@ -30,6 +34,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

Expand All @@ -40,13 +45,18 @@ py_test(
":tree",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

py_library(
name = "value",
srcs = ["value.py"],
deps = [":sharding"],
deps = [
":sharding",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
Expand Down Expand Up @@ -107,6 +117,7 @@ py_library(
py_library(
name = "pytree_metadata_options",
srcs = ["pytree_metadata_options.py"],
deps = ["//checkpoint/orbax/checkpoint/_src/tree:utils"],
)

py_library(
Expand All @@ -115,14 +126,18 @@ py_library(
deps = [
":empty_values",
":pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

py_library(
name = "empty_values",
srcs = ["empty_values.py"],
deps = [":pytree_metadata_options"],
deps = [
":pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)

py_test(
Expand All @@ -133,12 +148,14 @@ py_test(
deps = [
":empty_values",
":pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
],
)

py_library(
name = "array_metadata",
srcs = ["array_metadata.py"],
deps = ["//checkpoint/orbax/checkpoint/_src/arrays:types"],
)

py_library(
Expand Down
3 changes: 3 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ py_test(
srcs = ["format_utils_test.py"],
deps = [
":format_utils",
"//checkpoint/orbax/checkpoint/_src/checkpointers:checkpointer",
"//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
],
)
25 changes: 22 additions & 3 deletions checkpoint/orbax/checkpoint/_src/serialization/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ py_library(
name = "tensorstore_utils",
srcs = ["tensorstore_utils.py"],
srcs_version = "PY3",
deps = ["//orbax/checkpoint/_src/metadata:array_metadata"],
deps = [
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//orbax/checkpoint/_src/metadata:array_metadata",
],
)

py_library(
name = "types",
srcs = ["types.py"],
deps = [
":serialization",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
Expand All @@ -30,6 +35,8 @@ py_library(
":tensorstore_utils",
":types",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
Expand All @@ -46,7 +53,11 @@ py_test(
srcs = ["tensorstore_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":tensorstore_utils"],
deps = [
":tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
Expand All @@ -55,14 +66,22 @@ py_library(
deps = [
":replica_slices",
":tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:fragments",
"//checkpoint/orbax/checkpoint/_src/arrays:numpy_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/multihost",
],
)

py_library(
name = "replica_slices",
srcs = ["replica_slices.py"],
deps = ["//checkpoint/orbax/checkpoint/_src/multihost"],
deps = [
"//checkpoint/orbax/checkpoint/_src/arrays:fragments",
"//checkpoint/orbax/checkpoint/_src/arrays:numpy_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/multihost",
],
)

py_test(
Expand Down
10 changes: 8 additions & 2 deletions checkpoint/orbax/checkpoint/_src/tree/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@ py_library(
py_library(
name = "utils",
srcs = ["utils.py"],
deps = [":types"],
deps = [
":types",
"//checkpoint/orbax/checkpoint/_src/arrays:abstract_arrays",
],
)

py_test(
name = "utils_test",
srcs = ["utils_test.py"],
deps = [":utils"],
deps = [
":utils",
"//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
],
)
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ absl-py>=1.0,==1.*
numpy>=1.26.0
orbax-checkpoint>=0.9.0
etils[epath]
etils[epy]
simplejson
chex
optax
mock
nest_asyncio
tensorstore
humanize
flax
flax
typing_extensions

0 comments on commit 6ed02f3

Please sign in to comment.