diff --git a/scripts/determine_params.py b/scripts/determine_params.py index d97746e..799706f 100644 --- a/scripts/determine_params.py +++ b/scripts/determine_params.py @@ -14,7 +14,7 @@ n = len(model.model.layers) r = q_size // k_size h = model.config.num_attention_heads // r -k = q_size // m // h +k = k_size // m // h n_params = sum(x.numel() for x in model.parameters())