Commit

Merge branch 'main' into ft-delete-chat

andreaskoepf authored May 3, 2023
2 parents 64aeed8 + 26acf95 commit d46d2ea
Showing 58 changed files with 2,516 additions and 356 deletions.
20 changes: 15 additions & 5 deletions backend/oasst_backend/tree_manager.py
@@ -19,6 +19,8 @@
Task,
TextLabels,
User,
UserStats,
UserStatsTimeFrame,
message_tree_state,
)
from oasst_backend.prompt_repository import PromptRepository
@@ -269,10 +271,13 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
def activate_one(db: Session) -> int:
# select among distinct users
authors_qry = (
db.query(Message.user_id)
db.query(Message.user_id, func.coalesce(UserStats.reply_ranked_1, 0).label("reply_ranked_1"))
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.join(User, Message.user_id == User.id)
.outerjoin(
UserStats, and_(UserStats.user_id == User.id, UserStats.time_frame == UserStatsTimeFrame.month)
)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.lang == lang,
@@ -284,16 +289,21 @@ def activate_one(db: Session) -> int:
.distinct(Message.user_id)
)

author_ids = authors_qry.all()
if len(author_ids) == 0:
author_data = authors_qry.all()
if len(author_data) == 0:
logger.info(
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
)
return False

author_ids = [data["user_id"] for data in author_data]
# add one to avoid any scenario where all weights are 0
# this also means inactive users can still occasionally be selected
weights = [data["reply_ranked_1"] + 1 for data in author_data]

# first select an author
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
prompt_author_id: UUID = random.choices(author_ids, weights=weights)[0]
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_data)} candidates.")

# select random prompt of author
qry = (
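The tree_manager change above swaps uniform prompt-author selection (random.choice) for sampling weighted by each author's monthly ranked-first replies. A minimal self-contained sketch of that selection step, using hypothetical author rows in place of the query result:

import random

# Stand-ins for the (user_id, reply_ranked_1) rows the query above returns;
# the outer join yields 0 for users without monthly stats.
author_data = [
    {"user_id": "user-a", "reply_ranked_1": 0},
    {"user_id": "user-b", "reply_ranked_1": 3},
    {"user_id": "user-c", "reply_ranked_1": 12},
]

author_ids = [row["user_id"] for row in author_data]
# +1 keeps every weight positive, so inactive users remain selectable.
weights = [row["reply_ranked_1"] + 1 for row in author_data]

# random.choices samples proportionally to the weights: here user-c wins
# ~13/18 of draws, while user-a still wins ~1/18.
prompt_author_id = random.choices(author_ids, weights=weights)[0]
print(prompt_author_id)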
2 changes: 1 addition & 1 deletion data/datasets/__init__.py
@@ -24,7 +24,7 @@
"oa_dolly_15k": "OllieStanley/oa_dolly_15k",
"poetry_instruction": "checkai/instruction-poems",
"oa_stackexchange": "donfu/oa-stackexchange",
"multilingual_wikihow_qa_8k": "0x22almostEvil/multilingual-wikihow-qa-8k",
"multilingual_wikihow_qa_16k": "0x22almostEvil/multilingual-wikihow-qa-16k",
"stable_diffusion_instructional_dataset": "MadVoyager/stable_diffusion_instructional_dataset",
}

40 changes: 36 additions & 4 deletions data/datasets/oa_leet10k/oa_leet10k.ipynb
@@ -42,8 +42,11 @@
"import random\n",
"from IPython.display import display\n",
"from datasets import Dataset\n",
"import requests\n",
"\n",
"data_source = \"https://www.kaggle.com/datasets/erichartford/leetcode-solutions\"\n",
"lc_contests_data_source = \"https://github.com/Nan-Do/LeetCodeContestsDataset/raw/main/submissions.json\"\n",
"\n",
"output_dir = \"data\"\n",
"os.makedirs(output_dir, exist_ok=True)"
]
@@ -54,7 +57,12 @@
"metadata": {},
"outputs": [],
"source": [
"kaggle.api.dataset_download_files(\"erichartford/leetcode-solutions\", \"data\", unzip=True)"
"kaggle.api.dataset_download_files(\"erichartford/leetcode-solutions\", \"data\", unzip=True)\n",
"r = requests.get(lc_contests_data_source, allow_redirects=True)\n",
"with open(\"data/lc_contests.json\", \"wb\") as f:\n",
" for chunk in r.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)"
]
},
{
@@ -64,6 +72,7 @@
"outputs": [],
"source": [
"leetcode_solutions = pd.read_json(\"data/leetcode-solutions.jsonl\", lines=True)\n",
"leetcode_contests = pd.read_json(\"data/lc_contests.json\")\n",
"\n",
"# Create dataframe with columns INSTRUCTION, RESPONSE, SOURCE\n",
"# The INSTRUCTION a random choice from ONE_STEP_TEMPLATES with the language and content filled in\n",
@@ -83,7 +92,21 @@
" \"SOURCE\": data_source,\n",
" }\n",
" )\n",
"\n",
"oa_leetcode_contests = []\n",
"for index, row in leetcode_contests.iterrows():\n",
" oa_leetcode_contests.append(\n",
" {\n",
" \"INSTRUCTION\": row[\"instruction\"] + \"\\n\" + row[\"input\"],\n",
" \"RESPONSE\": row[\"output\"],\n",
" \"SOURCE\": \"https://github.com/Nan-Do/LeetCodeContestsDataset\",\n",
" }\n",
" )\n",
"\n",
"oa_leet10k = pd.DataFrame(oa_leet10k)\n",
"oa_leetcode_contests = pd.DataFrame(oa_leetcode_contests)\n",
"\n",
"print(f\"oa_leet10k: {oa_leet10k.shape[0]}, oa_leetcode_contests: {oa_leetcode_contests.shape[0]}\")\n",
"\n",
"# Print the first 5 rows of the dataframe with full width and newline characters correctly displayed in the RESPONSE column\n",
"with pd.option_context(\"display.max_colwidth\", 80):\n",
@@ -94,7 +117,13 @@
" \"text-align\": \"left\",\n",
" \"white-space\": \"pre-wrap\",\n",
" }\n",
" )\n",
" ),\n",
" oa_leetcode_contests.head(5).style.set_properties(\n",
" **{\n",
" \"text-align\": \"left\",\n",
" \"white-space\": \"pre-wrap\",\n",
" }\n",
" ),\n",
" )"
]
},
@@ -106,9 +135,12 @@
"source": [
"# Upload dataset to HF\n",
"oa_leet10k.to_parquet(\"oa_leet10k.parquet\", row_group_size=100, engine=\"pyarrow\")\n",
"ds = Dataset.from_parquet(\"oa_leet10k.parquet\")\n",
"ds_leet10k = Dataset.from_parquet(\"oa_leet10k.parquet\")\n",
"oa_leetcode_contests.to_parquet(\"oa_leetcode_contests.parquet\", row_group_size=100, engine=\"pyarrow\")\n",
"ds_leetcode_contests = Dataset.from_parquet(\"oa_leetcode_contests.parquet\")\n",
"# Uncomment to push dataset to HF\n",
"# ds.push_to_hub(\"ehartford/oa_leet10k\")"
"# ds_leet10k.push_to_hub(\"ehartford/oa_leet10k\")\n",
"# ds_leetcode_contests.push_to_hub(\"ehartford/oa_leet10k\")"
]
}
],
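The new download cell fetches the contests JSON with a plain requests.get and then writes it chunk-wise; since stream=True is not passed, the body is already fully buffered by the time iter_content runs. A streamed variant (a sketch, not what the notebook currently does) would look like:

import os
import requests

lc_contests_data_source = "https://github.com/Nan-Do/LeetCodeContestsDataset/raw/main/submissions.json"
os.makedirs("data", exist_ok=True)

# stream=True defers the body download so iter_content reads it in chunks
# instead of re-chunking an already-buffered response.
with requests.get(lc_contests_data_source, stream=True, timeout=30) as r:
    r.raise_for_status()
    with open("data/lc_contests.json", "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)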
6 changes: 4 additions & 2 deletions inference/full-dev-setup.sh
@@ -13,10 +13,12 @@ else
INFERENCE_TAG=latest
fi

POSTGRES_PORT=${POSTGRES_PORT:-5732}

# Creates a tmux window with splits for the individual services

tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 5732:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
tmux send-keys "docker run --rm -it -p $POSTGRES_PORT:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 6779:6379 --name redis redis" C-m

@@ -30,7 +32,7 @@ fi

tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "LOGURU_LEVEL=$LOGLEVEL POSTGRES_PORT=5732 REDIS_PORT=6779 DEBUG_API_KEYS='0000,0001' ALLOW_DEBUG_AUTH=True TRUSTED_CLIENT_KEYS=6969 uvicorn main:app" C-m
tmux send-keys "LOGURU_LEVEL=$LOGLEVEL POSTGRES_PORT=$POSTGRES_PORT REDIS_PORT=6779 DEBUG_API_KEYS='0000,0001' ALLOW_DEBUG_AUTH=True TRUSTED_CLIENT_KEYS=6969 uvicorn main:app" C-m
tmux split-window -h
tmux send-keys "cd text-client" C-m
tmux send-keys "sleep 5" C-m
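full-dev-setup.sh now reads the Postgres port from the environment with a 5732 fallback (${POSTGRES_PORT:-5732}) and threads it through both the docker run and the uvicorn invocation. For comparison, the same default-with-override pattern expressed in Python:

import os

# Mirrors ${POSTGRES_PORT:-5732}: use the env var when set, else the default.
POSTGRES_PORT = int(os.environ.get("POSTGRES_PORT", "5732"))
print(POSTGRES_PORT)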
@@ -0,0 +1,28 @@
"""added used plugin to message
Revision ID: 5b4211625a9f
Revises: ea19bbc743f9
Create Date: 2023-05-01 22:53:16.297495
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "5b4211625a9f"
down_revision = "ea19bbc743f9"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("used_plugin", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "used_plugin")
# ### end Alembic commands ###
3 changes: 2 additions & 1 deletion inference/server/oasst_inference_server/chat_repository.py
@@ -88,12 +88,13 @@ async def abort_work(self, message_id: str, reason: str) -> models.DbMessage:
await self.session.refresh(message)
return message

async def complete_work(self, message_id: str, content: str) -> models.DbMessage:
async def complete_work(self, message_id: str, content: str, used_plugin: inference.PluginUsed) -> models.DbMessage:
logger.debug(f"Completing work on message {message_id}")
message = await self.get_assistant_message_by_id(message_id)
message.state = inference.MessageState.complete
message.work_end_at = datetime.datetime.utcnow()
message.content = content
message.used_plugin = used_plugin
await self.session.commit()
logger.debug(f"Completed work on message {message_id}")
await self.session.refresh(message)
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/database.py
@@ -43,6 +43,8 @@ def custom_json_deserializer(s)
return chat_schema.CreateMessageRequest.parse_obj(d)
case "WorkRequest":
return inference.WorkRequest.parse_obj(d)
case "PluginUsed":
return inference.PluginUsed.parse_obj(d)
case None:
return d
case _:
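custom_json_deserializer dispatches on a type tag stored alongside the JSON payload, so PluginUsed now round-trips through the database like the other pydantic models. A stripped-down sketch of that dispatch pattern (the tag key and model here are hypothetical stand-ins for the server's actual serializer):

import json
import pydantic

class PluginUsed(pydantic.BaseModel):
    name: str
    trusted: bool = False

def custom_json_deserializer(s: str):
    d = json.loads(s)
    # Dispatch on the embedded type tag; plain dicts pass through untouched.
    match d.pop("__type__", None):
        case "PluginUsed":
            return PluginUsed.parse_obj(d)
        case None:
            return d
        case unknown:
            raise ValueError(f"Unknown type tag: {unknown}")

print(custom_json_deserializer('{"__type__": "PluginUsed", "name": "calculator"}'))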
3 changes: 3 additions & 0 deletions inference/server/oasst_inference_server/models/chat.py
@@ -28,6 +28,8 @@ class DbMessage(SQLModel, table=True):
safety_label: str | None = Field(None)
safety_rots: str | None = Field(None)

used_plugin: inference.PluginUsed | None = Field(None, sa_column=sa.Column(pg.JSONB))

state: inference.MessageState = Field(inference.MessageState.manual)
work_parameters: inference.WorkParameters = Field(None, sa_column=sa.Column(pg.JSONB))
work_begin_at: datetime.datetime | None = Field(None)
@@ -68,6 +70,7 @@ def to_read(self) -> inference.MessageRead:
safety_level=self.safety_level,
safety_label=self.safety_label,
safety_rots=self.safety_rots,
used_plugin=self.used_plugin,
)


1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -140,6 +140,7 @@ async def create_assistant_message(
work_parameters = inference.WorkParameters(
model_config=model_config,
sampling_parameters=request.sampling_parameters,
plugins=request.plugins,
)
assistant_message = await ucr.initiate_assistant_message(
parent_id=request.parent_id,
91 changes: 91 additions & 0 deletions inference/server/oasst_inference_server/routes/configs.py
@@ -1,9 +1,19 @@
import asyncio

import aiohttp
import fastapi
import pydantic
import yaml
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
from fastapi import HTTPException
from loguru import logger
from oasst_inference_server.settings import settings
from oasst_shared import model_configs
from oasst_shared.schemas import inference

# NOTE: Populate this with plugins that we will provide out of the box
OA_PLUGINS = []

router = fastapi.APIRouter(
prefix="/configs",
tags=["configs"],
@@ -63,6 +73,16 @@ class ModelConfigInfo(pydantic.BaseModel):
repetition_penalty=1.2,
),
),
ParameterConfig(
name="k50-Plugins",
description="Top-k sampling with k=50 and temperature=0.35",
sampling_parameters=inference.SamplingParameters(
max_new_tokens=1024,
temperature=0.35,
top_k=50,
repetition_penalty=(1 / 0.90),
),
),
ParameterConfig(
name="nucleus9",
description="Nucleus sampling with p=0.9",
@@ -93,6 +113,44 @@
]


async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig:
async with aiohttp.ClientSession() as session:
for attempt in range(retries):
try:
async with session.get(url, timeout=timeout) as response:
content_type = response.headers.get("Content-Type")

if response.status == 200:
if "application/json" in content_type or url.endswith(".json"):
config = await response.json()
elif (
"application/yaml" in content_type
or "application/x-yaml" in content_type
or url.endswith(".yaml")
or url.endswith(".yml")
):
config = yaml.safe_load(await response.text())
else:
raise HTTPException(
status_code=400,
detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.",
)

return inference.PluginConfig(**config)
elif response.status == 404:
raise HTTPException(status_code=404, detail="Plugin not found")
else:
raise HTTPException(status_code=response.status, detail="Unexpected status code")
except (ClientConnectorError, ServerTimeoutError) as e:
if attempt == retries - 1: # last attempt
raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}")
await asyncio.sleep(2**attempt) # exponential backoff

except aiohttp.ClientError as e:
raise HTTPException(status_code=500, detail=f"Request failed: {e}")
raise HTTPException(status_code=500, detail="Failed to fetch plugin")


@router.get("/model_configs")
async def get_model_configs() -> list[ModelConfigInfo]:
return [
@@ -103,3 +161,36 @@ async def get_model_configs() -> list[ModelConfigInfo]:
for model_config_name in model_configs.MODEL_CONFIGS
if (settings.allowed_model_config_names == "*" or model_config_name in settings.allowed_model_config_names_list)
]


@router.post("/plugin_config")
async def get_plugin_config(plugin: inference.PluginEntry) -> inference.PluginEntry:
try:
plugin_config = await fetch_plugin(plugin.url)
except HTTPException as e:
logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
raise fastapi.HTTPException(status_code=e.status_code, detail=e.detail)

return inference.PluginEntry(url=plugin.url, enabled=plugin.enabled, plugin_config=plugin_config)


@router.get("/builtin_plugins")
async def get_builtin_plugins() -> list[inference.PluginEntry]:
plugins = []

for plugin in OA_PLUGINS:
try:
plugin_config = await fetch_plugin(plugin.url)
except HTTPException as e:
logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
continue

final_plugin: inference.PluginEntry = inference.PluginEntry(
url=plugin.url,
enabled=plugin.enabled,
trusted=plugin.trusted,
plugin_config=plugin_config,
)
plugins.append(final_plugin)

return plugins
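fetch_plugin retries transient connection failures with exponential backoff: it sleeps 2**attempt seconds between attempts (1 s, then 2 s for the default three tries) and raises after the last. The retry shape in isolation, with a stand-in coroutine for the HTTP call:

import asyncio

async def with_backoff(fetch, retries: int = 3):
    # fetch is any zero-argument coroutine; ConnectionError stands in for
    # aiohttp's ClientConnectorError / ServerTimeoutError used above.
    for attempt in range(retries):
        try:
            return await fetch()
        except ConnectionError:
            if attempt == retries - 1:  # last attempt: give up
                raise
            await asyncio.sleep(2**attempt)  # 1s, 2s, ...

async def flaky():
    raise ConnectionError("plugin host unreachable")

# asyncio.run(with_backoff(flaky))  # raises after ~3s of cumulative backoff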
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/workers.py
@@ -350,6 +350,7 @@ async def handle_generated_text_response(
message = await cr.complete_work(
message_id=message_id,
content=response.text,
used_plugin=response.used_plugin,
)
logger.info(f"Completed work for {message_id=}")
message_packet = inference.InternalFinishedMessageResponse(
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -14,6 +14,8 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
parent_id: str
model_config_name: str
sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
used_plugin: inference.PluginUsed | None = None


class PendingResponseEvent(pydantic.BaseModel):
@@ -123,7 +123,8 @@ async def add_prompter_message(self, chat_id: str, parent_id: str | None, conten
if parent_id is None:
if len(chat.messages) > 0:
raise fastapi.HTTPException(status_code=400, detail="Trying to add first message to non-empty chat")
chat.title = content
if chat.title is None:
chat.title = content
else:
msg_dict = chat.get_msg_dict()
if parent_id not in msg_dict:
Expand Down