diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 5f3d61f3..afdbc1c9 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -131,233 +131,232 @@ nav: - Home: - Overview: "index.md" - Contributing: "contributing.md" - - Zeta: - - Overview: "zeta/index.md" - - zeta.nn: - - zeta.nn.biases: - - Xpos: "zeta/nn/biases/xpos.md" - - RelativePositionBias: "zeta/nn/biases/relative_bias.md" - - AlibiPositionalBias: "zeta/nn/biases/alibi.md" - - DynamicPositionBias: "zeta/nn/biases/dynamic.md" - - zeta.nn.embeddings: - - MultiWay: "zeta/nn/embeddings/multiway.md" - - RotaryEmbeddings: "zeta/nn/embeddings/rope.md" - - TruncatedRotaryEmbedding: "zeta/nn/embeddings/truncated_rope.md" - - PositionalEmbedding: "zeta/nn/embeddings/positional_embeddings.md" - - XPOS: "zeta/nn/embeddings/xpos.md" - - YarnEmbedding: "zeta/nn/embeddings/yarn.md" - - VisionEmbedding: "zeta/nn/embeddings/vis_emb.md" - - SinusoidalEmbeddings: "zeta/nn/embeddings/sinusoidal.md" - - PatchEmbeddings: "zeta/nn/embeddings/patch_embeddings.md" - - PositionInterpolationEmbeddings: "zeta/nn/embeddings/positional_interpolation.md" - - zeta.nn.modules: - - custom_mlp: "zeta/nn/modules/custom_mlp.md" - - mbconv: "zeta/nn/modules/mbconv.md" - - dynamicroutingblock: "zeta/nn/modules/dynamicroutingblock.md" - - clippedgeluactivation: "zeta/nn/modules/clippedgeluactivation.md" - - mambablock: "zeta/nn/modules/mambablock.md" - - vittransformerblock: "zeta/nn/modules/vittransformerblock.md" - - fuseddensegeludense: "zeta/nn/modules/fuseddensegeludense.md" - - pscan: "zeta/nn/modules/pscan.md" - - adaptive: "zeta/nn/modules/adaptive.md" - - filmconditioning: "zeta/nn/modules/filmconditioning.md" - - mmfusionffn: "zeta/nn/modules/mmfusionffn.md" - - quickgeluactivation: "zeta/nn/modules/quickgeluactivation.md" - - gatedresidualblock: "zeta/nn/modules/gatedresidualblock.md" - - highwaylayer: "zeta/nn/modules/highwaylayer.md" - - multimodalmambablock: "zeta/nn/modules/multimodalmambablock.md" - - rms_norm: "zeta/nn/modules/rms_norm.md" - - ssm: "zeta/nn/modules/ssm.md" - - dualpathblock: "zeta/nn/modules/dualpathblock.md" - - topngating: "zeta/nn/modules/topngating.md" - - mmlayernorm: "zeta/nn/modules/mmlayernorm.md" - - mm_adapter: "zeta/nn/modules/mm_adapter.md" - - laplaceactivation: "zeta/nn/modules/laplaceactivation.md" - - nfnstem: "zeta/nn/modules/nfnstem.md" - - laser: "zeta/nn/modules/laser.md" - - denseblock: "zeta/nn/modules/denseblock.md" - - depthwiseconv2d: "zeta/nn/modules/depthwiseconv2d.md" - - lora: "zeta/nn/modules/lora.md" - - vlayernorm: "zeta/nn/modules/vlayernorm.md" - - flexiconv: "zeta/nn/modules/flexiconv.md" - - pulsar: "zeta/nn/modules/pulsar.md" - - pool: "zeta/nn/modules/pool.md" - - time_up_sample: "zeta/nn/modules/time_up_sample.md" - - spatial_downsample: "zeta/nn/modules/spatial_downsample.md" - - parallel: "zeta/nn/modules/parallel.md" - - conv2dfeedforward: "zeta/nn/modules/conv2dfeedforward.md" - - video_autoencoder: "zeta/nn/modules/video_autoencoder.md" - - recursiveblock: "zeta/nn/modules/recursiveblock.md" - - relusquaredactivation: "zeta/nn/modules/relusquaredactivation.md" - - fastgeluactivation: "zeta/nn/modules/fastgeluactivation.md" - - token_learner: "zeta/nn/modules/token_learner.md" - - layernorm: "zeta/nn/modules/layernorm.md" - - averagemodelmerger: "zeta/nn/modules/averagemodelmerger.md" - - linearactivation: "zeta/nn/modules/linearactivation.md" - - stochdepth: "zeta/nn/modules/stochdepth.md" - - expert: "zeta/nn/modules/expert.md" - - siglip: "zeta/nn/modules/siglip.md" - - ether: 
"zeta/nn/modules/ether.md" - - newgeluactivation: "zeta/nn/modules/newgeluactivation.md" - - pytorchgelutanh: "zeta/nn/modules/pytorchgelutanh.md" - - multiscaleblock: "zeta/nn/modules/multiscaleblock.md" - - umambablock: "zeta/nn/modules/umambablock.md" - - film: "zeta/nn/modules/film.md" - - adaptive_conv: "zeta/nn/modules/adaptive_conv.md" - - fused_dropout_layernorm: "zeta/nn/modules/fused_dropout_layernorm.md" - - accurategeluactivation: "zeta/nn/modules/accurategeluactivation.md" - - exo: "zeta/nn/modules/exo.md" - - polymorphic_activation: "zeta/nn/modules/polymorphic_activation.md" - - fusedprojsoftmax: "zeta/nn/modules/fusedprojsoftmax.md" - - quantizedln: "zeta/nn/modules/quantizedln.md" - - postnorm: "zeta/nn/modules/postnorm.md" - - moerouter: "zeta/nn/modules/moerouter.md" - - geluactivation: "zeta/nn/modules/geluactivation.md" - - visionattention: "zeta/nn/modules/visionattention.md" - - fused_gelu_dense: "zeta/nn/modules/fused_gelu_dense.md" - - feedforward: "zeta/nn/modules/feedforward.md" - - wsconv2d: "zeta/nn/modules/wsconv2d.md" - - mlp: "zeta/nn/modules/mlp.md" - - slerpmodelmerger: "zeta/nn/modules/slerpmodelmerger.md" - - fuseddropoutlayernorm: "zeta/nn/modules/fuseddropoutlayernorm.md" - - tripleskipblock: "zeta/nn/modules/tripleskipblock.md" - - dm: "zeta/nn/modules/dm.md" - - feedbackblock: "zeta/nn/modules/feedbackblock.md" - - mixtureofexperts: "zeta/nn/modules/mixtureofexperts.md" - - mamba: "zeta/nn/modules/mamba.md" - - perceiverlayer: "zeta/nn/modules/perceiverlayer.md" - - mishactivation: "zeta/nn/modules/mishactivation.md" - - hebbian: "zeta/nn/modules/hebbian.md" - - simple_feedback: "zeta/nn/modules/simple_feedback.md" - - visual_expert: "zeta/nn/modules/visual_expert.md" - - stochasticskipblock: "zeta/nn/modules/stochasticskipblock.md" - - unet: "zeta/nn/modules/unet.md" - - zeta.nn.attention: - - FlashAttention: "zeta/nn/attention/flash_attention.md" - - MultiQueryAttention: "zeta/nn/attention/multiquery.md" - - MultiheadAttention: "zeta/nn/attention/multihead.md" - - FlashAttentionTwo: "zeta/nn/attention/flash2.md" - - BaseAttention: "zeta/nn/attention/base.md" - - LocalAttention: "zeta/nn/attention/local.md" - - LocalMHA: "zeta/nn/attention/localmha.md" - - MixtureOfAttention: "zeta/nn/attention/mixture_of_attention.md" - - MixtureOfAutoregressiveAttention: "zeta/nn/attention/mixture_of_attention_ar.md" - - SparseAttention: "zeta/nn/attention/sparse_attn.md" - - zeta.tokenizers: - - Language: - - LanguageTokenizerGPTX: "zeta/tokenizers/language_tokenizer.md" - - SentencePieceTokenizer: "zeta/tokenizers/sentencepiece.md" - - TokenMonster: "zeta/tokenizers/token_monster.md" - - MultiModal: - - MultiModalTokenizer: "zeta/tokenizers/multi_modal_tokenizer.md" - - zeta.utils: - - Misc: - - cast_tuple: "zeta/utils/cast_tuple.md" - - group_by_key_prefix: "zeta/utils/group_by_key_prefix.md" - - eval_decorator: "zeta/utils/eval_decorator.md" - - print_cuda_memory_usage: "zeta/utils/print_cuda_memory_usage.md" - - once: "zeta/utils/once.md" - - default: "zeta/utils/default.md" - - gumbel_noise: "zeta/utils/gumbel_noise.md" - - pad_at_dim: "zeta/utils/pad_at_dim.md" - - init_zero_: "zeta/utils/init_zero_.md" - - top_p: "zeta/utils/top_p.md" - - cast_if_src_dtype: "zeta/utils/cast_if_src_dtype.md" - - disable_warnings_and_logs: "zeta/utils/disable_warnings_and_logs.md" - - save_load_wrapper: "zeta/utils/save_load_wrapper.md" - - get_sinusoid_encoding_table: "zeta/utils/get_sinusoid_encoding_table.md" - - main: "zeta/utils/main.md" - - string_begins_with: 
"zeta/utils/string_begins_with.md" - - gif_to_tensor: "zeta/utils/gif_to_tensor.md" - - l2norm: "zeta/utils/l2norm.md" - - save_load: "zeta/utils/save_load.md" - - log: "zeta/utils/log.md" - - module_device: "zeta/utils/module_device.md" - - print_num_params: "zeta/utils/print_num_params.md" - - top_a: "zeta/utils/top_a.md" - - interpolate_pos_encoding_2d: "zeta/utils/interpolate_pos_encoding_2d.md" - - exists: "zeta/utils/exists.md" - - cosine_beta_schedule: "zeta/utils/cosine_beta_schedule.md" - - track_cuda_memory: "zeta/utils/track_cuda_memory.md" - - maybe: "zeta/utils/maybe.md" - - save_memory_snapshot: "zeta/utils/save_memory_snapshot.md" - - top_k: "zeta/utils/top_k.md" - - print_main: "zeta/utils/print_main.md" - - pick_and_pop: "zeta/utils/pick_and_pop.md" - - track_cuda_memory_usage: "zeta/utils/track_cuda_memory_usage.md" - - group_dict_by_key: "zeta/utils/group_dict_by_key.md" - - video_tensor_to_gift: "zeta/utils/video_tensor_to_gift.md" - - zeta.ops: - - Misc: - - img_compose_decompose: "zeta/ops/img_compose_decompose.md" - - img_transpose_2daxis: "zeta/ops/img_transpose_2daxis.md" - - img_transpose: "zeta/ops/img_transpose.md" - - img_order_of_axes: "zeta/ops/img_order_of_axes.md" - - mos: "zeta/ops/mos.md" - - merge_small_dims: "zeta/ops/merge_small_dims.md" - - multi_dim_cat: "zeta/ops/multi_dim_cat.md" - - img_compose_bw: "zeta/ops/img_compose_bw.md" - - squeeze_2d_new: "zeta/ops/squeeze_2d_new.md" - - temp_softmax: "zeta/ops/temp_softmax.md" - - gumbelmax: "zeta/ops/gumbelmax.md" - - _matrix_inverse_root_newton: "zeta/ops/_matrix_inverse_root_newton.md" - - compute_matrix_root_inverse_residuals: "zeta/ops/compute_matrix_root_inverse_residuals.md" - - matrix_root_diagonal: "zeta/ops/matrix_root_diagonal.md" - - sparse_softmax: "zeta/ops/sparse_softmax.md" - - reshape_audio_to_text: "zeta/ops/reshape_audio_to_text.md" - - local_softmax: "zeta/ops/local_softmax.md" - - softmaxes: "zeta/ops/softmaxes.md" - - _matrix_root_eigen: "zeta/ops/_matrix_root_eigen.md" - - main: "zeta/ops/main.md" - - norm_exp_softmax: "zeta/ops/norm_exp_softmax.md" - - multi_dim_split: "zeta/ops/multi_dim_split.md" - - img_width_to_height: "zeta/ops/img_width_to_height.md" - - fast_softmax: "zeta/ops/fast_softmax.md" - - standard_softmax: "zeta/ops/standard_softmax.md" - - unitwise_norm: "zeta/ops/unitwise_norm.md" - - reshape_video_to_text: "zeta/ops/reshape_video_to_text.md" - - img_decompose: "zeta/ops/img_decompose.md" - - unsqueeze_2d_new: "zeta/ops/unsqueeze_2d_new.md" - - reshape_img_to_text: "zeta/ops/reshape_img_to_text.md" - - channel_shuffle_new: "zeta/ops/channel_shuffle_new.md" - - matrix_inverse_root: "zeta/ops/matrix_inverse_root.md" - - sparsemax: "zeta/ops/sparsemax.md" - - gram_matrix_new: "zeta/ops/gram_matrix_new.md" - - logit_scaled_softmax: "zeta/ops/logit_scaled_softmax.md" - - selu_softmax: "zeta/ops/selu_softmax.md" - - reshape_text_to_img: "zeta/ops/reshape_text_to_img.md" - - zeta.optim: - - Optimizers: - - StableAdamWUnfused: "zeta/optims/adamw.md" - - GradientAscent: "zeta/optims/ga.md" - - DecoupledLionW: "zeta/training/optimizers/decoupled_lion.md" - - SophiaG: "zeta/training/optimizers/sophia.md" - - zeta.training: - - Training: - - fsdp: "zeta/training/fsdp.md" - - ParallelWrapper: "zeta/training/parallel_wrapper.md" - - train: "zeta/training/train.md" - - zeta.models: - - Language and MultiModal: - - vit: "zeta/models/vit.md" - - gpt4multimodal: "zeta/models/gpt4multimodal.md" - - maxvit: "zeta/models/maxvit.md" - - llama2: "zeta/models/llama2.md" - - gpt4: 
"zeta/models/gpt4.md" - - andromeda: "zeta/models/andromeda.md" - - basemodel: "zeta/models/basemodel.md" - - palme: "zeta/models/palme.md" - - megavit: "zeta/models/megavit.md" - - navit: "zeta/models/navit.md" - - zeta.structs: - - Structures: - - Decoder: "zeta/nn/architecture/decoder.md" - - Transformer: "zeta/nn/architecture/transformer.md" - - paralleltransformerblock: "paralleltransformerblock.md" - - zeta.quant: - - Quantization Algorithms: - - QUIK: "zeta/quant/quik.md" - - BitLinear: "zeta/quant/bitlinear.md" - - niva: "zeta/quant/niva.md" - - zeta.rl: - - DPO: "zeta/rl/dpo.md" \ No newline at end of file + - Overview: "zeta/index.md" + - zeta.nn: + - zeta.nn.biases: + - Xpos: "zeta/nn/biases/xpos.md" + - RelativePositionBias: "zeta/nn/biases/relative_bias.md" + - AlibiPositionalBias: "zeta/nn/biases/alibi.md" + - DynamicPositionBias: "zeta/nn/biases/dynamic.md" + - zeta.nn.embeddings: + - MultiWay: "zeta/nn/embeddings/multiway.md" + - RotaryEmbeddings: "zeta/nn/embeddings/rope.md" + - TruncatedRotaryEmbedding: "zeta/nn/embeddings/truncated_rope.md" + - PositionalEmbedding: "zeta/nn/embeddings/positional_embeddings.md" + - XPOS: "zeta/nn/embeddings/xpos.md" + - YarnEmbedding: "zeta/nn/embeddings/yarn.md" + - VisionEmbedding: "zeta/nn/embeddings/vis_emb.md" + - SinusoidalEmbeddings: "zeta/nn/embeddings/sinusoidal.md" + - PatchEmbeddings: "zeta/nn/embeddings/patch_embeddings.md" + - PositionInterpolationEmbeddings: "zeta/nn/embeddings/positional_interpolation.md" + - zeta.nn.modules: + - custom_mlp: "zeta/nn/modules/custom_mlp.md" + - mbconv: "zeta/nn/modules/mbconv.md" + - dynamicroutingblock: "zeta/nn/modules/dynamicroutingblock.md" + - clippedgeluactivation: "zeta/nn/modules/clippedgeluactivation.md" + - mambablock: "zeta/nn/modules/mambablock.md" + - vittransformerblock: "zeta/nn/modules/vittransformerblock.md" + - fuseddensegeludense: "zeta/nn/modules/fuseddensegeludense.md" + - pscan: "zeta/nn/modules/pscan.md" + - adaptive: "zeta/nn/modules/adaptive.md" + - filmconditioning: "zeta/nn/modules/filmconditioning.md" + - mmfusionffn: "zeta/nn/modules/mmfusionffn.md" + - quickgeluactivation: "zeta/nn/modules/quickgeluactivation.md" + - gatedresidualblock: "zeta/nn/modules/gatedresidualblock.md" + - highwaylayer: "zeta/nn/modules/highwaylayer.md" + - multimodalmambablock: "zeta/nn/modules/multimodalmambablock.md" + - rms_norm: "zeta/nn/modules/rms_norm.md" + - ssm: "zeta/nn/modules/ssm.md" + - dualpathblock: "zeta/nn/modules/dualpathblock.md" + - topngating: "zeta/nn/modules/topngating.md" + - mmlayernorm: "zeta/nn/modules/mmlayernorm.md" + - mm_adapter: "zeta/nn/modules/mm_adapter.md" + - laplaceactivation: "zeta/nn/modules/laplaceactivation.md" + - nfnstem: "zeta/nn/modules/nfnstem.md" + - laser: "zeta/nn/modules/laser.md" + - denseblock: "zeta/nn/modules/denseblock.md" + - depthwiseconv2d: "zeta/nn/modules/depthwiseconv2d.md" + - lora: "zeta/nn/modules/lora.md" + - vlayernorm: "zeta/nn/modules/vlayernorm.md" + - flexiconv: "zeta/nn/modules/flexiconv.md" + - pulsar: "zeta/nn/modules/pulsar.md" + - pool: "zeta/nn/modules/pool.md" + - time_up_sample: "zeta/nn/modules/time_up_sample.md" + - spatial_downsample: "zeta/nn/modules/spatial_downsample.md" + - parallel: "zeta/nn/modules/parallel.md" + - conv2dfeedforward: "zeta/nn/modules/conv2dfeedforward.md" + - video_autoencoder: "zeta/nn/modules/video_autoencoder.md" + - recursiveblock: "zeta/nn/modules/recursiveblock.md" + - relusquaredactivation: "zeta/nn/modules/relusquaredactivation.md" + - fastgeluactivation: 
"zeta/nn/modules/fastgeluactivation.md" + - token_learner: "zeta/nn/modules/token_learner.md" + - layernorm: "zeta/nn/modules/layernorm.md" + - averagemodelmerger: "zeta/nn/modules/averagemodelmerger.md" + - linearactivation: "zeta/nn/modules/linearactivation.md" + - stochdepth: "zeta/nn/modules/stochdepth.md" + - expert: "zeta/nn/modules/expert.md" + - siglip: "zeta/nn/modules/siglip.md" + - ether: "zeta/nn/modules/ether.md" + - newgeluactivation: "zeta/nn/modules/newgeluactivation.md" + - pytorchgelutanh: "zeta/nn/modules/pytorchgelutanh.md" + - multiscaleblock: "zeta/nn/modules/multiscaleblock.md" + - umambablock: "zeta/nn/modules/umambablock.md" + - film: "zeta/nn/modules/film.md" + - adaptive_conv: "zeta/nn/modules/adaptive_conv.md" + - fused_dropout_layernorm: "zeta/nn/modules/fused_dropout_layernorm.md" + - accurategeluactivation: "zeta/nn/modules/accurategeluactivation.md" + - exo: "zeta/nn/modules/exo.md" + - polymorphic_activation: "zeta/nn/modules/polymorphic_activation.md" + - fusedprojsoftmax: "zeta/nn/modules/fusedprojsoftmax.md" + - quantizedln: "zeta/nn/modules/quantizedln.md" + - postnorm: "zeta/nn/modules/postnorm.md" + - moerouter: "zeta/nn/modules/moerouter.md" + - geluactivation: "zeta/nn/modules/geluactivation.md" + - visionattention: "zeta/nn/modules/visionattention.md" + - fused_gelu_dense: "zeta/nn/modules/fused_gelu_dense.md" + - feedforward: "zeta/nn/modules/feedforward.md" + - wsconv2d: "zeta/nn/modules/wsconv2d.md" + - mlp: "zeta/nn/modules/mlp.md" + - slerpmodelmerger: "zeta/nn/modules/slerpmodelmerger.md" + - fuseddropoutlayernorm: "zeta/nn/modules/fuseddropoutlayernorm.md" + - tripleskipblock: "zeta/nn/modules/tripleskipblock.md" + - dm: "zeta/nn/modules/dm.md" + - feedbackblock: "zeta/nn/modules/feedbackblock.md" + - mixtureofexperts: "zeta/nn/modules/mixtureofexperts.md" + - mamba: "zeta/nn/modules/mamba.md" + - perceiverlayer: "zeta/nn/modules/perceiverlayer.md" + - mishactivation: "zeta/nn/modules/mishactivation.md" + - hebbian: "zeta/nn/modules/hebbian.md" + - simple_feedback: "zeta/nn/modules/simple_feedback.md" + - visual_expert: "zeta/nn/modules/visual_expert.md" + - stochasticskipblock: "zeta/nn/modules/stochasticskipblock.md" + - unet: "zeta/nn/modules/unet.md" + - zeta.nn.attention: + - FlashAttention: "zeta/nn/attention/flash_attention.md" + - MultiQueryAttention: "zeta/nn/attention/multiquery.md" + - MultiheadAttention: "zeta/nn/attention/multihead.md" + - FlashAttentionTwo: "zeta/nn/attention/flash2.md" + - BaseAttention: "zeta/nn/attention/base.md" + - LocalAttention: "zeta/nn/attention/local.md" + - LocalMHA: "zeta/nn/attention/localmha.md" + - MixtureOfAttention: "zeta/nn/attention/mixture_of_attention.md" + - MixtureOfAutoregressiveAttention: "zeta/nn/attention/mixture_of_attention_ar.md" + - SparseAttention: "zeta/nn/attention/sparse_attn.md" + - zeta.tokenizers: + - Language: + - LanguageTokenizerGPTX: "zeta/tokenizers/language_tokenizer.md" + - SentencePieceTokenizer: "zeta/tokenizers/sentencepiece.md" + - TokenMonster: "zeta/tokenizers/token_monster.md" + - MultiModal: + - MultiModalTokenizer: "zeta/tokenizers/multi_modal_tokenizer.md" + - zeta.utils: + - Misc: + - cast_tuple: "zeta/utils/cast_tuple.md" + - group_by_key_prefix: "zeta/utils/group_by_key_prefix.md" + - eval_decorator: "zeta/utils/eval_decorator.md" + - print_cuda_memory_usage: "zeta/utils/print_cuda_memory_usage.md" + - once: "zeta/utils/once.md" + - default: "zeta/utils/default.md" + - gumbel_noise: "zeta/utils/gumbel_noise.md" + - pad_at_dim: "zeta/utils/pad_at_dim.md" 
+ - init_zero_: "zeta/utils/init_zero_.md" + - top_p: "zeta/utils/top_p.md" + - cast_if_src_dtype: "zeta/utils/cast_if_src_dtype.md" + - disable_warnings_and_logs: "zeta/utils/disable_warnings_and_logs.md" + - save_load_wrapper: "zeta/utils/save_load_wrapper.md" + - get_sinusoid_encoding_table: "zeta/utils/get_sinusoid_encoding_table.md" + - main: "zeta/utils/main.md" + - string_begins_with: "zeta/utils/string_begins_with.md" + - gif_to_tensor: "zeta/utils/gif_to_tensor.md" + - l2norm: "zeta/utils/l2norm.md" + - save_load: "zeta/utils/save_load.md" + - log: "zeta/utils/log.md" + - module_device: "zeta/utils/module_device.md" + - print_num_params: "zeta/utils/print_num_params.md" + - top_a: "zeta/utils/top_a.md" + - interpolate_pos_encoding_2d: "zeta/utils/interpolate_pos_encoding_2d.md" + - exists: "zeta/utils/exists.md" + - cosine_beta_schedule: "zeta/utils/cosine_beta_schedule.md" + - track_cuda_memory: "zeta/utils/track_cuda_memory.md" + - maybe: "zeta/utils/maybe.md" + - save_memory_snapshot: "zeta/utils/save_memory_snapshot.md" + - top_k: "zeta/utils/top_k.md" + - print_main: "zeta/utils/print_main.md" + - pick_and_pop: "zeta/utils/pick_and_pop.md" + - track_cuda_memory_usage: "zeta/utils/track_cuda_memory_usage.md" + - group_dict_by_key: "zeta/utils/group_dict_by_key.md" + - video_tensor_to_gift: "zeta/utils/video_tensor_to_gift.md" + - zeta.ops: + - Misc: + - img_compose_decompose: "zeta/ops/img_compose_decompose.md" + - img_transpose_2daxis: "zeta/ops/img_transpose_2daxis.md" + - img_transpose: "zeta/ops/img_transpose.md" + - img_order_of_axes: "zeta/ops/img_order_of_axes.md" + - mos: "zeta/ops/mos.md" + - merge_small_dims: "zeta/ops/merge_small_dims.md" + - multi_dim_cat: "zeta/ops/multi_dim_cat.md" + - img_compose_bw: "zeta/ops/img_compose_bw.md" + - squeeze_2d_new: "zeta/ops/squeeze_2d_new.md" + - temp_softmax: "zeta/ops/temp_softmax.md" + - gumbelmax: "zeta/ops/gumbelmax.md" + - _matrix_inverse_root_newton: "zeta/ops/_matrix_inverse_root_newton.md" + - compute_matrix_root_inverse_residuals: "zeta/ops/compute_matrix_root_inverse_residuals.md" + - matrix_root_diagonal: "zeta/ops/matrix_root_diagonal.md" + - sparse_softmax: "zeta/ops/sparse_softmax.md" + - reshape_audio_to_text: "zeta/ops/reshape_audio_to_text.md" + - local_softmax: "zeta/ops/local_softmax.md" + - softmaxes: "zeta/ops/softmaxes.md" + - _matrix_root_eigen: "zeta/ops/_matrix_root_eigen.md" + - main: "zeta/ops/main.md" + - norm_exp_softmax: "zeta/ops/norm_exp_softmax.md" + - multi_dim_split: "zeta/ops/multi_dim_split.md" + - img_width_to_height: "zeta/ops/img_width_to_height.md" + - fast_softmax: "zeta/ops/fast_softmax.md" + - standard_softmax: "zeta/ops/standard_softmax.md" + - unitwise_norm: "zeta/ops/unitwise_norm.md" + - reshape_video_to_text: "zeta/ops/reshape_video_to_text.md" + - img_decompose: "zeta/ops/img_decompose.md" + - unsqueeze_2d_new: "zeta/ops/unsqueeze_2d_new.md" + - reshape_img_to_text: "zeta/ops/reshape_img_to_text.md" + - channel_shuffle_new: "zeta/ops/channel_shuffle_new.md" + - matrix_inverse_root: "zeta/ops/matrix_inverse_root.md" + - sparsemax: "zeta/ops/sparsemax.md" + - gram_matrix_new: "zeta/ops/gram_matrix_new.md" + - logit_scaled_softmax: "zeta/ops/logit_scaled_softmax.md" + - selu_softmax: "zeta/ops/selu_softmax.md" + - reshape_text_to_img: "zeta/ops/reshape_text_to_img.md" + - zeta.optim: + - Optimizers: + - StableAdamWUnfused: "zeta/optims/adamw.md" + - GradientAscent: "zeta/optims/ga.md" + - DecoupledLionW: "zeta/training/optimizers/decoupled_lion.md" + - SophiaG: 
"zeta/training/optimizers/sophia.md" + - zeta.training: + - Training: + - fsdp: "zeta/training/fsdp.md" + - ParallelWrapper: "zeta/training/parallel_wrapper.md" + - train: "zeta/training/train.md" + - zeta.models: + - Language and MultiModal: + - vit: "zeta/models/vit.md" + - gpt4multimodal: "zeta/models/gpt4multimodal.md" + - maxvit: "zeta/models/maxvit.md" + - llama2: "zeta/models/llama2.md" + - gpt4: "zeta/models/gpt4.md" + - andromeda: "zeta/models/andromeda.md" + - basemodel: "zeta/models/basemodel.md" + - palme: "zeta/models/palme.md" + - megavit: "zeta/models/megavit.md" + - navit: "zeta/models/navit.md" + - zeta.structs: + - Structures: + - Decoder: "zeta/nn/architecture/decoder.md" + - Transformer: "zeta/nn/architecture/transformer.md" + - paralleltransformerblock: "paralleltransformerblock.md" + - zeta.quant: + - Quantization Algorithms: + - QUIK: "zeta/quant/quik.md" + - BitLinear: "zeta/quant/bitlinear.md" + - niva: "zeta/quant/niva.md" + - zeta.rl: + - DPO: "zeta/rl/dpo.md" \ No newline at end of file diff --git a/training/gan/gan.py b/training/gan/gan.py new file mode 100644 index 00000000..69c88cfd --- /dev/null +++ b/training/gan/gan.py @@ -0,0 +1,167 @@ +import logging +import os +import shutil +from typing import Tuple + +import torch +import torch.nn as nn +import torch.optim as optim +import torchaudio +from datasets import load_dataset +from torch.utils.data import DataLoader + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class Generator(nn.Module): + def __init__(self, latent_dim: int, output_dim: int): + super().__init__() + self.model = nn.Sequential( + nn.Linear(latent_dim, 256), + nn.LeakyReLU(0.2), + nn.Linear(256, 512), + nn.LeakyReLU(0.2), + nn.Linear(512, 1024), + nn.LeakyReLU(0.2), + nn.Linear(1024, output_dim), + nn.Tanh(), + ) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return self.model(z) + + +class Discriminator(nn.Module): + def __init__(self, input_dim: int): + super().__init__() + self.model = nn.Sequential( + nn.Linear(input_dim, 1024), + nn.LeakyReLU(0.2), + nn.Linear(1024, 512), + nn.LeakyReLU(0.2), + nn.Linear(512, 256), + nn.LeakyReLU(0.2), + nn.Linear(256, 1), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +def train_gan( + generator: nn.Module, + discriminator: nn.Module, + dataloader: DataLoader, + num_epochs: int, + latent_dim: int, + device: torch.device, +) -> Tuple[nn.Module, nn.Module]: + criterion = nn.BCELoss() + g_optimizer = optim.Adam( + generator.parameters(), lr=0.0002, betas=(0.5, 0.999) + ) + d_optimizer = optim.Adam( + discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999) + ) + + for epoch in range(num_epochs): + for i, batch in enumerate(dataloader): + real_samples = batch["audio"].to(device) + batch_size = real_samples.size(0) + + d_optimizer.zero_grad() + real_labels = torch.ones(batch_size, 1).to(device) + fake_labels = torch.zeros(batch_size, 1).to(device) + + d_real_output = discriminator(real_samples) + d_real_loss = criterion(d_real_output, real_labels) + + z = torch.randn(batch_size, latent_dim).to(device) + fake_samples = generator(z) + d_fake_output = discriminator(fake_samples.detach()) + d_fake_loss = criterion(d_fake_output, fake_labels) + + d_loss = d_real_loss + d_fake_loss + d_loss.backward() + d_optimizer.step() + + g_optimizer.zero_grad() + z = torch.randn(batch_size, latent_dim).to(device) + fake_samples = generator(z) + d_output = 
discriminator(fake_samples) + g_loss = criterion(d_output, real_labels) + g_loss.backward() + g_optimizer.step() + + logger.info( + f"Epoch [{epoch+1}/{num_epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}" + ) + + return generator, discriminator + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Clear the dataset cache + cache_dir = os.path.expanduser("~/.cache/huggingface/datasets") + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + logger.info("Cleared dataset cache") + + # Try to load the dataset with retries + max_retries = 3 + for attempt in range(max_retries): + try: + dataset = load_dataset( + "mozilla-foundation/common_voice_11_0", + "en", + split="train[:1000]", + trust_remote_code=True, + ) + logger.info(f"Dataset loaded: {len(dataset)} samples") + break + except Exception as e: + if attempt < max_retries - 1: + logger.warning( + f"Dataset loading failed (attempt {attempt + 1}/{max_retries}). Retrying..." + ) + else: + logger.error("Failed to load dataset after multiple attempts.") + raise e + + def preprocess_audio(example): + audio = example["audio"]["array"] + resampled_audio = torchaudio.transforms.Resample( + example["audio"]["sampling_rate"], 16000 + )(torch.tensor(audio)) + return {"audio": resampled_audio.flatten()} + + dataset = dataset.map(preprocess_audio, remove_columns=dataset.column_names) + dataset.set_format(type="torch", columns=["audio"]) + + dataloader = DataLoader(dataset, batch_size=64, shuffle=True) + + latent_dim = 100 + output_dim = 16000 + + generator = Generator(latent_dim, output_dim).to(device) + discriminator = Discriminator(output_dim).to(device) + + num_epochs = 50 + generator, discriminator = train_gan( + generator, discriminator, dataloader, num_epochs, latent_dim, device + ) + + torch.save(generator.state_dict(), "generator.pth") + torch.save(discriminator.state_dict(), "discriminator.pth") + logger.info("Models saved successfully") + + +if __name__ == "__main__": + main() diff --git a/training/gan/new_gan.py b/training/gan/new_gan.py new file mode 100644 index 00000000..6bfe0f64 --- /dev/null +++ b/training/gan/new_gan.py @@ -0,0 +1,341 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple + + +class ResidualBlock(nn.Module): + """Residual block with dilated convolutions.""" + + def __init__(self, channels: int, dilation: int = 1): + super().__init__() + self.conv1 = nn.Conv1d( + channels, channels, kernel_size=3, dilation=dilation, padding="same" + ) + self.conv2 = nn.Conv1d( + channels, channels, kernel_size=3, dilation=dilation, padding="same" + ) + self.norm1 = nn.GroupNorm( + num_groups=channels // 4, num_channels=channels + ) + self.norm2 = nn.GroupNorm( + num_groups=channels // 4, num_channels=channels + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = F.leaky_relu(self.norm1(self.conv1(x)), 0.1) + x = F.leaky_relu(self.norm2(self.conv2(x)), 0.1) + return x + residual + + +class Generator(nn.Module): + """Dynamic generator for speech synthesis.""" + + def __init__( + self, + input_channels: int, + base_channels: int = 512, + upsample_rates: List[int] = [8, 8, 2, 2], + ): + super().__init__() + self.input_conv = nn.Conv1d( + input_channels, base_channels, kernel_size=7, padding=3 + ) + self.norm = nn.GroupNorm( + num_groups=base_channels // 4, num_channels=base_channels + ) + + self.upsample_layers = nn.ModuleList() + current_channels = base_channels 
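+        # NOTE: each ConvTranspose1d stage below upsamples the time axis by `rate`
+        # and halves the channel count, so the default upsample_rates [8, 8, 2, 2]
+        # give a combined upsampling factor of 8 * 8 * 2 * 2 = 256 (mel frames -> waveform samples).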
+ for rate in upsample_rates: + out_channels = current_channels // 2 + self.upsample_layers.append( + nn.ConvTranspose1d( + current_channels, + out_channels, + kernel_size=rate * 2, + stride=rate, + padding=rate // 2, + ) + ) + current_channels = out_channels + + self.resblocks = nn.ModuleList( + [ResidualBlock(current_channels, dilation=3**i) for i in range(3)] + ) + + self.output_conv = nn.Conv1d( + current_channels, 1, kernel_size=7, padding=3 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.leaky_relu(self.norm(self.input_conv(x)), 0.1) + + for upsample in self.upsample_layers: + x = F.leaky_relu(upsample(x), 0.1) + + for resblock in self.resblocks: + x = resblock(x) + + return torch.tanh(self.output_conv(x)) + + +class PeriodDiscriminator(nn.Module): + """Period-based discriminator.""" + + def __init__(self, period: int, base_channels: int = 32): + super().__init__() + self.period = period + self.layers = nn.ModuleList( + [ + nn.Conv2d(1, base_channels, (5, 1), (3, 1), padding=(2, 0)), + nn.Conv2d( + base_channels, + base_channels * 2, + (5, 1), + (3, 1), + padding=(2, 0), + ), + nn.Conv2d( + base_channels * 2, + base_channels * 4, + (5, 1), + (3, 1), + padding=(2, 0), + ), + nn.Conv2d( + base_channels * 4, + base_channels * 4, + (5, 1), + 1, + padding=(2, 0), + ), + nn.Conv2d(base_channels * 4, 1, (3, 1), 1, padding=(1, 0)), + ] + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + batch_size, _, time = x.shape + if time % self.period != 0: + pad_size = self.period - (time % self.period) + x = F.pad(x, (0, pad_size)) + x = x.view(batch_size, 1, -1, self.period) + + features = [] + for layer in self.layers[:-1]: + x = F.leaky_relu(layer(x), 0.1) + features.append(x) + x = self.layers[-1](x) + features.append(x) + return x.view(batch_size, -1, 1), features + + +class MultiPeriodDiscriminator(nn.Module): + """Multi-period discriminator.""" + + def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]): + super().__init__() + self.discriminators = nn.ModuleList( + [PeriodDiscriminator(period) for period in periods] + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + outputs, features = [], [] + for disc in self.discriminators: + output, feature = disc(x) + outputs.append(output) + features.append(feature) + return outputs, features + + +class ScaleDiscriminator(nn.Module): + """Scale discriminator with multiple resolutions.""" + + def __init__(self, base_channels: int = 16): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.Conv1d(1, base_channels, 15, stride=1, padding=7), + nn.Conv1d( + base_channels, + base_channels * 2, + 41, + stride=4, + padding=20, + groups=base_channels, + ), + nn.Conv1d( + base_channels * 2, + base_channels * 4, + 41, + stride=4, + padding=20, + groups=base_channels * 2, + ), + nn.Conv1d( + base_channels * 4, + base_channels * 8, + 41, + stride=4, + padding=20, + groups=base_channels * 4, + ), + nn.Conv1d( + base_channels * 8, + base_channels * 16, + 41, + stride=4, + padding=20, + groups=base_channels * 8, + ), + nn.Conv1d( + base_channels * 16, + base_channels * 32, + 5, + stride=1, + padding=2, + ), + nn.Conv1d(base_channels * 32, 1, 3, stride=1, padding=1), + ] + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + features = [] + for layer in self.layers[:-1]: + x = F.leaky_relu(layer(x), 0.1) + features.append(x) + x = self.layers[-1](x) + features.append(x) + return x, features + + +class 
MultiScaleDiscriminator(nn.Module): + """Multi-scale discriminator.""" + + def __init__(self, num_scales: int = 3): + super().__init__() + self.discriminators = nn.ModuleList( + [ScaleDiscriminator() for _ in range(num_scales)] + ) + self.downsample = nn.AvgPool1d( + kernel_size=2, stride=2, padding=1, count_include_pad=False + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + outputs, features = [], [] + for disc in self.discriminators: + output, feature = disc(x) + outputs.append(output) + features.append(feature) + x = self.downsample(x) + return outputs, features + + +class SpeechSynthesisGAN(nn.Module): + """State-of-the-art GAN for speech synthesis.""" + + def __init__(self, input_channels: int): + super().__init__() + self.generator = Generator(input_channels) + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward( + self, mel_spectrogram: torch.Tensor + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Forward pass of the SpeechSynthesisGAN. + + Args: + mel_spectrogram (torch.Tensor): Input mel spectrogram of shape (batch_size, input_channels, time_steps) + + Returns: + Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + Generated waveform, MPD outputs, and MSD outputs + """ + waveform = self.generator(mel_spectrogram) + mpd_outputs, _ = self.mpd(waveform) + msd_outputs, _ = self.msd(waveform) + return waveform, mpd_outputs, msd_outputs + + +def feature_loss( + real_features: List[List[torch.Tensor]], + fake_features: List[List[torch.Tensor]], +) -> torch.Tensor: + """ + Compute the feature matching loss between real and fake features. + + Args: + real_features (List[List[torch.Tensor]]): Features from the discriminator for real audio + fake_features (List[List[torch.Tensor]]): Features from the discriminator for generated audio + + Returns: + torch.Tensor: Feature matching loss + """ + loss = 0 + for real_feats, fake_feats in zip(real_features, fake_features): + for real_feat, fake_feat in zip(real_feats, fake_feats): + loss += F.l1_loss(fake_feat, real_feat) + return loss + + +def discriminator_loss( + real_outputs: List[torch.Tensor], fake_outputs: List[torch.Tensor] +) -> torch.Tensor: + """ + Compute the discriminator loss. + + Args: + real_outputs (List[torch.Tensor]): Discriminator outputs for real audio + fake_outputs (List[torch.Tensor]): Discriminator outputs for generated audio + + Returns: + torch.Tensor: Discriminator loss + """ + loss = 0 + for real_output, fake_output in zip(real_outputs, fake_outputs): + r_loss = torch.mean((1 - real_output) ** 2) + f_loss = torch.mean(fake_output**2) + loss += r_loss + f_loss + return loss + + +def generator_loss(fake_outputs: List[torch.Tensor]) -> torch.Tensor: + """ + Compute the generator loss. 
+ + Args: + fake_outputs (List[torch.Tensor]): Discriminator outputs for generated audio + + Returns: + torch.Tensor: Generator loss + """ + loss = 0 + for fake_output in fake_outputs: + loss += torch.mean((1 - fake_output) ** 2) + return loss + + +# Example usage +if __name__ == "__main__": + input_channels = 80 + batch_size = 4 + time_steps = 128 + + model = SpeechSynthesisGAN(input_channels) + mel_input = torch.randn(batch_size, input_channels, time_steps) + + waveform, mpd_outputs, msd_outputs = model(mel_input) + print(f"Generated waveform shape: {waveform.shape}") + print(f"Number of MPD outputs: {len(mpd_outputs)}") + print(f"Number of MSD outputs: {len(msd_outputs)}") diff --git a/training/vis/model.py b/training/vis/model.py new file mode 100644 index 00000000..9d31f881 --- /dev/null +++ b/training/vis/model.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn +import plotly.graph_objects as go +import numpy as np + + +# Helper function to convert Cartesian coordinates to spherical coordinates +def spherical_coordinates(num_points, radius): + points = [] + phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians + + for i in range(num_points): + y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1 + radius_y = np.sqrt(1 - y * y) # radius at y + theta = phi * i # golden angle increment + + x = np.cos(theta) * radius_y * radius + z = np.sin(theta) * radius_y * radius + points.append([x, y * radius, z]) + + return np.array(points) + + +# Dynamic extraction of model layers and parameters +def extract_model_layers(model): + layers = [] + params_per_layer = [] + + def recursive_layer_extraction(layer, layer_name=""): + if isinstance(layer, nn.Module): + num_params = sum( + p.numel() for p in layer.parameters() if p.requires_grad + ) + if num_params > 0: + layers.append( + { + "name": layer_name, + "type": layer.__class__.__name__, + "num_params": num_params, + } + ) + params_per_layer.append(num_params) + # Traverse through children layers + for name, child in layer.named_children(): + layer_full_name = f"{layer_name}.{name}" if layer_name else name + recursive_layer_extraction(child, layer_full_name) + + recursive_layer_extraction(model) + return layers, params_per_layer + + +# Function to create a dynamic 3D visualization of the model +def visualize_model_as_ball(model, max_points_per_layer=100): + layers, params_per_layer = extract_model_layers(model) + + # Visualize each layer as a concentric sphere with points + fig = go.Figure() + max_radius = 10 # Maximum radius for the outer layer + radius_step = max_radius / len(layers) + + # Create spheres and add points to each sphere + for i, layer in enumerate(layers): + radius = radius_step * (i + 1) + num_params = min( + params_per_layer[i], max_points_per_layer + ) # Limit points per layer for clarity + points = spherical_coordinates(num_params, radius) + + fig.add_trace( + go.Scatter3d( + x=points[:, 0], + y=points[:, 1], + z=points[:, 2], + mode="markers", + marker=dict( + size=5, + color=np.linspace(0, 1, num_params), + colorscale="Viridis", + opacity=0.8, + ), + name=f'Layer {i + 1}: {layer["type"]}', + hovertext=[ + f'{layer["name"]}, Param {j + 1}' for j in range(num_params) + ], + ) + ) + + # Configure layout + fig.update_layout( + scene=dict( + xaxis=dict(title="X", showgrid=False, zeroline=False), + yaxis=dict(title="Y", showgrid=False, zeroline=False), + zaxis=dict(title="Z", showgrid=False, zeroline=False), + ), + title="Dynamic Model Visualization as Ball", + showlegend=True, + ) + + fig.show() + + +# 
Example usage with a pretrained model +if __name__ == "__main__": + # Load a pretrained model (e.g., ResNet18 from torchvision) + model = torch.hub.load( + "pytorch/vision:v0.10.0", "resnet18", pretrained=True + ) + + # Visualize the model + visualize_model_as_ball(model) diff --git a/training/vis/t2.py b/training/vis/t2.py new file mode 100644 index 00000000..62e23da9 --- /dev/null +++ b/training/vis/t2.py @@ -0,0 +1,175 @@ +import plotly.graph_objects as go +import numpy as np +from transformers import BertModel, BertTokenizer + + +# Helper function to convert Cartesian coordinates to spherical coordinates +def spherical_coordinates(num_points, radius): + points = [] + phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians + + for i in range(num_points): + y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1 + radius_y = np.sqrt(1 - y * y) # radius at y + theta = phi * i # golden angle increment + + x = np.cos(theta) * radius_y * radius + z = np.sin(theta) * radius_y * radius + points.append([x, y * radius, z]) + + return np.array(points) + + +# Function to extract transformer layers and parameters (specific for BERT here) +def extract_transformer_layers(model): + layers = [] + params_per_layer = [] + + for name, param in model.named_parameters(): + num_params = param.numel() + if num_params > 0: + layers.append( + {"name": name, "num_params": num_params, "shape": param.shape} + ) + params_per_layer.append(num_params) + + return layers, params_per_layer + + +# Function to visualize the transformer model components as a 3D ball structure +def visualize_transformer_as_ball(model, max_points_per_layer=100): + layers, params_per_layer = extract_transformer_layers(model) + + # Visualize each layer as a concentric sphere with points + fig = go.Figure() + max_radius = 10 # Maximum radius for the outer layer + radius_step = max_radius / len(layers) + + # Create spheres and add points to each sphere + for i, layer in enumerate(layers): + radius = radius_step * (i + 1) + num_params = min( + params_per_layer[i], max_points_per_layer + ) # Limit points per layer for clarity + points = spherical_coordinates(num_params, radius) + + fig.add_trace( + go.Scatter3d( + x=points[:, 0], + y=points[:, 1], + z=points[:, 2], + mode="markers", + marker=dict( + size=5, + color=np.linspace(0, 1, num_params), + colorscale="Viridis", + opacity=0.8, + ), + name=f'Layer {i + 1}: {layer["name"]}', + hovertext=[ + f'{layer["name"]}, Shape: {layer["shape"]}' + for j in range(num_params) + ], + ) + ) + + # Configure layout + fig.update_layout( + scene=dict( + xaxis=dict(title="X", showgrid=False, zeroline=False), + yaxis=dict(title="Y", showgrid=False, zeroline=False), + zaxis=dict(title="Z", showgrid=False, zeroline=False), + ), + title="Transformer Model Visualization as Ball", + showlegend=True, + ) + + fig.show() + + +# Visualizing Attention Weights +def visualize_attention_weights( + attention_weights, sequence_length, attention_heads +): + fig = go.Figure() + + # For each attention head, draw the attention matrix as connections between tokens + for head in range(attention_heads): + attention = ( + attention_weights[0, head].detach().numpy() + ) # Get attention weights for this head + points = spherical_coordinates( + sequence_length, radius=5 + head + ) # Slightly different radii for heads + + # Add nodes (tokens) + fig.add_trace( + go.Scatter3d( + x=points[:, 0], + y=points[:, 1], + z=points[:, 2], + mode="markers", + marker=dict( + size=8, + color=np.linspace(0, 1, sequence_length), + 
colorscale="Plasma", + opacity=0.8, + ), + name=f"Attention Head {head + 1}", + hovertext=[f"Token {i + 1}" for i in range(sequence_length)], + ) + ) + + # Add edges (attention weights between tokens) + for i in range(sequence_length): + for j in range(sequence_length): + if ( + i != j and attention[i, j] > 0.1 + ): # Only show significant attention weights + fig.add_trace( + go.Scatter3d( + x=[points[i, 0], points[j, 0]], + y=[points[i, 1], points[j, 1]], + z=[points[i, 2], points[j, 2]], + mode="lines", + line=dict(color="rgba(0, 0, 255, 0.2)", width=2), + hoverinfo="none", + showlegend=False, + ) + ) + + fig.update_layout( + scene=dict( + xaxis=dict(title="X", showgrid=False, zeroline=False), + yaxis=dict(title="Y", showgrid=False, zeroline=False), + zaxis=dict(title="Z", showgrid=False, zeroline=False), + ), + title="Multi-Head Attention Weights", + showlegend=True, + ) + + fig.show() + + +# Example usage with a pretrained BERT model +if __name__ == "__main__": + # Load a pretrained transformer model (e.g., BERT) from Huggingface + model = BertModel.from_pretrained("bert-base-uncased") + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + + # Example sentence + sentence = "Transformers are powerful models for NLP tasks." + inputs = tokenizer(sentence, return_tensors="pt") + outputs = model(**inputs, output_attentions=True) + + # Visualize the model structure + visualize_transformer_as_ball(model) + + # Visualize attention weights (BERT has 12 attention heads by default) + attention_weights = outputs.attentions[ + -1 + ] # Get the attention weights from the last layer + sequence_length = inputs["input_ids"].shape[1] + visualize_attention_weights( + attention_weights, sequence_length, attention_heads=12 + ) diff --git a/training/vis/t3.py b/training/vis/t3.py new file mode 100644 index 00000000..2774feed --- /dev/null +++ b/training/vis/t3.py @@ -0,0 +1,103 @@ +import plotly.graph_objects as go +from torchvision import models +import numpy as np + + +def visualize_model_3d(model): + def get_layer_info(module, depth=0, path=""): + layers = [] + for name, child in module.named_children(): + child_path = f"{path}/{name}".lstrip("/") + child_info = { + "name": child_path, + "type": child.__class__.__name__, + "params": sum( + p.numel() for p in child.parameters() if p.requires_grad + ), + "depth": depth, + } + layers.append(child_info) + layers.extend(get_layer_info(child, depth + 1, child_path)) + return layers + + layers = get_layer_info(model) + + fig = go.Figure() + + max_depth = max(layer["depth"] for layer in layers) + max_params = max(layer["params"] for layer in layers) + + # Calculate positions on a sphere + phi = np.linspace(0, np.pi, len(layers)) + theta = np.linspace(0, 2 * np.pi, len(layers)) + + # Create nodes + for i, layer in enumerate(layers): + r = (layer["depth"] + 1) / (max_depth + 1) # Radius based on depth + x = r * np.sin(phi[i]) * np.cos(theta[i]) + y = r * np.sin(phi[i]) * np.sin(theta[i]) + z = r * np.cos(phi[i]) + + size = (layer["params"] / max_params * 20 + 5) if max_params > 0 else 5 + + fig.add_trace( + go.Scatter3d( + x=[x], + y=[y], + z=[z], + mode="markers", + marker=dict( + size=size, + color=layer["depth"], + colorscale="Viridis", + opacity=0.8, + ), + text=f"{layer['name']}
<br>{layer['type']}<br>
Params: {layer['params']}", + hoverinfo="text", + ) + ) + + # Create edges + for i in range(1, len(layers)): + prev_layer = layers[i - 1] + curr_layer = layers[i] + + r_prev = (prev_layer["depth"] + 1) / (max_depth + 1) + r_curr = (curr_layer["depth"] + 1) / (max_depth + 1) + + x_prev = r_prev * np.sin(phi[i - 1]) * np.cos(theta[i - 1]) + y_prev = r_prev * np.sin(phi[i - 1]) * np.sin(theta[i - 1]) + z_prev = r_prev * np.cos(phi[i - 1]) + + x_curr = r_curr * np.sin(phi[i]) * np.cos(theta[i]) + y_curr = r_curr * np.sin(phi[i]) * np.sin(theta[i]) + z_curr = r_curr * np.cos(phi[i]) + + fig.add_trace( + go.Scatter3d( + x=[x_prev, x_curr], + y=[y_prev, y_curr], + z=[z_prev, z_curr], + mode="lines", + line=dict(color="rgba(100,100,100,0.5)", width=2), + hoverinfo="none", + ) + ) + + fig.update_layout( + title="Spherical 3D Model Architecture Visualization", + scene=dict( + xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="data" + ), + width=900, + height=700, + margin=dict(r=0, l=0, b=0, t=40), + ) + + return fig + + +# Example usage +model = models.vgg16(pretrained=True) +fig = visualize_model_3d(model) +fig.show() diff --git a/training/vis/t4.py b/training/vis/t4.py new file mode 100644 index 00000000..8fafe18f --- /dev/null +++ b/training/vis/t4.py @@ -0,0 +1,232 @@ +import plotly.graph_objects as go +import numpy as np + + +def visualize_parameters( + model, dim="3d", color_by="mean", group_by=None, animate=False +): + params_info = [] + + for name, param in model.named_parameters(): + if param.requires_grad: + layer_name = name.split(".")[0] if group_by == "layer" else "All" + params_info.append( + { + "name": name, + "layer": layer_name, + "shape": param.shape, + "numel": param.numel(), + "mean": param.data.mean().item(), + "std": param.data.std().item(), + "grad_mean": ( + param.grad.mean().item() + if param.grad is not None + else 0 + ), + "grad_std": ( + param.grad.std().item() if param.grad is not None else 0 + ), + "type": param.__class__.__name__, + } + ) + + params_info.sort(key=lambda x: x["numel"], reverse=True) + + if dim == "3d": + fig = create_3d_plot(params_info, color_by) + else: + fig = create_2d_plot(params_info, color_by) + + if group_by == "layer": + fig = group_by_layer(fig, params_info) + + add_model_summary(fig, model, params_info) + + if animate: + fig = add_animation(fig, model) + + return fig + + +def create_3d_plot(params_info, color_by): + fig = go.Figure() + + n_params = len(params_info) + grid_size = int(np.ceil(np.cbrt(n_params))) + x, y, z = np.mgrid[0:grid_size, 0:grid_size, 0:grid_size] + x, y, z = ( + x.flatten()[:n_params], + y.flatten()[:n_params], + z.flatten()[:n_params], + ) + + max_numel = max(p["numel"] for p in params_info) + + marker_shapes = {"Weight": "circle", "Bias": "square"} + + for i, param in enumerate(params_info): + size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5 + color = param[color_by] + + fig.add_trace( + go.Scatter3d( + x=[x[i]], + y=[y[i]], + z=[z[i]], + mode="markers", + marker=dict( + size=size, + color=color, + colorscale="Viridis", + opacity=0.8, + symbol=marker_shapes.get(param["type"], "circle"), + colorbar=dict(title=color_by.capitalize()), + ), + text=create_hover_text(param), + hoverinfo="text", + ) + ) + + fig.update_layout( + title="3D Parameter Visualization", + scene=dict( + xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="cube" + ), + width=900, + height=700, + margin=dict(r=0, l=0, b=0, t=40), + ) + + return fig + + +def create_2d_plot(params_info, color_by): + fig = 
go.Figure() + + n_params = len(params_info) + grid_size = int(np.ceil(np.sqrt(n_params))) + x, y = np.mgrid[0:grid_size, 0:grid_size] + x, y = x.flatten()[:n_params], y.flatten()[:n_params] + + max_numel = max(p["numel"] for p in params_info) + + marker_shapes = {"Weight": "circle", "Bias": "square"} + + for i, param in enumerate(params_info): + size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5 + color = param[color_by] + + fig.add_trace( + go.Scatter( + x=[x[i]], + y=[y[i]], + mode="markers", + marker=dict( + size=size, + color=color, + colorscale="Viridis", + opacity=0.8, + symbol=marker_shapes.get(param["type"], "circle"), + colorbar=dict(title=color_by.capitalize()), + ), + text=create_hover_text(param), + hoverinfo="text", + ) + ) + + fig.update_layout( + title="2D Parameter Visualization", + xaxis_title="X", + yaxis_title="Y", + width=900, + height=700, + margin=dict(r=0, l=0, b=0, t=40), + ) + + return fig + + +def create_hover_text(param): + return ( + f"Name: {param['name']}
" + f"Shape: {param['shape']}
" + f"Elements: {param['numel']}
" + f"Mean: {param['mean']:.4f}
" + f"Std: {param['std']:.4f}
" + f"Grad Mean: {param['grad_mean']:.4f}
" + f"Grad Std: {param['grad_std']:.4f}
" + f"Type: {param['type']}" + ) + + +def group_by_layer(fig, params_info): + layers = sorted(set(p["layer"] for p in params_info)) + fig.update_layout( + updatemenus=[ + { + "buttons": [ + { + "label": layer, + "method": "update", + "args": [ + { + "visible": [ + p["layer"] == layer for p in params_info + ] + } + ], + } + for layer in layers + ] + + [ + { + "label": "All", + "method": "update", + "args": [{"visible": [True] * len(params_info)}], + } + ], + "direction": "down", + "showactive": True, + } + ] + ) + return fig + + +def add_model_summary(fig, model, params_info): + total_params = sum(p["numel"] for p in params_info) + summary = f"Total parameters: {total_params:,}
" + summary += f"Layers: {len(set(p['layer'] for p in params_info))}
" + summary += f"Model type: {model.__class__.__name__}" + + fig.add_annotation( + x=0.95, + y=0.95, + xref="paper", + yref="paper", + text=summary, + showarrow=False, + font=dict(size=12), + align="right", + bgcolor="rgba(255,255,255,0.8)", + bordercolor="black", + borderwidth=1, + ) + return fig + + +def add_animation(fig, model): + # This is a placeholder for animation functionality + # You would need to implement logic to update parameter values over time + return fig + + +# Example usage +if __name__ == "__main__": + from torchvision import models + + model = models.vgg16(pretrained=True) + fig = visualize_parameters( + model, dim="3d", color_by="mean", group_by="layer" + ) + fig.show() diff --git a/training/vis/t5.py b/training/vis/t5.py new file mode 100644 index 00000000..dbb2542a --- /dev/null +++ b/training/vis/t5.py @@ -0,0 +1,318 @@ +import plotly.graph_objects as go +import plotly.subplots as sp +import numpy as np +import plotly.io as pio + + +def visualize_parameters( + model, + dim="3d", + color_by="mean", + group_by=None, + colorscale="Viridis", + compare_model=None, +): + params_info = get_params_info(model) + if compare_model: + params_info_2 = get_params_info(compare_model) + return create_comparison_plot( + params_info, params_info_2, dim, color_by, colorscale + ) + + if dim == "3d": + fig = create_3d_plot(params_info, color_by, colorscale) + elif dim == "2d": + fig = create_2d_plot(params_info, color_by, colorscale) + else: + fig = create_heatmap(params_info, color_by) + + if group_by == "layer": + fig = group_by_layer(fig, params_info) + + add_model_summary(fig, model, params_info) + add_distribution_histogram(fig, params_info) + + return fig + + +def get_params_info(model): + params_info = [] + for name, param in model.named_parameters(): + if param.requires_grad: + layer_name = name.split(".")[0] + params_info.append( + { + "name": name, + "layer": layer_name, + "shape": param.shape, + "numel": param.numel(), + "mean": param.data.mean().item(), + "std": param.data.std().item(), + "grad_mean": ( + param.grad.mean().item() + if param.grad is not None + else 0 + ), + "grad_std": ( + param.grad.std().item() if param.grad is not None else 0 + ), + "type": param.__class__.__name__, + "data": param.data.flatten().tolist(), + } + ) + params_info.sort(key=lambda x: x["numel"], reverse=True) + return params_info + + +def create_3d_plot(params_info, color_by, colorscale): + fig = go.Figure() + + n_params = len(params_info) + grid_size = int(np.ceil(np.cbrt(n_params))) + x, y, z = np.mgrid[0:grid_size, 0:grid_size, 0:grid_size] + x, y, z = ( + x.flatten()[:n_params], + y.flatten()[:n_params], + z.flatten()[:n_params], + ) + + max_numel = max(p["numel"] for p in params_info) + + marker_shapes = {"Weight": "circle", "Bias": "square"} + + for i, param in enumerate(params_info): + size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5 + color = param[color_by] + + fig.add_trace( + go.Scatter3d( + x=[x[i]], + y=[y[i]], + z=[z[i]], + mode="markers", + marker=dict( + size=size, + color=color, + colorscale=colorscale, + opacity=0.8, + symbol=marker_shapes.get(param["type"], "circle"), + colorbar=dict(title=color_by.capitalize()), + ), + text=create_hover_text(param), + hoverinfo="text", + ) + ) + + fig.update_layout( + title="3D Parameter Visualization", + scene=dict( + xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="cube" + ), + width=900, + height=700, + margin=dict(r=0, l=0, b=0, t=40), + coloraxis=dict(colorscale=colorscale), + ) + + return fig + + +def 
create_2d_plot(params_info, color_by, colorscale): + fig = go.Figure() + + n_params = len(params_info) + grid_size = int(np.ceil(np.sqrt(n_params))) + x, y = np.mgrid[0:grid_size, 0:grid_size] + x, y = x.flatten()[:n_params], y.flatten()[:n_params] + + max_numel = max(p["numel"] for p in params_info) + + marker_shapes = {"Weight": "circle", "Bias": "square"} + + for i, param in enumerate(params_info): + size = np.log1p(param["numel"]) / np.log1p(max_numel) * 20 + 5 + color = param[color_by] + + fig.add_trace( + go.Scatter( + x=[x[i]], + y=[y[i]], + mode="markers", + marker=dict( + size=size, + color=color, + colorscale=colorscale, + opacity=0.8, + symbol=marker_shapes.get(param["type"], "circle"), + colorbar=dict(title=color_by.capitalize()), + ), + text=create_hover_text(param), + hoverinfo="text", + ) + ) + + fig.update_layout( + title="2D Parameter Visualization", + xaxis_title="X", + yaxis_title="Y", + width=900, + height=700, + margin=dict(r=0, l=0, b=0, t=40), + coloraxis=dict(colorscale=colorscale), + ) + + return fig + + +def create_heatmap(params_info, color_by): + data = [param[color_by] for param in params_info] + names = [param["name"] for param in params_info] + + fig = go.Figure( + data=go.Heatmap( + z=[data], y=["Parameters"], x=names, colorscale="Viridis" + ) + ) + + fig.update_layout( + title="Parameter Heatmap", + xaxis_title="Parameter Name", + yaxis_title="", + width=1200, + height=400, + ) + + return fig + + +def create_comparison_plot( + params_info_1, params_info_2, dim, color_by, colorscale +): + fig = sp.make_subplots( + rows=1, cols=2, subplot_titles=("Model 1", "Model 2") + ) + + if dim == "3d": + fig1 = create_3d_plot(params_info_1, color_by, colorscale) + fig2 = create_3d_plot(params_info_2, color_by, colorscale) + elif dim == "2d": + fig1 = create_2d_plot(params_info_1, color_by, colorscale) + fig2 = create_2d_plot(params_info_2, color_by, colorscale) + else: + fig1 = create_heatmap(params_info_1, color_by) + fig2 = create_heatmap(params_info_2, color_by) + + for trace in fig1.data: + fig.add_trace(trace, row=1, col=1) + for trace in fig2.data: + fig.add_trace(trace, row=1, col=2) + + fig.update_layout(height=600, width=1200, title_text="Model Comparison") + return fig + + +def add_distribution_histogram(fig, params_info): + data = [param["data"] for param in params_info] + flat_data = [item for sublist in data for item in sublist] + + histogram = go.Figure(data=[go.Histogram(x=flat_data)]) + histogram.update_layout(title_text="Parameter Distribution") + + fig.add_trace(histogram.data[0]) + return fig + + +def add_gradient_flow(fig, model): + # Placeholder for gradient flow visualization + # This would require tracking gradients during backward pass + pass + + +def create_hover_text(param): + return ( + f"Name: {param['name']}
" + f"Shape: {param['shape']}
" + f"Elements: {param['numel']}
" + f"Mean: {param['mean']:.4f}
" + f"Std: {param['std']:.4f}
" + f"Grad Mean: {param['grad_mean']:.4f}
" + f"Grad Std: {param['grad_std']:.4f}
" + f"Type: {param['type']}" + ) + + +def group_by_layer(fig, params_info): + layers = sorted(set(p["layer"] for p in params_info)) + fig.update_layout( + updatemenus=[ + { + "buttons": [ + { + "label": layer, + "method": "update", + "args": [ + { + "visible": [ + p["layer"] == layer for p in params_info + ] + } + ], + } + for layer in layers + ] + + [ + { + "label": "All", + "method": "update", + "args": [{"visible": [True] * len(params_info)}], + } + ], + "direction": "down", + "showactive": True, + } + ] + ) + return fig + + +def add_model_summary(fig, model, params_info): + total_params = sum(p["numel"] for p in params_info) + summary = f"Total parameters: {total_params:,}
" + summary += f"Layers: {len(set(p['layer'] for p in params_info))}
" + summary += f"Model type: {model.__class__.__name__}" + + fig.add_annotation( + x=0.95, + y=0.95, + xref="paper", + yref="paper", + text=summary, + showarrow=False, + font=dict(size=12), + align="right", + bgcolor="rgba(255,255,255,0.8)", + bordercolor="black", + borderwidth=1, + ) + return fig + + +def export_html(fig, filename="parameter_visualization.html"): + pio.write_html(fig, file=filename, auto_open=True) + + +# Example usage +if __name__ == "__main__": + from torchvision import models + + model = models.resnet18(pretrained=True) + fig = visualize_parameters( + model, dim="3d", color_by="mean", group_by="layer", colorscale="Plasma" + ) + + # Uncomment to compare two models + # model2 = models.resnet34(pretrained=True) + # fig = visualize_parameters(model, dim='3d', color_by='mean', colorscale='Plasma', compare_model=model2) + + fig.show() + export_html(fig) diff --git a/training/vis/transformer.py b/training/vis/transformer.py new file mode 100644 index 00000000..e1c5d32f --- /dev/null +++ b/training/vis/transformer.py @@ -0,0 +1,148 @@ +import plotly.graph_objects as go +import numpy as np +from transformers import BertModel, BertTokenizer + + +# Helper function to convert Cartesian coordinates to spherical coordinates +def spherical_coordinates(num_points, radius): + points = [] + phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians + + for i in range(num_points): + y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1 + radius_y = np.sqrt(1 - y * y) # radius at y + theta = phi * i # golden angle increment + + x = np.cos(theta) * radius_y * radius + z = np.sin(theta) * radius_y * radius + points.append([x, y * radius, z]) + + return np.array(points) + + +# Function to extract transformer layers and parameters (specific for BERT here) +def extract_transformer_layers(model): + layers = [] + params_per_layer = [] + + for name, param in model.named_parameters(): + num_params = param.numel() + if num_params > 0: + layers.append( + {"name": name, "num_params": num_params, "shape": param.shape} + ) + params_per_layer.append(num_params) + + return layers, params_per_layer + + +# Function to visualize the transformer model components as a 3D ball structure +def visualize_transformer_as_ball(model, max_points_per_layer=100): + layers, params_per_layer = extract_transformer_layers(model) + + # Visualize each layer as a concentric sphere with points + fig = go.Figure() + max_radius = 10 # Maximum radius for the outer layer + radius_step = max_radius / len(layers) + + # Create spheres and add points to each sphere + for i, layer in enumerate(layers): + radius = radius_step * (i + 1) + num_params = min( + params_per_layer[i], max_points_per_layer + ) # Limit points per layer for clarity + points = spherical_coordinates(num_params, radius) + + fig.add_trace( + go.Scatter3d( + x=points[:, 0], + y=points[:, 1], + z=points[:, 2], + mode="markers", + marker=dict( + size=5, + color=np.linspace(0, 1, num_params), + colorscale="Viridis", + opacity=0.8, + ), + name=f'Layer {i + 1}: {layer["name"]}', + hovertext=[ + f'{layer["name"]}, Shape: {layer["shape"]}' + for j in range(num_params) + ], + ) + ) + + # Configure layout + fig.update_layout( + scene=dict( + xaxis=dict(title="X", showgrid=False, zeroline=False), + yaxis=dict(title="Y", showgrid=False, zeroline=False), + zaxis=dict(title="Z", showgrid=False, zeroline=False), + ), + title="Transformer Model Visualization as Ball", + showlegend=True, + ) + + fig.show() + + +# Visualizing Attention Weights +def 
visualize_attention_weights(attention_heads, sequence_length): + fig = go.Figure() + + # Create multiple scatter plots for each attention head + for head in range(attention_heads): + points = spherical_coordinates( + sequence_length, radius=5 + head + ) # Slightly different radii for heads + + fig.add_trace( + go.Scatter3d( + x=points[:, 0], + y=points[:, 1], + z=points[:, 2], + mode="markers", + marker=dict( + size=6, + color=np.linspace(0, 1, sequence_length), + colorscale="Plasma", + opacity=0.8, + ), + name=f"Attention Head {head + 1}", + hovertext=[f"Token {i + 1}" for i in range(sequence_length)], + ) + ) + + fig.update_layout( + scene=dict( + xaxis=dict(title="X", showgrid=False, zeroline=False), + yaxis=dict(title="Y", showgrid=False, zeroline=False), + zaxis=dict(title="Z", showgrid=False, zeroline=False), + ), + title="Multi-Head Attention Weights", + showlegend=True, + ) + + fig.show() + + +# Example usage with a pretrained BERT model +if __name__ == "__main__": + # Load a pretrained transformer model (e.g., BERT) from Huggingface + model = BertModel.from_pretrained("bert-base-uncased") + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + + # Example sentence + sentence = "Transformers are powerful models for NLP tasks." + inputs = tokenizer(sentence, return_tensors="pt") + outputs = model(**inputs) + + # Visualize the model structure + visualize_transformer_as_ball(model) + + # Example: Visualize attention heads for a sequence length (BERT has 12 attention heads by default) + sequence_length = inputs["input_ids"].shape[1] + visualize_attention_weights( + attention_heads=12, sequence_length=sequence_length + ) diff --git a/training/yolo_alt/model.py b/training/yolo_alt/model.py new file mode 100644 index 00000000..87a143e9 --- /dev/null +++ b/training/yolo_alt/model.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride): + super(DepthwiseSeparableConv, self).__init__() + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride, + padding=kernel_size // 2, + groups=in_channels, + bias=False, + ) + self.pointwise = nn.Conv2d( + in_channels, out_channels, 1, 1, 0, bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class SEBlock(nn.Module): + def __init__(self, channels, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channels, channels // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channels // reduction, channels, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class GhostModule(nn.Module): + def __init__( + self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True + ): + super(GhostModule, self).__init__() + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + self.cheap_operation = nn.Sequential( 
+ nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, : self.oup, :, :] + + +class FastDetector(nn.Module): + def __init__(self, num_classes): + super(FastDetector, self).__init__() + + # Lightweight backbone with Ghost modules + self.backbone = nn.Sequential( + GhostModule(3, 16, 3, stride=2), + GhostModule(16, 32, 3, stride=2), + GhostModule(32, 64, 3, stride=2), + GhostModule(64, 128, 3, stride=2), + ) + + # Feature Pyramid Network + self.fpn = nn.ModuleList( + [ + DepthwiseSeparableConv(128, 128, 3, 1), + DepthwiseSeparableConv(128, 64, 3, 1), + DepthwiseSeparableConv(64, 32, 3, 1), + ] + ) + + # SE blocks for each FPN level + self.se_blocks = nn.ModuleList( + [ + SEBlock(128), + SEBlock(128), + SEBlock(64), + SEBlock(32), + ] + ) + + # Detection heads + self.heads = nn.ModuleList( + [ + self._make_head(128, num_classes), + self._make_head(128, num_classes), + self._make_head(64, num_classes), + self._make_head(32, num_classes), + ] + ) + + def _make_head(self, in_channels, num_classes): + return nn.Sequential( + DepthwiseSeparableConv(in_channels, in_channels, 3, 1), + nn.Conv2d( + in_channels, num_classes + 4, 1 + ), # cls + x, y, w, h (anchor-free) + ) + + def forward(self, x): + features = [] + for i, layer in enumerate(self.backbone): + x = layer(x) + features.append(x) + + # FPN + for i in range(len(features) - 1, 0, -1): + features[i - 1] = features[i - 1] + F.interpolate( + self.fpn[i - 1](features[i]), size=features[i - 1].shape[2:] + ) + + # Apply SE blocks and get predictions + outputs = [] + for feature, se_block, head in zip( + features, self.se_blocks, self.heads + ): + feature = se_block(feature) + outputs.append(head(feature).flatten(start_dim=2)) + + return outputs + + +# Example usage +num_classes = 80 # COCO dataset +model = FastDetector(num_classes) +input_tensor = torch.randn(1, 3, 416, 416) +outputs = model(input_tensor) +for i, output in enumerate(outputs): + print(f"Output {i} shape:", output.shape)
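# Hedged sketch (not part of the patch above): one way the empty
# add_gradient_flow placeholder in training/vis/t5.py could be filled in.
# Gradients are read after a backward pass and shown as a per-parameter bar
# chart; the helper name and the dummy loss below are illustrative assumptions.
import torch
import torch.nn as nn
import plotly.graph_objects as go


def plot_gradient_flow(model):
    names, grad_means = [], []
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            names.append(name)
            grad_means.append(param.grad.abs().mean().item())
    fig = go.Figure(go.Bar(x=names, y=grad_means))
    fig.update_layout(
        title="Mean absolute gradient per parameter",
        xaxis_title="Parameter",
        yaxis_title="|grad| mean",
    )
    return fig


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
    toy(torch.randn(4, 8)).sum().backward()  # dummy backward pass to populate .grad
    plot_gradient_flow(toy).show()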
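# Hedged sketch (not part of the patch above): visualize_attention_weights in
# training/vis/transformer.py only plots token positions. If real attention
# weights are wanted, they can be requested from the Hugging Face model with
# output_attentions=True; colouring points by the attention each token receives
# is an illustrative assumption, not something the patch implements.
import torch
from transformers import BertModel, BertTokenizer

bert = BertModel.from_pretrained("bert-base-uncased")
tok = BertTokenizer.from_pretrained("bert-base-uncased")
enc = tok("Transformers are powerful models for NLP tasks.", return_tensors="pt")

with torch.no_grad():
    out = bert(**enc, output_attentions=True)

last_layer = out.attentions[-1][0]   # (num_heads, seq_len, seq_len)
num_heads, seq_len, _ = last_layer.shape
received = last_layer.sum(dim=1)     # attention mass each token receives, per head
print(num_heads, seq_len, received.shape)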
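# Hedged sketch (not part of the patch above): splitting one FastDetector head
# output into class scores and box parameters. The (num_classes + 4) channel
# layout follows the comment in _make_head; the sigmoid and the 26x26 grid size
# here are illustrative assumptions, not part of training/yolo_alt/model.py.
import torch

num_classes = 80
head_out = torch.randn(1, num_classes + 4, 26 * 26)       # one FPN level, H*W flattened
cls_logits, box_params = head_out.split([num_classes, 4], dim=1)
cls_scores = cls_logits.sigmoid()                          # per-location class probabilities
print(cls_scores.shape, box_params.shape)                  # torch.Size([1, 80, 676]) torch.Size([1, 4, 676])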