Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add api key support #1735

Merged
merged 4 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions py/cli/command_group.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
# from .main import load_config
import json
import types
from functools import wraps
from rich.console import Console
from pathlib import Path
from typing import Any, Never

import asyncclick as click
from rich import box
from rich.table import Table
from asyncclick import pass_context
from asyncclick.exceptions import Exit
import types
from typing import Any, Never
from rich import box
from rich.console import Console
from rich.table import Table

from sdk import R2RAsyncClient

console = Console()

CONFIG_DIR = Path.home() / ".r2r"
CONFIG_FILE = CONFIG_DIR / "config.json"


def load_config() -> dict[str, Any]:
"""
Load the CLI config from ~/.r2r/config.json.
Returns an empty dict if the file doesn't exist or is invalid.
"""
if not CONFIG_FILE.is_file():
return {}
try:
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
# Ensure we always have a dict
if not isinstance(data, dict):
return {}
return data
except (IOError, json.JSONDecodeError):
return {}


def silent_exit(ctx, code=0):
if code != 0:
Expand Down Expand Up @@ -118,11 +143,33 @@ def exit(self, code: int = 0) -> Never:
raise SystemExit(code)


def initialize_client(base_url: str) -> R2RAsyncClient:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base_url parameter is not used when initializing R2RAsyncClient. Consider passing base_url to R2RAsyncClient to ensure the client connects to the correct URL.

"""Initialize R2R client with API key from config if available."""
client = R2RAsyncClient()

try:
config = load_config()
if api_key := config.get("api_key"):
client.set_api_key(api_key)
if not client.api_key:
console.print(
"[yellow]Warning: API key not properly set in client[/yellow]"
)

except Exception as e:
console.print(
"[yellow]Warning: Failed to load API key from config[/yellow]"
)
console.print_exception()

return client


@click.group(cls=CustomGroup)
@click.option(
"--base-url", default="http://localhost:7272", help="Base URL for the API"
)
@pass_context
async def cli(ctx: click.Context, base_url: str) -> None:
"""R2R CLI for all core operations."""
ctx.obj = R2RAsyncClient(base_url=base_url)
ctx.obj = initialize_client(base_url)
8 changes: 4 additions & 4 deletions py/cli/commands/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from rich.console import Console
from rich.box import ROUNDED
from rich.table import Table
from pathlib import Path
import configparser
from pathlib import Path

import asyncclick as click
from rich.box import ROUNDED
from rich.console import Console
from rich.table import Table

console = Console()

Expand Down
1 change: 0 additions & 1 deletion py/cli/commands/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ async def upgrade(schema, revision):
click.echo(
f"Running database upgrade for schema {schema or 'default'}..."
)
print(f"Upgrading revision = {revision}")
command = f"upgrade {revision}" if revision else "upgrade"
result = await run_alembic_command(command, schema_name=schema)

Expand Down
9 changes: 4 additions & 5 deletions py/cli/commands/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@
import os
import tempfile
import uuid
from typing import Any, Sequence, Optional
from builtins import list as _list
from typing import Any, Optional, Sequence
from urllib.parse import urlparse
from uuid import UUID

import asyncclick as click
import requests
from asyncclick import pass_context
from rich.box import ROUNDED
from rich.console import Console
from rich.table import Table

from cli.utils.param_types import JSON
from cli.utils.timer import timer
from r2r import R2RAsyncClient, R2RException

from rich.console import Console
from rich.box import ROUNDED
from rich.table import Table

console = Console()


Expand Down
12 changes: 6 additions & 6 deletions py/cli/commands/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,13 @@ def generate_report():
except subprocess.CalledProcessError as e:
report["docker_error"] = f"Error running Docker command: {e}"
except FileNotFoundError:
report["docker_error"] = (
"Docker command not found. Is Docker installed and in PATH?"
)
report[
"docker_error"
] = "Docker command not found. Is Docker installed and in PATH?"
except subprocess.TimeoutExpired:
report["docker_error"] = (
"Docker command timed out. Docker might be unresponsive."
)
report[
"docker_error"
] = "Docker command timed out. Docker might be unresponsive."

