Skip to content

Commit

Permalink
config
Browse files Browse the repository at this point in the history
  • Loading branch information
krichard1212 committed Dec 2, 2024
1 parent 63fad1c commit c1630eb
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 22 deletions.
29 changes: 29 additions & 0 deletions config_example.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
rules:




input:
- name: "language_detection"
type: "language_detection"
Expand Down Expand Up @@ -62,6 +66,31 @@ rules:
action:
type: "block"
# - type: "monitoring" # logging
- name: "llamaguard_check"
type: "llama_guard"
enabled: true
order_number: 4
config:
plugin_name: "llama_guard"
threshold: 0.5
relation: ">"
      # Can be left empty; in that case every category is included.
#categories: ["S1","S7"]

action:
type: "block"

- name: "PromptGuard Injection Detection"
type: "prompt_guard"
enabled: true
order_number: 5
config:
plugin_name: "prompt_guard"
threshold: 0.7
relation: ">"
temperature: 3.0
action:
type: "block"
output:
- name: "pii_example"
type: "pii_filter"
Expand Down
1 change: 1 addition & 0 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ type Config struct {
Url string `mapstructure:"url,omitempty"`
ApiKey string `mapstructure:"api_key,omitempty"`
PIIService interface{} `mapstructure:"piiservice,omitempty"`
Categories []string `mapstructure:"categories,omitempty"`
}

// Add this function to set default values
Expand Down
58 changes: 52 additions & 6 deletions lib/rules/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,61 @@ func genericHandler(inputConfig lib.Rule, rule RuleResult) (bool, string, error)
return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil
}
func handleLlamaGuardAction(inputConfig lib.Rule, rule RuleResult) (bool, string, error) {
log.Printf("%s detection result: Match=%v, Score=%f", inputConfig.Type, rule.Match, rule.Inspection.Score)
log.Printf("LlamaGuard detection result: Match=%v, Score=%f", rule.Match, rule.Inspection.Score)

// Log which categories we're checking
if len(inputConfig.Config.Categories) > 0 {
log.Printf("Checking specific categories: %v", inputConfig.Config.Categories)
} else {
log.Println("Checking all default categories")
}

if rule.Match {
if inputConfig.Action.Type == "block" {
log.Println("Blocking request due to LlamaGuard detection.")
return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s"}`, inputConfig.Type), nil
details := rule.Inspection.Details
if details != nil {
if rawAnalysis, ok := details["raw_analysis"].(string); ok {
log.Printf("LlamaGuard analysis: %s", rawAnalysis)
}

if violatedCategories, ok := details["violated_categories"].([]interface{}); ok {
categories := make([]string, len(violatedCategories))
for i, v := range violatedCategories {
categories[i] = v.(string)
}

relevantViolations := []string{}
configuredCategories := inputConfig.Config.Categories

if len(configuredCategories) > 0 {

for _, violation := range categories {
for _, configured := range configuredCategories {
if violation == configured {
relevantViolations = append(relevantViolations, violation)
break
}
}
}
} else {

relevantViolations = categories
}

if len(relevantViolations) > 0 {
log.Printf("Violated categories (after filtering): %v", relevantViolations)
if inputConfig.Action.Type == "block" {
log.Printf("Blocking request due to LlamaGuard detection in categories: %v", relevantViolations)
return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s", "violated_categories": %v}`,
inputConfig.Type, relevantViolations), nil
}
log.Printf("Monitoring request due to LlamaGuard detection in categories: %v", relevantViolations)
return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s", "violated_categories": %v}`,
inputConfig.Type, relevantViolations), nil
}
}
}
log.Println("Monitoring request due to LlamaGuard detection.")
return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil
}

log.Println("LlamaGuard Rule Not Matched")
return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil
}
Expand Down
33 changes: 17 additions & 16 deletions services/rule/src/plugins/llama_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
from typing import Dict, Any, List, Optional
import torch
import accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, HfApi

Expand All @@ -17,6 +16,8 @@
)
logger = logging.getLogger(__name__)

DEFAULT_CATEGORIES = ["S1", "S2", "S3", "S4", "S5", "S6", "S7",
"S8", "S9", "S10", "S11", "S12", "S13"]

def get_huggingface_token():
token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
Expand All @@ -25,7 +26,6 @@ def get_huggingface_token():
return None
return token


class LlamaGuardAnalyzer:
def __init__(self):
self.token = get_huggingface_token()
Expand All @@ -48,15 +48,13 @@ def __init__(self):
model_id = "meta-llama/Llama-Guard-3-1B"

if torch.cuda.is_available():
logger.info("Using GPU for model loading")
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
token=self.token
)
else:
logger.info("Using CPU for model loading")
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float32,
Expand Down Expand Up @@ -87,8 +85,7 @@ def clean_analysis_output(self, text: str) -> str:
def analyze_content(
self,
text: str,
categories: Optional[List[str]] = None,
excluded_categories: Optional[List[str]] = None
categories: Optional[List[str]] = None
) -> str:
try:
logger.info(f"Analyzing text: '{text[:100]}{'...' if len(text) > 100 else ''}'")
Expand All @@ -106,11 +103,17 @@ def analyze_content(
]

kwargs = {"return_tensors": "pt"}

if categories:
# Convert categories to the format expected by the model
cats_dict = {cat: cat for cat in categories}
kwargs["categories"] = cats_dict
if excluded_categories:
kwargs["excluded_category_keys"] = excluded_categories
logger.info(f"Using specified categories: {cats_dict}")
else:
# Use all default categories if none specified
cats_dict = {cat: cat for cat in DEFAULT_CATEGORIES}
kwargs["categories"] = cats_dict
logger.info("Using all default categories")

input_ids = self.tokenizer.apply_chat_template(
conversation,
Expand Down Expand Up @@ -138,7 +141,6 @@ def analyze_content(
logger.error(f"Error during analysis: {e}")
raise


analyzer = None
try:
logger.info("Initializing LlamaGuard analyzer...")
Expand All @@ -147,28 +149,27 @@ def analyze_content(
except Exception as e:
logger.error(f"Failed to initialize LlamaGuard analyzer: {str(e)}")


def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, Any]:
try:
if analyzer is None:
raise RuntimeError("LlamaGuard analyzer not initialized")

# Extract categories from config
categories = config.get('categories', [])
excluded_categories = config.get('excluded_categories', [])

analysis = analyzer.analyze_content(
text,
categories=categories,
excluded_categories=excluded_categories
categories=categories if categories else None
)

is_unsafe = not analysis.lower().startswith('safe')
score = 1.0 if is_unsafe else 0.0

violated_categories = []
if is_unsafe:
for category in ["S1", "S2", "S3", "S4", "S5", "S6", "S7",
"S8", "S9", "S10", "S11", "S12", "S13"]:
# Look for category violations in the analysis text
check_categories = categories if categories else DEFAULT_CATEGORIES
for category in check_categories:
if category in analysis:
violated_categories.append(category)

Expand All @@ -188,4 +189,4 @@ def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, An
"check_result": False,
"score": 0.0,
"details": {"error": str(e)}
}
}

0 comments on commit c1630eb

Please sign in to comment.