Skip to content

Commit

Permalink
compare types with "is" so flake8 passes (#1958)
Browse files Browse the repository at this point in the history
Summary:
## Motivation

Flake8 is failing in the CI because it now requires comparing types with "is" rather than "==".

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1958

Test Plan: Units, flake8

Reviewed By: saitcakmak

Differential Revision: D47941295

Pulled By: esantorella

fbshipit-source-id: 8cdfd40274567c52ab6821024f71814da7b0eb97
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 1, 2023
1 parent a065992 commit f62cee3
Show file tree
Hide file tree
Showing 10 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/multi_objective/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(
self.add_module("objective", objective)
self.constraints = constraints
if constraints:
if type(eta) != Tensor:
if type(eta) is not Tensor:
eta = torch.full((len(constraints),), eta)
self.register_buffer("eta", eta)
self.X_pending = None
Expand Down
2 changes: 1 addition & 1 deletion botorch/acquisition/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def __init__(
"""
super().__init__(objective=objective)
self.constraints = constraints
if type(eta) != Tensor:
if type(eta) is not Tensor:
eta = torch.full((len(constraints),), eta)
self.register_buffer("eta", eta)
self.register_buffer("infeasible_cost", torch.as_tensor(infeasible_cost))
Expand Down
4 changes: 2 additions & 2 deletions botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@


def _handle_torch_linalg(exception: Exception) -> bool:
return type(exception) == torch.linalg.LinAlgError
return type(exception) is torch.linalg.LinAlgError


def _handle_valerr_in_dist_init(exception: Exception) -> bool:
if not type(exception) == ValueError:
if type(exception) is not ValueError:
return False
return "satisfy the constraint PositiveDefinite()" in str(exception)

Expand Down
4 changes: 2 additions & 2 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def equals(self, other: InputTransform) -> bool:
"""
other_state_dict = other.state_dict()
return (
type(self) == type(other)
type(self) is type(other)
and (self.transform_on_train == other.transform_on_train)
and (self.transform_on_eval == other.transform_on_eval)
and (self.transform_on_fantasize == other.transform_on_fantasize)
Expand Down Expand Up @@ -1547,7 +1547,7 @@ def equals(self, other: InputTransform) -> bool:
A boolean indicating if the other transform is equivalent.
"""
return (
type(self) == type(other)
type(self) is type(other)
and (self.transform_on_train == other.transform_on_train)
and (self.transform_on_eval == other.transform_on_eval)
and (self.transform_on_fantasize == other.transform_on_fantasize)
Expand Down
2 changes: 1 addition & 1 deletion botorch/utils/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def compute_smoothed_feasibility_indicator(
Returns:
A `n_samples x b x q`-dim tensor of feasibility indicator values.
"""
if type(eta) != Tensor:
if type(eta) is not Tensor:
eta = torch.full((len(constraints),), eta)
if len(eta) != len(constraints):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion botorch/utils/probability/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def concat(self, other: PivotedCholesky, dim: int = 0) -> PivotedCholesky:
for name in ("tril", "perm", "diag"):
a = getattr(self, name)
b = getattr(other, name)
if type(a) != type(b):
if type(a) is not type(b):
raise NotImplementedError(f"Types of field {name} do not match.")

if a is not None:
Expand Down
2 changes: 1 addition & 1 deletion botorch/utils/probability/mvnxpb.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def concat(self, other: MVNXPB, dim: int) -> MVNXPB:
if _self is None and _other is None:
continue

if type(_self) != type(_other):
if type(_self) is not type(_other):
raise TypeError(
f"Concatenation failed: `self.{key}` has type {type(_self)}, "
f"but `other.{key}` is of type {type(_self)}."
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def potential_fn_rterr_foo(z):

# But once we register this specific error then it should
def catch_runtime_error(e):
return type(e) == RuntimeError and "foo" in str(e)
return type(e) is RuntimeError and "foo" in str(e)

register_exception_handler("foo_runtime", catch_runtime_error)
_, val = potential_grad(potential_fn_rterr_foo, z)
Expand Down
2 changes: 1 addition & 1 deletion test/optim/closures/test_model_closures.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_main(self):

def test_data_loader(self):
for mll in self.mlls.values():
if type(mll) != ExactMarginalLogLikelihood:
if type(mll) is not ExactMarginalLogLikelihood:
continue

dataset = TensorDataset(*mll.model.train_inputs, mll.model.train_targets)
Expand Down
4 changes: 2 additions & 2 deletions test/utils/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def _pow(a: int, b: int):
self.assertEqual(self.dispatcher[args], _pow)

retval = self.dispatcher(*args)
test_type = float if (type_a == float or type_b == float) else int
self.assertTrue(type(retval) == test_type)
test_type = float if (type_a is float or type_b is float) else int
self.assertIs(type(retval), test_type)
self.assertEqual(retval, test_type(8))

def test_notImplemented(self):
Expand Down

0 comments on commit f62cee3

Please sign in to comment.