Different forward outputs in different model #18

Open
wuys13 opened this issue Jan 15, 2025 · 1 comment
Comments

@wuys13

wuys13 commented Jan 15, 2025

Thank you for sharing this outstanding work!
I noticed that the forward methods of different models in your code return different outputs (https://github.com/mahmoodlab/SurvPath/tree/3f73ddd6705ec67d643020c5bb04fb13f9f382cc/models):

  1. Some return a risk score (e.g., model_SurvPath.py and model_ABMIL.py):

     logits = self.to_logits(embedding)
    
     hazards = torch.sigmoid(logits)
     survival = torch.cumprod(1 - hazards, dim=1)
     risk = -torch.sum(survival, dim=1)
     
     return risk
    
  2. Some return the raw logits (e.g., model_MLPWSI.py and model_TMIL.py):

     #---> get logits
     logits = self.to_logits(embedding)
    
     return logits
    
  3. In some cases the hazard and survival computation is even commented out (e.g., model_DeepMISL.py):

     logits  = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x 4] vector 
     # Y_hat = torch.topk(logits, 1, dim = 1)[1]
     # hazards = torch.sigmoid(logits)
     # S = torch.cumprod(1 - hazards, dim=1)
     
     # return hazards, S, Y_hat, None, None
     return logits
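For reference, the three kinds of output above are linked by the discrete-time survival transform that appears in these snippets. With logits z_1, ..., z_K over K time bins (K = 4 per the comment in model_DeepMISL.py), the hazards, survival curve, and risk are:

$$
h_k = \sigma(z_k), \qquad S_k = \prod_{j=1}^{k} \left(1 - h_j\right), \qquad \text{risk} = -\sum_{k=1}^{K} S_k
$$

Since the hazards, survival curve, and risk are all deterministic functions of the logits, a model that returns logits loses no information; whatever consumes its output just has to apply this transform if risk scores are needed.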
    

I wonder whether this is a modified version used for some tests, or whether you apply some post-processing elsewhere (I checked the Dataset and collate_fn code and did not find any such processing). Is it correct to run your bash scripts directly for the models that return "logits" from forward?

@guillaumejaume
Contributor

Apologies for the late reply. The short answer is: it depends. What are you trying to achieve? If you want risk scores, you need to transform the logits into the survival function "S" using the code you copy-pasted:

 logits  = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x 4] vector 
 Y_hat = torch.topk(logits, 1, dim = 1)[1]
 hazards = torch.sigmoid(logits)
 S = torch.cumprod(1 - hazards, dim=1)
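A minimal runnable sketch of the snippet above, completed with the risk step from model_SurvPath.py and model_ABMIL.py. It assumes logits of shape [batch, n_bins]; the function name logits_to_survival_outputs is hypothetical and not part of the SurvPath codebase.

```python
import torch


def logits_to_survival_outputs(logits: torch.Tensor):
    """Hypothetical helper (not in SurvPath): convert discrete-time survival
    logits of shape [batch, n_bins] into hazards, survival, Y_hat, and risk."""
    # Predicted time bin: index of the largest logit in each row, shape [batch, 1]
    y_hat = torch.topk(logits, 1, dim=1)[1]
    # Per-bin hazard h_k = sigmoid(z_k)
    hazards = torch.sigmoid(logits)
    # Survival curve S_k = prod_{j <= k} (1 - h_j)
    survival = torch.cumprod(1 - hazards, dim=1)
    # Scalar risk as in model_SurvPath.py / model_ABMIL.py (higher = worse prognosis)
    risk = -torch.sum(survival, dim=1)
    return hazards, survival, y_hat, risk


if __name__ == "__main__":
    # Dummy [1 x 4] logits vector; all-zero logits give hazards of 0.5 per bin
    logits = torch.zeros(1, 4)
    hazards, survival, y_hat, risk = logits_to_survival_outputs(logits)
    print(survival)  # tensor([[0.5000, 0.2500, 0.1250, 0.0625]])
    print(risk)      # tensor([-0.9375])
```

Returning hazards, S, and Y_hat in that order mirrors the commented-out return in model_DeepMISL.py; risk is appended so that a model returning logits can be put on the same scale as model_SurvPath.py.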

Hope this helps
