diff --git a/garak/generators/huggingface.py b/garak/generators/huggingface.py
index 4461aa10..fac315e0 100644
--- a/garak/generators/huggingface.py
+++ b/garak/generators/huggingface.py
@@ -19,7 +19,7 @@ from math import log
 import re
 import os
-from typing import List
+from typing import List, Union
 import warnings
 
 import backoff
@@ -151,6 +151,75 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
         self.deprefix_prompt = True
 
 
+class ConversationalPipeline(Generator):
+    """Conversational text generation using HuggingFace pipelines"""
+
+    generator_family_name = "Hugging Face 🤗 pipeline for conversations"
+    supports_multiple_generations = True
+
+    def __init__(self, name, do_sample=True, generations=10, device=0):
+        self.fullname, self.name = name, name.split("/")[-1]
+
+        super().__init__(name, generations=generations)
+
+        from transformers import pipeline, set_seed, Conversation
+
+        if _config.run.seed is not None:
+            set_seed(_config.run.seed)
+
+        import torch.cuda
+
+        if not torch.cuda.is_available():
+            logging.debug("Using CPU, torch.cuda.is_available() returned False")
+            device = -1
+
+        # Note that with pipeline, the tokenizer, model, and device must be accessed
+        # as attributes of self.generator, not of the ConversationalPipeline itself.
+        self.generator = pipeline(
+            "conversational",
+            model=name,
+            do_sample=do_sample,
+            device=device,
+        )
+        self.conversation = Conversation()
+        self.deprefix_prompt = name in models_to_deprefix
+        if _config.loaded:
+            if _config.run.deprefix is True:
+                self.deprefix_prompt = True
+
+    def clear_history(self):
+        from transformers import Conversation
+        self.conversation = Conversation()
+
+    def _call_model(self, prompt: Union[str, List[dict]]) -> List[str]:
+        """Take a conversation as a string or a list of message dicts and feed it to the model"""
+
+        # If the conversation is provided as a list of dicts, build a Conversation
+        # from it; otherwise, maintain conversational state on this Generator.
+        if isinstance(prompt, str):
+            self.conversation.add_message({"role": "user", "content": prompt})
+            self.conversation = self.generator(self.conversation)
+            generations = [self.conversation[-1]["content"]]
+
+        elif isinstance(prompt, list):
+            from transformers import Conversation
+
+            conversation = Conversation()
+            for item in prompt:
+                conversation.add_message(item)
+
+            conversation = self.generator(conversation)
+
+            generations = [conversation[-1]["content"]]
+        else:
+            raise TypeError(f"Expected list or str, got {type(prompt)}")
+
+        if not self.deprefix_prompt:
+            return generations
+        prefix = prompt if isinstance(prompt, str) else prompt[-1]["content"]  # last user turn
+        return [re.sub("^" + re.escape(prefix), "", i) for i in generations]
+
+
 class InferenceAPI(Generator):
     """Get text generations from Hugging Face Inference API"""
 
diff --git a/garak/resources/tap/generator_utils.py b/garak/resources/tap/generator_utils.py
index 8a4f3e59..532d3bdd 100644
--- a/garak/resources/tap/generator_utils.py
+++ b/garak/resources/tap/generator_utils.py
@@ -4,7 +4,7 @@
 from typing import Union
 
 from garak.generators.openai import chat_models, OpenAIGenerator
-from garak.generators.huggingface import Model
+from garak.generators.huggingface import ConversationalPipeline
 
 supported_openai = chat_models
 supported_huggingface = [
@@ -51,7 +51,9 @@ def load_generator(
             generations=generations,
         )
     elif model_name in supported_huggingface:
-        generator = Model(model_name, generations=generations, device=device)
+        generator = ConversationalPipeline(
+            model_name, generations=generations, device=device
+        )
     else:
         msg = (
             f"{model_name} is not currently supported for TAP generation. Support is available for the following "
@@ -59,7 +61,9 @@ def load_generator(
             f"Your jailbreaks will *NOT* be saved."
         )
         print(msg)
-        generator = Model(model_name, generations=generations, device=device)
+        generator = ConversationalPipeline(
+            model_name, generations=generations, device=device
+        )
 
     generator.max_tokens = max_tokens
     if temperature is not None:
diff --git a/garak/resources/tap/tap_main.py b/garak/resources/tap/tap_main.py
index 2c87fb99..7eac7766 100644
--- a/garak/resources/tap/tap_main.py
+++ b/garak/resources/tap/tap_main.py
@@ -121,11 +121,17 @@ def get_attack(self, convs, prompts):
                 continue
 
             for full_prompt in full_prompts_subset[left:right]:
-                outputs_list.append(self.attack_generator.generate(full_prompt)[0])
+                # Keep the first generation, or None if nothing came back
+                output = self.attack_generator.generate(full_prompt)
+                outputs_list.append(output[0] if output else None)
 
         # Check for valid outputs and update the list
         new_indices_to_regenerate = []
         for i, full_output in enumerate(outputs_list):
+            # Skip failed generations; enumerating outputs_list directly keeps
+            # i aligned with indices_to_regenerate
+            if full_output is None:
+                continue
             orig_index = indices_to_regenerate[i]
 
             if "gpt" not in self.attack_generator.name:
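
For reference, a minimal usage sketch of the new generator. The checkpoint name and
message contents here are illustrative, the sketch assumes garak's run config has been
initialized as in a normal run, and _call_model is called directly only to show the
stateless list path that TAP relies on:

    # Illustrative: any checkpoint supported by the HF "conversational" pipeline
    from garak.generators.huggingface import ConversationalPipeline

    generator = ConversationalPipeline("facebook/blenderbot-400M-distill", generations=1)

    # Stateful path: a plain string is appended to generator.conversation,
    # so history accumulates across calls
    print(generator.generate("What's a good way to learn Python?"))
    print(generator.generate("How long will that take?"))
    generator.clear_history()  # reset the stored conversation between runs

    # Stateless path: pass a whole conversation as a list of role/content dicts
    history = [
        {"role": "user", "content": "What's a good way to learn Python?"},
        {"role": "assistant", "content": "Start with the official tutorial."},
        {"role": "user", "content": "How long will that take?"},
    ]
    print(generator._call_model(history))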