Skip to content

Commit

Permalink
Improved dataset generator (#13)
Browse files Browse the repository at this point in the history
* re-balanced dataset + refactored dataset generator

* majority of the floortile data generator

* modified finetune + dataset_generator

* floor-tile correct
  • Loading branch information
maxzuo authored Oct 9, 2024
1 parent e7be47a commit 56ab2ff
Show file tree
Hide file tree
Showing 8 changed files with 29,717 additions and 7,745 deletions.
36,170 changes: 28,636 additions & 7,534 deletions dataset_config.yaml

Large diffs are not rendered by default.

1,216 changes: 1,025 additions & 191 deletions dataset_generator.py

Large diffs are not rendered by default.

39 changes: 33 additions & 6 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def preprocess(
inputs = [
strip(
tokenizer.apply_chat_template(
llmp.PlanningProblem(nl, d, p).apply_template(
llmp.PlanningProblem(nl, d, p).apply_template(
domain_prompt,
problem_prompt,
),
Expand Down Expand Up @@ -204,6 +204,34 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
return tokenizer, model


def extract_instruct_tokens(tokenizer: PreTrainedTokenizer) -> tuple[str, str]:
    """Extract the instruction and response templates from a tokenizer.

    A two-turn (user, assistant) conversation whose message contents are a
    placeholder string is rendered with the tokenizer's chat template. The
    rendered text is then split on the placeholder: the text before the user
    content is the instruction template and the text between the user content
    and the assistant content is the response template.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use.

    Returns:
        tuple[str, str]: The instruction template and the response template.

    Raises:
        ValueError: If the rendered conversation does not contain enough
            placeholder occurrences to yield both templates.
    """
    # `unk_token` may be None; splitting on None would split on whitespace
    # instead of on the message contents, so fall back to a fixed sentinel.
    placeholder = tokenizer.unk_token or "<|placeholder|>"

    chat_str = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": placeholder},
            {"role": "assistant", "content": placeholder},
        ],
        tokenize=False,
    )

    if tokenizer.chat_template:
        segments = chat_str.split(placeholder)
    else:
        # Without an explicit chat template the content is presumably
        # rendered surrounded by single spaces -- TODO confirm.
        segments = chat_str.split(f" {placeholder} ")

    # Drop a leading "<s> " marker and surrounding whitespace from each piece.
    templates = [segment.replace("<s> ", "").strip() for segment in segments]

    if len(templates) < 2:
        raise ValueError(
            "Could not extract instruction/response templates from the "
            f"rendered chat template: {chat_str!r}"
        )

    # Return a real tuple, matching the annotated return type.
    return templates[0], templates[1]


def main(config_path: str):
"""Train a model on a dataset using a given configuration.
Expand All @@ -217,17 +245,16 @@ def main(config_path: str):
# Load dataset
dataset = load_dataset(config["dataset"])

train_config = config["train"]
train_config: dict = config["train"]

# Load model
tokenizer, model = load_model(train_config)

# Create data collator
instr_template, resp_template = extract_instruct_tokens(tokenizer)
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer.encode(
train_config["model"]["response_template"],
add_special_tokens=False,
),
response_template=resp_template,
instruction_template=instr_template,
tokenizer=tokenizer,
)

Expand Down
21 changes: 20 additions & 1 deletion planetarium/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
__all__ = ["builder", "downward", "graph", "metric", "oracle", "evaluate"]
import os
from importlib import resources

__all__ = [
"builder",
"downward",
"graph",
"metric",
"oracle",
"evaluate",
"DOMAINS",
]

from . import builder
from . import downward
from . import graph
from . import metric
from . import oracle
from . import domains

# Mapping from domain name (file name up to its first ".") to the text of the
# corresponding resource shipped in the `domains` package.
DOMAINS = dict()

# load domains
for domain in resources.files(domains).iterdir():
    # Skip directory entries (e.g. a `__pycache__` folder): they cannot be
    # read as text.
    if not domain.is_file():
        continue
    # Use the Traversable API (`name`, `read_text`) rather than `os.path`
    # helpers, since a Traversable is not guaranteed to be an os.PathLike
    # (e.g. when the package is loaded from a zip archive). Read as UTF-8 so
    # the result does not depend on the platform's default encoding.
    DOMAINS[domain.name.split(".")[0]] = domain.read_text(encoding="utf-8")

from .evaluate import evaluate
2 changes: 1 addition & 1 deletion planetarium/domains/floor-tile.pddl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
;; Modified from: https://github.com/AI-Planning/pddl-generators/blob/main/floortile/domain.pddl

(define (domain floor-tile)
(:requirements :typing :action-costs)
(:requirements :typing)
(:types
robot tile color - object
)
Expand Down
1 change: 0 additions & 1 deletion planetarium/downward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def _get_best_plan(plan_filepath: str) -> tuple[str | None, float]:
best_plan = None

for plan_fp in glob.glob(f"{plan_filepath}*"):
print(plan_fp)
with open(plan_fp, "r") as f:
*pddl_plan, cost_str = f.readlines()
match = re.search(r"cost = ([-\d\.]+)", cost_str)
Expand Down
12 changes: 2 additions & 10 deletions planetarium/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@
from pddl.parser.problem import LenientProblemParser
from pddl.formatter import problem_to_string

from planetarium import builder, oracle, metric, downward
from . import domains
from planetarium import builder, oracle, metric, downward, DOMAINS


VALIDATE = os.getenv("VALIDATE", "Validate")
DOWNWARD = os.getenv("DOWNWARD", "downward")
DOMAINS = dict()

# load domains
for domain in resources.files(domains).iterdir():
with domain.open() as f:
DOMAINS[os.path.basename(domain).split(".")[0]] = f.read()


def evaluate(
source_pddl_str: str,
Expand Down Expand Up @@ -55,7 +47,7 @@ def evaluate(
try:
target_graph = builder.build(target_pddl_str)
parseable = True
except Exception:
except Exception as e:
return parseable, solveable, equivalent

clean_pddl_str = problem_to_string(LenientProblemParser()(target_pddl_str))
Expand Down
1 change: 0 additions & 1 deletion planetarium/oracles/floortile.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def _fixed_color_predicates(
if n.typing in ({"tile"}, {"robot"})
]
subgraph = init.graph.subgraph(subgraph_nodes).to_undirected()
print('subgraph', subgraph.nodes())

for u, v, edge in goal.edges:
if edge.predicate == "painted":
Expand Down

0 comments on commit 56ab2ff

Please sign in to comment.