-
Notifications
You must be signed in to change notification settings - Fork 856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Option to run LabelModel on GPU #1466
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1466 +/- ##
==========================================
+ Coverage 97.58% 97.59% +<.01%
==========================================
Files 55 55
Lines 2032 2034 +2
Branches 334 334
==========================================
+ Hits 1983 1985 +2
Misses 22 22
Partials 27 27
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved assuming there's a good answer for the clone()
question.
@@ -334,7 +336,7 @@ def get_conditional_probs(self) -> np.ndarray: | |||
np.ndarray | |||
An [m, k + 1, k] np.ndarray conditional probabilities table. | |||
""" | |||
return self._get_conditional_probs(self.mu.detach().numpy()) | |||
return self._get_conditional_probs(self.mu.clone().cpu().detach().numpy()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is clone() necessary here but not the other places where you're adding .cpu().detach()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch, probably added during debugging. removed
@@ -750,7 +752,7 @@ def _count_accurate_lfs(self, mu: np.ndarray) -> int: | |||
int | |||
Number of LFs better than random | |||
""" | |||
P = self.P.numpy() | |||
P = self.P.clone().cpu().detach().numpy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same Q.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
@@ -95,7 +95,9 @@ def test_generate_O(self): | |||
[1 / 4, 0, 0, 1 / 4, 0, 1 / 4], | |||
] | |||
) | |||
np.testing.assert_array_almost_equal(label_model.O.numpy(), true_O) | |||
np.testing.assert_array_almost_equal( | |||
label_model.O.cpu().detach().numpy(), true_O |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels very strange to see `label_model.O', but I know that's not part of this PR's scope.
Description of proposed changes
Fixed some parameter definitions to allow LabelModel to run on GPU.
Related issue(s)
Fixes #1430
Test plan
device='cuda'
for LabelModel tests.Checklist
Need help on these? Just ask!
tox -e complex
and/ortox -e spark
if appropriate.