Compilation takes too long #1
Hi,
Currently trying to use SOAP to fine-tune a Hugging Face base model, but compilation takes too long. Is this expected?

Sample code:

import jax
import jax.numpy as jnp
from transformers import BertConfig, BertTokenizer
from transformers.models.bert.modeling_flax_bert import FlaxBertModel, FlaxBertModule
from flax.training.train_state import TrainState
import optax
import soap


def create_train_state(model, learning_rate=1e-5):  # learning_rate is unused; the schedule below is used instead
    """Creates the initial `TrainState` for the model."""
    learning_rate_fn = optax.join_schedules(
        schedules=[
            optax.linear_schedule(0.0, 0.001, 1000),        # linear warmup
            optax.linear_schedule(0.001, 0.0, 5000 - 1000),  # linear decay
        ],
        boundaries=[1000],
    )
    opt = soap.soap(
        learning_rate=learning_rate_fn,
        b1=0.9,
        b2=0.999,
        eps=1e-8,
        weight_decay=0.01,
        precondition_frequency=5,
    )
    module = FlaxBertModule(model.config)
    state = TrainState.create(
        apply_fn=module.apply,
        params=model.params,
        tx=opt,
    )
    return state


@jax.jit
def train_step(state, batch):
    """Single training step."""
    def loss_fn(params):
        outputs = state.apply_fn(
            {'params': params},
            batch['input_ids'],
            attention_mask=batch['attention_mask'],
            deterministic=False,
            rngs={'dropout': jax.random.PRNGKey(0)},
        )
        # Dummy loss for the repro: mean of the final hidden states.
        loss = jnp.mean(outputs.last_hidden_state)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    # Update parameters
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


# Initialize model and tokenizer
config = BertConfig.from_pretrained('bert-base-uncased')
model = FlaxBertModel.from_pretrained('bert-base-uncased', config=config)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create dummy batch
text = [
    "This is a positive review!",
    "This is a negative review!",
    "This is a neutral review!",
    "This is a mixed review!",
    "This is a review!",
]
labels = jnp.array([1, 0, 2, 3, 4])

# Tokenize
encoded = tokenizer(
    text,
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors='np',
)

# Create batch (labels are unused by the dummy loss but kept for completeness)
batch = {
    'input_ids': encoded['input_ids'],
    'attention_mask': encoded['attention_mask'],
    'labels': labels,
}

# Initialize training state
state = create_train_state(model)

# Perform a single training step (the first call triggers JIT compilation)
new_state, loss = train_step(state, batch)
print(f"Loss: {loss}")
It's expected that it takes a while to compile: each parameter update contains a Python loop over the preconditioners, which jit unrolls into the compiled program. It's possible that I can use https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html to reduce the compilation time; I'll take a look at that. Thanks for providing some example code.
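Roughly, the idea is that a Python loop over n preconditioner blocks is unrolled into n copies of the update when traced under jit, whereas jax.lax.scan compiles the loop body only once. A toy sketch of that pattern (not the actual SOAP update; it assumes the blocks can be stacked into a single array, i.e. share a shape or are padded to one):

import jax
import jax.numpy as jnp

def update_block(carry, block):
    precond, grad = block
    # Toy stand-in for a per-block preconditioner update.
    new_precond = 0.99 * precond + 0.01 * jnp.outer(grad, grad)
    return carry, new_precond

preconds = jnp.stack([jnp.eye(4)] * 100)  # (100, 4, 4) stacked preconditioner blocks
grads = jnp.ones((100, 4))                # matching per-block gradients

# scan sweeps over the leading axis; the body is traced and compiled once, not 100 times.
_, new_preconds = jax.lax.scan(update_block, None, (preconds, grads))

In practice this only helps if the per-block shapes match (or are padded to a common shape), which is presumably why the plain loop is there in the first place.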
The long compile time makes parameter sweeps very difficult. Please consider using scan to avoid the loop.
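Until the loop is scanned, one possible stopgap for sweeps is to make the swept hyperparameter part of the optimizer state with optax.inject_hyperparams, so every sweep point reuses the one compiled train_step instead of triggering a fresh compile. This is only a sketch and assumes soap.soap's signature is compatible with inject_hyperparams; it reuses the names from the sample code above:

import optax

# Hypothetical sketch: swap the optimizer construction inside create_train_state for an
# inject_hyperparams-wrapped version. precondition_frequency is kept static since it is
# probably used in Python-level control flow inside the optimizer.
opt = optax.inject_hyperparams(soap.soap, static_args=('precondition_frequency',))(
    learning_rate=1e-3,  # a plain scalar here rather than a schedule
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=0.01,
    precondition_frequency=5,
)

# Between sweep points, overwrite the injected value on the created TrainState; the pytree
# structure and array shapes are unchanged, so train_step is not recompiled.
state.opt_state.hyperparams['learning_rate'] = 3e-4

This doesn't shorten the initial compile, it just avoids paying it once per sweep point.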