Skip to content

Commit

Permalink
Merge pull request #11 from Adibvafa/prediction
Browse files Browse the repository at this point in the history
Add support for top_p in non-deterministic generation
  • Loading branch information
Adibvafa authored Sep 21, 2024
2 parents ca24214 + dfa53da commit bca7eef
Show file tree
Hide file tree
Showing 5 changed files with 420 additions and 25 deletions.
11 changes: 6 additions & 5 deletions CodonTransformer/CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
START_CODONS,
STOP_CODONS,
STOP_SYMBOL,
STOP_SYMBOLS,
find_pattern_in_fasta,
get_taxonomy_id,
sort_amino2codon_skeleton,
Expand Down Expand Up @@ -177,13 +178,13 @@ def preprocess_protein_sequence(protein: str) -> str:
)

# Check for sequence validity
if any(
aminoacid not in AMINO_ACIDS + ["*", STOP_SYMBOL] for aminoacid in protein[:-1]
):
if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
raise ValueError("Invalid characters in protein sequence.")

if protein[-1] not in AMINO_ACIDS + ["*", STOP_SYMBOL]:
raise ValueError("Protein sequence must end with *, or _, or an amino acid.")
if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS:
raise ValueError(
"Protein sequence must end with `*`, or `_`, or an amino acid."
)

