Skip to content

Commit

Permalink
Update to mypy 0.990 (#801)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianeboyd authored Nov 16, 2022
1 parent f010ad8 commit 0b363b2
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pytest-cov>=2.7.0,<2.8.0
coverage>=5.0.0,<6.0.0
mock>=2.0.0,<3.0.0
flake8>=3.5.0,<3.6.0
mypy>=0.980,<0.990; python_version >= "3.7"
mypy>=0.990,<0.1000; python_version >= "3.7"
types-mock>=0.1.1
types-contextvars>=0.1.2; python_version < "3.7"
types-dataclasses>=0.1.3; python_version < "3.7"
Expand Down
3 changes: 3 additions & 0 deletions thinc/loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple, Sequence, cast, TypeVar, Generic, Any, Union, Optional, List
from typing import Dict
from abc import abstractmethod

from .types import Floats2d, Ints1d
from .util import get_array_module, to_categorical
Expand Down Expand Up @@ -27,9 +28,11 @@ def __init__(self, **kwargs: Any) -> None:
def __call__(self, guesses: GuessT, truths: TruthT) -> Tuple[GradT, LossT]:
return self.get_grad(guesses, truths), self.get_loss(guesses, truths)

@abstractmethod
def get_grad(self, guesses: GuessT, truths: TruthT) -> GradT:
...

@abstractmethod
def get_loss(self, guesses: GuessT, truths: TruthT) -> LossT:
...

Expand Down
4 changes: 3 additions & 1 deletion thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,9 @@ def torch_softmax_with_temperature(
Yt = torch.nn.functional.softmax(Xt_temp, dim=-1)
Yt.backward(dYt)

return cast(Floats2d, torch2xp(Yt)), cast(Floats2d, torch2xp(Xt.grad))
return cast(Floats2d, torch2xp(Yt)), cast(
Floats2d, torch2xp(cast(torch.Tensor, Xt.grad))
)


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
Expand Down
2 changes: 1 addition & 1 deletion thinc/tests/layers/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def torch_softmax_with_temperature(

return cast(
Floats2d, torch2xp(torch.nn.functional.softmax(XWbt_temp, dim=-1))
), cast(Floats2d, torch2xp(Xt.grad))
), cast(Floats2d, torch2xp(cast(torch.Tensor, Xt.grad)))


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
Expand Down
Loading

0 comments on commit 0b363b2

Please sign in to comment.