Skip to content

Commit

Permalink
fix import chain AttributeError: 'NoneType' object has no attribute '…
Browse files Browse the repository at this point in the history
…id' (#1254)

* fix import chain AttributeError: 'NoneType' object has no attribute 'id'

* stick with ids
  • Loading branch information
birdup000 authored Sep 22, 2024
1 parent 739d1b5 commit e334f70
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions agixt/Chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,47 +623,56 @@ def import_chain(self, chain_name: str, steps: dict):
if "prompt_type" not in step_data:
step_data["prompt_type"] = "prompt"
prompt_type = step_data["prompt_type"].lower()
target_id = None
if prompt_type == "prompt":
argument_key = "prompt_name"
prompt_category = prompt.get("prompt_category", "Default")
target_id = (
target = (
session.query(Prompt)
.filter(
Prompt.name == prompt[argument_key],
Prompt.user_id == self.user_id,
Prompt.prompt_category.has(name=prompt_category),
)
.first()
.id
)
target_type = "prompt"
if target:
target_id = target.id
elif prompt_type == "chain":
argument_key = "chain_name"
if "chain" in prompt:
argument_key = "chain"
target_id = (
target = (
session.query(ChainDB)
.filter(
ChainDB.name == prompt[argument_key],
ChainDB.user_id == self.user_id,
)
.first()
.id
)
target_type = "chain"
if target:
target_id = target.id
elif prompt_type == "command":
argument_key = "command_name"
target_id = (
target = (
session.query(Command)
.filter(Command.name == prompt[argument_key])
.first()
.id
)
target_type = "command"
if target:
target_id = target.id
else:
# Handle the case where the argument key is not found
# You can choose to skip this step or raise an exception
# Handle the case where the prompt_type is not recognized
logging.error(f"Unrecognized prompt_type: {prompt_type}")
continue

if target_id is None:
# Handle the case where the target is not found
logging.error(
f"Target not found for {prompt_type}: {prompt[argument_key]}"
)
continue

argument_value = prompt[argument_key]
prompt_arguments = prompt.copy()
del prompt_arguments[argument_key]
Expand All @@ -673,9 +682,9 @@ def import_chain(self, chain_name: str, steps: dict):
agent_id=agent.id,
prompt_type=step_data["prompt_type"],
prompt=argument_value,
target_chain_id=target_id if target_type == "chain" else None,
target_command_id=target_id if target_type == "command" else None,
target_prompt_id=target_id if target_type == "prompt" else None,
target_chain_id=target_id if prompt_type == "chain" else None,
target_command_id=target_id if prompt_type == "command" else None,
target_prompt_id=target_id if prompt_type == "prompt" else None,
)
session.add(chain_step)
session.commit()
Expand All @@ -687,7 +696,7 @@ def import_chain(self, chain_name: str, steps: dict):
)
if not argument:
# Handle the case where argument not found based on argument_name
# You can choose to skip this argument or raise an exception
logging.warning(f"Argument not found: {argument_name}")
continue

chain_step_argument = ChainStepArgument(
Expand Down

0 comments on commit e334f70

Please sign in to comment.