Skip to content

Commit

Permalink
Merge branch 'release/0.1.15' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
ControlNet committed Oct 8, 2021
2 parents eb88b7b + 2c20b07 commit c9fea68
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn ~= 0.4.3
einops ~= 0.3.0
av ~= 8.0.3
matplotlib >= 3.3.4
clean-fid >= 0.1.12
clean-fid >= 0.1.12, < 0.1.14
pandas >= 1.3.1
ipython >= 7.26.0
pygments >= 2.9.0
39 changes: 39 additions & 0 deletions src/tensorneko/layer/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Sequence, Union

import torch
from torch import Tensor

from tensorneko import NekoModule
from tensorneko.util import F


class Aggregation(NekoModule):
"""
The torch module for aggregation.
Args:
mode (``str``, optional): The mode of aggregation. Default "mean".
dim (``int`` | ``Sequence[int]``, optional): The dimension chosen to apply aggregate function. Default None.
Examples::
x = torch.rand(1, 16, 32, 32)
global_avg_pooling = Aggregation("avg", dim=(1, 2, 3))
x_pooled = max_pooling(x)
"""
def __init__(self, mode: str = "mean", dim: Union[int, Sequence[int]] = None):
super().__init__()
if mode == "mean":
self.agg_func = F(torch.mean, dim=dim)
elif mode == "sum":
self.agg_func = F(torch.sum, dim=dim)
elif mode == "max":
self.agg_func = F(torch.max, dim=dim)
elif mode == "min":
self.agg_func = F(torch.min, dim=dim)
else:
raise ValueError("Wrong mode value. It should be in [mean, sum, max, min]")

def forward(self, x: Tensor) -> Tensor:
return self.agg_func(x)
12 changes: 9 additions & 3 deletions src/tensorneko/layer/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@ class Stack(NekoModule):
"""
The module version of torch.stack function family.
Args:
mode (``str``, optional): The mode of the pytorch stack type. Default original stack.
dim (``int``, optional): The dimension of stack apply to. Cannot use in non-default mode. Default 0.
Examples::
dstack = Stack("d")
x_stack = dstack([x1, x2])
"""

def __init__(self, mode: str = "", dim: int = 0):
super().__init__()
# other mode cannot specify the dim
assert not (mode == "" and dim != 0), "Other modes cannot specify the dim"
assert not (mode != "" and dim != 0), "Other modes cannot specify the dim"
if mode == "":
self.stack_func = F(torch.stack, dim=dim)
elif mode.lower() == "d":
Expand All @@ -32,8 +39,7 @@ def __init__(self, mode: str = "", dim: int = 0):
elif mode.lower() == "row":
self.stack_func = torch.row_stack
else:
raise ValueError("""Not a valid `mode` argument. It should be in ["", "d", "v", "h", "column", "row"].
""")
raise ValueError("""Not a valid `mode` argument. It should be in ["", "d", "v", "h", "column", "row"].""")

def forward(self, tensors: Union[List[Tensor], Tuple[Tensor, ...]]) -> Tensor:
return self.stack_func(tensors)
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.14
0.1.15

0 comments on commit c9fea68

Please sign in to comment.