diff --git a/gensim/downloader.py b/gensim/downloader.py
index f3672ead01..3b2cf34ffa 100644
--- a/gensim/downloader.py
+++ b/gensim/downloader.py
@@ -29,6 +29,7 @@ Also, this API available via CLI::
 
 
     python -m gensim.downloader --info # same as api.info(dataname)
+    python -m gensim.downloader --info name # same as api.info(name_only=True)
     python -m gensim.downloader --download # same as api.load(dataname, return_path=True)
 
 """
@@ -154,7 +155,7 @@ def _calculate_md5_checksum(fname):
     return hash_md5.hexdigest()
 
 
-def info(name=None, show_only_latest=True):
+def info(name=None, show_only_latest=True, name_only=False):
     """Provide the information related to model/dataset.
 
     Parameters
@@ -164,6 +165,10 @@ def info(name=None, show_only_latest=True):
     show_only_latest : bool, optional
         If storage contains different versions for one data/model, this flag allow to hide outdated versions.
         Affects only if `name` is None.
+    name_only : bool, optional
+        If True, return only the names of available models and corpora (a dict with two lists of names)
+        instead of the detailed information. Affects only if `name` is None; the names of all stored
+        versions are listed, regardless of `show_only_latest`.
 
     Returns
     -------
@@ -205,6 +210,10 @@ def info(name=None, show_only_latest=True):
+    # Checked before `show_only_latest` so that `name_only=True` is never silently ignored.
+    if name_only:
+        return {"corpora": list(information['corpora']), "models": list(information['models'])}
+
     if not show_only_latest:
         return information
 
     return {
         "corpora": {name: data for (name, data) in information['corpora'].items() if data.get("latest", True)},
         "models": {name: data for (name, data) in information['models'].items() if data.get("latest", True)}
@@ -444,5 +453,11 @@ def load(name, return_path=False):
         data_path = load(args.download[0], return_path=True)
         logger.info("Data has been installed and data path is %s", data_path)
     elif args.info is not None:
-        output = info() if (args.info == full_information) else info(name=args.info)
-        print(json.dumps(output, indent=4))
+        # The literal value "name" is reserved for listing names only.
+        if args.info == 'name':
+            output = info(name_only=True)
+        elif args.info == full_information:
+            output = info()
+        else:
+            output = info(name=args.info)
+        print(json.dumps(output, indent=4))
diff --git a/gensim/test/test_api.py b/gensim/test/test_api.py
index bf84800205..13245b2205 100644
--- a/gensim/test/test_api.py
+++ b/gensim/test/test_api.py
@@ -72,6 +72,14 @@ def test_info(self):
         self.assertEqual(sorted(data.keys()), sorted(['models', 'corpora']))
         self.assertTrue(len(data['models']))
         self.assertTrue(len(data['corpora']))
+        name_only_data = api.info(name_only=True)
+        self.assertEqual(set(name_only_data), {'models', 'corpora'})
+        self.assertIsInstance(name_only_data['models'], list)
+        self.assertIsInstance(name_only_data['corpora'], list)
+        # `name_only` takes precedence over `show_only_latest`: names are still returned as lists.
+        all_names = api.info(show_only_latest=False, name_only=True)
+        self.assertIsInstance(all_names['models'], list)
+        self.assertIsInstance(all_names['corpora'], list)
 
 
 if __name__ == '__main__':