diff --git a/semilearn/algorithms/srpseudolabel/srpseudolabel.py b/semilearn/algorithms/srpseudolabel/srpseudolabel.py
index 41f15b1..80cae7a 100644
--- a/semilearn/algorithms/srpseudolabel/srpseudolabel.py
+++ b/semilearn/algorithms/srpseudolabel/srpseudolabel.py
@@ -34,11 +34,11 @@ def __init__(self, args, net_builder, tb_log=None, logger=None, **kwargs):
         super().__init__(args, net_builder, tb_log, logger, **kwargs)
         self.init(p_cutoff=args.p_cutoff, unsup_warm_up=args.unsup_warm_up)
         self.it=0
-        self.rewarder = Rewarder(128,384).cuda(self.gpu)
-        self.generator = Generator(384).cuda(self.gpu)
-        self.starttiming=20000
-        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
-        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
+        self.rewarder = Rewarder(128,self.featinput).cuda(self.gpu)
+        self.generator = Generator(self.featinput).cuda(self.gpu)
+        self.starttiming=self.start
+        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=self.srlr)
+        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.srlr)
         self.criterion = torch.nn.MSELoss()
         self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
@@ -109,4 +109,4 @@ def get_argument():
             SSL_Argument('--p_cutoff', float, 0.95),
             SSL_Argument('--unsup_warm_up', float, 0.4, 'warm up ratio for unsupervised loss'),
             # SSL_Argument('--use_flex', str2bool, False),
-        ]
\ No newline at end of file
+        ]
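
The updated `__init__` reads `self.featinput`, `self.start`, and `self.srlr` in place of the previously hardcoded values (384, 20000, 0.0005), but this hunk does not show where those attributes are populated. Below is a minimal sketch of one plausible wiring, assuming the new hyperparameters are registered through the existing `SSL_Argument` mechanism and copied from `args` before the Rewarder/Generator are built; the option names and defaults are illustrative (taken from the removed hardcoded values), not part of this diff.

```python
# Hypothetical sketch (not in the diff): expose the new SemiReward
# hyperparameters as CLI options of the SRPseudoLabel algorithm class,
# defaulting to the values that the diff removes.
@staticmethod
def get_argument():
    return [
        SSL_Argument('--p_cutoff', float, 0.95),
        SSL_Argument('--unsup_warm_up', float, 0.4, 'warm up ratio for unsupervised loss'),
        SSL_Argument('--featinput', int, 384, 'feature dimension fed to Rewarder/Generator'),
        SSL_Argument('--start', int, 20000, 'iteration at which SemiReward training starts'),
        SSL_Argument('--srlr', float, 0.0005, 'learning rate for the rewarder/generator optimizers'),
    ]

# ...and in __init__, before the Rewarder/Generator/optimizers are created:
self.featinput = args.featinput
self.start = args.start
self.srlr = args.srlr
```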