
Commit

[FEAT][FractoralNorm]
Kye committed May 12, 2024
1 parent 2e6e0b6 commit b9abb28
Showing 8 changed files with 60 additions and 13 deletions.
10 changes: 10 additions & 0 deletions fractoral_norm.py
@@ -0,0 +1,10 @@
from zeta.nn import FractoralNorm # Importing the FractoralNorm class from the zeta.nn module
import torch # Importing the torch module for tensor operations

# Norm
x = torch.randn(2, 3, 4) # Generating a random tensor of size (2, 3, 4)

# FractoralNorm
normed = FractoralNorm(4, 4)(x) # Applying FractoralNorm (num_features=4, depth=4) to the tensor x

print(normed) # Printing the normalized tensor; its shape is still torch.Size([2, 3, 4])
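Since this commit also forwards extra constructor arguments to the underlying nn.LayerNorm layers (see the zeta/nn/modules/fractoral_norm.py diff further down), here is a small hedged sketch of how LayerNorm options could be passed through; the eps and elementwise_affine values are illustrative and not part of the committed example:

# Hedged sketch: the *args/**kwargs added in this commit should reach every stacked LayerNorm
normed_custom = FractoralNorm(4, 2, eps=1e-6, elementwise_affine=False)(x)
print(normed_custom.shape)  # Expected: torch.Size([2, 3, 4])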
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.4.5"
version = "2.4.6"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <kye@apac.ai>"]
license = "MIT"
@@ -16,7 +16,7 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.9"
python = "^3.10"
torch = ">=2.1.1,<3.0"
pytest = "8.1.1"
torchfix = "*"
File renamed without changes.
8 changes: 5 additions & 3 deletions zeta/__init__.py
@@ -9,8 +9,10 @@
from zeta.optim import * # noqa: F403, E402
from zeta.quant import * # noqa: F403, E402
from zeta.rl import * # noqa: F403, E402

# from zeta.tokenizers import * # noqa: F403, E402
from zeta.training import * # noqa: F403, E402
from zeta.utils import * # noqa: F403, E402
from zeta.experimental import * # noqa: F403, E402

try:
    from zeta.experimental import *  # noqa: F403, E402
except ImportError:
    pass
5 changes: 4 additions & 1 deletion zeta/nn/modules/__init__.py
@@ -214,7 +214,8 @@
from zeta.nn.modules.query_proposal import TextHawkQueryProposal
from zeta.nn.modules.pixel_shuffling import PixelShuffleDownscale
from zeta.nn.modules.kan import KAN

from zeta.nn.modules.layer_scale import LayerScale
from zeta.nn.modules.fractoral_norm import FractoralNorm
# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
# from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding
@@ -426,4 +427,6 @@
"TextHawkQueryProposal",
"PixelShuffleDownscale",
"KAN",
"LayerScale",
"FractoralNorm",
]
10 changes: 5 additions & 5 deletions zeta/nn/modules/feedforward.py
@@ -3,7 +3,7 @@
from zeta.nn.modules.glu import GLU
from zeta.nn.modules.swiglu import SwiGLU
from typing import Optional
from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton
# from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton


class ReluSquared(nn.Module):
@@ -95,10 +95,10 @@ def __init__(
project_in = GLU(
dim, inner_dim, activation, mult_bias=glu_mult_bias
)
elif triton_kernels_on is True:
project_in = nn.Sequential(
LinearTriton(dim, inner_dim, bias=no_bias), activation
)
# elif triton_kernels_on is True:
# project_in = nn.Sequential(
# LinearTriton(dim, inner_dim, bias=no_bias), activation
# )
else:
project_in = nn.Sequential(
nn.Linear(dim, inner_dim, bias=not no_bias), activation
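A side note on the feedforward change: rather than commenting the Triton path out, the optional dependency could be guarded the same way this commit guards zeta.experimental in zeta/__init__.py. The sketch below is an alternative suggestion, not part of the commit; the None fallback and the adjusted elif condition are assumptions:

# Hedged sketch: guard the optional Triton import and fall back to nn.Linear when it is missing.
try:
    from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton
except ImportError:
    LinearTriton = None  # Triton projection unavailable on this install

# ...inside FeedForward.__init__:
# elif triton_kernels_on and LinearTriton is not None:
#     project_in = nn.Sequential(
#         LinearTriton(dim, inner_dim, bias=no_bias), activation
#     )

Incidentally, the commented-out call passes bias=no_bias while the nn.Linear fallback passes bias=not no_bias; that asymmetry may be intentional, but it looks worth double-checking.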
4 changes: 2 additions & 2 deletions zeta/nn/modules/fractoral_norm.py
@@ -10,11 +10,11 @@ class FractoralNorm(nn.Module):
depth (int): Number of times to apply LayerNorm.
"""

def __init__(self, num_features: int, depth: int):
def __init__(self, num_features: int, depth: int, *args, **kwargs):
super().__init__()

self.layers = nn.ModuleList(
[nn.LayerNorm(num_features) for _ in range(depth)]
[nn.LayerNorm(num_features, *args, **kwargs) for _ in range(depth)]
)

def forward(self, x: Tensor) -> Tensor:
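The diff above truncates the forward body. For context, a minimal sketch of what the complete module presumably looks like after this change, assuming forward simply chains the stacked LayerNorm layers; this reconstruction is illustrative rather than the exact file contents:

from torch import Tensor, nn


class FractoralNorm(nn.Module):
    """Applies nn.LayerNorm over the last dimension `depth` times in sequence."""

    def __init__(self, num_features: int, depth: int, *args, **kwargs):
        super().__init__()
        # Extra args/kwargs (e.g. eps, elementwise_affine) reach every LayerNorm.
        self.layers = nn.ModuleList(
            [nn.LayerNorm(num_features, *args, **kwargs) for _ in range(depth)]
        )

    def forward(self, x: Tensor) -> Tensor:
        # Assumed behavior: apply each stacked LayerNorm in turn.
        for layer in self.layers:
            x = layer(x)
        return x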
32 changes: 32 additions & 0 deletions zeta/nn/modules/layer_scale.py
@@ -0,0 +1,32 @@
from torch.nn import Module
import torch
from torch import nn, Tensor

class LayerScale(Module):
    """
    Applies layer scaling to the output of a given module.

    Args:
        fn (Module): The module to apply layer scaling to.
        dim (int): The dimension along which to apply the scaling.
        init_value (float, optional): The initial value for the scaling factor. Defaults to 0.

    Attributes:
        fn (Module): The module to apply layer scaling to.
        gamma (Parameter): The scaling factor parameter.
    """

    def __init__(self, fn: Module, dim, init_value=0.):
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)

        # If the wrapped module returns a single tensor, scale it directly.
        if isinstance(out, Tensor):
            return out * self.gamma

        # Otherwise assume a tuple whose first element is the main tensor;
        # scale it and pass the remaining outputs through unchanged.
        out, *rest = out
        return out * self.gamma, *rest
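A brief hedged usage sketch for LayerScale: it wraps an arbitrary sub-module so that the module's output is multiplied by a learnable per-channel scale gamma. The wrapped nn.Linear and the dimensions below are placeholders chosen for illustration:

import torch
from torch import nn
from zeta.nn.modules.layer_scale import LayerScale

block = nn.Linear(64, 64)  # any module returning a tensor
scaled_block = LayerScale(block, dim=64, init_value=1e-5)

x = torch.randn(8, 16, 64)
y = scaled_block(x)  # output scaled elementwise by gamma
print(y.shape)  # torch.Size([8, 16, 64])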
