Make use of instruction field optional (#123)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
alex-jw-brooks authored Apr 16, 2024
1 parent de6f272 commit aae11d7
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions scripts/run_evaluation.py
@@ -51,8 +51,12 @@ def parse_and_validate_args():
         "--eos_token",
         help="EOS token emitted by the model; will recursively remove the token if present",
     )
+    parser.add_argument(
+        "--use_instruction",
+        help="Indicates whether or not the instruction field should be used in formatting",
+        action="store_true",
+    )
     parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction)
 
     parsed_args = parser.parse_args()
 
     print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}")
@@ -86,29 +90,37 @@ def parse_and_validate_args():
 }
 
 
-def get_formatted_example(example: dict[str, str]) -> dict[str, str]:
+def get_formatted_example(
+    example: dict[str, str], use_instruction: bool
+) -> dict[str, str]:
     """Given a single example, format it based on whether or not we have an input provided.
     Args:
         example: dict[str, str]
             Dictionary containing the keys for instruction / input / output, i.e., Alpaca formatted
             data.
+        use_instruction: bool
+            Indicates whether or not the instruction field will be used.
     Returns:
         dict[str, str]
             Dictionary containing the following:
                 "input" - the formatted text to run the prediction on.
                 "output" - the target text we aim to generate.
     """
-    prompt_input, prompt_no_input = (
-        PROMPT_DICT["prompt_input"],
-        PROMPT_DICT["prompt_no_input"],
-    )
-    formatted_input = (
-        prompt_input.format_map(example)
-        if example.get("input", "") != ""
-        else prompt_no_input.format_map(example)
-    )
-    # NOTE: Currently we ignore the instruction field due to the type of tasks we're tuning against
+    if use_instruction:
+        prompt_input, prompt_no_input = (
+            PROMPT_DICT["prompt_input"],
+            PROMPT_DICT["prompt_no_input"],
+        )
+        formatted_input = (
+            prompt_input.format_map(example)
+            if example.get("input", "") != ""
+            else prompt_no_input.format_map(example)
+        )
+    else:
+        formatted_input = f"Input: \n{example.get('input')}\n\n### Response:"
     return {
         # Text to run the prediction on
         "input": formatted_input,
@@ -122,6 +134,7 @@ def get_prediction_results(
     model: TunedCausalLM,
     data: datasets.arrow_dataset.Dataset,
     max_new_tokens: int,
+    use_instruction: bool,
     delimiter: Optional[str],
     eos_token: Optional[str],
 ) -> tuple[list]:
@@ -135,6 +148,8 @@
             HF dataset to be processed for evaluation.
         max_new_tokens: int
             Max number of tokens to be used for generation.
+        use_instruction: bool
+            Indicates whether or not the instruction field should be used.
         delimiter: Optional[str]
             Delimiter to be used for splitting apart multioutput instances.
         eos_token: Optional[str]
@@ -151,7 +166,7 @@
     model_pred_info = []
     for datum in tqdm(data):
         # Format the alpaca example
-        formatted_datum = get_formatted_example(datum)
+        formatted_datum = get_formatted_example(datum, use_instruction)
         # Run the formatted text through the model, and only save the newly generated text strings
         prediction = model.run(
             formatted_datum["input"],
@@ -434,6 +449,7 @@ def export_experiment_info(
         tuned_model,
         eval_data,
         args.max_new_tokens,
+        args.use_instruction,
         args.delimiter,
         args.eos_token,
     )
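
Taken together, the flag is threaded from the CLI through get_prediction_results into the per-example formatting. Below is a hedged, self-contained sketch of that loop with stand-in objects; TunedCausalLM's run() signature and the dataset contents are assumptions, and get_formatted_example is simplified to its use_instruction=False branch:

from tqdm import tqdm

class StubModel:
    """Stand-in for TunedCausalLM; the real run() interface may differ."""
    def run(self, text: str, max_new_tokens: int = 20) -> str:
        return "positive"  # pretend generation

def get_formatted_example(example: dict, use_instruction: bool) -> dict:
    # Simplified: only the use_instruction=False branch from this commit;
    # mapping "output" to example["output"] follows the docstring above.
    return {
        "input": f"Input: \n{example.get('input')}\n\n### Response:",
        "output": example["output"],
    }

data = [{"instruction": "Classify.", "input": "Great movie!", "output": "positive"}]
model = StubModel()

for datum in tqdm(data):
    formatted_datum = get_formatted_example(datum, use_instruction=False)
    prediction = model.run(formatted_datum["input"], max_new_tokens=20)

Note that because the flag defaults to False, invocations that omit --use_instruction take this bare-input path, rather than the prompt templates the script used unconditionally before this change.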
