Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 427564090
  • Loading branch information
ModelSearch authored and hanna-maz committed Feb 9, 2022
1 parent 8c5eed4 commit d90bc39
Show file tree
Hide file tree
Showing 56 changed files with 2,568 additions and 1,393 deletions.
31 changes: 18 additions & 13 deletions model_search/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ exports_files(
)

py_library(
name = "blocks",
srcs = ["blocks.py"],
name = "block",
srcs = ["block.py"],
srcs_version = "PY3",
deps = [
":hparam",
Expand All @@ -44,36 +44,36 @@ py_library(
)

py_library(
name = "blocks_builder",
srcs = ["blocks_builder.py"],
name = "block_builder",
srcs = ["block_builder.py"],
srcs_version = "PY3",
deps = [
":blocks",
":block",
":registry",
"//model_search/hparams:hyperparameters",
],
)

py_test(
name = "blocks_builder_test",
srcs = ["blocks_builder_test.py"],
name = "block_builder_test",
srcs = ["block_builder_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":blocks_builder",
":block_builder",
"@absl_py//absl/testing:parameterized",
],
)

py_test(
name = "blocks_test",
name = "block_test",
size = "large",
srcs = ["blocks_test.py"],
srcs = ["block_test.py"],
python_version = "PY3",
shard_count = 15,
srcs_version = "PY3",
deps = [
":blocks",
":block",
"@absl_py//absl/testing:parameterized",
"//model_search:hparam",
"//model_search/architecture:architecture_utils",
Expand Down Expand Up @@ -222,7 +222,6 @@ py_library(
srcs_version = "PY3",
deps = [
":controller",
":ensembler",
":hparam",
":loss_fns",
":metric_fns",
Expand Down Expand Up @@ -285,9 +284,14 @@ py_library(
srcs = ["task_manager.py"],
srcs_version = "PY3",
deps = [
":blocks_builder",
":block_builder",
":ensembler",
":loss_fns",
"//model_search/architecture:architecture_utils",
"//model_search/architecture:tower",
"//model_search/generators:base_tower_generator",
"//model_search/generators:trial_utils",
"//model_search/meta:distillation",
],
)

Expand All @@ -303,6 +307,7 @@ py_test(
"//model_search/proto:all_proto_py_pb2",
"@absl_py//absl/testing:parameterized",
"//model_search/architecture:architecture_utils",
"//model_search/generators:trial_utils",
],
)

Expand Down
54 changes: 36 additions & 18 deletions model_search/architecture/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,78 @@ package(default_visibility = ["//visibility:public"])

licenses(["notice"])

exports_files(["LICENSE"])

py_library(
name = "architecture_utils",
srcs = ["architecture_utils.py"],
srcs_version = "PY3",
deps = [
"//model_search/proto:all_proto_py_pb2",
"//model_search:blocks",
"//model_search:blocks_builder",
"//model_search:block",
"//model_search:block_builder",
"//model_search:hparam",
"//model_search:utils",
],
)

py_test(
name = "architecture_utils_test",
srcs = ["architecture_utils_test.py"],
python_version = "PY3",
py_library(
name = "tower",
srcs = ["tower.py"],
srcs_version = "PY3",
deps = [
":architecture_utils",
"//model_search/proto:all_proto_py_pb2",
"@absl_py//absl/testing:parameterized",
"//model_search:blocks_builder",
"//model_search:block",
"//model_search:block_builder",
"//model_search:hparam",
"//model_search/metadata:trial",
"//model_search:utils",
],
)

py_test(
name = "graph_architecture_test",
srcs = ["graph_architecture_test.py"],
data = [
"//model_search/testdata",
],
name = "architecture_utils_test",
srcs = ["architecture_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":architecture_utils",
":graph_architecture",
"//model_search/proto:all_proto_py_pb2",
"@absl_py//absl/testing:parameterized",
"//model_search:block_builder",
"//model_search:hparam",
"//model_search/metadata:trial",
],
)

# py_strict_test(
# name = "graph_architecture_test",
# srcs = ["graph_architecture_test.py"],
# data = [
# "//model_search/testdata",
# ],
# python_version = "PY3",
# srcs_version = "PY3",
# tags = [
# "notap",
# ],
# deps = [
# ":architecture_utils",
# ":graph_architecture",
#
# "@absl_py//absl/testing:parameterized",
# "//model_search/proto:phoenix_spec_py_pb2",
# "//third_party/py/numpy",
# "//third_party/py/tensorflow",
# ],
# )

py_library(
name = "graph_architecture",
srcs = ["graph_architecture.py"],
srcs_version = "PY3",
deps = [
":architecture_utils",
"//model_search/proto:all_proto_py_pb2",
"//model_search:blocks_builder",
"//model_search:block_builder",
"//model_search:utils",
],
)
Loading

0 comments on commit d90bc39

Please sign in to comment.