Skip to content

Commit

Permalink
modified finetune + dataset_generator
Browse files Browse the repository at this point in the history
  • Loading branch information
maxzuo committed Oct 8, 2024
1 parent d9927b3 commit 63be527
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 9 deletions.
102 changes: 99 additions & 3 deletions dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,9 @@ def grid(

for robot, data in zip(robots, robot_data):
predicates.append(Predicate("robot-has", robot, colors[data["color"]]))
predicates.append(Predicate("robot-at", robot, grid[data["pos"][0]][data["pos"][1]]))
predicates.append(
Predicate("robot-at", robot, grid[data["pos"][0]][data["pos"][1]])
)

for color in colors:
predicates.append(Predicate("available-color", color))
Expand Down Expand Up @@ -1424,7 +1426,9 @@ def checkerboard(
assert len(tiles) == grid_size_x * grid_size_y

tiles_iter = iter(tiles)
grid = [[next(tiles_iter) for _ in range(grid_size_x)] for _ in range(grid_size_y)]
grid = [
[next(tiles_iter) for _ in range(grid_size_x)] for _ in range(grid_size_y)
]

predicates = []

Expand Down Expand Up @@ -1457,6 +1461,46 @@ def all_different(

return predicates

def paint_x(
    self,
    robots: list[Constant],
    tiles: list[Constant],
    colors: list[Constant],
    goal: bool = False,
    grid_size: list[int] | None = None,
    **kwargs,
) -> list[Predicate]:
    """Generate goal predicates painting an 'X' (both diagonals) on a square grid.

    Args:
        robots (list[Constant]): Robot constants (unused; kept for a uniform
            task signature).
        tiles (list[Constant]): Tile constants, laid out row-major onto the grid.
        colors (list[Constant]): One color (the X only) or two colors
            (X plus background fill).
        goal (bool): Must be True; this task only describes a goal state.
        grid_size (list[int] | None): [width, height]; must be square.

    Returns:
        list[Predicate]: "painted" predicates describing the X pattern.

    Raises:
        ValueError: If the grid is not square, ``goal`` is False, or the
            number of colors is not 1 or 2.
    """
    if not grid_size or grid_size[0] != grid_size[1]:
        raise ValueError("Paint x task requires a square grid size")
    if not goal:
        raise ValueError("Paint x task is only supported as a goal state")
    if not (0 < len(colors) <= 2):
        # One color paints the X; an optional second paints the remaining tiles.
        raise ValueError("Paint x task requires one or two colors")

    tiles_iter = iter(tiles)

    # Row-major layout: grid[row][col].
    grid = [
        [next(tiles_iter) for _ in range(grid_size[0])] for _ in range(grid_size[1])
    ]

    n = grid_size[0]
    predicates = []
    painted = set()
    for i in range(n):
        # Main diagonal.
        predicates.append(Predicate("painted", grid[i][i], colors[0]))
        painted.add((i, i))
        # Anti-diagonal. Skip the center tile of an odd-sized grid, which the
        # main diagonal already painted — avoids a duplicate predicate.
        if (i, n - i - 1) not in painted:
            predicates.append(
                Predicate("painted", grid[i][n - i - 1], colors[0])
            )
            painted.add((i, n - i - 1))

    if len(colors) == 2:
        # Fill every off-diagonal tile with the background color.
        for i in range(n):
            for j in range(n):
                if (i, j) not in painted:
                    predicates.append(Predicate("painted", grid[i][j], colors[1]))

    return predicates

def paint_all(
self,
robots: list[Constant],
Expand Down Expand Up @@ -1567,6 +1611,12 @@ def get_robot_grid_string(
case ("paint_all", False):
return "Your goal is to paint all the tiles with the same color."

case ("paint_x", False):
if n_colors == 1:
return "Your goal is to paint an 'X' shape on the tiles with a single color."
elif n_colors == 2:
return "Your goal is to paint an 'X' shape on the tiles, and every other tile should be painted with a different color."

case ("checkerboard", False):
return "Your goal is to paint the tiles in a checkerboard pattern."

Expand Down Expand Up @@ -1982,7 +2032,7 @@ def insert_split(
Args:
conn (sqlite3.Connection): SQLite database connection.
name (str): Name of the split.
split (dict[int | str, list[int]]): Split data.
"""
cursor = conn.cursor()
Expand All @@ -2007,6 +2057,19 @@ def split(
random.seed(config.get("random_seed", 42))
conn = sqlite3.connect(database_path)
cursor = conn.cursor()
# drop splits table if exists
cursor.execute("DROP TABLE IF EXISTS splits")
cursor.execute(
"""
CREATE TABLE splits (
problem_id INTEGER NOT NULL,
split_type TEXT NOT NULL,
split TEXT NOT NULL,
PRIMARY KEY (problem_id, split),
FOREIGN KEY (problem_id) REFERENCES problems (id)
)
"""
)

# split by domain
cursor.execute(
Expand Down Expand Up @@ -2095,11 +2158,44 @@ def split(
f"{i}": problems[i::num_random_splits] for i in range(num_random_splits)
}

# split by held-out problems
heldout_splits = {"train": [], "test": []}
heldout_config = config.get("heldout", {})

heldout_test_ids = set()
for d, queries in heldout_config.items():
if d != domain:
continue
for query in queries:
init_query, goal_query = query['init'], query['goal']
args = [domain]
query_str = ""
if init_query:
query_str += f" AND init = ?"
args.append(init_query)
if goal_query:
query_str += f" AND goal = ?"
args.append(goal_query)
cursor.execute(
f"""
SELECT id
FROM problems
WHERE domain = ?{query_str}
""",
args,
)

heldout_test_ids.update([row[0] for row in cursor.fetchall()])

heldout_splits["test"] = list(heldout_test_ids)
heldout_splits["train"] = list(set(problems) - heldout_test_ids)

splits = {
"abstraction": abstraction_splits,
"size": size_splits,
"placeholder": strict_splits,
"random": random_splits,
"heldout": heldout_splits,
}
domain_splits[domain] = splits
pbar.update()
Expand Down
39 changes: 33 additions & 6 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def preprocess(
inputs = [
strip(
tokenizer.apply_chat_template(
llmp.PlanningProblem(nl, d, p).apply_template(
llmp.PlanningProblem(nl, d, p).apply_template(
domain_prompt,
problem_prompt,
),
Expand Down Expand Up @@ -204,6 +204,34 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
return tokenizer, model


def extract_instruct_tokens(tokenizer: PreTrainedTokenizer) -> tuple[str, str]:
    """Extract the instruction/response wrapper strings from a chat template.

    Renders a two-turn (user, assistant) conversation whose contents are a
    placeholder token, then splits the rendered string on that placeholder to
    recover the text the template inserts before each turn.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer whose chat template to probe.

    Returns:
        tuple[str, str]: ``(instruction_template, response_template)``.

    Raises:
        ValueError: If the tokenizer has no ``unk_token`` to use as a
            placeholder, or the rendered chat does not split into at least
            two segments.
    """
    placeholder = tokenizer.unk_token
    if placeholder is None:
        # Without a placeholder token we cannot locate the turn boundaries.
        raise ValueError("Tokenizer has no unk_token to use as a placeholder")

    chat_str = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": placeholder},
            {"role": "assistant", "content": placeholder},
        ],
        tokenize=False,
    )

    if not tokenizer.chat_template:
        # Default (library-provided) template: turns are space-separated.
        templates = chat_str.split(f" {placeholder} ")
    else:
        templates = chat_str.split(placeholder)
        # NOTE(review): the "<s> " removal is Llama-style BOS cleanup —
        # confirm it is a no-op for other model families.
        templates = [t.replace("<s> ", "").strip() for t in templates]

    if len(templates) < 2:
        raise ValueError("Could not extract two chat-template segments")

    # Return an actual tuple so the value matches the declared return type.
    return (templates[0], templates[1])


def main(config_path: str):
"""Train a model on a dataset using a given configuration.
Expand All @@ -217,17 +245,16 @@ def main(config_path: str):
# Load dataset
dataset = load_dataset(config["dataset"])

train_config = config["train"]
train_config: dict = config["train"]

# Load model
tokenizer, model = load_model(train_config)

# Create data collator
instr_template, resp_template = extract_instruct_tokens(tokenizer)
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer.encode(
train_config["model"]["response_template"],
add_special_tokens=False,
),
response_template=resp_template,
instruction_template=instr_template,
tokenizer=tokenizer,
)

Expand Down

0 comments on commit 63be527

Please sign in to comment.