Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models fixes #1126

Merged
merged 13 commits into from
Aug 28, 2023
24 changes: 19 additions & 5 deletions gui/pages/Content/Agents/AgentCreate.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
updateExecution,
uploadFile,
getAgentDetails, addAgentRun, fetchModels,
getAgentWorkflows
getAgentWorkflows, validateOrAddModels
} from "@/pages/api/DashboardService";
import {
formatBytes,
Expand Down Expand Up @@ -56,7 +56,7 @@ export default function AgentCreate({
const [searchValue, setSearchValue] = useState('');
const [showButton, setShowButton] = useState(false);
const [showPlaceholder, setShowPlaceholder] = useState(true);
const [modelsArray, setModelsArray] = useState([]);
const [modelsArray, setModelsArray] = useState(['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']);

const constraintsArray = [
"If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.",
Expand All @@ -69,7 +69,7 @@ export default function AgentCreate({
const [goals, setGoals] = useState(['Describe the agent goals here']);
const [instructions, setInstructions] = useState(['']);

const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat']
const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k']
const [model, setModel] = useState(models[1]);
const modelRef = useRef(null);
const [modelDropdown, setModelDropdown] = useState(false);
Expand Down Expand Up @@ -155,7 +155,7 @@ export default function AgentCreate({
.then((response) => {
const models = response.data.map(model => model.name) || [];
const selected_model = localStorage.getItem("agent_model_" + String(internalId)) || '';
setModelsArray(models);
setModelsArray(prevModels => Array.from(new Set([...prevModels, ...models])));
if (models.length > 0 && !selected_model) {
setLocalStorageValue("agent_model_" + String(internalId), models[0], setModel);
} else {
Expand Down Expand Up @@ -494,7 +494,21 @@ export default function AgentCreate({
return true;
}

const handleAddAgent = () => {
const validateModel = async () => {
const response = await validateOrAddModels(model)
if (response.data.error) {
toast.error(response.data.error, {autoClose: 1800});
return false;
}
return true;
}

const handleAddAgent = async () => {
if(env === 'DEV') {
const bool = await validateModel()
if(!bool) return;
}

if (!validateAgentData(true)) {
return;
}
Expand Down
5 changes: 5 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,8 @@ export const fetchMarketPlaceModel = () => {
return api.get(`/models_controller/get/list`)
}

export const validateOrAddModels = (model) => {
return api.get(`/models_controller/validate_or_add_gpt_models`, {
params: { model }
});
}
9 changes: 9 additions & 0 deletions superagi/controllers/models_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ async def fetch_data(request: ModelName, organisation=Depends(get_user_organisat
raise HTTPException(status_code=500, detail="Internal Server Error")


@router.get("/validate_or_add_gpt_models", status_code=200)
async def validate_or_add_gpt_models(model: str = None, organisation=Depends(get_user_organisation)):
try:
return Models.validate_model_in_db(db.session, organisation.id, model)
except Exception as e:
logging.error(f"Error Validating or Adding GPT Models: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")


@router.get("/get/list", status_code=200)
def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)):
"""
Expand Down
46 changes: 46 additions & 0 deletions superagi/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.sql import func
from typing import List, Dict, Union
from superagi.models.base_model import DBBaseModel
from superagi.helper.encyption_helper import encrypt_data, decrypt_data
import requests, logging

# marketplace_url = "https://app.superagi.com/api"
Expand Down Expand Up @@ -201,3 +202,48 @@ def fetch_model_details(cls, session, organisation_id, model_id: int) -> Dict[st
except Exception as e:
logging.error(f"Unexpected Error Occured: {e}")
return {"error": "Unexpected Error Occured"}

@classmethod
def validate_model_in_db(cls, session, organisation_id, model):
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the imports to top of the file.

from superagi.models.models_config import ModelsConfig
from superagi.models.configuration import Configuration

models = {"gpt-3.5-turbo-0301": 4032, "gpt-4-0314": 8092, "gpt-3.5-turbo": 4032,
"gpt-4": 8092, "gpt-3.5-turbo-16k": 16184, "gpt-4-32k": 32768}

model_config = session.query(Models).filter(Models.model_name == model,
Models.org_id == organisation_id).first()
if model_config is None:
model_provider = session.query(ModelsConfig).filter(ModelsConfig.provider == "OpenAI",
ModelsConfig.org_id == organisation_id).first()

if model_provider is None:
configurations = session.query(Configuration).filter(Configuration.key == 'model_api_key',
Configuration.organisation_id == organisation_id).first()
model_api_key = decrypt_data(configurations.value)

if configurations is None:
return {"error": "Model not found and the API Key is missing"}

model_details = ModelsConfig.store_api_key(session, organisation_id, "OpenAI", model_api_key)

# Get 'model_provider_id'
model_provider_id = model_details.get('model_provider_id')

result = cls.store_model_details(session, organisation_id, model, model, '',
model_provider_id, models[model], 'Custom', '')
if result is not None:
return {"success": "Model was not Installed, so I have dont it for you"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle using HTTP code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use HTTP Exception


else:
result = cls.store_model_details(session, organisation_id, model, model, '',
model_provider.id, models[model], 'Custom', '')
if result is not None:
return {"success": "Model was not Installed, so I have dont it for you"}

else:
return {"success": "Model is found"}

except Exception as e:
logging.error(f"Unexpected Error occurred while Validating GPT Models: {e}")
11 changes: 8 additions & 3 deletions superagi/models/models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,17 @@ def store_api_key(cls, session, organisation_id, model_provider, model_api_key):
ModelsConfig.provider == model_provider)).first()
if existing_entry:
existing_entry.api_key = encrypt_data(model_api_key)
session.commit()
result = {'message': 'The API key was successfully updated'}
else:
new_entry = ModelsConfig(org_id=organisation_id, provider=model_provider,
api_key=encrypt_data(model_api_key))
session.add(new_entry)
session.commit()
session.flush()
result = {'message': 'The API key was successfully stored', 'model_provider_id': new_entry.id}

session.commit()

return {'message': 'The API key was successfully stored'}
return result

@classmethod
def fetch_api_keys(cls, session, organisation_id):
Expand All @@ -107,6 +110,8 @@ def fetch_api_key(cls, session, organisation_id, model_provider):
if api_key_data is None:
return []
else:
print("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove print

print(decrypt_data(api_key_data.api_key))
api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider,
'api_key': decrypt_data(api_key_data.api_key)}]
return api_key
Expand Down
12 changes: 12 additions & 0 deletions tests/unit_tests/models/test_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,16 @@ def test_fetch_model_by_id(mock_session):

# Call the method
model = ModelsConfig.fetch_model_by_id(mock_session, organisation_id, model_provider_id)
assert model == {"provider": "some_provider"}

def test_fetch_model_by_id_marketplace(mock_session):
# Arrange
model_provider_id = 1
# Mock model
mock_model = MagicMock()
mock_model.provider = 'some_provider'
mock_session.query.return_value.filter.return_value.first.return_value = mock_model

# Call the method
model = ModelsConfig.fetch_model_by_id_marketplace(mock_session, model_provider_id)
assert model == {"provider": "some_provider"}