# Replace '*' at the end of protein with STOP_SYMBOL if present
if protein[-1] == "*":
Expand Down
95 changes: 77 additions & 18 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from CodonTransformer.CodonData import get_merged_seq
from CodonTransformer.CodonUtils import (
AMINO_ACIDS,
INDEX2TOKEN,
NUM_ORGANISMS,
ORGANISM2ID,
Expand All @@ -40,6 +39,7 @@ def predict_dna_sequence(
attention_type: str = "original_full",
deterministic: bool = True,
temperature: float = 0.2,
top_p: float = 0.95,
) -> DNASequencePrediction:
"""
Predict the DNA sequence for a given protein using the CodonTransformer model.
Expand Down Expand Up @@ -76,6 +76,10 @@ def predict_dna_sequence(
- Medium randomness: 0.5
- High randomness: 0.8
The temperature must be a positive float. Defaults to 0.2.
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
Tokens with cumulative probability up to `top_p` are considered for sampling.
This parameter helps balance diversity and coherence in the predicted DNA sequences.
The value must be a float between 0 and 1. Defaults to 0.95.
Returns:
DNASequencePrediction: An object containing the prediction results:
Expand All @@ -86,7 +90,7 @@ def predict_dna_sequence(
Raises:
ValueError: If the protein sequence is empty, if the organism is invalid,
or if the temperature is not a positive float.
if the temperature is not a positive float, or if `top_p` is not between 0 and 1.
Note:
This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from
Expand Down Expand Up @@ -123,7 +127,7 @@ def predict_dna_sequence(
... deterministic=True
... )
>>>
>>> # Predict DNA sequence with low randomness
>>> # Predict DNA sequence with low randomness and top_p sampling
>>> output_random = predict_dna_sequence(
... protein=protein,
... organism=organism,
Expand All @@ -132,7 +136,8 @@ def predict_dna_sequence(
... model=model,
... attention_type="original_full",
... deterministic=False,
... temperature=0.2
... temperature=0.2,
... top_p=0.95
... )
>>>
>>> print(format_model_output(output))
Expand All @@ -141,14 +146,14 @@ def predict_dna_sequence(
if not protein:
raise ValueError("Protein sequence cannot be empty.")

# Ensure the protein sequence contains only valid amino acids
if not all(aminoacid in AMINO_ACIDS for aminoacid in protein):
raise ValueError("Invalid amino acid found in protein sequence.")

# Validate temperature
if not isinstance(temperature, (float, int)) or temperature <= 0:
raise ValueError("Temperature must be a positive float.")

# Validate top_p
if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
raise ValueError("top_p must be a float between 0 and 1.")

# Load tokenizer
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = load_tokenizer(tokenizer)
Expand Down Expand Up @@ -181,18 +186,10 @@ def predict_dna_sequence(

# Decode the predicted DNA sequence from the model output
if deterministic:
# Select the most probable tokens (argmax)
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
else:
# Sample tokens according to their probability distribution
# Apply temperature scaling and convert logits to probabilities
logits = logits / temperature
probabilities = torch.softmax(logits, dim=-1)

# Sample from the probability distribution at each position
probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size]
predicted_indices = (
torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist()
predicted_indices = sample_non_deterministic(
logits=logits, temperature=temperature, top_p=top_p
)

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
Expand All @@ -210,6 +207,68 @@ def predict_dna_sequence(
)


def sample_non_deterministic(
logits: torch.Tensor,
temperature: float = 0.2,
top_p: float = 0.95,
) -> List[int]:
"""
Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.
This function applies temperature scaling to the logits, computes probabilities,
and then performs nucleus sampling to select token indices. It is used for
non-deterministic decoding in language models to introduce randomness while
maintaining coherence in the generated sequences.
Args:
logits (torch.Tensor): The logits output from the model of shape
[seq_len, vocab_size] or [batch_size, seq_len, vocab_size].
temperature (float, optional): Temperature value for scaling logits.
Must be a positive float. Defaults to 1.0.
top_p (float, optional): Cumulative probability threshold for nucleus sampling.
Must be a float between 0 and 1. Tokens with cumulative probability up to
`top_p` are considered for sampling. Defaults to 0.95.
Returns:
List[int]: A list of sampled token indices corresponding to the predicted tokens.
Raises:
ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1.
Example:
>>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size]
>>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9)
"""
if not isinstance(temperature, (float, int)) or temperature <= 0:
raise ValueError("Temperature must be a positive float.")

if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
raise ValueError("top_p must be a float between 0 and 1.")

# Compute probabilities using temperature scaling
logits /= temperature
probs = torch.softmax(logits, dim=-1)

# Remove batch dimension if present
if probs.dim() == 3:
probs = probs.squeeze(0) # Shape: [seq_len, vocab_size]

# Sort probabilities in descending order
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > top_p

# Zero out probabilities for tokens beyond the top-p threshold
probs_sort[mask] = 0.0

# Renormalize the probabilities
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1)

return predicted_indices.tolist()


def load_model(
model_path: Optional[str] = None,
device: torch.device = None,
Expand Down
1 change: 1 addition & 0 deletions CodonTransformer/CodonUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"W", # Tryptophan
"Y", # Tyrosine
]
STOP_SYMBOLS = ["_", "*"] # Stop codon symbols

# Dictionary ambiguous amino acids to standard amino acids
AMBIGUOUS_AMINOACID_MAP: Dict[str, str] = {
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ This subpackage contains functions and classes that handle the core prediction f

Predict the DNA sequence for a given protein using the CodonTransformer model.

- `sample_non_deterministic(logits: torch.Tensor, temperature: float = 0.2, top_p: float = 0.95) -> List[int]`

Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.

- `load_model(path: str, device: torch.device = None, num_organisms: int = None, remove_prefix: bool = True, attention_type: str = "original_full") -> torch.nn.Module`

Load a BigBirdForMaskedLM model from a file or checkpoint.
Expand Down Expand Up @@ -383,6 +387,7 @@ The CodonUtils subpackage contains constants and helper functions essential for
#### Constants

- `AMINO_ACIDS`: List of all standard amino acids
- `STOP_SYMBOLS`: List of possible stop symbols to end the protein with
- `AMBIGUOUS_AMINOACID_MAP`: Mapping of ambiguous amino acids to standard amino acids
- `START_CODONS` and `STOP_CODONS`: Lists of start and stop codons
- `TOKEN2INDEX` and `INDEX2TOKEN`: Mappings between tokens and their indices
Expand Down
Loading

0 comments on commit bca7eef

Please sign in to comment.