-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsplitter.py
35 lines (26 loc) · 1.06 KB
/
splitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from oba import Obj
from base_lego import BaseLego
from model.operators.lm_operator import BaseLMOperator
from utils.config_init import CommandInit
class Splitter(BaseLego):
    """Cache selected hidden layers of the language-model item operator.

    The layers to cache are read from ``config.layers`` as a ``+``-separated
    list of integer indexes (for example ``"0+6+-1"``).  A negative index is
    offset by the item operator's ``num_hidden_layers``, so ``-1`` becomes
    ``num_hidden_layers - 1``.
    """

    def run(self):
        """Resolve the requested layer indexes and hand them to ``item_op.cache``.

        Raises:
            ValueError: if the item operator is not a ``BaseLMOperator``, if a
                layer token cannot be parsed as an integer, or if no
                pretrained embeddings are configured.
        """
        item_op = self.legommender.item_op
        if not isinstance(item_op, BaseLMOperator):
            raise ValueError('item encoder is not a LMOperator')

        # Coerce to ``str`` first: a lone index such as ``11`` may reach us as
        # an ``int`` if the config loader auto-converts numeric values
        # (NOTE(review): confirm against CommandInit), and ``int`` has no
        # ``.split``.  For string input this is a no-op.
        layer_tokens = str(self.config.layers).split('+')

        # Negative indexes count from the end; ``num_hidden_layers`` is only
        # read when a negative index is actually present.
        layers = [
            index if index >= 0 else index + item_op.num_hidden_layers
            for index in map(int, layer_tokens)
        ]

        if not self.embed.embeddings:
            raise ValueError('please specify pretrained embedding configurations when using LM layer split')

        item_op.cache(layers)
if __name__ == '__main__':
    # Build the run configuration from the command line.  ``data``, ``model``,
    # ``embed`` and ``layers`` are listed as required arguments; the rest fall
    # back to the defaults below.
    command = CommandInit(
        required_args=['data', 'model', 'embed', 'layers'],
        default_args={
            'exp': 'config/exp/default.yaml',
            # unused but required arguments
            'hidden_size': 256,
            'batch_size': 64,
        },
    )

    Splitter(config=command.parse()).run()