Skip to content

Commit

Permalink
feat: add support for extra request parameters (headers, body, query)
Browse files Browse the repository at this point in the history
Add support for specifying extra_headers, extra_body, and extra_query
arguments in OpenAI function calls:

- ai.openai_list_models
- ai.openai_embed
- ai.openai_chat_complete
- ai.openai_moderate

These parameters are passed through to the
underlying OpenAI library methods that support them, enabling more
flexible request customization.

This change allows users to:

- Add custom HTTP headers via extra_headers
- Include additional body parameters via extra_body
- Specify extra query parameters via extra_query

All parameters are passed unchanged to the corresponding OpenAI library
methods.
  • Loading branch information
alejandrodnm committed Feb 4, 2025
1 parent 4aad292 commit f74defd
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 105 deletions.
11 changes: 11 additions & 0 deletions docs/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,14 @@ from
) x
;
```
### Undocumented request params
If you want to include additional parameters in a request, you can do so using
the `extra_query`, `extra_body`, and `extra_headers` options.
These values will be passed down to the corresponding method in the OpenAI
Python library and used in the same way as described in the [OpenAI Python
library documentation for Undocumented request params][undocumented-params].
[undocumented-params]: https://github.com/openai/openai-python#undocumented-request-params
44 changes: 37 additions & 7 deletions projects/extension/ai/openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from collections.abc import Generator
from datetime import datetime

Expand Down Expand Up @@ -25,15 +26,36 @@ def make_client(
return openai.Client(api_key=api_key, base_url=base_url)


def str_arg_to_dict(arg: str | None) -> dict | None:
return json.loads(arg) if arg is not None else None


def create_kwargs(**kwargs) -> dict:
    """Return a dict of the given keyword arguments with ``None`` values dropped.

    Used to build the keyword set forwarded to the OpenAI client so that
    omitted SQL arguments fall back to the library's own defaults instead of
    being passed explicitly as ``None``.
    """
    return {name: value for name, value in kwargs.items() if value is not None}


def list_models(
    plpy,
    api_key: str,
    base_url: str | None = None,
    extra_headers: str | None = None,
    extra_query: str | None = None,
    timeout: float | None = None,
) -> Generator[tuple[str, datetime, str], None, None]:
    """Yield (id, created, owned_by) for each model on the OpenAI platform.

    extra_headers / extra_query are JSON strings (or None) that are decoded
    and forwarded unchanged to ``client.models.list``; timeout is passed
    through as-is.  Omitted arguments are dropped so the OpenAI library's
    defaults apply.
    """
    client = make_client(plpy, api_key, base_url)
    # Local import to bring timezone into scope for UTC conversion.
    from datetime import datetime, timezone

    kwargs = create_kwargs(
        extra_headers=str_arg_to_dict(extra_headers),
        extra_query=str_arg_to_dict(extra_query),
        timeout=timeout,
    )

    # Note: a stale duplicate of this loop header (without **kwargs) was
    # left behind by a bad merge/diff and has been removed.
    for model in client.models.list(**kwargs):
        # model.created is a Unix timestamp; surface it as an aware UTC datetime.
        created = datetime.fromtimestamp(model.created, timezone.utc)
        yield model.id, created, model.owned_by

Expand All @@ -46,14 +68,22 @@ def embed(
base_url: str | None = None,
dimensions: int | None = None,
user: str | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
extra_body: str | None = None,
timeout: float | None = None,
) -> Generator[tuple[int, list[float]], None, None]:
client = make_client(plpy, api_key, base_url)
args = {}
if dimensions is not None:
args["dimensions"] = dimensions
if user is not None:
args["user"] = user
response = client.embeddings.create(input=input, model=model, **args)

kwargs = create_kwargs(
dimensions=dimensions,
user=user,
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
extra_body=str_arg_to_dict(extra_body),
timeout=timeout,
)
response = client.embeddings.create(input=input, model=model, **kwargs)
if not hasattr(response, "data"):
return None
for obj in response.data:
Expand Down
160 changes: 107 additions & 53 deletions projects/extension/sql/idempotent/001-openai.sql
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ set search_path to pg_catalog, pg_temp
-- openai_list_models
-- list models supported on the openai platform
-- https://platform.openai.com/docs/api-reference/models/list
create or replace function ai.openai_list_models(api_key text default null, api_key_name text default null, base_url text default null)
create or replace function ai.openai_list_models
( api_key text default null
, api_key_name text default null
, base_url text default null
, extra_headers jsonb default null
, extra_query jsonb default null
, timeout float8 default null
)
returns table
( id text
, created timestamptz
Expand All @@ -46,7 +53,14 @@ as $python$
import ai.openai
import ai.secrets
api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, ai.openai.DEFAULT_KEY_NAME, SD)
for tup in ai.openai.list_models(plpy, api_key_resolved, base_url):
models = ai.openai.list_models(
plpy,
api_key_resolved,
base_url,
extra_headers,
extra_query,
timeout)
for tup in models:
yield tup
$python$
language plpython3u volatile parallel safe security invoker
Expand All @@ -65,13 +79,29 @@ create or replace function ai.openai_embed
, base_url text default null
, dimensions int default null
, openai_user text default null
, extra_headers jsonb default null
, extra_query jsonb default null
, extra_body jsonb default null
, timeout float8 default null
) returns @extschema:vector@.vector
as $python$
#ADD-PYTHON-LIB-DIR
import ai.openai
import ai.secrets
api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, ai.openai.DEFAULT_KEY_NAME, SD)
for tup in ai.openai.embed(plpy, model, input_text, api_key=api_key_resolved, base_url=base_url, dimensions=dimensions, user=openai_user):
embeddings = ai.openai.embed(
plpy,
model,
input_text,
api_key_resolved,
base_url,
dimensions,
openai_user,
extra_headers,
extra_query,
extra_body,
timeout)
for tup in embeddings:
return tup[1]
$python$
language plpython3u immutable parallel safe security invoker
Expand All @@ -90,6 +120,10 @@ create or replace function ai.openai_embed
, base_url text default null
, dimensions int default null
, openai_user text default null
, extra_headers jsonb default null
, extra_query jsonb default null
, extra_body jsonb default null
, timeout float8 default null
) returns table
( "index" int
, embedding @extschema:vector@.vector
Expand All @@ -99,7 +133,20 @@ as $python$
import ai.openai
import ai.secrets
api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, ai.openai.DEFAULT_KEY_NAME, SD)
for tup in ai.openai.embed(plpy, model, input_texts, api_key=api_key_resolved, base_url=base_url, dimensions=dimensions, user=openai_user):

embeddings = ai.openai.embed(
plpy,
model,
input_texts,
api_key_resolved,
base_url,
dimensions,
openai_user,
extra_headers,
extra_query,
extra_body,
timeout)
for tup in embeddings:
yield tup
$python$
language plpython3u immutable parallel safe security invoker
Expand All @@ -118,13 +165,30 @@ create or replace function ai.openai_embed
, base_url text default null
, dimensions int default null
, openai_user text default null
, extra_headers jsonb default null
, extra_query jsonb default null
, extra_body jsonb default null
, timeout float8 default null
) returns @extschema:vector@.vector
as $python$
#ADD-PYTHON-LIB-DIR
import ai.openai
import ai.secrets
api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, ai.openai.DEFAULT_KEY_NAME, SD)
for tup in ai.openai.embed(plpy, model, input_tokens, api_key=api_key_resolved, base_url=base_url, dimensions=dimensions, user=openai_user):

embeddings = ai.openai.embed(
plpy,
model,
input_tokens,
api_key_resolved,
base_url,
dimensions,
openai_user,
extra_headers,
extra_query,
extra_body,
timeout)
for tup in embeddings:
return tup[1]
$python$
language plpython3u immutable parallel safe security invoker
Expand Down Expand Up @@ -156,6 +220,10 @@ create or replace function ai.openai_chat_complete
, tools jsonb default null
, tool_choice text default null
, openai_user text default null
, extra_headers jsonb default null
, extra_query jsonb default null
, extra_body jsonb default null
, timeout float8 default null
) returns jsonb
as $python$
#ADD-PYTHON-LIB-DIR
Expand All @@ -169,58 +237,32 @@ as $python$
if not isinstance(messages_1, list):
plpy.error("messages is not an array")

args = {}

if frequency_penalty is not None:
args["frequency_penalty"] = frequency_penalty

if logit_bias is not None:
args["logit_bias"] = json.loads(logit_bias)

if logprobs is not None:
args["logprobs"] = logprobs

if top_logprobs is not None:
args["top_logprobs"] = top_logprobs

if max_tokens is not None:
args["max_tokens"] = max_tokens

if n is not None:
args["n"] = n

if presence_penalty is not None:
args["presence_penalty"] = presence_penalty

if response_format is not None:
args["response_format"] = json.loads(response_format)

if seed is not None:
args["seed"] = seed

if stop is not None:
args["stop"] = stop

if temperature is not None:
args["temperature"] = temperature

if top_p is not None:
args["top_p"] = top_p

if tools is not None:
args["tools"] = json.loads(tools)

if tool_choice is not None:
args["tool_choice"] = tool_choice if tool_choice in {'auto', 'none', 'required'} else json.loads(tool_choice)

if openai_user is not None:
args["user"] = openai_user
kwargs = ai.openai.create_kwargs(
frequency_penalty=frequency_penalty,
logit_bias=ai.openai.str_arg_to_dict(logit_bias),
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=ai.openai.str_arg_to_dict(response_format),
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
tools=ai.openai.str_arg_to_dict(tools),
tool_choice=tool_choice if tool_choice in {'auto', 'none', 'required'} else ai.openai.str_arg_to_dict(tool_choice),
user=openai_user,
extra_headers=ai.openai.str_arg_to_dict(extra_headers),
extra_query=ai.openai.str_arg_to_dict(extra_query),
extra_body=ai.openai.str_arg_to_dict(extra_body),
timeout=timeout)

response = client.chat.completions.create(
model=model
, messages=messages_1
, stream=False
, **args
, **kwargs
)

return response.model_dump_json()
Expand Down Expand Up @@ -266,14 +308,26 @@ create or replace function ai.openai_moderate
, api_key text default null
, api_key_name text default null
, base_url text default null
, extra_headers jsonb default null
, extra_query jsonb default null
, extra_body jsonb default null
, timeout float8 default null
) returns jsonb
as $python$
#ADD-PYTHON-LIB-DIR
import ai.openai
import ai.secrets
api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, ai.openai.DEFAULT_KEY_NAME, SD)
client = ai.openai.make_client(plpy, api_key_resolved, base_url)
moderation = client.moderations.create(input=input_text, model=model)
kwargs = ai.openai.create_kwargs(
extra_headers=ai.openai.str_arg_to_dict(extra_headers),
extra_query=ai.openai.str_arg_to_dict(extra_query),
extra_body=ai.openai.str_arg_to_dict(extra_body),
timeout=timeout)
moderation = client.moderations.create(
input=input_text,
model=model,
**kwargs)
return moderation.model_dump_json()
$python$
language plpython3u immutable parallel safe security invoker
Expand Down
27 changes: 27 additions & 0 deletions projects/extension/sql/incremental/014-extra-openai-args.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- Incremental migration: drop the previous overloads of the ai.openai_*
-- functions.  The replacement definitions add trailing default arguments
-- (extra_headers, extra_query, extra_body, timeout), which changes each
-- function's argument signature, so the old versions must be removed
-- explicitly before the new ones are created.

-- old signature: (api_key, api_key_name, base_url)
drop function if exists ai.openai_list_models(text, text, text);
-- old signatures: (model, input, api_key, api_key_name, base_url, dimensions, openai_user)
-- one overload per input type: text, text[], int[]
drop function if exists ai.openai_embed(text, text, text, text, text, int, text);
drop function if exists ai.openai_embed(text, text [], text, text, text, int, text);
drop function if exists ai.openai_embed(text, int [], text, text, text, int, text);
-- old 20-argument signature, prior to the extra request-parameter columns
drop function if exists ai.openai_chat_complete(
text,
jsonb,
text,
text,
text,
float8,
jsonb,
boolean,
int,
int,
int,
float8,
jsonb,
int,
text,
float8,
float8,
jsonb,
text,
text
);
-- old signature: (input, model, api_key, api_key_name, base_url)
drop function if exists ai.openai_moderate(text, text, text, text, text);
18 changes: 9 additions & 9 deletions projects/extension/tests/contents/output16.expected
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
DROP DATABASE
CREATE DATABASE
CREATE EXTENSION
Objects in extension "ai"
Object description
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Objects in extension "ai"
Object description
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
event trigger _vectorizer_handle_drops
function ai.anthropic_generate(text,jsonb,integer,text,text,text,double precision,integer,text,text,text[],double precision,jsonb,jsonb,integer,double precision)
function ai.anthropic_list_models(text,text,text)
Expand Down Expand Up @@ -46,13 +46,13 @@ CREATE EXTENSION
function ai.ollama_list_models(text)
function ai.ollama_ps(text)
function ai.openai_chat_complete_simple(text,text,text)
function ai.openai_chat_complete(text,jsonb,text,text,text,double precision,jsonb,boolean,integer,integer,integer,double precision,jsonb,integer,text,double precision,double precision,jsonb,text,text)
function ai.openai_chat_complete(text,jsonb,text,text,text,double precision,jsonb,boolean,integer,integer,integer,double precision,jsonb,integer,text,double precision,double precision,jsonb,text,text,jsonb,jsonb,jsonb,double precision)
function ai.openai_detokenize(text,integer[])
function ai.openai_embed(text,integer[],text,text,text,integer,text)
function ai.openai_embed(text,text,text,text,text,integer,text)
function ai.openai_embed(text,text[],text,text,text,integer,text)
function ai.openai_list_models(text,text,text)
function ai.openai_moderate(text,text,text,text,text)
function ai.openai_embed(text,integer[],text,text,text,integer,text,jsonb,jsonb,jsonb,double precision)
function ai.openai_embed(text,text,text,text,text,integer,text,jsonb,jsonb,jsonb,double precision)
function ai.openai_embed(text,text[],text,text,text,integer,text,jsonb,jsonb,jsonb,double precision)
function ai.openai_list_models(text,text,text,jsonb,jsonb,double precision)
function ai.openai_moderate(text,text,text,text,text,jsonb,jsonb,jsonb,double precision)
function ai.openai_tokenize(text,text)
function ai.processing_default(integer,integer)
function ai._resolve_indexing_default()
Expand Down
Loading

0 comments on commit f74defd

Please sign in to comment.