Update srpseudolabel.py
WayneJin0918 authored Oct 5, 2023
1 parent: 160a5c8 · commit: 9e2742e
Showing 1 changed file with 6 additions and 6 deletions: semilearn/algorithms/srpseudolabel/srpseudolabel.py
@@ -34,11 +34,11 @@ def __init__(self, args, net_builder, tb_log=None, logger=None, **kwargs):
         super().__init__(args, net_builder, tb_log, logger, **kwargs)
         self.init(p_cutoff=args.p_cutoff, unsup_warm_up=args.unsup_warm_up)
         self.it=0
-        self.rewarder = Rewarder(128,384).cuda(self.gpu)
-        self.generator = Generator(384).cuda(self.gpu)
-        self.starttiming=20000
-        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
-        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
+        self.rewarder = Rewarder(128,self.featinput).cuda(self.gpu)
+        self.generator = Generator(self.featinput).cuda(self.gpu)
+        self.starttiming=self.start
+        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=self.srlr)
+        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.srlr)
         self.criterion = torch.nn.MSELoss()

         self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
@@ -109,4 +109,4 @@ def get_argument():
             SSL_Argument('--p_cutoff', float, 0.95),
             SSL_Argument('--unsup_warm_up', float, 0.4, 'warm up ratio for unsupervised loss'),
             # SSL_Argument('--use_flex', str2bool, False),
-        ]
+        ]
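
The commit replaces the hard-coded feature dimension (384), start iteration (20000), and SemiReward learning rate (0.0005) with the configurable attributes self.featinput, self.start, and self.srlr. How those attributes are populated is not shown in this diff; below is a minimal sketch of one way they could be exposed in the same style as the existing get_argument() entries. The flag names, defaults, help strings, and the copy-to-self step are assumptions for illustration, not part of the commit.

    # Hypothetical sketch (NOT part of this commit): defaults mirror the
    # previously hard-coded values. SSL_Argument is already imported in this
    # file, as shown by the existing get_argument() entries.
    #
    # In __init__, before Rewarder/Generator are built (assumed wiring):
    #     self.featinput = args.featinput   # was 384
    #     self.start = args.start           # was 20000
    #     self.srlr = args.srlr             # was 0.0005

    @staticmethod
    def get_argument():
        return [
            SSL_Argument('--p_cutoff', float, 0.95),
            SSL_Argument('--unsup_warm_up', float, 0.4, 'warm up ratio for unsupervised loss'),
            SSL_Argument('--featinput', int, 384, 'input feature dimension for Rewarder/Generator (assumed flag)'),
            SSL_Argument('--start', int, 20000, 'iteration at which SemiReward training starts (assumed flag)'),
            SSL_Argument('--srlr', float, 0.0005, 'learning rate for the rewarder/generator optimizers (assumed flag)'),
        ]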
