Skip to content
/ Echo Public

Whisper-like ASR (automatic speech recognition) transformer model. Experimental.

Notifications You must be signed in to change notification settings

sine2pi/Echo

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

eval train legend2

Pilot run: a sanity check to see whether the models could learn. Attention heads: Echo 8, Whisper 16. Note how the runtime changes with the focused-attention hybrid.

    #--#
    import base64, gzip, math, os, functools, warnings, numpy as np, torch, transformers, aiohttp, torch.nn.functional as F, evaluate, json, random
    from torch import Tensor, amp, optim, nn
    from torch.utils.checkpoint import checkpoint
    from torch.utils.tensorboard.writer import SummaryWriter
    from threading import Thread
    from typing import Dict, Optional, Tuple, Union, List, Any
    from transformers.modeling_utils import PreTrainedModel
    from dataclasses import dataclass
    from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments, PretrainedConfig, TrainerCallback, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizerFast)
    from evaluate import module
    from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
    from sklearn.model_selection import KFold, train_test_split
    from datasets import load_dataset, Dataset, concatenate_datasets, IterableDatasetDict, Audio, load_from_disk
    from torch.nn.functional import scaled_dot_product_attention
    
    # Only show transformers log records at ERROR level or above.
    transformers.utils.logging.set_verbosity_error()
    # Ignore all Python warnings via the warnings filter.
    warnings.filterwarnings(action="ignore")
    # NOTE(review): this replaces warnings.warn process-wide with a no-op, so every warning
    # issued through it (from any library) is dropped regardless of filters; with the
    # "ignore" filter above it is likely redundant — consider removing it.
    warnings.warn = lambda *args, **kwargs: None
    # First CUDA device if available, otherwise CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Module-level default dtype; not referenced by the code visible in this file.
    dtype = torch.float32       
    
    #--#
    
    class Linear(nn.Linear):
        """nn.Linear whose weight and bias are cast to the input's dtype on every call."""

        def forward(self, x: Tensor) -> Tensor:  # type: ignore
            weight = self.weight.to(x.dtype)
            bias = self.bias.to(x.dtype) if self.bias is not None else None
            return F.linear(x, weight, bias)
    
    class Conv1d(nn.Conv1d):
        """nn.Conv1d whose weight and bias follow the input's dtype."""

        def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:  # type: ignore
            target_dtype = x.dtype
            cast_bias = bias.to(target_dtype) if bias is not None else None
            return super()._conv_forward(x, weight.to(target_dtype), cast_bias)
    
    class LayerNorm(nn.LayerNorm):
        """LayerNorm evaluated in float32, with the result cast back to the input's dtype."""

        def forward(self, x: Tensor) -> Tensor:  # type: ignore
            input_dtype = x.dtype
            normalized = super().forward(x.float())
            return normalized.type(input_dtype)
    #--#
    
    class CombinedRotaryEmbedding(nn.Module):
        """Learned Givens rotations, a mixing matrix, then rotary position encoding.

        ``forward`` accepts ``x`` as (batch, seq, dims) or (batch, seq, head, h_dim). Each
        h_dim-sized head vector is multiplied by a sequence of Givens rotations (planes from
        ``r_pairs``, angles ``thetas * theta_scale``), then by ``r_matrix``. It is then rotated
        by the position-dependent angles ``position * inv_freq``. The result is returned as
        (batch, seq, dims).
        """

        def __init__(self, base, dims: int, head: int, theta_scale_learnable: bool = True,
                     n_rots_scale_learnable: bool = True, r_matrix_learnable: bool = False, inv_freq_learnable: bool = True):
            super().__init__()
            self.dims = dims
            self.head = head
            self.base = base

            assert self.dims % self.head == 0, "dims must be divisible by head"
            self.h_dim = self.dims // self.head
            assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"
            # One Givens rotation per pair of head-dimension coordinates.
            self.n_rots = ((dims // head) // 2)

            self.thetas = nn.Parameter(torch.zeros(self.n_rots))
            # Rotation planes (i, j), stored as floats and truncated to ints in forward.
            self.r_pairs = nn.Parameter(data=torch.rand(self.n_rots, 2) * self.h_dim)

            self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_scale_learnable)
            self.n_rots_scale = nn.Parameter(torch.ones(1), requires_grad=n_rots_scale_learnable)

            # --- R Matrix --- loss += embedding_layer.orthogonal_regularization_term()
            self.r_matrix = nn.Parameter(torch.eye(n=self.h_dim), requires_grad=r_matrix_learnable)

            inv_freq_data = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
            self.inv_freq = nn.Parameter(inv_freq_data, requires_grad=inv_freq_learnable)

            self.orthogonal_reg_weight = 0.01

        def givens_r_matrix(self, dims, i, j, theta):
            """Return a (dims, dims) Givens rotation by ``theta`` in the (i, j) plane.

            Built with torch.cos/torch.sin so the matrix stays differentiable w.r.t. ``theta``.
            math.cos/math.sin would convert theta to a Python float and cut the autograd
            graph, leaving ``thetas``/``theta_scale`` without gradients.
            """
            theta = torch.as_tensor(theta)
            cos_t = torch.cos(theta).reshape(())
            sin_t = torch.sin(theta).reshape(())
            G = torch.eye(dims, device=theta.device, dtype=theta.dtype)
            G[i, i] = cos_t
            G[i, j] = -sin_t
            G[j, i] = sin_t
            G[j, j] = cos_t
            return G

        def update_base(self, new_base):
            """Recompute ``inv_freq`` for a new rotary base and resample the rotation planes."""
            if new_base is not None and new_base != self.base:
                self.base = new_base
                inv_freq = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
                self.inv_freq.data.copy_(inv_freq)
                self.update_pairs()

        def reset_parameters(self):
            """Re-initialise ``r_matrix`` orthogonally and zero the rotation angles."""
            nn.init.orthogonal_(tensor=self.r_matrix)
            nn.init.zeros_(tensor=self.thetas)

        def orthogonal_regularization_term(self):
            """Return ``orthogonal_reg_weight * ||R R^T - I||_F^2`` (0 if ``r_matrix`` is frozen)."""
            loss = torch.tensor(0.0, device=self.r_matrix.device)
            if self.r_matrix.requires_grad:
                product = torch.matmul(self.r_matrix, self.r_matrix.t())
                identity = torch.eye(self.r_matrix.size(0), device=self.r_matrix.device)
                loss = ((product - identity) ** 2).sum()
            return self.orthogonal_reg_weight * loss

        def update_pairs(self):
            """Resample ``n_rots`` distinct unordered planes (i, j) with i != j using ``random``."""
            pairs = []
            while len(pairs) < self.n_rots:
                i, j = random.randint(0, self.h_dim - 1), random.randint(0, self.h_dim - 1)
                if i != j and (i, j) not in pairs and (j, i) not in pairs:
                    pairs.append((i, j))
            self.r_pairs.data.copy_(torch.tensor(pairs, dtype=torch.float32))

        def forward(self, x, global_step=None):
            """Apply the learned rotations and the rotary position encoding to ``x``.

            Args:
                x: (batch, seq, dims) or (batch, seq, head, h_dim) tensor.
                global_step: unused; accepted for interface compatibility.

            Returns:
                Tensor of shape (batch, seq, dims).

            Raises:
                ValueError: if ``x`` is not 3D/4D or its trailing sizes do not match.
            """
            if x.dim() not in [3, 4]:
                raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")

            batch_size, seq_len, *rest = x.size()

            if x.dim() == 3:
                dims = rest[0]
                if dims != self.head * self.h_dim:
                    raise ValueError(
                        f"Expected dims ({dims}) to be compatible with head ({self.head}) * h_dim ({self.h_dim}={self.head * self.h_dim})")
            else:
                head, h_dim = rest
                if head != self.head or h_dim != self.h_dim:
                    raise ValueError(
                        f"For 4D input, expected head {self.head} and h_dim {self.h_dim}, but got head {head} and h_dim {h_dim}")

            # One row per head vector.
            x = x.reshape(-1, self.h_dim)

            # int() detaches n_rots_scale; clamp so a scale above 1 cannot index past the end
            # of r_pairs/thetas.
            adjusted_n_rots = int(torch.round(self.n_rots_scale * self.n_rots))
            adjusted_n_rots = max(0, min(adjusted_n_rots, self.n_rots))
            # Read every plane once (a single host sync) instead of once per rotation.
            pairs = self.r_pairs.detach().long().tolist()

            for k in range(adjusted_n_rots):
                i, j = pairs[k]
                if i == j:
                    # Not a plane: writing the Givens entries would only scale coordinate i
                    # by cos(theta), which is not a rotation. Skip it.
                    continue
                theta = self.thetas[k] * self.theta_scale[0]
                G = self.givens_r_matrix(dims=self.h_dim, i=i, j=j, theta=theta)
                x = torch.matmul(input=x, other=G)

            x = torch.matmul(input=x, other=self.r_matrix)
            x = x.view(batch_size, seq_len, self.head, self.h_dim)

            positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            sinusoid_inp = torch.outer(positions, self.inv_freq.to(device=x.device))
            sin = sinusoid_inp.sin()[None, :, None, :]
            cos = sinusoid_inp.cos()[None, :, None, :]

            # Rotate (even, odd) coordinate pairs; the two halves are concatenated, not re-interleaved.
            x1, x2 = x[..., ::2], x[..., 1::2]
            x = torch.cat(tensors=[x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
            return x.view(batch_size, seq_len, self.dims)
    
    class LearnedSinusoidalEmbeddings(nn.Module):
        """Learnable position table initialised with sinusoids; lookups are L2-normalised.

        Args:
            n_ctx: number of positions in the table.
            dims: embedding width (odd widths are supported).
            checkpoint: if True, run the table lookup under activation checkpointing.
        """

        def __init__(self, n_ctx, dims, checkpoint=False):
            super().__init__()
            self.n_ctx = n_ctx
            self.dims = dims
            self.checkpoint = checkpoint

            position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, dims, 2).float() * -(math.log(10000.0) / dims))
            features = torch.zeros(n_ctx, dims)
            features[:, 0::2] = torch.sin(position * div_term)
            # For odd dims there is one fewer odd column than even column.
            features[:, 1::2] = torch.cos(position * div_term)[:, : dims // 2]
            # Copy of the initial table; the buffer name is kept for state_dict compatibility.
            self.register_buffer('my_big_toe', features)
            self.positional_embeddings = nn.Parameter(self.my_big_toe.clone())

        def forward(self, positions):
            """Return L2-normalised rows of the table for integer ``positions``."""
            if self.checkpoint:
                # use_reentrant=False: the reentrant implementation only propagates gradients
                # through inputs that require grad. Integer positions never do, so the table
                # would silently get no gradient.
                position_embeddings = checkpoint(self._lookup, positions, use_reentrant=False)
            else:
                position_embeddings = self._lookup(positions)
            return F.normalize(position_embeddings, p=2, dim=-1)  # type: ignore

        def _lookup(self, positions):
            """Index the learnable table (a method so it can be passed to checkpoint)."""
            return self.positional_embeddings[positions]
    
    
    class CombinedPositionalEmbedding(nn.Module):
        """Sum of a CombinedRotaryEmbedding-transformed input and learned sinusoidal positions."""

        def __init__(self, base, dims, head, n_ctx, theta_scale_learnable=True, n_rots_scale_learnable=True, 
                     r_matrix_learnable=False, inv_freq_learnable=True, checkpoint=False):
            super().__init__()
            self.rotary_embedding = CombinedRotaryEmbedding(
                base,
                dims,
                head,
                theta_scale_learnable=theta_scale_learnable,
                n_rots_scale_learnable=n_rots_scale_learnable,
                r_matrix_learnable=r_matrix_learnable,
                inv_freq_learnable=inv_freq_learnable,
            )
            self.sinusoidal_embedding = LearnedSinusoidalEmbeddings(n_ctx=n_ctx, dims=dims, checkpoint=checkpoint)

        def forward(self, x, positions, global_step=None):
            # Rotated input plus the (broadcast) position vectors.
            return self.rotary_embedding(x, global_step) + self.sinusoidal_embedding(positions)
            
    #--#
    
    class MultiheadAttention(nn.Module):
        """Whisper-style multi-head attention for self- or cross-attention.

        ``forward`` returns ``(out, qk)``. ``out`` has the query's shape. ``qk`` holds the
        detached, masked pre-softmax logits (batch, head, q_len, k_len) on the manual path,
        or ``None`` when SDPA is used (class flag ``MultiheadAttention.use_sdpa``).
        """
        use_sdpa = True

        def __init__(self, base, dims, head, max_dist):
            super().__init__()
            assert dims % head == 0, "dims must be divisible by head"
            self.head = head
            self.h_dim = dims // head
            assert self.h_dim % 2 == 0, "Head dimension must be even for rotary embeddings"
            # Set by owners such as HybridAttention; not read inside this class.
            self.max_dist = max_dist

            self.query = nn.Linear(dims, dims)
            self.key = nn.Linear(dims, dims, bias=False)
            self.value = nn.Linear(dims, dims)
            self.out = nn.Linear(dims, dims)

        def forward(self, x, xa = None, mask = None, kv_cache = None):
            """Attend from ``x`` to itself, or to ``xa`` when it is given.

            Args:
                x: (batch, seq, dims) queries.
                xa: optional source of keys/values (cross-attention).
                mask: additive mask sliced to (seq, seq) on the manual path. On the SDPA
                    path only its presence matters: it enables causal masking.
                kv_cache: optional dict keyed by the key/value Linear modules. Cached
                    tensors are only used when ``xa`` is given.
            """
            q = self.query(x)

            if kv_cache is None or xa is None or self.key not in kv_cache:
                source = x if xa is None else xa
                k = self.key(source)
                v = self.value(source)
            else:
                k = kv_cache[self.key]
                v = kv_cache[self.value]

            wv, qk = self.qkv_attention(q, k, v, mask)
            return self.out(wv), qk

        def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor,
                          mask: Optional[Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            """Scaled dot-product attention over heads; returns (output, detached logits or None)."""
            n_batch, n_ctx, dims = q.shape

            # Applied to both q and k, giving the usual 1/sqrt(head_dim) overall.
            scale = (dims // self.head) ** -0.25
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)

            if MultiheadAttention.use_sdpa:
                a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and n_ctx > 1)
                out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
                qk = None
            else:
                # The softmax weights come from the same masked, correctly scaled logits that
                # are returned as qk. Previously they were computed without the mask and
                # scaled by head_dim**-0.25 only, so this path disagreed with SDPA.
                qk = (q * scale) @ (k * scale).transpose(-1, -2)
                if mask is not None:
                    qk = qk + mask[:n_ctx, :n_ctx]
                qk = qk.float()
                w = torch.softmax(qk, dim=-1).to(q.dtype)
                out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
                qk = qk.detach()

            return out, qk
            
    #--#
    
    class AdaptiveSpanAttention(nn.Module):
        """Attention over a prefix whose length follows a predicted span scale.

        Attention runs directly on the given query/key/value tensors (no projections),
        restricted to their first ``eff_span`` positions. The softmax temperature is derived
        from ``span_scale``.
        """

        def __init__(self, base, dims, head, max_dist, sharpen_longer, win_size, max_span, temp_scale=0.01):  
            super().__init__()

            self.max_dist = max_dist
            self.win_size = win_size
            self.max_span = max_span
            self.temp_scale = temp_scale
            # Only ``.head`` is read in forward. The sub-module is kept so existing
            # state_dicts (which contain its weights) still load.
            self.multihead_attn = MultiheadAttention(base, dims, head, max_dist)
            # NOTE(review): this parameter is not read; forward uses its span_scale argument.
            self.span_scale = nn.Parameter(torch.tensor(1.0))
            self.sharpen_longer = sharpen_longer  

        def forward(self, query, key, value, span_scale):
            """Attend from the first ``eff_span`` queries to the first ``eff_span`` keys.

            Args:
                query, key, value: (batch, seq, dims) tensors; dims must divide by the head count.
                span_scale: tensor whose mean (presumably in (0, 1)) scales ``max_span``.

            Returns:
                ``(attn_out, attn_weights)`` of shapes (batch, eff_span, dims) and
                (batch, head, eff_span, eff_span).
            """
            scale_value = span_scale.mean().item()
            # At least one position: int(max_span * s) is 0 for small s, which used to give
            # empty tensors and a reshape error.
            span_len = max(1, int(self.max_span * scale_value))
            span_len = min(span_len, query.shape[1], key.shape[1], value.shape[1])
            eff_span = min(span_len, self.max_dist)

            # sharpen_longer selects which temperature schedule is used.
            if self.sharpen_longer:
                temperature = 1.0 + self.temp_scale * (1.0 - scale_value)
            else:
                temperature = 0.5 + self.temp_scale * scale_value

            batch_size, _, dims = query.shape
            head = self.multihead_attn.head
            # Applied to both q and k (1/sqrt(head_dim) overall), as in MultiheadAttention.
            # It was previously applied to the scores only once. The old code also ran
            # self.multihead_attn on the spans and discarded the result; that call is gone.
            scale = (dims // head) ** -0.25

            def split_heads(t):
                t = t[:, :eff_span, :]
                return t.view(t.shape[0], t.shape[1], head, -1).permute(0, 2, 1, 3)

            q, k, v = split_heads(query), split_heads(key), split_heads(value)

            attn_scores = torch.matmul(q * scale, (k * scale).transpose(-2, -1))
            attn_weights = torch.softmax(attn_scores / temperature, dim=-1)
            attn_out = torch.matmul(attn_weights, v)
            attn_out = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
            attn_out = attn_out.contiguous().view(batch_size, eff_span, dims)

            return attn_out, attn_weights
        
    class SpanPredictor(nn.Module):
        """Map features to a span scale in (0, 1) with a single linear unit and a sigmoid."""

        def __init__(self, dims):
            super().__init__()
            self.linear = nn.Linear(dims, 1)

        def forward(self, global_out):
            logit = self.linear(global_out)
            return logit.sigmoid()
    
    class HybridAttention(nn.Module):
        """Combine sliding-window local attention with full-sequence global attention.

        Each branch sees its own LayerNorm-ed copy of ``x``. The outputs are concatenated
        and projected back to ``dims``. The global output also feeds a SpanPredictor whose
        scale sets the local window and span length for the current call.
        """

        def __init__(self, base, dims, head, max_dist, sharpen_longer, win_size=32, max_span=32, slid_win=32):
            super().__init__()
            self.max_dist = max_dist
            self.win_size = win_size
            self.max_span = max_span
            self.slid_win = slid_win

            self.span_pred = SpanPredictor(dims)
            self.dist_local = max_dist
            self.dist_global = max_dist
            self.attn_local = AdaptiveSpanAttention(base, dims, head, max_dist, sharpen_longer, win_size, max_span)
            self.attn_global = MultiheadAttention(base=base, dims=dims, head=head, max_dist=self.dist_global)
            self.ln_local = LayerNorm(dims)
            self.ln_global = LayerNorm(dims)
            self.projection = Linear(2 * dims, dims)

        def forward(self, x, new_dist=None, new_base=None, xa=None, mask=None, kv_cache=None):
            """Return the projected concatenation of local and global attention over ``x``.

            Args:
                x: (batch, seq, dims) tensor.
                mask: forwarded to the global branch. With SDPA, any non-None mask makes
                    that branch causal. The local branch is never masked.
                new_dist, new_base, xa, kv_cache: accepted but unused.
            """
            local = self.ln_local(x)
            globe = self.ln_global(x)

            # Self-attention with the caller's mask. This was previously called as
            # attn_global(globe, globe, globe), which passed activations as ``mask``. That
            # made the branch causal everywhere on the SDPA path (the encoder included) and
            # added activations to the logits otherwise.
            globe_out, _ = self.attn_global(globe, mask=mask)

            span_scale = self.span_pred(globe_out.mean(dim=1))
            scale_value = span_scale.mean().item()

            win_size = max(1, int(self.slid_win * scale_value))
            span_len = max(1, int(self.max_span * scale_value))

            # The sub-modules' max_dist is refreshed on every call from the predicted span.
            self.attn_local.max_dist = min(self.dist_local, span_len, win_size)
            self.attn_global.max_dist = min(self.max_dist, local.size(1))

            local_out = self.slide_win(local, win_size, span_len, span_scale)

            combined = torch.cat([local_out, globe_out], dim=-1)
            return self.projection(combined)

        def slide_win(self, x, win_size, span_len, span_scale):
            """Run ``attn_local`` over consecutive ``win_size`` chunks of ``x``.

            Chunk queries are x[:, i:i+win_size]. Keys/values span
            x[:, max(0, i - span_len + win_size) : i + span_len].
            NOTE(review): no causal mask is applied here, so when used in the decoder a
            position can attend to later positions inside its window.
            """
            batch_size, seq_len, dims = x.size()
            out = torch.zeros_like(x, device=x.device)  

            for i in range(0, seq_len, win_size):
                end = min(i + win_size, seq_len)
                query = x[:, i:end, :]

                start = max(0, i - span_len + win_size) 
                key = x[:, start:i + span_len, :]
                value = x[:, start:i + span_len, :]
                attn_out, _ = self.attn_local(query, key, value, span_scale)
                out[:, i:end, :] = attn_out 
            return out

    #--#

    class ResidualAttentionBlock(nn.Module):
        """Pre-LayerNorm block: residual attention (hybrid or plain), then a residual 4x GELU MLP.

        NOTE(review): ``win_size``, ``max_span``, ``checkpoint`` and ``cross`` are accepted but
        unused. HybridAttention is built with its own defaults rather than these values, and
        no cross-attention layer is created. Confirm this is intended.
        """

        def __init__(self, base, dims, head, max_dist, win_size, max_span, hybrid, checkpoint, cross, sharpen_longer):
            super().__init__()

            if hybrid:
                self.attn = HybridAttention(base, dims, head, max_dist, sharpen_longer)
            else:
                self.attn = MultiheadAttention(base, dims, head, max_dist)
            self.attn_ln = LayerNorm(dims)

            n_mlp = dims * 4
            self.mlp = nn.Sequential(nn.Linear(dims, n_mlp), nn.GELU(), nn.Linear(n_mlp, dims))
            self.mlp_ln = LayerNorm(dims)

        def forward(self, x, mask=None, kv_cache=None):
            x = self._attn_forward(x, mask, kv_cache)
            x = self._mlp_forward(x)
            return x

        def _attn_forward(self, x, mask=None, kv_cache=None):
            """Residual attention sub-layer; ``mask`` is passed to both attention kinds."""
            residual = x
            x = self.attn_ln(x)

            if isinstance(self.attn, HybridAttention):
                # Forwarding the mask keeps the decoder's global branch causal and lets the
                # encoder's (mask=None) global branch attend bidirectionally.
                attn_output = self.attn(x, mask=mask)
            else:
                attn_output, _ = self.attn(x, mask=mask, kv_cache=kv_cache)
            return residual + attn_output

        def _mlp_forward(self, x):
            residual = x
            x = self.mlp_ln(x)
            return residual + self.mlp(x)
            
    #--#
    
    class AudioEncoder(nn.Module):
        """Convolutional front end plus a stack of ResidualAttentionBlocks over mel features.

        Two Conv1d+GELU layers (the second with stride 2) map (batch, mels, frames) to
        (batch, frames', dims). Learned sinusoidal positions are added and the result passes
        through CombinedRotaryEmbedding before the blocks and a final LayerNorm.
        """

        def __init__(self, base, mels, dims, head, n_layer, n_ctx, max_dist,
                     win_size, max_span, hybrid, checkpoint, cross, sharpen_longer):
            super().__init__()
            self.conv1 = Conv1d(mels, dims, kernel_size=3, padding=1)
            self.conv2 = Conv1d(dims, dims, kernel_size=3, stride=2, padding=1)
            self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, dims)
            self.checkpoint = checkpoint

            self.combined_rotary = CombinedRotaryEmbedding(base, dims, head)

            self.blocks = nn.ModuleList([ResidualAttentionBlock(base, dims, head, max_dist, win_size, max_span, hybrid, checkpoint, cross, sharpen_longer) for _ in range(n_layer)])

            self.ln_post = LayerNorm(dims)

        def forward(self, x):
            """Encode mel features (batch, mels, frames) into (batch, frames', dims)."""
            # use_reentrant=False: reentrant checkpointing only tracks gradients through
            # inputs that require grad. Raw features don't, so every encoder parameter would
            # silently receive no gradient.
            if self.checkpoint:
                x = checkpoint(self._conv_forward, x, use_reentrant=False)
            else:
                x = self._conv_forward(x)

            for block in self.blocks:
                if self.checkpoint:
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
            return self.ln_post(x)

        def _conv_forward(self, x):
            """Convolutions, positional embeddings and combined rotary transform."""
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
            x = x.permute(0, 2, 1)

            p = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)

            x = x + p
            x = self.combined_rotary(x)

            return x
    
    #--#
            
    class TextDecoder(nn.Module):
        """Token decoder: embeddings + sinusoidal positions + CombinedRotaryEmbedding, then blocks.

        Logits are computed against the token embedding matrix (weight tying).
        NOTE(review): ``xa`` (the encoder output) only supplies a dtype here and is not passed
        to any block, so the decoder does not condition on the audio. ``cross`` is forwarded
        to the blocks but unused there. Confirm this is intentional.
        """

        def __init__(self, base, vocab, dims, head, n_layer, n_ctx, max_dist,
                     win_size, max_span, hybrid, checkpoint, cross, sharpen_longer):
            super().__init__()
            self.token_embedding = nn.Embedding(vocab, dims)
            self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, dims)
            self.checkpoint = checkpoint

            self.combined_rotary = CombinedRotaryEmbedding(base, dims, head)

            self.blocks = nn.ModuleList([ResidualAttentionBlock(base, dims, head, max_dist, win_size, max_span, hybrid, checkpoint, cross, sharpen_longer) for _ in range(n_layer)])

            # NOTE(review): ln_post is not used in forward (self.ln is); kept for state_dict compatibility.
            self.ln_post = LayerNorm(dims)
            self.ln = LayerNorm(dims)

            # Additive causal mask: -inf above the diagonal, 0 elsewhere.
            mask = torch.empty(n_ctx, n_ctx).fill_(value=-np.inf).triu_(diagonal=1)
            # register_buffer already exposes it as self.mask; no separate assignment needed.
            self.register_buffer("mask", mask, persistent=False)

        def forward(self, x, xa, kv_cache=None):
            """Return float32 logits (batch, seq, vocab) for token ids ``x`` of shape (batch, seq)."""
            # use_reentrant=False: token ids never require grad, so reentrant checkpointing
            # would leave the embeddings and blocks without gradients.
            if self.checkpoint:
                x = checkpoint(self._embedding_forward, x, xa, kv_cache, use_reentrant=False)
            else:
                x = self._embedding_forward(x, xa, kv_cache)

            for block in self.blocks:
                if self.checkpoint:
                    x = checkpoint(block, x, self.mask, kv_cache, use_reentrant=False)
                else:
                    x = block(x, self.mask, kv_cache)

            x = self.ln(x)

            logits = (x @ torch.transpose(self.token_embedding.weight.to(dtype=x.dtype), 0, 1)).float()
            return logits

        def _embedding_forward(self, x, xa, kv_cache):
            """Embed tokens, add positions offset by the cache length, apply combined rotary."""
            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
            positions = torch.arange(x.shape[1], device=x.device) + offset
            pos_emb = self.positional_embedding(positions).unsqueeze(0)

            x = self.token_embedding(x) + pos_emb
            x = x.to(xa.dtype)

            x = self.combined_rotary(x)
            return x
    
    #--#
    
    class EchoConfig(PretrainedConfig):
        """Hyper-parameters for the Echo model.

        ``a_*`` fields size the audio encoder and ``t_*`` fields the text decoder. The rest are
        shared attention settings and special-token ids. Other keyword arguments are handed
        to ``PretrainedConfig``.
        """
        model_type = "Echo"

        def __init__(
            self,
            checkpoint=False,
            cross=False,
            hybrid=True,
            sharpen_longer=True,
            a_ctx=1500,
            a_head=8,
            a_layer=8,
            a_dims=1024,
            mels=128,
            t_ctx=448,
            t_head=8,
            t_layer=4,
            t_dims=1024,
            win_size=64,
            max_span=64,
            max_dist=128,
            base=10000,

            pad_token_id=50257,
            unk_token_id=50257,
            vocab=51865,
            eos_token_id=50257,
            bos_token_id=50257,
            decoder_start_token_id=50258,
            **kwargs,
        ):
            super().__init__(**kwargs)
            # Set after the base initialiser so these values win over any defaults it applies.
            settings = {
                "base": base,
                "bos_token_id": bos_token_id,
                "checkpoint": checkpoint,
                "cross": cross,
                "decoder_start_token_id": decoder_start_token_id,
                "eos_token_id": eos_token_id,
                "hybrid": hybrid,
                "max_dist": max_dist,
                "max_span": max_span,
                "a_ctx": a_ctx,
                "a_head": a_head,
                "a_layer": a_layer,
                "a_dims": a_dims,
                "mels": mels,
                "t_ctx": t_ctx,
                "t_head": t_head,
                "t_layer": t_layer,
                "t_dims": t_dims,
                "pad_token_id": pad_token_id,
                "unk_token_id": unk_token_id,
                "vocab": vocab,
                "win_size": win_size,
                "sharpen_longer": sharpen_longer,
            }
            for field_name, field_value in settings.items():
                setattr(self, field_name, field_value)
    
    
    class Echo(PreTrainedModel):
        """Whisper-like encoder-decoder: AudioEncoder over mel features, TextDecoder over tokens.

        ``forward`` returns ``{"loss": ..., "logits": ...}``. In training mode each loss is also
        passed to :meth:`adjust_base`, which rescales the rotary base of the encoder.
        """
        config_class = EchoConfig

        def __init__(self, config: EchoConfig):
            super().__init__(config)
            self.config = config

            self.encoder = AudioEncoder(
                base=self.config.base,
                mels=self.config.mels,
                dims=self.config.a_dims,
                head=self.config.a_head,
                n_layer=self.config.a_layer,
                n_ctx=self.config.a_ctx,
                max_dist=self.config.max_dist,
                win_size=self.config.win_size,
                max_span=self.config.max_span,
                hybrid=self.config.hybrid,
                checkpoint=self.config.checkpoint,
                cross=self.config.cross,
                sharpen_longer=self.config.sharpen_longer,
            )

            self.decoder = TextDecoder(
                base=self.config.base,
                vocab=self.config.vocab,
                dims=self.config.t_dims,
                head=self.config.t_head,
                n_layer=self.config.t_layer,
                n_ctx=self.config.t_ctx,
                max_dist=self.config.max_dist,
                win_size=self.config.win_size,
                max_span=self.config.max_span,
                hybrid=self.config.hybrid,
                checkpoint=self.config.checkpoint,
                cross=self.config.cross,
                sharpen_longer=self.config.sharpen_longer,
            )

            # All heads of the upper half of decoder layers are marked in "alignment_heads".
            all_heads = torch.zeros(self.config.t_layer, self.config.t_head, dtype=torch.bool)
            all_heads[self.config.t_layer // 2:] = True
            self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

            self.base = self.config.base
            self.win_size = self.config.win_size
            # Shared by adjust_window and adjust_base.
            self.adjust_counter = 0
            self.best_loss = float('inf')
            self.kv_cache = {}

        def update_window(self, new_window):
            """Record ``new_window`` and push it to every HybridAttention module.

            NOTE(review): HybridAttention, as defined in this file, has no ``update_window``
            method, so this raises AttributeError when the model contains hybrid blocks.
            """
            self.win_size = new_window
            for module in self.modules():
                if isinstance(module, HybridAttention):
                    module.update_window(self.win_size)

        def adjust_window(self, loss, factor=1.00005):
            """Every 10th call, grow the window if ``loss`` beat the best loss, else shrink it.

            Returns:
                The current window size.
            """
            if self.adjust_counter % 10 == 0:
                if loss < self.best_loss:
                    new_window = self.win_size * factor
                else:
                    new_window = self.win_size / factor
                self.update_window(new_window)
                self.best_loss = loss
            # Count every call, as adjust_base does. The counter previously advanced only
            # when an adjustment happened, so after the first call the %10 test never fired.
            self.adjust_counter += 1
            return self.win_size

        def adjust_base(self, loss, factor=1.0025):
            """Every 25th call, scale the rotary base up if ``loss`` improved, else down.

            Returns:
                The current base.
            """
            if self.adjust_counter % 25 == 0:
                if loss < self.best_loss:
                    new_base = self.base * factor
                else:
                    new_base = self.base / factor
                self.update_base(new_base)
                self.base = new_base
                self.best_loss = loss
            self.adjust_counter += 1
            return self.base

        def update_base(self, new_base):
            """Apply ``new_base`` to every CombinedRotaryEmbedding in the encoder.

            NOTE(review): the decoder's rotary embedding is not updated; confirm this is intended.
            """
            self.new_base = new_base
            for name, module in self.encoder.named_modules():
                if isinstance(module, (CombinedRotaryEmbedding)):
                    module.update_base(self.new_base)

        @staticmethod
        def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
            """Shift ids right by one, prepend the start token and replace -100 with padding."""
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
            shifted_input_ids[:, 0] = decoder_start_token_id
            shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
            return shifted_input_ids

        def forward(self, input_features, labels=None, dec_input_ids=None):
            """Encode audio, decode tokens and (with labels) compute cross-entropy.

            Args:
                input_features: encoder input (mel features).
                labels: target ids; -100 marks ignored positions. When ``dec_input_ids`` is
                    not given, decoder inputs are ``labels`` shifted right.
                dec_input_ids: explicit decoder input ids.

            Raises:
                ValueError: if neither ``labels`` nor ``dec_input_ids`` is given.
            """
            if dec_input_ids is None:
                if labels is None:
                    raise ValueError("Either labels or dec_input_ids must be provided.")
                dec_input_ids = self.shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

            encoded_features = self.encoder(input_features).to(self.device)
            logits = self.decoder(dec_input_ids, encoded_features)

            loss = None
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                labels = labels.to(logits.device).long()
                loss = loss_fct(logits.view(-1, self.config.vocab), labels.view(-1))

                # Only adapt the rotary base from training losses; evaluation must not
                # mutate the model.
                if self.training:
                    self.adjust_base(loss.item())

            return {"loss": loss, "logits": logits}

        def reset_parameters(self):
            """Reset encoder rotary embeddings, then apply ``_init_weights`` to the encoder."""
            for name, module in self.encoder.named_modules():
                if isinstance(module, CombinedRotaryEmbedding):
                    module.reset_parameters()
            self.encoder.apply(self._init_weights)

        def _initialize_weights(self, module):
            """Initialise a fixed set of weights in place; ``module`` is ignored.

            The set is: token embedding ~ N(0, 0.02); every nn.Linear inside the decoder
            blocks (Xavier-normal weight, zero bias); the decoder's final LayerNorm; both
            encoder convolutions; and the encoder's ln_post.
            """
            nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=0.02)

            for block in self.decoder.blocks:
                # modules() recurses into attention/MLP sub-modules. The previous children()
                # loop only saw attn/attn_ln/mlp/mlp_ln, none of which is an nn.Linear, so no
                # Linear layer was ever initialised here.
                for layer in block.modules():
                    if isinstance(layer, nn.Linear):
                        nn.init.xavier_normal_(layer.weight)
                        if layer.bias is not None:
                            nn.init.zeros_(layer.bias)

            nn.init.constant_(self.decoder.ln.weight, 1)
            if self.decoder.ln.bias is not None:
                nn.init.constant_(self.decoder.ln.bias, 0)

            nn.init.xavier_normal_(self.encoder.conv1.weight)
            if self.encoder.conv1.bias is not None:
                nn.init.zeros_(self.encoder.conv1.bias)

            nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')
            if self.encoder.conv2.bias is not None:
                nn.init.zeros_(self.encoder.conv2.bias)

            nn.init.constant_(self.encoder.ln_post.weight, 1)
            if self.encoder.ln_post.bias is not None:
                nn.init.constant_(self.encoder.ln_post.bias, 0)

        def apply_initialization(self, module):
            """Public entry point for :meth:`_initialize_weights`."""
            self._initialize_weights(module)
    
    from datetime import datetime
    log_dir = os.path.join('./output/', datetime.now().strftime('%Y-%m-%d_%H'))
    os.makedirs(log_dir, exist_ok=True)

    name="/echo_test/"
    # EchoConfig names its size parameters a_*/t_*. The previous audio_*/text_* keywords are
    # not EchoConfig parameters, so they fell through **kwargs and the model was built with
    # the defaults instead (a_head=8, t_layer=4). The same values are now passed under the
    # real names.
    config = EchoConfig(
        checkpoint=False,
        cross=False,
        hybrid=True,
        sharpen_longer=True,
        a_ctx=1500,
        a_head=16,
        a_layer=8,
        a_dims=1024,
        mels=128,
        t_ctx=448,
        t_head=8,
        t_layer=8,
        t_dims=1024,
        win_size=64,
        max_span=64,
        max_dist=128,
        base=10000,
        pad_token_id=50257,
        unk_token_id=50257,
        vocab=51865,
        eos_token_id=50257,
        bos_token_id=50257,
        decoder_start_token_id=50258,
    )

    config.save_pretrained(log_dir+name)
    model = Echo(config).to(device)
    # _initialize_weights ignores its argument. The old call passed ``module``, which only
    # resolved because of ``from evaluate import module`` at the top of the file.
    model.apply_initialization(module=model)
    
    #--#
    
    # Tokenizer, feature extractor and processor, all from the openai/whisper-small checkpoint.
    # The feature extractor is overridden to 128 mel bins, matching mels=128 in the config.
    # NOTE(review): the processor loads its own feature extractor with the checkpoint's
    # default feature_size (presumably 80 for whisper-small) — confirm before using it.
    tokenizer = WhisperTokenizerFast.from_pretrained(
        "openai/whisper-small",
        language="en",
        task="transcribe",
    )
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "openai/whisper-small",
        feature_size=128,
    )
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    
    class GradientClippingCallback(TrainerCallback):
        """Clip the global gradient norm to 0.98 just before each optimizer step.

        The previous ``on_step_end`` hook runs after the Trainer has already stepped the
        optimizer and zeroed the gradients, so clipping there had no effect.
        ``on_pre_optimizer_step`` fires while the gradients are still populated. On
        transformers releases without that event the hook is simply never invoked.
        """

        def on_pre_optimizer_step(self, args, state, control, **kwargs):
            model = kwargs.get("model")
            if model is not None:
                torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=0.98)
    
    class MetricsCallback(TrainerCallback):
        """Mirror eval metrics to TensorBoard and print a few decoded sample transcripts.

        ``create_compute_metrics`` stores the latest prediction/label id arrays on this
        instance (``predictions`` / ``label_ids``). ``on_evaluate`` decodes a few of them,
        logs them, and then clears both.

        Args:
            tb_writer: TensorBoard ``SummaryWriter`` receiving scalars and text.
            tokenizer: tokenizer used to decode id arrays.
            metric: ``evaluate`` metric used by ``compute_wer``.
            log_every_n_steps: log only when ``global_step`` is a multiple of this.
            num_samples: number of random prediction/label pairs shown per evaluation.
        """
        def __init__(self, tb_writer, tokenizer, metric, log_every_n_steps=1, num_samples=2):
            super().__init__()
            self.tb_writer = tb_writer
            self.tokenizer = tokenizer
            self.metric = metric
            self.log_every_n_steps = log_every_n_steps
            self.num_samples = num_samples
            self.predictions = None
            self.label_ids = None
            # Initialised here so on_evaluate cannot hit an AttributeError when it runs
            # before any metrics dict has been seen.
            self.eval_loss = None

        def compute_wer(self, pred_str, label_str):
            """Return the word error rate of ``pred_str`` against ``label_str`` as a percentage."""
            wer = 100 * self.metric.compute(predictions=pred_str, references=label_str)
            return wer

        def on_evaluate(self, args, dims, control, model, metrics=None, **kwargs):
            """Log ``eval_*`` scalars and a few decoded samples, then drop the stored ids.

            ``dims`` is the Trainer's state object; only ``global_step`` is read from it.
            """
            step = dims.global_step
            should_log = step % self.log_every_n_steps == 0

            if metrics is not None:
                self.eval_loss = metrics.get('eval_loss')
                if should_log:
                    for key, value in metrics.items():
                        if key.startswith("eval_"):
                            self.tb_writer.add_scalar(key, value, step)

            # Decoding is only needed when something will be logged.
            if should_log and self.predictions is not None and self.label_ids is not None:
                pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)
                label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)

                total_samples = min(len(pred_str), len(label_str))
                # random.sample raises ValueError if asked for more items than exist
                # (e.g. an eval set with a single example).
                sample_count = min(self.num_samples, total_samples)
                # eval_loss can be missing from the metrics dict; ":.2f" would fail on None.
                loss_text = f"{self.eval_loss:.2f}" if self.eval_loss is not None else "n/a"

                for sample_index in random.sample(range(total_samples), sample_count):
                    self.tb_writer.add_text(f"Prediction_{sample_index}", pred_str[sample_index], step)
                    self.tb_writer.add_text(f"Label_{sample_index}", label_str[sample_index], step)
                    print(f"Evaluation: - Step {step} - Loss: {loss_text}")
                    print(f"Prediction: {pred_str[sample_index]}")
                    print(f"Label: {label_str[sample_index]}")
                    print("-" * 10)

            self.predictions = None
            self.label_ids = None
    
    def create_compute_metrics(callback_instance):
        """Return a Trainer ``compute_metrics`` closure bound to ``callback_instance``.

        ``callback_instance`` must expose ``tokenizer`` (with ``pad_token_id`` and
        ``batch_decode``) and ``metric`` (an ``evaluate`` metric used for WER). Each call
        also stores the id arrays on it as ``predictions`` / ``label_ids`` so a callback
        can print samples later.
        """
        def compute_metrics(eval_pred):
            """Compute WER (%) and token accuracy / weighted precision, recall, F1."""
            pad_id = callback_instance.tokenizer.pad_token_id

            pred_ids = eval_pred.predictions
            if isinstance(pred_ids, tuple):
                # Extra model outputs may be returned; the first element holds logits/ids.
                pred_ids = pred_ids[0]
            if pred_ids.ndim == 3:
                # (batch, seq, vocab) logits -> greedy token ids.
                pred_ids = np.argmax(pred_ids, axis=-1)

            # -100 marks ignored / cross-batch padding positions and cannot be decoded, so
            # map it to the pad id. np.where returns copies, leaving eval_pred untouched.
            pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
            label_ids = np.where(eval_pred.label_ids == -100, pad_id, eval_pred.label_ids)

            callback_instance.predictions = pred_ids
            callback_instance.label_ids = label_ids

            pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            wer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)

            # Token-level scores cover only non-pad label positions. If pad_id coincides
            # with another special token (e.g. EOS), those positions are excluded too.
            labels_flat = label_ids.flatten()
            pred_flat = pred_ids.flatten()
            mask = labels_flat != pad_id
            y_true = labels_flat[mask]
            y_pred = pred_flat[mask]

            if y_true.size == 0:
                # Nothing to score (all labels were padding).
                accuracy = precision = recall = f1 = 0.0
            else:
                accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
                precision = precision_score(y_true=y_true, y_pred=y_pred, average='weighted', zero_division=0)
                recall = recall_score(y_true=y_true, y_pred=y_pred, average='weighted', zero_division=0)
                f1 = f1_score(y_true=y_true, y_pred=y_pred, average='weighted', zero_division=0)
            return {"wer": wer, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
        return compute_metrics
    
    @dataclass
    class DataCollatorSpeechSeq2SeqWithPadding:
        """Pad a list of examples into one batch of audio features and label ids.

        Fields:
            processor: stored for compatibility; not used by ``__call__``.
            tokenizer: pads ``labels``; must provide ``pad`` and ``bos_token_id``.
            feature_extractor: pads ``input_features``; must provide ``pad``.
        """
        processor: Any
        tokenizer: Any
        feature_extractor: Any

        def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            """Collate ``features`` into a batch with ``input_features`` and ``labels``.

            Label padding positions are replaced by -100 (the usual loss ignore index). If
            every row starts with ``tokenizer.bos_token_id``, that first column is dropped.
            Presumably the model prepends its own start token — confirm in Echo.
            """
            # Use this instance's components rather than module-level globals, so the
            # collator honours whatever it was constructed with.
            audio_inputs = [{"input_features": feature["input_features"]} for feature in features]
            batch = self.feature_extractor.pad(audio_inputs, return_tensors="pt")
            label_inputs = [{"input_ids": feature["labels"]} for feature in features]
            labels_batch = self.tokenizer.pad(label_inputs, return_tensors="pt")
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
            if (labels[:, 0] == self.tokenizer.bos_token_id).all().cpu().item():
                labels = labels[:, 1:]
            batch["labels"] = labels
            return batch
    
    def get_length_of_dataset(dataset):
        """Return the total audio duration of ``dataset`` in hours.

        Each item must provide ``item["audio"]["array"]`` (the samples) and
        ``item["audio"]["sampling_rate"]`` (samples per second).
        """
        seconds_per_hour = 3600
        total_seconds = sum(
            len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
            for sample in dataset
        )
        return total_seconds / seconds_per_hour
    
    def prepare_dataset(batch):
        """Add model inputs to a single dataset example and return it.

        Sets ``input_features`` from the example's audio, using the module-level
        ``feature_extractor``. Sets ``labels`` to the token ids of ``batch["text"]``,
        using the module-level ``tokenizer``. The example is mutated in place.
        """
        audio = batch["audio"]
        extracted = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
        batch["input_features"] = extracted.input_features[0]
        batch["labels"] = tokenizer(batch["text"]).input_ids
        return batch
    
    # Streaming LibriSpeech splits: `.map(prepare_dataset)` runs per example as the data is
    # iterated, and only the model inputs are kept. trust_remote_code=True lets the
    # dataset repository's loading code execute.
    train=load_dataset("fixie-ai/librispeech_asr", "clean", split="train.100", streaming=True, trust_remote_code=True).map(prepare_dataset).select_columns(["input_features", "labels"])
    
    test=load_dataset("fixie-ai/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True).map(prepare_dataset).select_columns(["input_features", "labels"])
    
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, tokenizer=tokenizer, feature_extractor=feature_extractor)
    
    # WER metric shared by MetricsCallback and compute_metrics.
    metric = evaluate.load(path="wer")
    # Same directory that is used as the Trainer's output_dir further down.
    tb_writer = SummaryWriter(log_dir=log_dir)
    # Eval scalars and sample transcripts are logged only when global_step % 5 == 0.
    metrics_callback = MetricsCallback(tb_writer=tb_writer, tokenizer=tokenizer, metric=metric, log_every_n_steps=5)
    # compute_metrics stores the decoded-able ids on metrics_callback for on_evaluate.
    compute_metrics = create_compute_metrics(callback_instance=metrics_callback)
            
    #--#
    
    # Trainer configuration. The datasets are streaming, so training length is
    # bounded by max_steps rather than by epochs.
    training_args = Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,  # move eval outputs off the device after every step
        tf32=True,  # tf32/bf16 assume a GPU that supports them — TODO confirm target hardware
        bf16=True,
        eval_strategy="steps",
        max_steps=10000,
        save_steps=500,  # must be a multiple of eval_steps for load_best_model_at_end
        eval_steps=500,
        warmup_ratio = 0.1,
        logging_steps=1,
        logging_dir=log_dir + "/logs_hf",
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",  # key returned by compute_metrics; lower is better
        greater_is_better=False,
        push_to_hub=False,
        optim="adafactor",
        weight_decay=0.0025,
        disable_tqdm=False,
        save_total_limit=1,
        save_strategy="steps",
        remove_unused_columns=False,  # keep the columns exactly as the collator expects them
        label_names=["labels"],
        gradient_checkpointing=False,
        eval_on_start=False,
    )
    
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train,
        eval_dataset=test,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
        # `processing_class=`; here it is the feature extractor, which is saved with checkpoints.
        tokenizer=feature_extractor,
        # GradientClippingCallback is not registered, so clipping comes only from
        # TrainingArguments.max_grad_norm (transformers default 1.0).
        callbacks=[metrics_callback]
    )
    
    #--#
    
    trainer.train(resume_from_checkpoint=False)
    # With load_best_model_at_end=True this evaluates the best (lowest-WER) checkpoint.
    eval_results = trainer.evaluate()

Releases

No releases published

Packages

No packages published