
Compilation takes too long #1

Open
ozanarmagan opened this issue Nov 1, 2024 · 3 comments

@ozanarmagan

Hi,

I'm currently trying to use SOAP to fine-tune a Hugging Face base model, but compilation takes too long. Is this expected?

@ozanarmagan
Author

Sample code:

import jax
import jax.numpy as jnp
from transformers import FlaxBertForSequenceClassification, BertConfig
from transformers import BertTokenizer
from flax.training.train_state import TrainState
from transformers.models.bert.modeling_flax_bert import FlaxBertModel, FlaxBertModule
import optax
import soap

def create_train_state(model, learning_rate=1e-5):
    """Creates initial `TrainState` for the model."""
    learning_rate_fn = optax.join_schedules(
        schedules=[
            optax.linear_schedule(0.0, 0.001, 1000),
            optax.linear_schedule(0.001, 0.0, 5000 - 1000),
        ],
        boundaries=[1000],
    )
    
    opt = soap.soap(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, precondition_frequency=5)
    module = FlaxBertModule(model.config)
    state = TrainState.create(
        apply_fn=module.apply,
        params=model.params,
        tx=opt
    )
    
    return state

@jax.jit
def train_step(state, batch):
    """Single training step."""
    def loss_fn(params):
        outputs = state.apply_fn(
            {'params': params},
            batch['input_ids'],
            attention_mask=batch['attention_mask'],
            deterministic=False,
            rngs={'dropout': jax.random.PRNGKey(0)}
        )
        logits = outputs.last_hidden_state
        loss = jnp.mean(logits)  # dummy loss, just to exercise the optimizer
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    
    # Update parameters
    new_state = state.apply_gradients(grads=grads)
    
    return new_state, loss

# Initialize model and tokenizer
config = BertConfig.from_pretrained(
    'bert-base-uncased'
)
model = FlaxBertModel.from_pretrained(
    'bert-base-uncased',
    config=config
)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create dummy batch
text = ["This is a positive review!", "This is a negative review!", "This is a neutral review!", "This is a mixed review!", "This is a review!"]
labels = jnp.array([1, 0, 2, 3, 4])

# Tokenize
encoded = tokenizer(
    text,
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors='np'
)

# Create batch
batch = {
    'input_ids': encoded['input_ids'],
    'attention_mask': encoded['attention_mask'],
    'labels': labels
}

# Initialize training state
state = create_train_state(model)

# Perform single training step
new_state, loss = train_step(state, batch)
print(f"Loss: {loss}")

@haydn-jones
Owner

It's expected that it takes a while to compile: each parameter update requires a Python loop over the preconditioners, and that loop gets unrolled at trace time. It's possible that I can use https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html to reduce the compilation time, I'll take a look at that.
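
Roughly the pattern I have in mind (a minimal sketch, not the actual SOAP update; the per-preconditioner work and shapes are made up, and it assumes the preconditioners can be stacked into an array of the same shape). A Python for-loop over the preconditioners gets one copy of its body compiled per preconditioner, whereas jax.lax.scan traces and compiles the body once:

import jax
import jax.numpy as jnp

def update_one(carry, precond):
    # Hypothetical per-preconditioner work; stands in for the real SOAP update.
    new_precond = 0.99 * precond + 0.01 * jnp.eye(precond.shape[-1])
    return carry, new_precond

@jax.jit
def update_all_scan(preconds):
    # preconds: stacked array of shape (num_preconditioners, d, d).
    # The scan body is traced once, regardless of num_preconditioners.
    _, new_preconds = jax.lax.scan(update_one, None, preconds)
    return new_preconds

preconds = jnp.stack([jnp.eye(8) for _ in range(64)])
out = update_all_scan(preconds)  # compiles the body once instead of 64 times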

Thanks for providing some example code.

@kishorenc

The long compile time makes parameter sweeps very difficult. Please consider using scan to avoid the loop.