# Get OS information
report["os_info"] = {
Expand Down
4 changes: 2 additions & 2 deletions py/cli/commands/users.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from builtins import list as _list
from uuid import UUID

import asyncclick as click
from asyncclick import pass_context
from builtins import list as _list
from uuid import UUID

from cli.utils.timer import timer
from r2r import R2RAsyncClient, R2RException
Expand Down
73 changes: 70 additions & 3 deletions py/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import json
from typing import Any, Dict

import asyncclick as click
from rich.console import Console

from cli.command_group import cli
from cli.commands import (
collections,
conversations,
config,
conversations,
database,
documents,
graphs,
Expand All @@ -12,9 +18,10 @@
system,
users,
)

from rich.console import Console
from cli.utils.telemetry import posthog, telemetry
from r2r import R2RAsyncClient

from .command_group import CONFIG_DIR, CONFIG_FILE, load_config

console = Console()

Expand Down Expand Up @@ -61,5 +68,65 @@ def main():
posthog.shutdown()


def _ensure_config_dir_exists() -> None:
"""Ensure that the ~/.r2r/ directory exists."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)


def save_config(config_data: Dict[str, Any]) -> None:
"""
Persist the given config data to ~/.r2r/config.json.
"""
_ensure_config_dir_exists()
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)


@cli.command("set-api-key", short_help="Set your R2R API key")
@click.argument("api_key", required=True, type=str)
@click.pass_context
async def set_api_key(ctx, api_key: str):
"""
Store your R2R API key locally so you don’t have to pass it on every command.
Example usage:
r2r set-api sk-1234abcd
"""
try:
# 1) Load existing config
config = load_config()

# 2) Overwrite or add the API key
config["api_key"] = api_key

# 3) Save changes
save_config(config)

console.print("[green]API key set successfully![/green]")
except Exception as e:
console.print("[red]Failed to set API key:[/red]", str(e))


@cli.command("get-api", short_help="Get your stored R2R API key")
@click.pass_context
async def get_api(ctx):
"""
Display your stored R2R API key.
Example usage:
r2r get-api
"""
try:
config = load_config()
api_key = config.get("api_key")

if api_key:
console.print(f"API Key: {api_key}")
else:
console.print(
"[yellow]No API key found. Set one using 'r2r set-api <key>'[/yellow]"
)
except Exception as e:
console.print("[red]Failed to retrieve API key:[/red]", str(e))


if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion py/core/base/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class AsyncParser(ABC, Generic[T]):

@abstractmethod
async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
pass
1 change: 0 additions & 1 deletion py/core/base/providers/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ class Config:


class IngestionProvider(Provider, ABC):

config: IngestionConfig
database_provider: "PostgresDatabaseProvider"
llm_provider: CompletionProvider
Expand Down
1 change: 0 additions & 1 deletion py/core/database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def build(self):


class PostgresConnectionManager(DatabaseConnectionManager):

def __init__(self):
self.pool: Optional[SemaphoreConnectionPool] = None

Expand Down
8 changes: 3 additions & 5 deletions py/core/database/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ async def semantic_search(
async def full_text_search(
self, query_text: str, search_settings: SearchSettings
) -> list[ChunkSearchResult]:

conditions = []
params: list[str | int | bytes] = [query_text]

Expand Down Expand Up @@ -506,9 +505,9 @@ async def hybrid_search(
semantic_results: list[ChunkSearchResult] = await self.semantic_search(
query_vector, semantic_settings
)
full_text_results: list[ChunkSearchResult] = (
await self.full_text_search(query_text, full_text_settings)
)
full_text_results: list[
ChunkSearchResult
] = await self.full_text_search(query_text, full_text_settings)

semantic_limit = search_settings.limit
full_text_limit = search_settings.hybrid_settings.full_text_limit
Expand Down Expand Up @@ -1049,7 +1048,6 @@ async def get_semantic_neighbors(
id: UUID,
similarity_threshold: float = 0.5,
) -> list[dict[str, Any]]:

table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
query = f"""
WITH target_vector AS (
Expand Down
1 change: 0 additions & 1 deletion py/core/database/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ async def create_collection(
description: str = "",
collection_id: Optional[UUID] = None,
) -> CollectionResponse:

if not name and not collection_id:
name = self.config.default_collection_name
collection_id = generate_default_user_collection_id(owner_id)
Expand Down
7 changes: 3 additions & 4 deletions py/core/database/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ async def upsert_documents_overview(
else:
new_attempt_number = db_version

db_entry["ingestion_attempt_number"] = (
new_attempt_number
)
db_entry[
"ingestion_attempt_number"
] = new_attempt_number

update_query = f"""
UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
Expand Down Expand Up @@ -152,7 +152,6 @@ async def upsert_documents_overview(
document.id,
)
else:

insert_query = f"""
INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
(id, collection_ids, owner_id, type, metadata, title, version,
Expand Down
6 changes: 3 additions & 3 deletions py/core/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def __init__(
else:
self.top_level_columns = set(top_level_columns)
self.json_column = json_column
self.params: list[Any] = (
params # params are mutated during construction
)
self.params: list[
Any
] = params # params are mutated during construction
self.mode = mode

def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]:
Expand Down
Loading
Loading