-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathextractor.py
47 lines (38 loc) · 1.62 KB
/
extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import numpy as np
from pigmento import pnt
from base_lego import BaseLego
from loader.symbols import Symbols
from utils.config_init import CommandInit
class Extractor(BaseLego):
    """Dump cached user or item embeddings of an experiment to ``.npy`` files."""

    def extract_user_embedding(self):
        """Save the cached user representations to ``user_embeddings.npy``.

        The file is written into ``self.exp.dir``.

        Raises:
            AssertionError: if ``self.cacher.user.cached`` is falsy.
        """
        # The returned loader is discarded; the call is made only for its
        # side effects (presumably it populates the user cache -- confirm).
        self.manager.get_train_loader(Symbols.test)
        # NOTE: kept as an assert so callers see the same exception type;
        # be aware that asserts are stripped when Python runs with -O.
        assert self.cacher.user.cached, 'fast eval not enabled'

        user_embeddings = self.cacher.user.repr.detach().cpu().numpy()
        store_path = os.path.join(self.exp.dir, 'user_embeddings.npy')
        pnt(f'store user embeddings to {store_path}')
        np.save(store_path, user_embeddings)

    def extract_item_embedding(self):
        """Save the cached item representations to ``item_embeddings.npy``.

        The item cache is (re)built from ``self.resampler.item_cache`` first,
        then the file is written into ``self.exp.dir``.
        """
        self.cacher.item.cache(self.resampler.item_cache)

        item_embeddings = self.cacher.item.repr.detach().cpu().numpy()
        store_path = os.path.join(self.exp.dir, 'item_embeddings.npy')
        pnt(f'store item embeddings to {store_path}')
        np.save(store_path, item_embeddings)

    def run(self):
        """Dispatch to the extractor selected by ``self.config.target``.

        The target is matched case-insensitively against
        ``Symbols.user.name`` and ``Symbols.item.name``.

        Returns:
            Whatever the selected extractor returns (currently ``None``).

        Raises:
            ValueError: if the target matches neither name.
        """
        target = self.config.target.lower()

        # Compare by value, not identity: ``str.lower()`` returns a new
        # string object, so an ``is`` test would practically never succeed
        # and every valid target would fall through to the error below.
        if target == Symbols.user.name:
            return self.extract_user_embedding()
        if target == Symbols.item.name:
            return self.extract_item_embedding()

        raise ValueError(f'unknown target: {self.config.target}, expect "user" or "item"')
if __name__ == '__main__':
    # Values handed to CommandInit as ``default_args``.
    # NOTE(review): '${hidden_size}$' looks like a placeholder that the config
    # loader resolves to ``hidden_size`` -- confirm against CommandInit.
    defaults = {
        'exp': 'config/exp/default.yaml',
        'embed': 'config/embed/null.yaml',
        'hidden_size': 256,
        'item_hidden_size': '${hidden_size}$',
    }
    command = CommandInit(
        required_args=['data', 'model', 'target'],
        default_args=defaults,
    )

    # Parse the configuration, build the extractor and run it.
    Extractor(config=command.parse()).run()