diff --git a/py/kubeflow/testing/util.py b/py/kubeflow/testing/util.py index 7568f782139..026b1a1f1b3 100755 --- a/py/kubeflow/testing/util.py +++ b/py/kubeflow/testing/util.py @@ -93,6 +93,25 @@ def run(command, def run_and_output(*args, **argv): return run(*args, **argv) +def combine_repos(list_of_repos): + """Builds a dictionary of repo owner/names to commit hashes. + + Args: + list_of_repos: A list of repos to checkout, each one in the format of + "owner/name@commit". Later values override earlier ones. + Returns: + repos: A dictionary of repository names to commit hashes. + """ + + # Convert list_of_repos to a dictionary where key is "repo_owner/repo_name" + # and value is the commit hash. By convention, values that appear later in + # the list would override earlier ones. + repos = {} + for r in list_of_repos: + parts = r.split('@') + repos[parts[0]] = parts[1] + + return repos def clone_repo(dest, repo_owner=MASTER_REPO_OWNER, diff --git a/py/kubeflow/tests/util_test.py b/py/kubeflow/tests/util_test.py index 457677f0a11..b1d0f983a26 100644 --- a/py/kubeflow/tests/util_test.py +++ b/py/kubeflow/tests/util_test.py @@ -36,5 +36,35 @@ def testSplitGcsUri(self): self.assertEqual("some-bucket", bucket) self.assertEqual("", path) + def testCombineReposDefault(self): + repos = util.combine_repos([]) + expected_repos = {} + self.assertDictEqual(repos, expected_repos) + + def testCombineReposOverrides(self): + repos = util.combine_repos(["kubeflow/kubeflow@HEAD", + "kubeflow/tf-operator@HEAD", + "kubeflow/kubeflow@12345", + "kubeflow/tf-operator@23456"]) + expected_repos = { + "kubeflow/kubeflow": "12345", + "kubeflow/tf-operator": "23456" + } + self.assertDictEqual(repos, expected_repos) + + def testCombineReposExtras(self): + repos = util.combine_repos(["kubeflow/kubeflow@HEAD", + "kubeflow/tf-operator@HEAD", + "kubeflow/kfctl@12345", + "kubeflow/katib@23456"]) + expected_repos = { + "kubeflow/kubeflow": "HEAD", + "kubeflow/tf-operator": "HEAD", + "kubeflow/kfctl": "12345", + "kubeflow/katib": "23456" + } + self.assertDictEqual(repos, expected_repos) + + if __name__ == "__main__": unittest.main()