Loading pretrained models #38

Closed
VikaNa opened this issue Dec 17, 2020 · 7 comments

VikaNa commented Dec 17, 2020

  • Contextualized Topic Models version: 1.7.0
  • Python version: 3.7.3
  • Operating System: Windows

Description

I have trained a Zero Shot Cross Lingual topic model and saved it with the .save() method. Now, I have problems with loading this model.

What I Did

CTM().load(model_dir = 'path to the epoch_9.pth file', epoch=9)

This resulted in the following error:

TypeError: __init__() missing 2 required positional arguments: 'input_size' and 'bert_input_size'

Could you please tell me the right way to load a pretrained Zero Shot model?
Thank you!

vinid self-assigned this Dec 17, 2020
Contributor

vinid commented Dec 17, 2020

Hello @VikaNa!

You just need to initialize a model with the correct input sizes first, and then you can call the load method.

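from contextualized_topic_models.models.ctm import ZeroShotTM
# The sizes must match those of the trained model (the values here are examples)
ct = ZeroShotTM(input_size=2000, bert_input_size=512)
ct.load("path_to_folder", epoch=13)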

One thing: this is mostly an experimental feature that we haven't been testing. I did some experiments and it seems to work (see the screenshot, I trained a ctm model, saved it, and then reloaded it and checked if the topics produced were the same):

Screenshot from 2020-12-17 10-56-13

I plan to add some tests for this feature in the future. To save the model you can also use the standard pickle package in Python: https://thepythonguru.com/pickling-objects-in-python/
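
As a rough illustration of the pickle route mentioned above, here is a minimal sketch. The helper names `save_model` and `load_model` and the file name in the usage note are made up for this example, and it assumes the trained model object is picklable, which has not been checked here for every CTM version.

```python
import pickle


def save_model(model, path):
    """Serialize any picklable object (e.g. a trained model) to disk."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    """Restore an object previously written by save_model.

    Only unpickle files you trust: pickle can run arbitrary code on load.
    """
    with open(path, "rb") as f:
        return pickle.load(f)
```

Usage would look like `save_model(ct, "ctm.pkl")` and later `ct = load_model("ctm.pkl")`. The library has to be importable when loading, ideally in the same version that was used for saving.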

Let me know if everything works or if you encounter other issues :)!

Author

VikaNa commented Dec 20, 2020

Hello @vinid, thank you for your quick response!
I was able to load a pretrained Zero Shot model, and the topics seem to be the same.
However, I have a couple more questions regarding the methods of the ZeroShotTM class.
For example, I was not able to use methods like .get_wordcloud() or .get_word_distribution_by_topic_id(). The following errors occurred:

AttributeError: 'ZeroShotTM' object has no attribute 'get_wordcloud'

or

AttributeError: 'ZeroShotTM' object has no attribute 'get_word_distribution_by_topic_id'

Could you help me with this issue?

Contributor

vinid commented Dec 20, 2020

Hello!

Those two features were introduced in contextualized-topic-models 1.7.1. You can run

pip install -U contextualized-topic-models

to get the latest stable version.
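
To confirm whether the installed version is recent enough before calling those methods, a check along these lines may help. This is only a sketch: the helper names are hypothetical, the 1.7.1 threshold is taken from the comment above, and `importlib.metadata` needs Python 3.8 or newer (the reporter's Python 3.7.3 would need the `importlib_metadata` backport instead).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """Turn '1.7.1' into (1, 7, 1); suffixes such as 'rc1' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


def supports_wordcloud(installed):
    """True for 1.7.1 or later, the release named in this thread."""
    return parse_version(installed) >= (1, 7, 1)


def installed_ctm_version():
    """Return the installed version string, or None if the package is absent."""
    try:
        return version("contextualized-topic-models")
    except PackageNotFoundError:
        return None
```

For example, `supports_wordcloud(installed_ctm_version() or "0")` is false on 1.7.0, which matches the AttributeError reported above.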

Beware: as I mentioned, the saving/loading methods are still experimental, so I am not sure whether you can reload a model trained with version 1.7.0. If you can check and get back to me, that would be great. You might have to retrain the model with 1.7.1 to be able to load it.

Thanks :)


rubmz commented May 13, 2021

After patching one issue in the code (because I don't have CUDA on my local PC), it worked fine!

ctm.py, line 418:
The code needs to check whether CUDA is available; if it is not, you need to pass map_location=torch.device('cpu') to the torch.load() function.
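
The kind of check described here could look roughly like the sketch below. It is not the actual patch to ctm.py: `resolve_map_location` and `load_checkpoint` are hypothetical helper names, and only `torch.cuda.is_available()`, `torch.device()` and the `map_location` argument of `torch.load()` are real PyTorch API.

```python
def resolve_map_location(cuda_available):
    """Pick the device name to hand to torch.load(map_location=...)."""
    return "cuda" if cuda_available else "cpu"


def load_checkpoint(path):
    """Load a checkpoint onto the GPU if one is present, else onto the CPU."""
    import torch  # imported lazily so the helper above works without PyTorch

    device = resolve_map_location(torch.cuda.is_available())
    # map_location remaps tensors that were saved on a CUDA device
    return torch.load(path, map_location=torch.device(device))
```

On a machine without a GPU this always resolves to `"cpu"`, so a checkpoint saved on a CUDA machine can still be deserialized.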

Contributor

vinid commented May 13, 2021

Thanks a lot! I'll reopen this so I remember to fix it.

vinid reopened this May 13, 2021
vinid added the bug label May 13, 2021
vinid closed this as completed Sep 6, 2021
@amirmohammadkz

I tested the saving/loading methods on a dataset.
[image]
Here are some results, using CTM with 2 topics: the thetas were completely different, but 95% of the classification labels were correct. So reproducibility is still an open question here.
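
A comparison like the one reported here can be sketched in plain Python. Two assumptions are made that the thread does not confirm: the thetas are available as lists of per-document topic probabilities, and "classification label" means the most probable topic of each document. `compare_thetas` is a hypothetical helper, not part of the library.

```python
def compare_thetas(before, after):
    """Compare two document-topic matrices given as lists of rows.

    Returns (largest absolute difference, share of documents whose
    most probable topic is the same in both matrices).
    """
    if not before or len(before) != len(after):
        raise ValueError("both matrices need the same, non-zero number of rows")
    max_diff = 0.0
    same_top = 0
    for row_a, row_b in zip(before, after):
        row_diff = max(abs(a - b) for a, b in zip(row_a, row_b))
        max_diff = max(max_diff, row_diff)
        # Assumption: the "label" of a document is its most probable topic
        if row_a.index(max(row_a)) == row_b.index(max(row_b)):
            same_top += 1
    return max_diff, same_top / len(before)
```

Passing the thetas obtained before saving and after reloading gives both numbers at once, which makes it easier to tell "different probabilities, same labels" apart from a genuinely different model.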

Contributor

vinid commented Feb 23, 2022

Refer to #106
