Option to run LabelModel on GPU #1466

Merged · 4 commits · Sep 20, 2019
29 changes: 18 additions & 11 deletions snorkel/labeling/model/label_model.py
@@ -238,7 +238,9 @@ def _generate_O(self, L: np.ndarray, higher_order: bool = False) -> None:
"""
L_aug = self._get_augmented_label_matrix(L, higher_order=higher_order)
self.d = L_aug.shape[1]
self.O = torch.from_numpy(L_aug.T @ L_aug / self.n).float()
self.O = (
torch.from_numpy(L_aug.T @ L_aug / self.n).float().to(self.config.device)
)

def _init_params(self) -> None:
r"""Initialize the learned params.
@@ -269,7 +271,7 @@ def _init_params(self) -> None:

# Get the per-value labeling propensities
# Note that self.O must have been computed already!
lps = torch.diag(self.O).numpy()
lps = torch.diag(self.O).cpu().detach().numpy()

# TODO: Update for higher-order cliques!
self.mu_init = torch.zeros(self.d, self.cardinality)
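The recurring edit in this and the following hunks swaps tensor.numpy() for tensor.cpu().detach().numpy(). The longer chain is needed once tensors may live on the GPU or carry gradients; a small sketch of the failure mode and the fix, using a toy tensor rather than the model's actual mu.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
mu = torch.rand(6, 2, device=device, requires_grad=True)

# mu.numpy() would raise here: NumPy cannot consume CUDA tensors or tensors
# still attached to the autograd graph.
mu_np = mu.cpu().detach().numpy()  # copy to host memory and drop the graph
print(mu_np.shape)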
@@ -335,7 +337,7 @@ def get_conditional_probs(self) -> np.ndarray:
np.ndarray
An [m, k + 1, k] np.ndarray conditional probabilities table.
"""
return self._get_conditional_probs(self.mu.detach().numpy())
return self._get_conditional_probs(self.mu.cpu().detach().numpy())

def get_weights(self) -> np.ndarray:
"""Return the vector of learned LF weights for combining LFs.
@@ -356,7 +358,7 @@ def get_weights(self) -> np.ndarray:
accs = np.zeros(self.m)
cprobs = self.get_conditional_probs()
for i in range(self.m):
accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.numpy()).sum()
accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.cpu().detach().numpy()).sum()
return np.clip(accs / self.coverage, 1e-6, 1.0)

def predict_proba(self, L: np.ndarray) -> np.ndarray:
@@ -385,7 +387,7 @@ def predict_proba(self, L: np.ndarray) -> np.ndarray:
L_shift = L + 1 # convert to {0, 1, ..., k}
self._set_constants(L_shift)
L_aug = self._get_augmented_label_matrix(L_shift)
mu = self.mu.detach().numpy()
mu = self.mu.cpu().detach().numpy()
jtm = np.ones(L_aug.shape[1])

# Note: We omit abstains, effectively assuming uniform distribution here
@@ -516,7 +518,7 @@ def _loss_l2(self, l2: float = 0) -> torch.Tensor:
D = l2 * torch.eye(self.d)
else:
D = torch.diag(torch.from_numpy(l2)).type(torch.float32)

D = D.to(self.config.device)
# Note that mu is a matrix and this is the *Frobenius norm*
return torch.norm(D @ (self.mu - self.mu_init)) ** 2
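Moving D onto self.config.device matters because it multiplies mu, which now lives on that device, and PyTorch refuses to mix CPU and CUDA operands in a single op. A minimal sketch of the regularizer under that constraint; the dimensions and l2 value are toy assumptions.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
d, k, l2 = 6, 2, 0.1

mu = torch.rand(d, k, device=device)
mu_init = torch.zeros(d, k, device=device)

# D must share mu's device; D @ (mu - mu_init) with mismatched devices raises a RuntimeError.
D = (l2 * torch.eye(d)).to(device)
reg = torch.norm(D @ (mu - mu_init)) ** 2  # squared Frobenius norm, as in the method above
print(float(reg))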

@@ -569,7 +571,7 @@ def _set_class_balance(
raise ValueError(
f"Class balance prior is 0 for class(es) {np.where(self.p)[0]}."
)
self.P = torch.diag(torch.from_numpy(self.p)).float()
self.P = torch.diag(torch.from_numpy(self.p)).float().to(self.config.device)

def _set_constants(self, L: np.ndarray) -> None:
self.n, self.m = L.shape
@@ -756,7 +758,7 @@ def _count_accurate_lfs(self, mu: np.ndarray) -> int:
int
Number of LFs better than random
"""
P = self.P.numpy()
P = self.P.cpu().detach().numpy()
cprobs = self._get_conditional_probs(mu)
count = 0
for i in range(self.m):
@@ -787,8 +789,8 @@ def _break_col_permutation_symmetry(self) -> None:
assumption that we could use, and in practice this may require further
iteration here.
"""
mu = self.mu.detach().numpy()
P = self.P.numpy()
mu = self.mu.cpu().detach().numpy()
P = self.P.cpu().detach().numpy()
d, k = mu.shape

# Iterate through the possible permutation matrices and track heuristic scores
@@ -805,7 +807,11 @@ def _break_col_permutation_symmetry(self) -> None:
scores.append(-1)

# Set mu according to highest-scoring permutation
self.mu.data = torch.Tensor(mu @ Zs[np.argmax(scores)]) # type: ignore
self.mu = nn.Parameter(
torch.Tensor(mu @ Zs[np.argmax(scores)]).to( # type: ignore
self.config.device
)
)

def fit(
self,
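Instead of writing into self.mu.data, the new code rebuilds self.mu as an nn.Parameter whose tensor is explicitly placed on the configured device, so the permuted values stay registered on the module and on the right device. A small sketch of that re-registration pattern on a bare module; the module and shapes here are hypothetical.

import numpy as np
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mu = nn.Parameter(torch.rand(6, 2))

model = TinyModel().to(device)
permuted = np.random.rand(6, 2)  # stands in for mu @ Zs[np.argmax(scores)]

# Re-wrap as a Parameter on the target device; it remains in model.parameters().
model.mu = nn.Parameter(torch.Tensor(permuted).to(device))
print(next(model.parameters()).device)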
@@ -878,6 +884,7 @@ def fit(
self.train()

# Move model to GPU
self.mu_init = self.mu_init.to(self.config.device)
if self.config.verbose and self.config.device != "cpu": # pragma: no cover
logging.info("Using GPU...")
self.to(self.config.device)
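With mu_init and the module itself moved here in fit, training can run end to end on the GPU while the public API keeps returning NumPy arrays. A hedged usage sketch, assuming the device setting can be passed through the LabelModel constructor into its config, as self.config.device in this diff suggests; the label matrix is a toy example.

import numpy as np
import torch
from snorkel.labeling.model import LabelModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy label matrix: 3 LFs over 4 data points, -1 meaning abstain.
L = np.array([[0, -1, 0], [0, 1, -1], [1, 1, 0], [-1, 0, 1]])

label_model = LabelModel(cardinality=2, verbose=True, device=device)  # device kwarg assumed
label_model.fit(L, n_epochs=100, seed=123)
probs = label_model.predict_proba(L)  # NumPy array on the host, whatever the device
print(probs.shape)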
17 changes: 11 additions & 6 deletions test/labeling/model/test_label_model.py
@@ -95,7 +95,9 @@ def test_generate_O(self):
[1 / 4, 0, 0, 1 / 4, 0, 1 / 4],
]
)
np.testing.assert_array_almost_equal(label_model.O.numpy(), true_O)
np.testing.assert_array_almost_equal(
label_model.O.cpu().detach().numpy(), true_O
)

Review comment (Member): Feels very strange to see `label_model.O`, but I know that's not part of this PR's scope.

label_model = self._set_up_model(L)
label_model._generate_O(L + 1, higher_order=False)
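On the test side, the same conversion keeps assertions device-agnostic: bring the tensor back to host memory before handing it to NumPy's helpers. A toy illustration; the values below are not the fixture used in this suite.

import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
O = torch.tensor([[0.5, 0.25], [0.25, 0.5]], device=device)
true_O = np.array([[0.5, 0.25], [0.25, 0.5]])

# .cpu().detach().numpy() makes the comparison work wherever the model ran.
np.testing.assert_array_almost_equal(O.cpu().detach().numpy(), true_O)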
@@ -109,12 +111,16 @@ def test_generate_O(self):
[1 / 4, 0, 0, 1 / 4, 0, 1 / 4],
]
)
np.testing.assert_array_almost_equal(label_model.O.numpy(), true_O)
np.testing.assert_array_almost_equal(
label_model.O.cpu().detach().numpy(), true_O
)

# Higher order returns same matrix (num source = num cliques)
# Need to test c_tree form
label_model._generate_O(L + 1, higher_order=True)
np.testing.assert_array_almost_equal(label_model.O.numpy(), true_O)
np.testing.assert_array_almost_equal(
label_model.O.cpu().detach().numpy(), true_O
)

def test_augmented_L_construction(self):
# 5 LFs
@@ -288,9 +294,8 @@ def test_score(self):

def test_loss(self):
L = np.array([[0, -1, 0], [0, 1, -1]])
label_model = self._set_up_model(L)
label_model._get_augmented_label_matrix(L + 1, higher_order=True)

label_model = LabelModel(cardinality=2, verbose=False)
label_model.fit(L, n_epochs=1)
label_model.mu = nn.Parameter(label_model.mu_init.clone() + 0.05)

# l2_loss = l2*M*K*||mu - mu_init||_2 = 3*2*(0.05^2) = 0.03