Update on "ppo chess with llm draft"
[ghstack-poisoned]
mikaylagawarecki committed Feb 5, 2025
1 parent 9f4f291 commit 797486b
Showing 1 changed file with 29 additions and 36 deletions.
65 changes: 29 additions & 36 deletions examples/agents/ppo-chess-llm.py
@@ -329,23 +329,14 @@ def remove_logits(td):
return_log_prob=True,
)

class AggregateProb(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
log_prob = x.sum(dim=-1)
return log_prob

actor_llm_policy = ProbSeq(
Mod(
LLMWrapper(llm, tokenizer, mode="policy"),
in_keys=["obs_tokens"],
out_keys=["logits", "hidden"],
),
prob_module,
# if a plain lambda is used here: 'function' object has no attribute 'in_keys'
Mod(AggregateProb(), in_keys=["sample_log_prob"], out_keys=["sample_log_prob"]),
# using return_composite=True so aggregate_probabilities is set
return_composite=True,
)
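# Editor's note, illustrative only (not part of the original file): the AggregateProb
# module in this hunk sums per-token log-probs ("sample_log_prob") over the last
# dimension to obtain a single sequence-level log-prob, i.e.
# log p(a_1..a_T) = sum_t log p(a_t). A toy example of that sum, with made-up values:
import torch

per_token_log_prob = torch.tensor([[-0.1, -0.5, -0.3]])  # (batch=1, num_tokens=3)
sequence_log_prob = per_token_log_prob.sum(dim=-1)        # tensor([-0.9000])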

@@ -372,24 +363,38 @@ def play(env, data_llm_policy, actor_llm_policy, tokenizer):
batch_size=32,
sampler=SliceSamplerWithoutReplacement(slice_len=8, end_key=("next", "done")),
)
# layout=torch.jagged errors with Qwen
# obs_tokens in layout=torch.jagged errors with Qwen
# File "/home/mg1998/.conda/envs/rl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 859, in forward
# cache_position = torch.arange(
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# )
# AttributeError: 'ConstantIntNode' object has no attribute 'add'
# After attempting to fix this, there were other issues, e.g. NJT can't unsqueeze dim=0
# rb.append_transform(lambda td: td.densify(layout=torch.jagged))
rb.append_transform(lambda td: td.densify(layout=torch.jagged))
rb.append_transform(
lambda td: td.set(
"obs_tokens", td.get("obs_tokens").to_padded_tensor(tokenizer.pad_token_id)
)
)

# pad the "obs_tokens" entry of the "next" sub-tensordict as well
def foo(td):
td["next"].set(
"obs_tokens",
td["next", "obs_tokens"].to_padded_tensor(tokenizer.pad_token_id),
)
return td

rb.append_transform(foo)
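# Editor's note, illustrative only (not part of the original file): the transforms
# above work around the Qwen/jagged-layout error by converting nested (NJT) token
# batches to dense padded tensors. A minimal standalone sketch of that conversion,
# with 0 standing in for tokenizer.pad_token_id:
import torch

example_seqs = [torch.tensor([5, 7, 9]), torch.tensor([2, 4])]  # variable-length token sequences
example_njt = torch.nested.nested_tensor(example_seqs, layout=torch.jagged)
example_padded = example_njt.to_padded_tensor(0)  # dense (2, 3) tensor, short row filled with 0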

# rb.append_transform(
# lambda td: td.set(
# "obs_tokens", td.get("obs_tokens").to_padded_tensor(tokenizer.pad_token_id)
# "tokenized_action",
# td.get("tokenized_action").to_padded_tensor(tokenizer.pad_token_id),
# )
# )

# rb.append_transform(lambda td: td.set("tokenized_action", td.get("tokenized_action").to_padded_tensor(tokenizer.pad_token_id)))

collector = SyncDataCollector(
env, data_llm_policy, frames_per_batch=100, total_frames=10000
env, data_llm_policy, frames_per_batch=20, total_frames=10000
)
loss_module = ClipPPOLoss(
actor_network=actor_llm_policy,
@@ -431,29 +436,17 @@ def pad_tensors_to_same_shape(tensor1, tensor2):
for data in tqdm(collector):
# FIXME: reward seems to be getting wrongly propagated (e.g. sunfish's win gets reflected as llm's win)
rb.empty()
# FIXME: what is the right way to do this?
data = data.densify(layout=torch.jagged)
data.set(
"obs_tokens",
data.get("obs_tokens").to_padded_tensor(tokenizer.pad_token_id),
)
data.get("next").set(
"obs_tokens",
data.get("next").get("obs_tokens").to_padded_tensor(tokenizer.pad_token_id),
)

obs_tokens = data.get("obs_tokens")
next_obs_tokens = data.get("next").get("obs_tokens")
obs_tokens, next_obs_tokens = pad_tensors_to_same_shape(
obs_tokens, next_obs_tokens
)
data.set("obs_tokens", obs_tokens)
data.get("next").set("obs_tokens", next_obs_tokens)

data = gae(data)
rb.extend(data)

for data in tqdm(rb):
obs_tokens = data["obs_tokens"]
next_obs_tokens = data["next", "obs_tokens"]
obs_tokens, next_obs_tokens = pad_tensors_to_same_shape(
obs_tokens, next_obs_tokens
)
data["obs_tokens"] = obs_tokens
data["next", "obs_tokens"] = next_obs_tokens
data = gae(data)
loss = loss_module(data)
loss.sum(reduce=True).backward()
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 100.0)
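# Editor's note, illustrative only: pad_tensors_to_same_shape is defined earlier in
# the file and its body is not shown in this hunk. A plausible sketch of such a
# helper, assuming it right-pads both tensors along the last (token) dimension to a
# common length; the pad value here is chosen for illustration:
import torch.nn.functional as F

def pad_tensors_to_same_shape_sketch(tensor1, tensor2, pad_value=0):
    target_len = max(tensor1.shape[-1], tensor2.shape[-1])
    tensor1 = F.pad(tensor1, (0, target_len - tensor1.shape[-1]), value=pad_value)
    tensor2 = F.pad(tensor2, (0, target_len - tensor2.shape[-1]), value=pad_value)
    return tensor1, tensor2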