diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 5f3d61f3..afdbc1c9 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -131,233 +131,232 @@ nav:
- Home:
- Overview: "index.md"
- Contributing: "contributing.md"
- - Zeta:
- - Overview: "zeta/index.md"
- - zeta.nn:
- - zeta.nn.biases:
- - Xpos: "zeta/nn/biases/xpos.md"
- - RelativePositionBias: "zeta/nn/biases/relative_bias.md"
- - AlibiPositionalBias: "zeta/nn/biases/alibi.md"
- - DynamicPositionBias: "zeta/nn/biases/dynamic.md"
- - zeta.nn.embeddings:
- - MultiWay: "zeta/nn/embeddings/multiway.md"
- - RotaryEmbeddings: "zeta/nn/embeddings/rope.md"
- - TruncatedRotaryEmbedding: "zeta/nn/embeddings/truncated_rope.md"
- - PositionalEmbedding: "zeta/nn/embeddings/positional_embeddings.md"
- - XPOS: "zeta/nn/embeddings/xpos.md"
- - YarnEmbedding: "zeta/nn/embeddings/yarn.md"
- - VisionEmbedding: "zeta/nn/embeddings/vis_emb.md"
- - SinusoidalEmbeddings: "zeta/nn/embeddings/sinusoidal.md"
- - PatchEmbeddings: "zeta/nn/embeddings/patch_embeddings.md"
- - PositionInterpolationEmbeddings: "zeta/nn/embeddings/positional_interpolation.md"
- - zeta.nn.modules:
- - custom_mlp: "zeta/nn/modules/custom_mlp.md"
- - mbconv: "zeta/nn/modules/mbconv.md"
- - dynamicroutingblock: "zeta/nn/modules/dynamicroutingblock.md"
- - clippedgeluactivation: "zeta/nn/modules/clippedgeluactivation.md"
- - mambablock: "zeta/nn/modules/mambablock.md"
- - vittransformerblock: "zeta/nn/modules/vittransformerblock.md"
- - fuseddensegeludense: "zeta/nn/modules/fuseddensegeludense.md"
- - pscan: "zeta/nn/modules/pscan.md"
- - adaptive: "zeta/nn/modules/adaptive.md"
- - filmconditioning: "zeta/nn/modules/filmconditioning.md"
- - mmfusionffn: "zeta/nn/modules/mmfusionffn.md"
- - quickgeluactivation: "zeta/nn/modules/quickgeluactivation.md"
- - gatedresidualblock: "zeta/nn/modules/gatedresidualblock.md"
- - highwaylayer: "zeta/nn/modules/highwaylayer.md"
- - multimodalmambablock: "zeta/nn/modules/multimodalmambablock.md"
- - rms_norm: "zeta/nn/modules/rms_norm.md"
- - ssm: "zeta/nn/modules/ssm.md"
- - dualpathblock: "zeta/nn/modules/dualpathblock.md"
- - topngating: "zeta/nn/modules/topngating.md"
- - mmlayernorm: "zeta/nn/modules/mmlayernorm.md"
- - mm_adapter: "zeta/nn/modules/mm_adapter.md"
- - laplaceactivation: "zeta/nn/modules/laplaceactivation.md"
- - nfnstem: "zeta/nn/modules/nfnstem.md"
- - laser: "zeta/nn/modules/laser.md"
- - denseblock: "zeta/nn/modules/denseblock.md"
- - depthwiseconv2d: "zeta/nn/modules/depthwiseconv2d.md"
- - lora: "zeta/nn/modules/lora.md"
- - vlayernorm: "zeta/nn/modules/vlayernorm.md"
- - flexiconv: "zeta/nn/modules/flexiconv.md"
- - pulsar: "zeta/nn/modules/pulsar.md"
- - pool: "zeta/nn/modules/pool.md"
- - time_up_sample: "zeta/nn/modules/time_up_sample.md"
- - spatial_downsample: "zeta/nn/modules/spatial_downsample.md"
- - parallel: "zeta/nn/modules/parallel.md"
- - conv2dfeedforward: "zeta/nn/modules/conv2dfeedforward.md"
- - video_autoencoder: "zeta/nn/modules/video_autoencoder.md"
- - recursiveblock: "zeta/nn/modules/recursiveblock.md"
- - relusquaredactivation: "zeta/nn/modules/relusquaredactivation.md"
- - fastgeluactivation: "zeta/nn/modules/fastgeluactivation.md"
- - token_learner: "zeta/nn/modules/token_learner.md"
- - layernorm: "zeta/nn/modules/layernorm.md"
- - averagemodelmerger: "zeta/nn/modules/averagemodelmerger.md"
- - linearactivation: "zeta/nn/modules/linearactivation.md"
- - stochdepth: "zeta/nn/modules/stochdepth.md"
- - expert: "zeta/nn/modules/expert.md"
- - siglip: "zeta/nn/modules/siglip.md"
- - ether: "zeta/nn/modules/ether.md"
- - newgeluactivation: "zeta/nn/modules/newgeluactivation.md"
- - pytorchgelutanh: "zeta/nn/modules/pytorchgelutanh.md"
- - multiscaleblock: "zeta/nn/modules/multiscaleblock.md"
- - umambablock: "zeta/nn/modules/umambablock.md"
- - film: "zeta/nn/modules/film.md"
- - adaptive_conv: "zeta/nn/modules/adaptive_conv.md"
- - fused_dropout_layernorm: "zeta/nn/modules/fused_dropout_layernorm.md"
- - accurategeluactivation: "zeta/nn/modules/accurategeluactivation.md"
- - exo: "zeta/nn/modules/exo.md"
- - polymorphic_activation: "zeta/nn/modules/polymorphic_activation.md"
- - fusedprojsoftmax: "zeta/nn/modules/fusedprojsoftmax.md"
- - quantizedln: "zeta/nn/modules/quantizedln.md"
- - postnorm: "zeta/nn/modules/postnorm.md"
- - moerouter: "zeta/nn/modules/moerouter.md"
- - geluactivation: "zeta/nn/modules/geluactivation.md"
- - visionattention: "zeta/nn/modules/visionattention.md"
- - fused_gelu_dense: "zeta/nn/modules/fused_gelu_dense.md"
- - feedforward: "zeta/nn/modules/feedforward.md"
- - wsconv2d: "zeta/nn/modules/wsconv2d.md"
- - mlp: "zeta/nn/modules/mlp.md"
- - slerpmodelmerger: "zeta/nn/modules/slerpmodelmerger.md"
- - fuseddropoutlayernorm: "zeta/nn/modules/fuseddropoutlayernorm.md"
- - tripleskipblock: "zeta/nn/modules/tripleskipblock.md"
- - dm: "zeta/nn/modules/dm.md"
- - feedbackblock: "zeta/nn/modules/feedbackblock.md"
- - mixtureofexperts: "zeta/nn/modules/mixtureofexperts.md"
- - mamba: "zeta/nn/modules/mamba.md"
- - perceiverlayer: "zeta/nn/modules/perceiverlayer.md"
- - mishactivation: "zeta/nn/modules/mishactivation.md"
- - hebbian: "zeta/nn/modules/hebbian.md"
- - simple_feedback: "zeta/nn/modules/simple_feedback.md"
- - visual_expert: "zeta/nn/modules/visual_expert.md"
- - stochasticskipblock: "zeta/nn/modules/stochasticskipblock.md"
- - unet: "zeta/nn/modules/unet.md"
- - zeta.nn.attention:
- - FlashAttention: "zeta/nn/attention/flash_attention.md"
- - MultiQueryAttention: "zeta/nn/attention/multiquery.md"
- - MultiheadAttention: "zeta/nn/attention/multihead.md"
- - FlashAttentionTwo: "zeta/nn/attention/flash2.md"
- - BaseAttention: "zeta/nn/attention/base.md"
- - LocalAttention: "zeta/nn/attention/local.md"
- - LocalMHA: "zeta/nn/attention/localmha.md"
- - MixtureOfAttention: "zeta/nn/attention/mixture_of_attention.md"
- - MixtureOfAutoregressiveAttention: "zeta/nn/attention/mixture_of_attention_ar.md"
- - SparseAttention: "zeta/nn/attention/sparse_attn.md"
- - zeta.tokenizers:
- - Language:
- - LanguageTokenizerGPTX: "zeta/tokenizers/language_tokenizer.md"
- - SentencePieceTokenizer: "zeta/tokenizers/sentencepiece.md"
- - TokenMonster: "zeta/tokenizers/token_monster.md"
- - MultiModal:
- - MultiModalTokenizer: "zeta/tokenizers/multi_modal_tokenizer.md"
- - zeta.utils:
- - Misc:
- - cast_tuple: "zeta/utils/cast_tuple.md"
- - group_by_key_prefix: "zeta/utils/group_by_key_prefix.md"
- - eval_decorator: "zeta/utils/eval_decorator.md"
- - print_cuda_memory_usage: "zeta/utils/print_cuda_memory_usage.md"
- - once: "zeta/utils/once.md"
- - default: "zeta/utils/default.md"
- - gumbel_noise: "zeta/utils/gumbel_noise.md"
- - pad_at_dim: "zeta/utils/pad_at_dim.md"
- - init_zero_: "zeta/utils/init_zero_.md"
- - top_p: "zeta/utils/top_p.md"
- - cast_if_src_dtype: "zeta/utils/cast_if_src_dtype.md"
- - disable_warnings_and_logs: "zeta/utils/disable_warnings_and_logs.md"
- - save_load_wrapper: "zeta/utils/save_load_wrapper.md"
- - get_sinusoid_encoding_table: "zeta/utils/get_sinusoid_encoding_table.md"
- - main: "zeta/utils/main.md"
- - string_begins_with: "zeta/utils/string_begins_with.md"
- - gif_to_tensor: "zeta/utils/gif_to_tensor.md"
- - l2norm: "zeta/utils/l2norm.md"
- - save_load: "zeta/utils/save_load.md"
- - log: "zeta/utils/log.md"
- - module_device: "zeta/utils/module_device.md"
- - print_num_params: "zeta/utils/print_num_params.md"
- - top_a: "zeta/utils/top_a.md"
- - interpolate_pos_encoding_2d: "zeta/utils/interpolate_pos_encoding_2d.md"
- - exists: "zeta/utils/exists.md"
- - cosine_beta_schedule: "zeta/utils/cosine_beta_schedule.md"
- - track_cuda_memory: "zeta/utils/track_cuda_memory.md"
- - maybe: "zeta/utils/maybe.md"
- - save_memory_snapshot: "zeta/utils/save_memory_snapshot.md"
- - top_k: "zeta/utils/top_k.md"
- - print_main: "zeta/utils/print_main.md"
- - pick_and_pop: "zeta/utils/pick_and_pop.md"
- - track_cuda_memory_usage: "zeta/utils/track_cuda_memory_usage.md"
- - group_dict_by_key: "zeta/utils/group_dict_by_key.md"
- - video_tensor_to_gift: "zeta/utils/video_tensor_to_gift.md"
- - zeta.ops:
- - Misc:
- - img_compose_decompose: "zeta/ops/img_compose_decompose.md"
- - img_transpose_2daxis: "zeta/ops/img_transpose_2daxis.md"
- - img_transpose: "zeta/ops/img_transpose.md"
- - img_order_of_axes: "zeta/ops/img_order_of_axes.md"
- - mos: "zeta/ops/mos.md"
- - merge_small_dims: "zeta/ops/merge_small_dims.md"
- - multi_dim_cat: "zeta/ops/multi_dim_cat.md"
- - img_compose_bw: "zeta/ops/img_compose_bw.md"
- - squeeze_2d_new: "zeta/ops/squeeze_2d_new.md"
- - temp_softmax: "zeta/ops/temp_softmax.md"
- - gumbelmax: "zeta/ops/gumbelmax.md"
- - _matrix_inverse_root_newton: "zeta/ops/_matrix_inverse_root_newton.md"
- - compute_matrix_root_inverse_residuals: "zeta/ops/compute_matrix_root_inverse_residuals.md"
- - matrix_root_diagonal: "zeta/ops/matrix_root_diagonal.md"
- - sparse_softmax: "zeta/ops/sparse_softmax.md"
- - reshape_audio_to_text: "zeta/ops/reshape_audio_to_text.md"
- - local_softmax: "zeta/ops/local_softmax.md"
- - softmaxes: "zeta/ops/softmaxes.md"
- - _matrix_root_eigen: "zeta/ops/_matrix_root_eigen.md"
- - main: "zeta/ops/main.md"
- - norm_exp_softmax: "zeta/ops/norm_exp_softmax.md"
- - multi_dim_split: "zeta/ops/multi_dim_split.md"
- - img_width_to_height: "zeta/ops/img_width_to_height.md"
- - fast_softmax: "zeta/ops/fast_softmax.md"
- - standard_softmax: "zeta/ops/standard_softmax.md"
- - unitwise_norm: "zeta/ops/unitwise_norm.md"
- - reshape_video_to_text: "zeta/ops/reshape_video_to_text.md"
- - img_decompose: "zeta/ops/img_decompose.md"
- - unsqueeze_2d_new: "zeta/ops/unsqueeze_2d_new.md"
- - reshape_img_to_text: "zeta/ops/reshape_img_to_text.md"
- - channel_shuffle_new: "zeta/ops/channel_shuffle_new.md"
- - matrix_inverse_root: "zeta/ops/matrix_inverse_root.md"
- - sparsemax: "zeta/ops/sparsemax.md"
- - gram_matrix_new: "zeta/ops/gram_matrix_new.md"
- - logit_scaled_softmax: "zeta/ops/logit_scaled_softmax.md"
- - selu_softmax: "zeta/ops/selu_softmax.md"
- - reshape_text_to_img: "zeta/ops/reshape_text_to_img.md"
- - zeta.optim:
- - Optimizers:
- - StableAdamWUnfused: "zeta/optims/adamw.md"
- - GradientAscent: "zeta/optims/ga.md"
- - DecoupledLionW: "zeta/training/optimizers/decoupled_lion.md"
- - SophiaG: "zeta/training/optimizers/sophia.md"
- - zeta.training:
- - Training:
- - fsdp: "zeta/training/fsdp.md"
- - ParallelWrapper: "zeta/training/parallel_wrapper.md"
- - train: "zeta/training/train.md"
- - zeta.models:
- - Language and MultiModal:
- - vit: "zeta/models/vit.md"
- - gpt4multimodal: "zeta/models/gpt4multimodal.md"
- - maxvit: "zeta/models/maxvit.md"
- - llama2: "zeta/models/llama2.md"
- - gpt4: "zeta/models/gpt4.md"
- - andromeda: "zeta/models/andromeda.md"
- - basemodel: "zeta/models/basemodel.md"
- - palme: "zeta/models/palme.md"
- - megavit: "zeta/models/megavit.md"
- - navit: "zeta/models/navit.md"
- - zeta.structs:
- - Structures:
- - Decoder: "zeta/nn/architecture/decoder.md"
- - Transformer: "zeta/nn/architecture/transformer.md"
- - paralleltransformerblock: "paralleltransformerblock.md"
- - zeta.quant:
- - Quantization Algorithms:
- - QUIK: "zeta/quant/quik.md"
- - BitLinear: "zeta/quant/bitlinear.md"
- - niva: "zeta/quant/niva.md"
- - zeta.rl:
- - DPO: "zeta/rl/dpo.md"
\ No newline at end of file
+ - Overview: "zeta/index.md"
+ - zeta.nn:
+ - zeta.nn.biases:
+ - Xpos: "zeta/nn/biases/xpos.md"
+ - RelativePositionBias: "zeta/nn/biases/relative_bias.md"
+ - AlibiPositionalBias: "zeta/nn/biases/alibi.md"
+ - DynamicPositionBias: "zeta/nn/biases/dynamic.md"
+ - zeta.nn.embeddings:
+ - MultiWay: "zeta/nn/embeddings/multiway.md"
+ - RotaryEmbeddings: "zeta/nn/embeddings/rope.md"
+ - TruncatedRotaryEmbedding: "zeta/nn/embeddings/truncated_rope.md"
+ - PositionalEmbedding: "zeta/nn/embeddings/positional_embeddings.md"
+ - XPOS: "zeta/nn/embeddings/xpos.md"
+ - YarnEmbedding: "zeta/nn/embeddings/yarn.md"
+ - VisionEmbedding: "zeta/nn/embeddings/vis_emb.md"
+ - SinusoidalEmbeddings: "zeta/nn/embeddings/sinusoidal.md"
+ - PatchEmbeddings: "zeta/nn/embeddings/patch_embeddings.md"
+ - PositionInterpolationEmbeddings: "zeta/nn/embeddings/positional_interpolation.md"
+ - zeta.nn.modules:
+ - custom_mlp: "zeta/nn/modules/custom_mlp.md"
+ - mbconv: "zeta/nn/modules/mbconv.md"
+ - dynamicroutingblock: "zeta/nn/modules/dynamicroutingblock.md"
+ - clippedgeluactivation: "zeta/nn/modules/clippedgeluactivation.md"
+ - mambablock: "zeta/nn/modules/mambablock.md"
+ - vittransformerblock: "zeta/nn/modules/vittransformerblock.md"
+ - fuseddensegeludense: "zeta/nn/modules/fuseddensegeludense.md"
+ - pscan: "zeta/nn/modules/pscan.md"
+ - adaptive: "zeta/nn/modules/adaptive.md"
+ - filmconditioning: "zeta/nn/modules/filmconditioning.md"
+ - mmfusionffn: "zeta/nn/modules/mmfusionffn.md"
+ - quickgeluactivation: "zeta/nn/modules/quickgeluactivation.md"
+ - gatedresidualblock: "zeta/nn/modules/gatedresidualblock.md"
+ - highwaylayer: "zeta/nn/modules/highwaylayer.md"
+ - multimodalmambablock: "zeta/nn/modules/multimodalmambablock.md"
+ - rms_norm: "zeta/nn/modules/rms_norm.md"
+ - ssm: "zeta/nn/modules/ssm.md"
+ - dualpathblock: "zeta/nn/modules/dualpathblock.md"
+ - topngating: "zeta/nn/modules/topngating.md"
+ - mmlayernorm: "zeta/nn/modules/mmlayernorm.md"
+ - mm_adapter: "zeta/nn/modules/mm_adapter.md"
+ - laplaceactivation: "zeta/nn/modules/laplaceactivation.md"
+ - nfnstem: "zeta/nn/modules/nfnstem.md"
+ - laser: "zeta/nn/modules/laser.md"
+ - denseblock: "zeta/nn/modules/denseblock.md"
+ - depthwiseconv2d: "zeta/nn/modules/depthwiseconv2d.md"
+ - lora: "zeta/nn/modules/lora.md"
+ - vlayernorm: "zeta/nn/modules/vlayernorm.md"
+ - flexiconv: "zeta/nn/modules/flexiconv.md"
+ - pulsar: "zeta/nn/modules/pulsar.md"
+ - pool: "zeta/nn/modules/pool.md"
+ - time_up_sample: "zeta/nn/modules/time_up_sample.md"
+ - spatial_downsample: "zeta/nn/modules/spatial_downsample.md"
+ - parallel: "zeta/nn/modules/parallel.md"
+ - conv2dfeedforward: "zeta/nn/modules/conv2dfeedforward.md"
+ - video_autoencoder: "zeta/nn/modules/video_autoencoder.md"
+ - recursiveblock: "zeta/nn/modules/recursiveblock.md"
+ - relusquaredactivation: "zeta/nn/modules/relusquaredactivation.md"
+ - fastgeluactivation: "zeta/nn/modules/fastgeluactivation.md"
+ - token_learner: "zeta/nn/modules/token_learner.md"
+ - layernorm: "zeta/nn/modules/layernorm.md"
+ - averagemodelmerger: "zeta/nn/modules/averagemodelmerger.md"
+ - linearactivation: "zeta/nn/modules/linearactivation.md"
+ - stochdepth: "zeta/nn/modules/stochdepth.md"
+ - expert: "zeta/nn/modules/expert.md"
+ - siglip: "zeta/nn/modules/siglip.md"
+ - ether: "zeta/nn/modules/ether.md"
+ - newgeluactivation: "zeta/nn/modules/newgeluactivation.md"
+ - pytorchgelutanh: "zeta/nn/modules/pytorchgelutanh.md"
+ - multiscaleblock: "zeta/nn/modules/multiscaleblock.md"
+ - umambablock: "zeta/nn/modules/umambablock.md"
+ - film: "zeta/nn/modules/film.md"
+ - adaptive_conv: "zeta/nn/modules/adaptive_conv.md"
+ - fused_dropout_layernorm: "zeta/nn/modules/fused_dropout_layernorm.md"
+ - accurategeluactivation: "zeta/nn/modules/accurategeluactivation.md"
+ - exo: "zeta/nn/modules/exo.md"
+ - polymorphic_activation: "zeta/nn/modules/polymorphic_activation.md"
+ - fusedprojsoftmax: "zeta/nn/modules/fusedprojsoftmax.md"
+ - quantizedln: "zeta/nn/modules/quantizedln.md"
+ - postnorm: "zeta/nn/modules/postnorm.md"
+ - moerouter: "zeta/nn/modules/moerouter.md"
+ - geluactivation: "zeta/nn/modules/geluactivation.md"
+ - visionattention: "zeta/nn/modules/visionattention.md"
+ - fused_gelu_dense: "zeta/nn/modules/fused_gelu_dense.md"
+ - feedforward: "zeta/nn/modules/feedforward.md"
+ - wsconv2d: "zeta/nn/modules/wsconv2d.md"
+ - mlp: "zeta/nn/modules/mlp.md"
+ - slerpmodelmerger: "zeta/nn/modules/slerpmodelmerger.md"
+ - fuseddropoutlayernorm: "zeta/nn/modules/fuseddropoutlayernorm.md"
+ - tripleskipblock: "zeta/nn/modules/tripleskipblock.md"
+ - dm: "zeta/nn/modules/dm.md"
+ - feedbackblock: "zeta/nn/modules/feedbackblock.md"
+ - mixtureofexperts: "zeta/nn/modules/mixtureofexperts.md"
+ - mamba: "zeta/nn/modules/mamba.md"
+ - perceiverlayer: "zeta/nn/modules/perceiverlayer.md"
+ - mishactivation: "zeta/nn/modules/mishactivation.md"
+ - hebbian: "zeta/nn/modules/hebbian.md"
+ - simple_feedback: "zeta/nn/modules/simple_feedback.md"
+ - visual_expert: "zeta/nn/modules/visual_expert.md"
+ - stochasticskipblock: "zeta/nn/modules/stochasticskipblock.md"
+ - unet: "zeta/nn/modules/unet.md"
+ - zeta.nn.attention:
+ - FlashAttention: "zeta/nn/attention/flash_attention.md"
+ - MultiQueryAttention: "zeta/nn/attention/multiquery.md"
+ - MultiheadAttention: "zeta/nn/attention/multihead.md"
+ - FlashAttentionTwo: "zeta/nn/attention/flash2.md"
+ - BaseAttention: "zeta/nn/attention/base.md"
+ - LocalAttention: "zeta/nn/attention/local.md"
+ - LocalMHA: "zeta/nn/attention/localmha.md"
+ - MixtureOfAttention: "zeta/nn/attention/mixture_of_attention.md"
+ - MixtureOfAutoregressiveAttention: "zeta/nn/attention/mixture_of_attention_ar.md"
+ - SparseAttention: "zeta/nn/attention/sparse_attn.md"
+ - zeta.tokenizers:
+ - Language:
+ - LanguageTokenizerGPTX: "zeta/tokenizers/language_tokenizer.md"
+ - SentencePieceTokenizer: "zeta/tokenizers/sentencepiece.md"
+ - TokenMonster: "zeta/tokenizers/token_monster.md"
+ - MultiModal:
+ - MultiModalTokenizer: "zeta/tokenizers/multi_modal_tokenizer.md"
+ - zeta.utils:
+ - Misc:
+ - cast_tuple: "zeta/utils/cast_tuple.md"
+ - group_by_key_prefix: "zeta/utils/group_by_key_prefix.md"
+ - eval_decorator: "zeta/utils/eval_decorator.md"
+ - print_cuda_memory_usage: "zeta/utils/print_cuda_memory_usage.md"
+ - once: "zeta/utils/once.md"
+ - default: "zeta/utils/default.md"
+ - gumbel_noise: "zeta/utils/gumbel_noise.md"
+ - pad_at_dim: "zeta/utils/pad_at_dim.md"
+ - init_zero_: "zeta/utils/init_zero_.md"
+ - top_p: "zeta/utils/top_p.md"
+ - cast_if_src_dtype: "zeta/utils/cast_if_src_dtype.md"
+ - disable_warnings_and_logs: "zeta/utils/disable_warnings_and_logs.md"
+ - save_load_wrapper: "zeta/utils/save_load_wrapper.md"
+ - get_sinusoid_encoding_table: "zeta/utils/get_sinusoid_encoding_table.md"
+ - main: "zeta/utils/main.md"
+ - string_begins_with: "zeta/utils/string_begins_with.md"
+ - gif_to_tensor: "zeta/utils/gif_to_tensor.md"
+ - l2norm: "zeta/utils/l2norm.md"
+ - save_load: "zeta/utils/save_load.md"
+ - log: "zeta/utils/log.md"
+ - module_device: "zeta/utils/module_device.md"
+ - print_num_params: "zeta/utils/print_num_params.md"
+ - top_a: "zeta/utils/top_a.md"
+ - interpolate_pos_encoding_2d: "zeta/utils/interpolate_pos_encoding_2d.md"
+ - exists: "zeta/utils/exists.md"
+ - cosine_beta_schedule: "zeta/utils/cosine_beta_schedule.md"
+ - track_cuda_memory: "zeta/utils/track_cuda_memory.md"
+ - maybe: "zeta/utils/maybe.md"
+ - save_memory_snapshot: "zeta/utils/save_memory_snapshot.md"
+ - top_k: "zeta/utils/top_k.md"
+ - print_main: "zeta/utils/print_main.md"
+ - pick_and_pop: "zeta/utils/pick_and_pop.md"
+ - track_cuda_memory_usage: "zeta/utils/track_cuda_memory_usage.md"
+ - group_dict_by_key: "zeta/utils/group_dict_by_key.md"
+ - video_tensor_to_gift: "zeta/utils/video_tensor_to_gift.md"
+ - zeta.ops:
+ - Misc:
+ - img_compose_decompose: "zeta/ops/img_compose_decompose.md"
+ - img_transpose_2daxis: "zeta/ops/img_transpose_2daxis.md"
+ - img_transpose: "zeta/ops/img_transpose.md"
+ - img_order_of_axes: "zeta/ops/img_order_of_axes.md"
+ - mos: "zeta/ops/mos.md"
+ - merge_small_dims: "zeta/ops/merge_small_dims.md"
+ - multi_dim_cat: "zeta/ops/multi_dim_cat.md"
+ - img_compose_bw: "zeta/ops/img_compose_bw.md"
+ - squeeze_2d_new: "zeta/ops/squeeze_2d_new.md"
+ - temp_softmax: "zeta/ops/temp_softmax.md"
+ - gumbelmax: "zeta/ops/gumbelmax.md"
+ - _matrix_inverse_root_newton: "zeta/ops/_matrix_inverse_root_newton.md"
+ - compute_matrix_root_inverse_residuals: "zeta/ops/compute_matrix_root_inverse_residuals.md"
+ - matrix_root_diagonal: "zeta/ops/matrix_root_diagonal.md"
+ - sparse_softmax: "zeta/ops/sparse_softmax.md"
+ - reshape_audio_to_text: "zeta/ops/reshape_audio_to_text.md"
+ - local_softmax: "zeta/ops/local_softmax.md"
+ - softmaxes: "zeta/ops/softmaxes.md"
+ - _matrix_root_eigen: "zeta/ops/_matrix_root_eigen.md"
+ - main: "zeta/ops/main.md"
+ - norm_exp_softmax: "zeta/ops/norm_exp_softmax.md"
+ - multi_dim_split: "zeta/ops/multi_dim_split.md"
+ - img_width_to_height: "zeta/ops/img_width_to_height.md"
+ - fast_softmax: "zeta/ops/fast_softmax.md"
+ - standard_softmax: "zeta/ops/standard_softmax.md"
+ - unitwise_norm: "zeta/ops/unitwise_norm.md"
+ - reshape_video_to_text: "zeta/ops/reshape_video_to_text.md"
+ - img_decompose: "zeta/ops/img_decompose.md"
+ - unsqueeze_2d_new: "zeta/ops/unsqueeze_2d_new.md"
+ - reshape_img_to_text: "zeta/ops/reshape_img_to_text.md"
+ - channel_shuffle_new: "zeta/ops/channel_shuffle_new.md"
+ - matrix_inverse_root: "zeta/ops/matrix_inverse_root.md"
+ - sparsemax: "zeta/ops/sparsemax.md"
+ - gram_matrix_new: "zeta/ops/gram_matrix_new.md"
+ - logit_scaled_softmax: "zeta/ops/logit_scaled_softmax.md"
+ - selu_softmax: "zeta/ops/selu_softmax.md"
+ - reshape_text_to_img: "zeta/ops/reshape_text_to_img.md"
+ - zeta.optim:
+ - Optimizers:
+ - StableAdamWUnfused: "zeta/optims/adamw.md"
+ - GradientAscent: "zeta/optims/ga.md"
+ - DecoupledLionW: "zeta/training/optimizers/decoupled_lion.md"
+ - SophiaG: "zeta/training/optimizers/sophia.md"
+ - zeta.training:
+ - Training:
+ - fsdp: "zeta/training/fsdp.md"
+ - ParallelWrapper: "zeta/training/parallel_wrapper.md"
+ - train: "zeta/training/train.md"
+ - zeta.models:
+ - Language and MultiModal:
+ - vit: "zeta/models/vit.md"
+ - gpt4multimodal: "zeta/models/gpt4multimodal.md"
+ - maxvit: "zeta/models/maxvit.md"
+ - llama2: "zeta/models/llama2.md"
+ - gpt4: "zeta/models/gpt4.md"
+ - andromeda: "zeta/models/andromeda.md"
+ - basemodel: "zeta/models/basemodel.md"
+ - palme: "zeta/models/palme.md"
+ - megavit: "zeta/models/megavit.md"
+ - navit: "zeta/models/navit.md"
+ - zeta.structs:
+ - Structures:
+ - Decoder: "zeta/nn/architecture/decoder.md"
+ - Transformer: "zeta/nn/architecture/transformer.md"
+ - paralleltransformerblock: "paralleltransformerblock.md"
+ - zeta.quant:
+ - Quantization Algorithms:
+ - QUIK: "zeta/quant/quik.md"
+ - BitLinear: "zeta/quant/bitlinear.md"
+ - niva: "zeta/quant/niva.md"
+ - zeta.rl:
+ - DPO: "zeta/rl/dpo.md"
\ No newline at end of file
diff --git a/training/gan/gan.py b/training/gan/gan.py
new file mode 100644
index 00000000..69c88cfd
--- /dev/null
+++ b/training/gan/gan.py
@@ -0,0 +1,167 @@
+import logging
+import os
+import shutil
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchaudio
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+class Generator(nn.Module):
+ def __init__(self, latent_dim: int, output_dim: int):
+ super().__init__()
+ self.model = nn.Sequential(
+ nn.Linear(latent_dim, 256),
+ nn.LeakyReLU(0.2),
+ nn.Linear(256, 512),
+ nn.LeakyReLU(0.2),
+ nn.Linear(512, 1024),
+ nn.LeakyReLU(0.2),
+ nn.Linear(1024, output_dim),
+ nn.Tanh(),
+ )
+
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
+ return self.model(z)
+
+
+class Discriminator(nn.Module):
+ def __init__(self, input_dim: int):
+ super().__init__()
+ self.model = nn.Sequential(
+ nn.Linear(input_dim, 1024),
+ nn.LeakyReLU(0.2),
+ nn.Linear(1024, 512),
+ nn.LeakyReLU(0.2),
+ nn.Linear(512, 256),
+ nn.LeakyReLU(0.2),
+ nn.Linear(256, 1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.model(x)
+
+
+def train_gan(
+ generator: nn.Module,
+ discriminator: nn.Module,
+ dataloader: DataLoader,
+ num_epochs: int,
+ latent_dim: int,
+ device: torch.device,
+) -> Tuple[nn.Module, nn.Module]:
+ criterion = nn.BCELoss()
+ g_optimizer = optim.Adam(
+ generator.parameters(), lr=0.0002, betas=(0.5, 0.999)
+ )
+ d_optimizer = optim.Adam(
+ discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)
+ )
+
+ for epoch in range(num_epochs):
+ for i, batch in enumerate(dataloader):
+ real_samples = batch["audio"].to(device)
+ batch_size = real_samples.size(0)
+
+ d_optimizer.zero_grad()
+ real_labels = torch.ones(batch_size, 1).to(device)
+ fake_labels = torch.zeros(batch_size, 1).to(device)
+
+ d_real_output = discriminator(real_samples)
+ d_real_loss = criterion(d_real_output, real_labels)
+
+ z = torch.randn(batch_size, latent_dim).to(device)
+ fake_samples = generator(z)
+ d_fake_output = discriminator(fake_samples.detach())
+ d_fake_loss = criterion(d_fake_output, fake_labels)
+
+ d_loss = d_real_loss + d_fake_loss
+ d_loss.backward()
+ d_optimizer.step()
+
+ g_optimizer.zero_grad()
+ z = torch.randn(batch_size, latent_dim).to(device)
+ fake_samples = generator(z)
+ d_output = discriminator(fake_samples)
+ g_loss = criterion(d_output, real_labels)
+ g_loss.backward()
+ g_optimizer.step()
+
+ logger.info(
+ f"Epoch [{epoch+1}/{num_epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}"
+ )
+
+ return generator, discriminator
+
+
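+# Illustrative sketch (not part of the original script): once training has
+# finished, the saved generator weights can be reloaded to synthesize new
+# waveforms. "generator.pth" matches what main() saves below; the rest of
+# this helper is an assumption for demonstration purposes.
+def generate_samples(
+    checkpoint_path: str = "generator.pth",
+    latent_dim: int = 100,
+    output_dim: int = 16000,
+    num_samples: int = 4,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    generator = Generator(latent_dim, output_dim).to(device)
+    generator.load_state_dict(
+        torch.load(checkpoint_path, map_location=device)
+    )
+    generator.eval()
+    with torch.no_grad():
+        z = torch.randn(num_samples, latent_dim, device=device)
+        # Output is (num_samples, output_dim) waveforms in [-1, 1] (Tanh).
+        return generator(z)
+
+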
+def main():
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+
+    # Clear the Hugging Face datasets cache. Note this removes *all* cached
+    # datasets, forcing a fresh download on every run.
+ cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
+ if os.path.exists(cache_dir):
+ shutil.rmtree(cache_dir)
+ logger.info("Cleared dataset cache")
+
+ # Try to load the dataset with retries
+ max_retries = 3
+ for attempt in range(max_retries):
+ try:
+ dataset = load_dataset(
+ "mozilla-foundation/common_voice_11_0",
+ "en",
+ split="train[:1000]",
+ trust_remote_code=True,
+ )
+ logger.info(f"Dataset loaded: {len(dataset)} samples")
+ break
+ except Exception as e:
+ if attempt < max_retries - 1:
+ logger.warning(
+ f"Dataset loading failed (attempt {attempt + 1}/{max_retries}). Retrying..."
+ )
+ else:
+ logger.error("Failed to load dataset after multiple attempts.")
+ raise e
+
+    def preprocess_audio(example):
+        audio = torch.tensor(example["audio"]["array"], dtype=torch.float32)
+        audio = torchaudio.transforms.Resample(
+            example["audio"]["sampling_rate"], 16000
+        )(audio).flatten()
+        # Pad/truncate to output_dim samples so examples batch cleanly.
+        audio = torch.nn.functional.pad(
+            audio, (0, max(0, 16000 - audio.numel()))
+        )
+        return {"audio": audio[:16000]}
+
+ dataset = dataset.map(preprocess_audio, remove_columns=dataset.column_names)
+ dataset.set_format(type="torch", columns=["audio"])
+
+ dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
+
+ latent_dim = 100
+ output_dim = 16000
+
+ generator = Generator(latent_dim, output_dim).to(device)
+ discriminator = Discriminator(output_dim).to(device)
+
+ num_epochs = 50
+ generator, discriminator = train_gan(
+ generator, discriminator, dataloader, num_epochs, latent_dim, device
+ )
+
+ torch.save(generator.state_dict(), "generator.pth")
+ torch.save(discriminator.state_dict(), "discriminator.pth")
+ logger.info("Models saved successfully")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/training/gan/new_gan.py b/training/gan/new_gan.py
new file mode 100644
index 00000000..6bfe0f64
--- /dev/null
+++ b/training/gan/new_gan.py
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Tuple
+
+
+class ResidualBlock(nn.Module):
+ """Residual block with dilated convolutions."""
+
+ def __init__(self, channels: int, dilation: int = 1):
+ super().__init__()
+ self.conv1 = nn.Conv1d(
+ channels, channels, kernel_size=3, dilation=dilation, padding="same"
+ )
+ self.conv2 = nn.Conv1d(
+ channels, channels, kernel_size=3, dilation=dilation, padding="same"
+ )
+ self.norm1 = nn.GroupNorm(
+ num_groups=channels // 4, num_channels=channels
+ )
+ self.norm2 = nn.GroupNorm(
+ num_groups=channels // 4, num_channels=channels
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ x = F.leaky_relu(self.norm1(self.conv1(x)), 0.1)
+ x = F.leaky_relu(self.norm2(self.conv2(x)), 0.1)
+ return x + residual
+
+
+class Generator(nn.Module):
+ """Dynamic generator for speech synthesis."""
+
+ def __init__(
+ self,
+ input_channels: int,
+ base_channels: int = 512,
+ upsample_rates: List[int] = [8, 8, 2, 2],
+ ):
+ super().__init__()
+ self.input_conv = nn.Conv1d(
+ input_channels, base_channels, kernel_size=7, padding=3
+ )
+ self.norm = nn.GroupNorm(
+ num_groups=base_channels // 4, num_channels=base_channels
+ )
+
+ self.upsample_layers = nn.ModuleList()
+ current_channels = base_channels
+ for rate in upsample_rates:
+ out_channels = current_channels // 2
+ self.upsample_layers.append(
+ nn.ConvTranspose1d(
+ current_channels,
+ out_channels,
+ kernel_size=rate * 2,
+ stride=rate,
+ padding=rate // 2,
+ )
+ )
+ current_channels = out_channels
+
+ self.resblocks = nn.ModuleList(
+ [ResidualBlock(current_channels, dilation=3**i) for i in range(3)]
+ )
+
+ self.output_conv = nn.Conv1d(
+ current_channels, 1, kernel_size=7, padding=3
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.leaky_relu(self.norm(self.input_conv(x)), 0.1)
+
+ for upsample in self.upsample_layers:
+ x = F.leaky_relu(upsample(x), 0.1)
+
+ for resblock in self.resblocks:
+ x = resblock(x)
+
+ return torch.tanh(self.output_conv(x))
+
+
+class PeriodDiscriminator(nn.Module):
+ """Period-based discriminator."""
+
+ def __init__(self, period: int, base_channels: int = 32):
+ super().__init__()
+ self.period = period
+ self.layers = nn.ModuleList(
+ [
+ nn.Conv2d(1, base_channels, (5, 1), (3, 1), padding=(2, 0)),
+ nn.Conv2d(
+ base_channels,
+ base_channels * 2,
+ (5, 1),
+ (3, 1),
+ padding=(2, 0),
+ ),
+ nn.Conv2d(
+ base_channels * 2,
+ base_channels * 4,
+ (5, 1),
+ (3, 1),
+ padding=(2, 0),
+ ),
+ nn.Conv2d(
+ base_channels * 4,
+ base_channels * 4,
+ (5, 1),
+ 1,
+ padding=(2, 0),
+ ),
+ nn.Conv2d(base_channels * 4, 1, (3, 1), 1, padding=(1, 0)),
+ ]
+ )
+
+ def forward(
+ self, x: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ batch_size, _, time = x.shape
+ if time % self.period != 0:
+ pad_size = self.period - (time % self.period)
+ x = F.pad(x, (0, pad_size))
+ x = x.view(batch_size, 1, -1, self.period)
+
+ features = []
+ for layer in self.layers[:-1]:
+ x = F.leaky_relu(layer(x), 0.1)
+ features.append(x)
+ x = self.layers[-1](x)
+ features.append(x)
+ return x.view(batch_size, -1, 1), features
+
+
+class MultiPeriodDiscriminator(nn.Module):
+ """Multi-period discriminator."""
+
+ def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]):
+ super().__init__()
+ self.discriminators = nn.ModuleList(
+ [PeriodDiscriminator(period) for period in periods]
+ )
+
+ def forward(
+ self, x: torch.Tensor
+ ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
+ outputs, features = [], []
+ for disc in self.discriminators:
+ output, feature = disc(x)
+ outputs.append(output)
+ features.append(feature)
+ return outputs, features
+
+
+class ScaleDiscriminator(nn.Module):
+ """Scale discriminator with multiple resolutions."""
+
+ def __init__(self, base_channels: int = 16):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [
+ nn.Conv1d(1, base_channels, 15, stride=1, padding=7),
+ nn.Conv1d(
+ base_channels,
+ base_channels * 2,
+ 41,
+ stride=4,
+ padding=20,
+ groups=base_channels,
+ ),
+ nn.Conv1d(
+ base_channels * 2,
+ base_channels * 4,
+ 41,
+ stride=4,
+ padding=20,
+ groups=base_channels * 2,
+ ),
+ nn.Conv1d(
+ base_channels * 4,
+ base_channels * 8,
+ 41,
+ stride=4,
+ padding=20,
+ groups=base_channels * 4,
+ ),
+ nn.Conv1d(
+ base_channels * 8,
+ base_channels * 16,
+ 41,
+ stride=4,
+ padding=20,
+ groups=base_channels * 8,
+ ),
+ nn.Conv1d(
+ base_channels * 16,
+ base_channels * 32,
+ 5,
+ stride=1,
+ padding=2,
+ ),
+ nn.Conv1d(base_channels * 32, 1, 3, stride=1, padding=1),
+ ]
+ )
+
+ def forward(
+ self, x: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ features = []
+ for layer in self.layers[:-1]:
+ x = F.leaky_relu(layer(x), 0.1)
+ features.append(x)
+ x = self.layers[-1](x)
+ features.append(x)
+ return x, features
+
+
+class MultiScaleDiscriminator(nn.Module):
+ """Multi-scale discriminator."""
+
+ def __init__(self, num_scales: int = 3):
+ super().__init__()
+ self.discriminators = nn.ModuleList(
+ [ScaleDiscriminator() for _ in range(num_scales)]
+ )
+ self.downsample = nn.AvgPool1d(
+ kernel_size=2, stride=2, padding=1, count_include_pad=False
+ )
+
+ def forward(
+ self, x: torch.Tensor
+ ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
+ outputs, features = [], []
+ for disc in self.discriminators:
+ output, feature = disc(x)
+ outputs.append(output)
+ features.append(feature)
+ x = self.downsample(x)
+ return outputs, features
+
+
+class SpeechSynthesisGAN(nn.Module):
+ """State-of-the-art GAN for speech synthesis."""
+
+ def __init__(self, input_channels: int):
+ super().__init__()
+ self.generator = Generator(input_channels)
+ self.mpd = MultiPeriodDiscriminator()
+ self.msd = MultiScaleDiscriminator()
+
+ def forward(
+ self, mel_spectrogram: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Forward pass of the SpeechSynthesisGAN.
+
+ Args:
+ mel_spectrogram (torch.Tensor): Input mel spectrogram of shape (batch_size, input_channels, time_steps)
+
+ Returns:
+ Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ Generated waveform, MPD outputs, and MSD outputs
+ """
+ waveform = self.generator(mel_spectrogram)
+ mpd_outputs, _ = self.mpd(waveform)
+ msd_outputs, _ = self.msd(waveform)
+ return waveform, mpd_outputs, msd_outputs
+
+
+def feature_loss(
+ real_features: List[List[torch.Tensor]],
+ fake_features: List[List[torch.Tensor]],
+) -> torch.Tensor:
+ """
+ Compute the feature matching loss between real and fake features.
+
+ Args:
+ real_features (List[List[torch.Tensor]]): Features from the discriminator for real audio
+ fake_features (List[List[torch.Tensor]]): Features from the discriminator for generated audio
+
+ Returns:
+ torch.Tensor: Feature matching loss
+ """
+ loss = 0
+ for real_feats, fake_feats in zip(real_features, fake_features):
+ for real_feat, fake_feat in zip(real_feats, fake_feats):
+ loss += F.l1_loss(fake_feat, real_feat)
+ return loss
+
+
+def discriminator_loss(
+ real_outputs: List[torch.Tensor], fake_outputs: List[torch.Tensor]
+) -> torch.Tensor:
+ """
+ Compute the discriminator loss.
+
+ Args:
+ real_outputs (List[torch.Tensor]): Discriminator outputs for real audio
+ fake_outputs (List[torch.Tensor]): Discriminator outputs for generated audio
+
+ Returns:
+ torch.Tensor: Discriminator loss
+ """
+ loss = 0
+ for real_output, fake_output in zip(real_outputs, fake_outputs):
+ r_loss = torch.mean((1 - real_output) ** 2)
+ f_loss = torch.mean(fake_output**2)
+ loss += r_loss + f_loss
+ return loss
+
+
+def generator_loss(fake_outputs: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Compute the generator loss.
+
+ Args:
+ fake_outputs (List[torch.Tensor]): Discriminator outputs for generated audio
+
+ Returns:
+ torch.Tensor: Generator loss
+ """
+ loss = 0
+ for fake_output in fake_outputs:
+ loss += torch.mean((1 - fake_output) ** 2)
+ return loss
+
+
+# Example usage
+if __name__ == "__main__":
+ input_channels = 80
+ batch_size = 4
+ time_steps = 128
+
+ model = SpeechSynthesisGAN(input_channels)
+ mel_input = torch.randn(batch_size, input_channels, time_steps)
+
+ waveform, mpd_outputs, msd_outputs = model(mel_input)
+ print(f"Generated waveform shape: {waveform.shape}")
+ print(f"Number of MPD outputs: {len(mpd_outputs)}")
+ print(f"Number of MSD outputs: {len(msd_outputs)}")
diff --git a/training/vis/model.py b/training/vis/model.py
new file mode 100644
index 00000000..9d31f881
--- /dev/null
+++ b/training/vis/model.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+import plotly.graph_objects as go
+import numpy as np
+
+
+# Helper function to generate evenly spaced points on a sphere of the given
+# radius using the golden-angle (Fibonacci spiral) method
+def spherical_coordinates(num_points, radius):
+ points = []
+ phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians
+
+ for i in range(num_points):
+ y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
+ radius_y = np.sqrt(1 - y * y) # radius at y
+ theta = phi * i # golden angle increment
+
+ x = np.cos(theta) * radius_y * radius
+ z = np.sin(theta) * radius_y * radius
+ points.append([x, y * radius, z])
+
+ return np.array(points)
+
+
+# Dynamic extraction of model layers and parameters
+def extract_model_layers(model):
+ layers = []
+ params_per_layer = []
+
+ def recursive_layer_extraction(layer, layer_name=""):
+ if isinstance(layer, nn.Module):
+ num_params = sum(
+ p.numel() for p in layer.parameters() if p.requires_grad
+ )
+ if num_params > 0:
+ layers.append(
+ {
+ "name": layer_name,
+ "type": layer.__class__.__name__,
+ "num_params": num_params,
+ }
+ )
+ params_per_layer.append(num_params)
+ # Traverse through children layers
+ for name, child in layer.named_children():
+ layer_full_name = f"{layer_name}.{name}" if layer_name else name
+ recursive_layer_extraction(child, layer_full_name)
+
+ recursive_layer_extraction(model)
+ return layers, params_per_layer
+
+
+# Function to create a dynamic 3D visualization of the model
+def visualize_model_as_ball(model, max_points_per_layer=100):
+ layers, params_per_layer = extract_model_layers(model)
+
+ # Visualize each layer as a concentric sphere with points
+ fig = go.Figure()
+ max_radius = 10 # Maximum radius for the outer layer
+ radius_step = max_radius / len(layers)
+
+ # Create spheres and add points to each sphere
+ for i, layer in enumerate(layers):
+ radius = radius_step * (i + 1)
+ num_params = min(
+ params_per_layer[i], max_points_per_layer
+ ) # Limit points per layer for clarity
+ points = spherical_coordinates(num_params, radius)
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode="markers",
+ marker=dict(
+ size=5,
+ color=np.linspace(0, 1, num_params),
+ colorscale="Viridis",
+ opacity=0.8,
+ ),
+ name=f'Layer {i + 1}: {layer["type"]}',
+ hovertext=[
+ f'{layer["name"]}, Param {j + 1}' for j in range(num_params)
+ ],
+ )
+ )
+
+ # Configure layout
+ fig.update_layout(
+ scene=dict(
+ xaxis=dict(title="X", showgrid=False, zeroline=False),
+ yaxis=dict(title="Y", showgrid=False, zeroline=False),
+ zaxis=dict(title="Z", showgrid=False, zeroline=False),
+ ),
+ title="Dynamic Model Visualization as Ball",
+ showlegend=True,
+ )
+
+ fig.show()
+
+
+# Example usage with a pretrained model
+if __name__ == "__main__":
+ # Load a pretrained model (e.g., ResNet18 from torchvision)
+ model = torch.hub.load(
+ "pytorch/vision:v0.10.0", "resnet18", pretrained=True
+ )
+
+ # Visualize the model
+ visualize_model_as_ball(model)
diff --git a/training/vis/t2.py b/training/vis/t2.py
new file mode 100644
index 00000000..62e23da9
--- /dev/null
+++ b/training/vis/t2.py
@@ -0,0 +1,175 @@
+import plotly.graph_objects as go
+import numpy as np
+from transformers import BertModel, BertTokenizer
+
+
+# Helper function to generate evenly spaced points on a sphere of the given
+# radius using the golden-angle (Fibonacci spiral) method
+def spherical_coordinates(num_points, radius):
+ points = []
+ phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians
+
+ for i in range(num_points):
+ y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
+ radius_y = np.sqrt(1 - y * y) # radius at y
+ theta = phi * i # golden angle increment
+
+ x = np.cos(theta) * radius_y * radius
+ z = np.sin(theta) * radius_y * radius
+ points.append([x, y * radius, z])
+
+ return np.array(points)
+
+
+# Function to extract parameter tensors and their sizes (works for any
+# nn.Module with named_parameters; used here with BERT)
+def extract_transformer_layers(model):
+ layers = []
+ params_per_layer = []
+
+ for name, param in model.named_parameters():
+ num_params = param.numel()
+ if num_params > 0:
+ layers.append(
+ {"name": name, "num_params": num_params, "shape": param.shape}
+ )
+ params_per_layer.append(num_params)
+
+ return layers, params_per_layer
+
+
+# Function to visualize the transformer model components as a 3D ball structure
+def visualize_transformer_as_ball(model, max_points_per_layer=100):
+ layers, params_per_layer = extract_transformer_layers(model)
+
+ # Visualize each layer as a concentric sphere with points
+ fig = go.Figure()
+ max_radius = 10 # Maximum radius for the outer layer
+ radius_step = max_radius / len(layers)
+
+ # Create spheres and add points to each sphere
+ for i, layer in enumerate(layers):
+ radius = radius_step * (i + 1)
+ num_params = min(
+ params_per_layer[i], max_points_per_layer
+ ) # Limit points per layer for clarity
+ points = spherical_coordinates(num_params, radius)
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode="markers",
+ marker=dict(
+ size=5,
+ color=np.linspace(0, 1, num_params),
+ colorscale="Viridis",
+ opacity=0.8,
+ ),
+ name=f'Layer {i + 1}: {layer["name"]}',
+ hovertext=[
+ f'{layer["name"]}, Shape: {layer["shape"]}'
+ for j in range(num_params)
+ ],
+ )
+ )
+
+ # Configure layout
+ fig.update_layout(
+ scene=dict(
+ xaxis=dict(title="X", showgrid=False, zeroline=False),
+ yaxis=dict(title="Y", showgrid=False, zeroline=False),
+ zaxis=dict(title="Z", showgrid=False, zeroline=False),
+ ),
+ title="Transformer Model Visualization as Ball",
+ showlegend=True,
+ )
+
+ fig.show()
+
+
+# Visualizing Attention Weights
+def visualize_attention_weights(
+ attention_weights, sequence_length, attention_heads
+):
+ fig = go.Figure()
+
+ # For each attention head, draw the attention matrix as connections between tokens
+ for head in range(attention_heads):
+ attention = (
+ attention_weights[0, head].detach().numpy()
+ ) # Get attention weights for this head
+ points = spherical_coordinates(
+ sequence_length, radius=5 + head
+ ) # Slightly different radii for heads
+
+ # Add nodes (tokens)
+ fig.add_trace(
+ go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode="markers",
+ marker=dict(
+ size=8,
+ color=np.linspace(0, 1, sequence_length),
+ colorscale="Plasma",
+ opacity=0.8,
+ ),
+ name=f"Attention Head {head + 1}",
+ hovertext=[f"Token {i + 1}" for i in range(sequence_length)],
+ )
+ )
+
+ # Add edges (attention weights between tokens)
+ for i in range(sequence_length):
+ for j in range(sequence_length):
+ if (
+ i != j and attention[i, j] > 0.1
+ ): # Only show significant attention weights
+ fig.add_trace(
+ go.Scatter3d(
+ x=[points[i, 0], points[j, 0]],
+ y=[points[i, 1], points[j, 1]],
+ z=[points[i, 2], points[j, 2]],
+ mode="lines",
+ line=dict(color="rgba(0, 0, 255, 0.2)", width=2),
+ hoverinfo="none",
+ showlegend=False,
+ )
+ )
+
+ fig.update_layout(
+ scene=dict(
+ xaxis=dict(title="X", showgrid=False, zeroline=False),
+ yaxis=dict(title="Y", showgrid=False, zeroline=False),
+ zaxis=dict(title="Z", showgrid=False, zeroline=False),
+ ),
+ title="Multi-Head Attention Weights",
+ showlegend=True,
+ )
+
+ fig.show()
+
+
+# Example usage with a pretrained BERT model
+if __name__ == "__main__":
+ # Load a pretrained transformer model (e.g., BERT) from Huggingface
+ model = BertModel.from_pretrained("bert-base-uncased")
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+ # Example sentence
+ sentence = "Transformers are powerful models for NLP tasks."
+ inputs = tokenizer(sentence, return_tensors="pt")
+ outputs = model(**inputs, output_attentions=True)
+
+ # Visualize the model structure
+ visualize_transformer_as_ball(model)
+
+ # Visualize attention weights (BERT has 12 attention heads by default)
+ attention_weights = outputs.attentions[
+ -1
+ ] # Get the attention weights from the last layer
+ sequence_length = inputs["input_ids"].shape[1]
+ visualize_attention_weights(
+ attention_weights, sequence_length, attention_heads=12
+ )
diff --git a/training/vis/t3.py b/training/vis/t3.py
new file mode 100644
index 00000000..2774feed
--- /dev/null
+++ b/training/vis/t3.py
@@ -0,0 +1,103 @@
+import plotly.graph_objects as go
+from torchvision import models
+import numpy as np
+
+
+def visualize_model_3d(model):
+ def get_layer_info(module, depth=0, path=""):
+ layers = []
+ for name, child in module.named_children():
+ child_path = f"{path}/{name}".lstrip("/")
+ child_info = {
+ "name": child_path,
+ "type": child.__class__.__name__,
+ "params": sum(
+ p.numel() for p in child.parameters() if p.requires_grad
+ ),
+ "depth": depth,
+ }
+ layers.append(child_info)
+ layers.extend(get_layer_info(child, depth + 1, child_path))
+ return layers
+
+ layers = get_layer_info(model)
+
+ fig = go.Figure()
+
+ max_depth = max(layer["depth"] for layer in layers)
+ max_params = max(layer["params"] for layer in layers)
+
+ # Calculate positions on a sphere
+ phi = np.linspace(0, np.pi, len(layers))
+ theta = np.linspace(0, 2 * np.pi, len(layers))
+
+ # Create nodes
+ for i, layer in enumerate(layers):
+ r = (layer["depth"] + 1) / (max_depth + 1) # Radius based on depth
+ x = r * np.sin(phi[i]) * np.cos(theta[i])
+ y = r * np.sin(phi[i]) * np.sin(theta[i])
+ z = r * np.cos(phi[i])
+
+ size = (layer["params"] / max_params * 20 + 5) if max_params > 0 else 5
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=[x],
+ y=[y],
+ z=[z],
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=layer["depth"],
+ colorscale="Viridis",
+ opacity=0.8,
+ ),
+ text=f"{layer['name']}
{layer['type']}
Params: {layer['params']}",
+ hoverinfo="text",
+ )
+ )
+
+ # Create edges
+ for i in range(1, len(layers)):
+ prev_layer = layers[i - 1]
+ curr_layer = layers[i]
+
+ r_prev = (prev_layer["depth"] + 1) / (max_depth + 1)
+ r_curr = (curr_layer["depth"] + 1) / (max_depth + 1)
+
+ x_prev = r_prev * np.sin(phi[i - 1]) * np.cos(theta[i - 1])
+ y_prev = r_prev * np.sin(phi[i - 1]) * np.sin(theta[i - 1])
+ z_prev = r_prev * np.cos(phi[i - 1])
+
+ x_curr = r_curr * np.sin(phi[i]) * np.cos(theta[i])
+ y_curr = r_curr * np.sin(phi[i]) * np.sin(theta[i])
+ z_curr = r_curr * np.cos(phi[i])
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=[x_prev, x_curr],
+ y=[y_prev, y_curr],
+ z=[z_prev, z_curr],
+ mode="lines",
+ line=dict(color="rgba(100,100,100,0.5)", width=2),
+ hoverinfo="none",
+ )
+ )
+
+ fig.update_layout(
+ title="Spherical 3D Model Architecture Visualization",
+ scene=dict(
+ xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="data"
+ ),
+ width=900,
+ height=700,
+ margin=dict(r=0, l=0, b=0, t=40),
+ )
+
+ return fig
+
+
+# Example usage
+model = models.vgg16(pretrained=True)
+fig = visualize_model_3d(model)
+fig.show()
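+
+# Optional: persist the interactive figure for sharing. The file name below
+# is an arbitrary choice, not part of the original script.
+# fig.write_html("model_architecture_3d.html")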
diff --git a/training/vis/t4.py b/training/vis/t4.py
new file mode 100644
index 00000000..8fafe18f
--- /dev/null
+++ b/training/vis/t4.py
@@ -0,0 +1,232 @@
+import plotly.graph_objects as go
+import numpy as np
+
+
+def visualize_parameters(
+ model, dim="3d", color_by="mean", group_by=None, animate=False
+):
+ params_info = []
+
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ layer_name = name.split(".")[0] if group_by == "layer" else "All"
+ params_info.append(
+ {
+ "name": name,
+ "layer": layer_name,
+ "shape": param.shape,
+ "numel": param.numel(),
+ "mean": param.data.mean().item(),
+ "std": param.data.std().item(),
+ "grad_mean": (
+ param.grad.mean().item()
+ if param.grad is not None
+ else 0
+ ),
+ "grad_std": (
+ param.grad.std().item() if param.grad is not None else 0
+ ),
+ "type": param.__class__.__name__,
+ }
+ )
+
+ params_info.sort(key=lambda x: x["numel"], reverse=True)
+
+ if dim == "3d":
+ fig = create_3d_plot(params_info, color_by)
+ else:
+ fig = create_2d_plot(params_info, color_by)
+
+ if group_by == "layer":
+ fig = group_by_layer(fig, params_info)
+
+ add_model_summary(fig, model, params_info)
+
+ if animate:
+ fig = add_animation(fig, model)
+
+ return fig
+
+
+def create_3d_plot(params_info, color_by):
+ fig = go.Figure()
+
+ n_params = len(params_info)
+ grid_size = int(np.ceil(np.cbrt(n_params)))
+ x, y, z = np.mgrid[0:grid_size, 0:grid_size, 0:grid_size]
+ x, y, z = (
+ x.flatten()[:n_params],
+ y.flatten()[:n_params],
+ z.flatten()[:n_params],
+ )
+
+ max_numel = max(p["numel"] for p in params_info)
+
+ marker_shapes = {"Weight": "circle", "Bias": "square"}
+
+ for i, param in enumerate(params_info):
+ size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5
+ color = param[color_by]
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=[x[i]],
+ y=[y[i]],
+ z=[z[i]],
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=color,
+ colorscale="Viridis",
+ opacity=0.8,
+ symbol=marker_shapes.get(param["type"], "circle"),
+ colorbar=dict(title=color_by.capitalize()),
+ ),
+ text=create_hover_text(param),
+ hoverinfo="text",
+ )
+ )
+
+ fig.update_layout(
+ title="3D Parameter Visualization",
+ scene=dict(
+ xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="cube"
+ ),
+ width=900,
+ height=700,
+ margin=dict(r=0, l=0, b=0, t=40),
+ )
+
+ return fig
+
+
+def create_2d_plot(params_info, color_by):
+ fig = go.Figure()
+
+ n_params = len(params_info)
+ grid_size = int(np.ceil(np.sqrt(n_params)))
+ x, y = np.mgrid[0:grid_size, 0:grid_size]
+ x, y = x.flatten()[:n_params], y.flatten()[:n_params]
+
+ max_numel = max(p["numel"] for p in params_info)
+
+ marker_shapes = {"Weight": "circle", "Bias": "square"}
+
+ for i, param in enumerate(params_info):
+ size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5
+ color = param[color_by]
+
+ fig.add_trace(
+ go.Scatter(
+ x=[x[i]],
+ y=[y[i]],
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=color,
+ colorscale="Viridis",
+ opacity=0.8,
+ symbol=marker_shapes.get(param["type"], "circle"),
+ colorbar=dict(title=color_by.capitalize()),
+ ),
+ text=create_hover_text(param),
+ hoverinfo="text",
+ )
+ )
+
+ fig.update_layout(
+ title="2D Parameter Visualization",
+ xaxis_title="X",
+ yaxis_title="Y",
+ width=900,
+ height=700,
+ margin=dict(r=0, l=0, b=0, t=40),
+ )
+
+ return fig
+
+
+def create_hover_text(param):
+ return (
+ f"Name: {param['name']}
"
+ f"Shape: {param['shape']}
"
+ f"Elements: {param['numel']}
"
+ f"Mean: {param['mean']:.4f}
"
+ f"Std: {param['std']:.4f}
"
+ f"Grad Mean: {param['grad_mean']:.4f}
"
+ f"Grad Std: {param['grad_std']:.4f}
"
+ f"Type: {param['type']}"
+ )
+
+
+def group_by_layer(fig, params_info):
+ layers = sorted(set(p["layer"] for p in params_info))
+ fig.update_layout(
+ updatemenus=[
+ {
+ "buttons": [
+ {
+ "label": layer,
+ "method": "update",
+ "args": [
+ {
+ "visible": [
+ p["layer"] == layer for p in params_info
+ ]
+ }
+ ],
+ }
+ for layer in layers
+ ]
+ + [
+ {
+ "label": "All",
+ "method": "update",
+ "args": [{"visible": [True] * len(params_info)}],
+ }
+ ],
+ "direction": "down",
+ "showactive": True,
+ }
+ ]
+ )
+ return fig
+
+
+def add_model_summary(fig, model, params_info):
+ total_params = sum(p["numel"] for p in params_info)
+ summary = f"Total parameters: {total_params:,}
"
+ summary += f"Layers: {len(set(p['layer'] for p in params_info))}
"
+ summary += f"Model type: {model.__class__.__name__}"
+
+ fig.add_annotation(
+ x=0.95,
+ y=0.95,
+ xref="paper",
+ yref="paper",
+ text=summary,
+ showarrow=False,
+ font=dict(size=12),
+ align="right",
+ bgcolor="rgba(255,255,255,0.8)",
+ bordercolor="black",
+ borderwidth=1,
+ )
+ return fig
+
+
+def add_animation(fig, model):
+ # This is a placeholder for animation functionality
+ # You would need to implement logic to update parameter values over time
+ return fig
+
+
+# Example usage
+if __name__ == "__main__":
+ from torchvision import models
+
+ model = models.vgg16(pretrained=True)
+ fig = visualize_parameters(
+ model, dim="3d", color_by="mean", group_by="layer"
+ )
+ fig.show()
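+
+    # The same helper also supports a flat 2D grid view, e.g. colored by the
+    # per-parameter standard deviation (one of the collected statistics):
+    # fig_2d = visualize_parameters(model, dim="2d", color_by="std")
+    # fig_2d.show()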
diff --git a/training/vis/t5.py b/training/vis/t5.py
new file mode 100644
index 00000000..dbb2542a
--- /dev/null
+++ b/training/vis/t5.py
@@ -0,0 +1,318 @@
+import plotly.graph_objects as go
+import plotly.subplots as sp
+import numpy as np
+import plotly.io as pio
+
+
+def visualize_parameters(
+ model,
+ dim="3d",
+ color_by="mean",
+ group_by=None,
+ colorscale="Viridis",
+ compare_model=None,
+):
+ params_info = get_params_info(model)
+ if compare_model:
+ params_info_2 = get_params_info(compare_model)
+ return create_comparison_plot(
+ params_info, params_info_2, dim, color_by, colorscale
+ )
+
+ if dim == "3d":
+ fig = create_3d_plot(params_info, color_by, colorscale)
+ elif dim == "2d":
+ fig = create_2d_plot(params_info, color_by, colorscale)
+ else:
+ fig = create_heatmap(params_info, color_by)
+
+ if group_by == "layer":
+ fig = group_by_layer(fig, params_info)
+
+ add_model_summary(fig, model, params_info)
+ add_distribution_histogram(fig, params_info)
+
+ return fig
+
+
+def get_params_info(model):
+ params_info = []
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ layer_name = name.split(".")[0]
+ params_info.append(
+ {
+ "name": name,
+ "layer": layer_name,
+ "shape": param.shape,
+ "numel": param.numel(),
+ "mean": param.data.mean().item(),
+ "std": param.data.std().item(),
+ "grad_mean": (
+ param.grad.mean().item()
+ if param.grad is not None
+ else 0
+ ),
+ "grad_std": (
+ param.grad.std().item() if param.grad is not None else 0
+ ),
+ "type": param.__class__.__name__,
+ "data": param.data.flatten().tolist(),
+ }
+ )
+ params_info.sort(key=lambda x: x["numel"], reverse=True)
+ return params_info
+
+
+def create_3d_plot(params_info, color_by, colorscale):
+ fig = go.Figure()
+
+ n_params = len(params_info)
+ grid_size = int(np.ceil(np.cbrt(n_params)))
+ x, y, z = np.mgrid[0:grid_size, 0:grid_size, 0:grid_size]
+ x, y, z = (
+ x.flatten()[:n_params],
+ y.flatten()[:n_params],
+ z.flatten()[:n_params],
+ )
+
+ max_numel = max(p["numel"] for p in params_info)
+
+ marker_shapes = {"Weight": "circle", "Bias": "square"}
+
+ for i, param in enumerate(params_info):
+ size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5
+ color = param[color_by]
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=[x[i]],
+ y=[y[i]],
+ z=[z[i]],
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=color,
+ colorscale=colorscale,
+ opacity=0.8,
+ symbol=marker_shapes.get(param["type"], "circle"),
+ colorbar=dict(title=color_by.capitalize()),
+ ),
+ text=create_hover_text(param),
+ hoverinfo="text",
+ )
+ )
+
+ fig.update_layout(
+ title="3D Parameter Visualization",
+ scene=dict(
+ xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="cube"
+ ),
+ width=900,
+ height=700,
+ margin=dict(r=0, l=0, b=0, t=40),
+ coloraxis=dict(colorscale=colorscale),
+ )
+
+ return fig
+
+
+def create_2d_plot(params_info, color_by, colorscale):
+ fig = go.Figure()
+
+ n_params = len(params_info)
+ grid_size = int(np.ceil(np.sqrt(n_params)))
+ x, y = np.mgrid[0:grid_size, 0:grid_size]
+ x, y = x.flatten()[:n_params], y.flatten()[:n_params]
+
+ max_numel = max(p["numel"] for p in params_info)
+
+ marker_shapes = {"Weight": "circle", "Bias": "square"}
+
+ for i, param in enumerate(params_info):
+ size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5
+ color = param[color_by]
+
+ fig.add_trace(
+ go.Scatter(
+ x=[x[i]],
+ y=[y[i]],
+ mode="markers",
+ marker=dict(
+ size=size,
+ color=color,
+ colorscale=colorscale,
+ opacity=0.8,
+ symbol=marker_shapes.get(param["type"], "circle"),
+ colorbar=dict(title=color_by.capitalize()),
+ ),
+ text=create_hover_text(param),
+ hoverinfo="text",
+ )
+ )
+
+ fig.update_layout(
+ title="2D Parameter Visualization",
+ xaxis_title="X",
+ yaxis_title="Y",
+ width=900,
+ height=700,
+ margin=dict(r=0, l=0, b=0, t=40),
+ coloraxis=dict(colorscale=colorscale),
+ )
+
+ return fig
+
+
+def create_heatmap(params_info, color_by):
+ data = [param[color_by] for param in params_info]
+ names = [param["name"] for param in params_info]
+
+ fig = go.Figure(
+ data=go.Heatmap(
+ z=[data], y=["Parameters"], x=names, colorscale="Viridis"
+ )
+ )
+
+ fig.update_layout(
+ title="Parameter Heatmap",
+ xaxis_title="Parameter Name",
+ yaxis_title="",
+ width=1200,
+ height=400,
+ )
+
+ return fig
+
+
+def create_comparison_plot(
+ params_info_1, params_info_2, dim, color_by, colorscale
+):
+    # 3D scatter traces need "scene"-type subplots.
+    specs = [[{"type": "scene"}, {"type": "scene"}]] if dim == "3d" else None
+    fig = sp.make_subplots(
+        rows=1, cols=2, subplot_titles=("Model 1", "Model 2"), specs=specs
+    )
+
+ if dim == "3d":
+ fig1 = create_3d_plot(params_info_1, color_by, colorscale)
+ fig2 = create_3d_plot(params_info_2, color_by, colorscale)
+ elif dim == "2d":
+ fig1 = create_2d_plot(params_info_1, color_by, colorscale)
+ fig2 = create_2d_plot(params_info_2, color_by, colorscale)
+ else:
+ fig1 = create_heatmap(params_info_1, color_by)
+ fig2 = create_heatmap(params_info_2, color_by)
+
+ for trace in fig1.data:
+ fig.add_trace(trace, row=1, col=1)
+ for trace in fig2.data:
+ fig.add_trace(trace, row=1, col=2)
+
+ fig.update_layout(height=600, width=1200, title_text="Model Comparison")
+ return fig
+
+
+def add_distribution_histogram(fig, params_info):
+ data = [param["data"] for param in params_info]
+ flat_data = [item for sublist in data for item in sublist]
+
+ histogram = go.Figure(data=[go.Histogram(x=flat_data)])
+ histogram.update_layout(title_text="Parameter Distribution")
+
+ fig.add_trace(histogram.data[0])
+ return fig
+
+
+def add_gradient_flow(fig, model):
+ # Placeholder for gradient flow visualization
+ # This would require tracking gradients during backward pass
+ pass
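+
+
+# A minimal, hedged sketch of the gradient-flow idea above: it assumes the caller has
+# already run a backward pass so that p.grad is populated. The helper name and the
+# bar-chart presentation are illustrative choices, not an existing API in this module.
+def add_gradient_flow_bars(fig, model):
+    names, grad_means = [], []
+    for name, p in model.named_parameters():
+        if p.grad is not None:
+            names.append(name)
+            grad_means.append(p.grad.abs().mean().item())
+    # One bar per parameter tensor: mean absolute gradient after backward()
+    fig.add_trace(go.Bar(x=names, y=grad_means, name="Mean |grad|"))
+    return fig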
+
+
+def create_hover_text(param):
+    # Plotly hover text needs <br> (not \n) for line breaks
+    return (
+        f"Name: {param['name']}<br>"
+        f"Shape: {param['shape']}<br>"
+        f"Elements: {param['numel']}<br>"
+        f"Mean: {param['mean']:.4f}<br>"
+        f"Std: {param['std']:.4f}<br>"
+        f"Grad Mean: {param['grad_mean']:.4f}<br>"
+        f"Grad Std: {param['grad_std']:.4f}<br>"
+        f"Type: {param['type']}"
+    )
+
+
+def group_by_layer(fig, params_info):
+ layers = sorted(set(p["layer"] for p in params_info))
+ fig.update_layout(
+ updatemenus=[
+ {
+ "buttons": [
+ {
+ "label": layer,
+ "method": "update",
+ "args": [
+ {
+ "visible": [
+ p["layer"] == layer for p in params_info
+ ]
+ }
+ ],
+ }
+ for layer in layers
+ ]
+ + [
+ {
+ "label": "All",
+ "method": "update",
+ "args": [{"visible": [True] * len(params_info)}],
+ }
+ ],
+ "direction": "down",
+ "showactive": True,
+ }
+ ]
+ )
+ return fig
+
+
+def add_model_summary(fig, model, params_info):
+ total_params = sum(p["numel"] for p in params_info)
+ summary = f"Total parameters: {total_params:,}
"
+ summary += f"Layers: {len(set(p['layer'] for p in params_info))}
"
+ summary += f"Model type: {model.__class__.__name__}"
+
+ fig.add_annotation(
+ x=0.95,
+ y=0.95,
+ xref="paper",
+ yref="paper",
+ text=summary,
+ showarrow=False,
+ font=dict(size=12),
+ align="right",
+ bgcolor="rgba(255,255,255,0.8)",
+ bordercolor="black",
+ borderwidth=1,
+ )
+ return fig
+
+
+def export_html(fig, filename="parameter_visualization.html"):
+ pio.write_html(fig, file=filename, auto_open=True)
+
+
+# Example usage
+if __name__ == "__main__":
+ from torchvision import models
+
+ model = models.resnet18(pretrained=True)
+ fig = visualize_parameters(
+ model, dim="3d", color_by="mean", group_by="layer", colorscale="Plasma"
+ )
+
+ # Uncomment to compare two models
+ # model2 = models.resnet34(pretrained=True)
+ # fig = visualize_parameters(model, dim='3d', color_by='mean', colorscale='Plasma', compare_model=model2)
+
+ fig.show()
+ export_html(fig)
diff --git a/training/vis/transformer.py b/training/vis/transformer.py
new file mode 100644
index 00000000..e1c5d32f
--- /dev/null
+++ b/training/vis/transformer.py
@@ -0,0 +1,148 @@
+import plotly.graph_objects as go
+import numpy as np
+from transformers import BertModel, BertTokenizer
+
+
+# Helper function to generate roughly evenly spaced points on a sphere of the given radius (Fibonacci / golden-angle spiral)
+def spherical_coordinates(num_points, radius):
+ points = []
+ phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians
+
+ for i in range(num_points):
+ y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
+ radius_y = np.sqrt(1 - y * y) # radius at y
+ theta = phi * i # golden angle increment
+
+ x = np.cos(theta) * radius_y * radius
+ z = np.sin(theta) * radius_y * radius
+ points.append([x, y * radius, z])
+
+ return np.array(points)
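+
+
+# Illustrative sanity check (not used elsewhere in this file): every generated point
+# has Euclidean norm equal to the requested radius, so the cloud really is a sphere.
+# pts = spherical_coordinates(200, radius=3.0)
+# assert np.allclose(np.linalg.norm(pts, axis=1), 3.0)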
+
+
+# Function to extract layer names and parameter counts (uses named_parameters, so it works for any PyTorch model, BERT included)
+def extract_transformer_layers(model):
+ layers = []
+ params_per_layer = []
+
+ for name, param in model.named_parameters():
+ num_params = param.numel()
+ if num_params > 0:
+ layers.append(
+ {"name": name, "num_params": num_params, "shape": param.shape}
+ )
+ params_per_layer.append(num_params)
+
+ return layers, params_per_layer
+
+
+# Function to visualize the transformer model components as a 3D ball structure
+def visualize_transformer_as_ball(model, max_points_per_layer=100):
+ layers, params_per_layer = extract_transformer_layers(model)
+
+ # Visualize each layer as a concentric sphere with points
+ fig = go.Figure()
+ max_radius = 10 # Maximum radius for the outer layer
+ radius_step = max_radius / len(layers)
+
+ # Create spheres and add points to each sphere
+ for i, layer in enumerate(layers):
+ radius = radius_step * (i + 1)
+ num_params = min(
+ params_per_layer[i], max_points_per_layer
+ ) # Limit points per layer for clarity
+ points = spherical_coordinates(num_params, radius)
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode="markers",
+ marker=dict(
+ size=5,
+ color=np.linspace(0, 1, num_params),
+ colorscale="Viridis",
+ opacity=0.8,
+ ),
+ name=f'Layer {i + 1}: {layer["name"]}',
+ hovertext=[
+ f'{layer["name"]}, Shape: {layer["shape"]}'
+ for j in range(num_params)
+ ],
+ )
+ )
+
+ # Configure layout
+ fig.update_layout(
+ scene=dict(
+ xaxis=dict(title="X", showgrid=False, zeroline=False),
+ yaxis=dict(title="Y", showgrid=False, zeroline=False),
+ zaxis=dict(title="Z", showgrid=False, zeroline=False),
+ ),
+ title="Transformer Model Visualization as Ball",
+ showlegend=True,
+ )
+
+ fig.show()
+
+
+# Visualizing attention heads as spherical token layouts (illustrative positions only; the model's actual attention weights are not used here)
+def visualize_attention_weights(attention_heads, sequence_length):
+ fig = go.Figure()
+
+ # Create multiple scatter plots for each attention head
+ for head in range(attention_heads):
+ points = spherical_coordinates(
+ sequence_length, radius=5 + head
+ ) # Slightly different radii for heads
+
+ fig.add_trace(
+ go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 1],
+ z=points[:, 2],
+ mode="markers",
+ marker=dict(
+ size=6,
+ color=np.linspace(0, 1, sequence_length),
+ colorscale="Plasma",
+ opacity=0.8,
+ ),
+ name=f"Attention Head {head + 1}",
+ hovertext=[f"Token {i + 1}" for i in range(sequence_length)],
+ )
+ )
+
+ fig.update_layout(
+ scene=dict(
+ xaxis=dict(title="X", showgrid=False, zeroline=False),
+ yaxis=dict(title="Y", showgrid=False, zeroline=False),
+ zaxis=dict(title="Z", showgrid=False, zeroline=False),
+ ),
+ title="Multi-Head Attention Weights",
+ showlegend=True,
+ )
+
+ fig.show()
+
+
+# Example usage with a pretrained BERT model
+if __name__ == "__main__":
+ # Load a pretrained transformer model (e.g., BERT) from Huggingface
+ model = BertModel.from_pretrained("bert-base-uncased")
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+ # Example sentence
+ sentence = "Transformers are powerful models for NLP tasks."
+ inputs = tokenizer(sentence, return_tensors="pt")
+ outputs = model(**inputs)
+
+ # Visualize the model structure
+ visualize_transformer_as_ball(model)
+
+ # Example: Visualize attention heads for a sequence length (BERT has 12 attention heads by default)
+ sequence_length = inputs["input_ids"].shape[1]
+ visualize_attention_weights(
+ attention_heads=12, sequence_length=sequence_length
+ )
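+
+    # Hedged sketch, not part of the functions above: BERT returns its real attention
+    # tensors when called with output_attentions=True; each entry has shape
+    # (batch, num_heads, seq_len, seq_len) and can be inspected directly, e.g. as a heatmap.
+    outputs_with_attn = model(**inputs, output_attentions=True)
+    first_layer_attn = outputs_with_attn.attentions[0][0]  # (num_heads, seq_len, seq_len)
+    go.Figure(go.Heatmap(z=first_layer_attn[0].detach().numpy())).show()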
diff --git a/training/yolo_alt/model.py b/training/yolo_alt/model.py
new file mode 100644
index 00000000..87a143e9
--- /dev/null
+++ b/training/yolo_alt/model.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class DepthwiseSeparableConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
+ super(DepthwiseSeparableConv, self).__init__()
+ self.depthwise = nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride,
+ padding=kernel_size // 2,
+ groups=in_channels,
+ bias=False,
+ )
+ self.pointwise = nn.Conv2d(
+ in_channels, out_channels, 1, 1, 0, bias=False
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.depthwise(x)
+ x = self.pointwise(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class SEBlock(nn.Module):
+ def __init__(self, channels, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channels, channels // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channels // reduction, channels, bias=False),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y.expand_as(x)
+
+
+class GhostModule(nn.Module):
+ def __init__(
+ self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True
+ ):
+ super(GhostModule, self).__init__()
+ self.oup = oup
+ init_channels = math.ceil(oup / ratio)
+ new_channels = init_channels * (ratio - 1)
+
+ self.primary_conv = nn.Sequential(
+ nn.Conv2d(
+ inp,
+ init_channels,
+ kernel_size,
+ stride,
+ kernel_size // 2,
+ bias=False,
+ ),
+ nn.BatchNorm2d(init_channels),
+ nn.ReLU(inplace=True) if relu else nn.Sequential(),
+ )
+
+ self.cheap_operation = nn.Sequential(
+ nn.Conv2d(
+ init_channels,
+ new_channels,
+ dw_size,
+ 1,
+ dw_size // 2,
+ groups=init_channels,
+ bias=False,
+ ),
+ nn.BatchNorm2d(new_channels),
+ nn.ReLU(inplace=True) if relu else nn.Sequential(),
+ )
+
+ def forward(self, x):
+ x1 = self.primary_conv(x)
+ x2 = self.cheap_operation(x1)
+ out = torch.cat([x1, x2], dim=1)
+ return out[:, : self.oup, :, :]
+
+
+class FastDetector(nn.Module):
+ def __init__(self, num_classes):
+ super(FastDetector, self).__init__()
+
+ # Lightweight backbone with Ghost modules
+ self.backbone = nn.Sequential(
+ GhostModule(3, 16, 3, stride=2),
+ GhostModule(16, 32, 3, stride=2),
+ GhostModule(32, 64, 3, stride=2),
+ GhostModule(64, 128, 3, stride=2),
+ )
+
+        # Feature Pyramid Network: each lateral conv maps the deeper feature's
+        # channels down to the shallower feature it is merged into
+        # (fpn[i - 1] takes features[i] and outputs the channels of features[i - 1])
+        self.fpn = nn.ModuleList(
+            [
+                DepthwiseSeparableConv(32, 16, 3, 1),
+                DepthwiseSeparableConv(64, 32, 3, 1),
+                DepthwiseSeparableConv(128, 64, 3, 1),
+            ]
+        )
+
+        # SE blocks for each FPN level (channels match the backbone outputs)
+        self.se_blocks = nn.ModuleList(
+            [
+                SEBlock(16),
+                SEBlock(32),
+                SEBlock(64),
+                SEBlock(128),
+            ]
+        )
+
+        # Detection heads, one per FPN level
+        self.heads = nn.ModuleList(
+            [
+                self._make_head(16, num_classes),
+                self._make_head(32, num_classes),
+                self._make_head(64, num_classes),
+                self._make_head(128, num_classes),
+            ]
+        )
+
+ def _make_head(self, in_channels, num_classes):
+ return nn.Sequential(
+ DepthwiseSeparableConv(in_channels, in_channels, 3, 1),
+ nn.Conv2d(
+ in_channels, num_classes + 4, 1
+ ), # cls + x, y, w, h (anchor-free)
+ )
+
+ def forward(self, x):
+ features = []
+        for layer in self.backbone:
+ x = layer(x)
+ features.append(x)
+
+ # FPN
+ for i in range(len(features) - 1, 0, -1):
+ features[i - 1] = features[i - 1] + F.interpolate(
+ self.fpn[i - 1](features[i]), size=features[i - 1].shape[2:]
+ )
+
+ # Apply SE blocks and get predictions
+ outputs = []
+ for feature, se_block, head in zip(
+ features, self.se_blocks, self.heads
+ ):
+ feature = se_block(feature)
+ outputs.append(head(feature).flatten(start_dim=2))
+
+ return outputs
+
+
+# Example usage
+num_classes = 80 # COCO dataset
+model = FastDetector(num_classes)
+input_tensor = torch.randn(1, 3, 416, 416)
+outputs = model(input_tensor)
+for i, output in enumerate(outputs):
+ print(f"Output {i} shape:", output.shape)