Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Stable diffusion v3 #615

Merged
merged 10 commits into from
Oct 22, 2024
1 change: 1 addition & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export const HEADER_KEYS: Record<string, string> = {
CUSTOM_HOST: `x-${POWERED_BY}-custom-host`,
REQUEST_TIMEOUT: `x-${POWERED_BY}-request-timeout`,
STRICT_OPEN_AI_COMPLIANCE: `x-${POWERED_BY}-strict-open-ai-compliance`,
CONTENT_TYPE: `Content-Type`,
};

export const RESPONSE_HEADER_KEYS: Record<string, string> = {
Expand Down
26 changes: 22 additions & 4 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import {
OPEN_AI,
AZURE_AI_INFERENCE,
ANTHROPIC,
MULTIPART_FORM_DATA_ENDPOINTS,
CONTENT_TYPES,
HUGGING_FACE,
STABILITY_AI,
} from '../globals';
import Providers from '../providers';
import { ProviderAPIConfig, endpointStrings } from '../providers/types';
Expand Down Expand Up @@ -524,6 +524,7 @@ export async function tryPost(
fn,
transformedRequestBody,
transformedRequestUrl: url,
gatewayRequestBody: params,
});

// Construct the base object for the POST request
Expand All @@ -535,9 +536,10 @@ export async function tryPost(
requestHeaders
);

fetchOptions.body = MULTIPART_FORM_DATA_ENDPOINTS.includes(fn)
? (transformedRequestBody as FormData)
: JSON.stringify(transformedRequestBody);
fetchOptions.body =
headers[HEADER_KEYS.CONTENT_TYPE] === CONTENT_TYPES.MULTIPART_FORM_DATA
? (transformedRequestBody as FormData)
: JSON.stringify(transformedRequestBody);

providerOption.retry = {
attempts: providerOption.retry?.attempts ?? 0,
Expand Down Expand Up @@ -1012,6 +1014,14 @@ export function constructConfigFromRequestHeaders(
azureModelName: requestHeaders[`x-${POWERED_BY}-azure-model-name`],
};

const stabilityAiConfig = {
stabilityClientId: requestHeaders[`x-${POWERED_BY}-stability-client-id`],
stabilityClientUserId:
requestHeaders[`x-${POWERED_BY}-stability-client-user-id`],
stabilityClientVersion:
requestHeaders[`x-${POWERED_BY}-stability-client-version`],
};

const azureAiInferenceConfig = {
azureDeploymentName:
requestHeaders[`x-${POWERED_BY}-azure-deployment-name`],
Expand Down Expand Up @@ -1128,6 +1138,12 @@ export function constructConfigFromRequestHeaders(
...anthropicConfig,
};
}
if (parsedConfigJson.provider === STABILITY_AI) {
parsedConfigJson = {
...parsedConfigJson,
...stabilityAiConfig,
};
}
}
return convertKeysToCamelCase(parsedConfigJson, [
'override_params',
Expand Down Expand Up @@ -1158,6 +1174,8 @@ export function constructConfigFromRequestHeaders(
huggingfaceConfig),
mistralFimCompletion:
requestHeaders[`x-${POWERED_BY}-mistral-fim-completion`],
...(requestHeaders[`x-${POWERED_BY}-provider`] === STABILITY_AI &&
stabilityAiConfig),
};
}

Expand Down
5 changes: 5 additions & 0 deletions src/providers/bedrock/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,8 @@ export const MISTRAL_CONTROL_TOKENS = {
MIDDLE: '[MIDDLE]',
SUFFIX: '[SUFFIX]',
};

export const BEDROCK_STABILITY_V1_MODELS = [
'stable-diffusion-xl-v0',
'stable-diffusion-xl-v1',
];
46 changes: 39 additions & 7 deletions src/providers/bedrock/imageGenerate.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { BEDROCK } from '../../globals';
import { StabilityAIImageGenerateV2Config } from '../stability-ai/imageGenerateV2';
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';
import { generateInvalidProviderResponseError } from '../utils';
import { BedrockErrorResponseTransform } from './chatComplete';
import { BedrockErrorResponse } from './embed';

export const BedrockStabilityAIImageGenerateConfig: ProviderConfig = {
export const BedrockStabilityAIImageGenerateV1Config: ProviderConfig = {
prompt: {
param: 'text_prompts',
required: true,
Expand Down Expand Up @@ -47,29 +48,60 @@ interface ImageArtifact {
seed: number;
}

interface BedrockStabilityAIImageGenerateResponse {
interface BedrockStabilityAIImageGenerateV1Response {
result: string;
artifacts: ImageArtifact[];
}

export const BedrockStabilityAIImageGenerateResponseTransform: (
response: BedrockStabilityAIImageGenerateResponse | BedrockErrorResponse,
export const BedrockStabilityAIImageGenerateV1ResponseTransform: (
response: BedrockStabilityAIImageGenerateV1Response | BedrockErrorResponse,
responseStatus: number
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200) {
const errorResposne = BedrockErrorResponseTransform(
const errorResponse = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResposne) return errorResposne;
if (errorResponse) return errorResponse;
}

if ('artifacts' in response) {
return {
created: `${new Date().getTime()}`,
created: Math.floor(Date.now() / 1000),
data: response.artifacts.map((art) => ({ b64_json: art.base64 })),
provider: BEDROCK,
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};

interface BedrockStabilityAIImageGenerateV2Response {
seeds: number[];
finish_reasons: string[];
images: string[];
}

export const BedrockStabilityAIImageGenerateV2Config =
StabilityAIImageGenerateV2Config;

export const BedrockStabilityAIImageGenerateV2ResponseTransform: (
response: BedrockStabilityAIImageGenerateV2Response | BedrockErrorResponse,
responseStatus: number
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200) {
const errorResponse = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResponse) return errorResponse;
}

if ('images' in response) {
return {
created: Math.floor(Date.now() / 1000),
data: response.images.map((image) => ({ b64_json: image })),
provider: BEDROCK,
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};
25 changes: 19 additions & 6 deletions src/providers/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,18 @@ import {
BedrockTitanCompleteResponseTransform,
BedrockTitanCompleteStreamChunkTransform,
} from './complete';
import { BEDROCK_STABILITY_V1_MODELS } from './constants';
import {
BedrockCohereEmbedConfig,
BedrockCohereEmbedResponseTransform,
BedrockTitanEmbedConfig,
BedrockTitanEmbedResponseTransform,
} from './embed';
import {
BedrockStabilityAIImageGenerateConfig,
BedrockStabilityAIImageGenerateResponseTransform,
BedrockStabilityAIImageGenerateV1Config,
BedrockStabilityAIImageGenerateV1ResponseTransform,
BedrockStabilityAIImageGenerateV2Config,
BedrockStabilityAIImageGenerateV2ResponseTransform,
} from './imageGenerate';

const BedrockConfig: ProviderConfigs = {
Expand All @@ -63,8 +66,9 @@ const BedrockConfig: ProviderConfigs = {
// To remove the region in case its a cross-region inference profile ID
// https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html
const providerModel = params.model.replace(/^(us\.|eu\.)/, '');
const provider = providerModel?.split('.')[0];
const model = providerModel?.split('.')[1];
const providerModelArray = providerModel.split('.');
const provider = providerModelArray[0];
const model = providerModelArray.slice(1).join('.');
switch (provider) {
case ANTHROPIC:
return {
Expand Down Expand Up @@ -148,11 +152,20 @@ const BedrockConfig: ProviderConfigs = {
},
};
case 'stability':
if (model && BEDROCK_STABILITY_V1_MODELS.includes(model)) {
return {
imageGenerate: BedrockStabilityAIImageGenerateV1Config,
api: BedrockAPIConfig,
responseTransforms: {
imageGenerate: BedrockStabilityAIImageGenerateV1ResponseTransform,
},
};
}
return {
imageGenerate: BedrockStabilityAIImageGenerateConfig,
imageGenerate: BedrockStabilityAIImageGenerateV2Config,
api: BedrockAPIConfig,
responseTransforms: {
imageGenerate: BedrockStabilityAIImageGenerateResponseTransform,
imageGenerate: BedrockStabilityAIImageGenerateV2ResponseTransform,
},
};
default:
Expand Down
2 changes: 1 addition & 1 deletion src/providers/fireworks-ai/imageGenerate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ export const FireworksAIImageGenerateResponseTransform: (
}
if (response instanceof Array) {
return {
created: `${new Date().getTime()}`, // Corrected method call
created: Math.floor(Date.now() / 1000), // Corrected method call
data: response?.map((r) => ({
b64_json: r.base64,
seed: r.seed,
Expand Down
2 changes: 1 addition & 1 deletion src/providers/segmind/imageGenerate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ export const SegmindImageGenerateResponseTransform: (
let dataObj: object[] = imageArr.map((img) => ({ b64_json: img }));

return {
created: `${new Date().getTime()}`, // Corrected method call
created: Math.floor(Date.now() / 1000), // Corrected method call
data: dataObj,
provider: SEGMIND,
} as ImageGenerateResponse;
Expand Down
22 changes: 18 additions & 4 deletions src/providers/stability-ai/api.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import { CONTENT_TYPES } from '../../globals';
import { ProviderAPIConfig } from '../types';
import { isStabilityV1Model } from './utils';

const StabilityAIAPIConfig: ProviderAPIConfig = {
getBaseURL: () => 'https://api.stability.ai/v1',
headers: ({ providerOptions }) => {
return { Authorization: `Bearer ${providerOptions.apiKey}` };
getBaseURL: () => 'https://api.stability.ai',
headers: ({ providerOptions, gatewayRequestBody }) => {
const headers: Record<string, string> = {
Authorization: `Bearer ${providerOptions.apiKey}`,
};
if (isStabilityV1Model(gatewayRequestBody?.model)) return headers;
headers['Content-Type'] = CONTENT_TYPES.MULTIPART_FORM_DATA;
headers['Accept'] = CONTENT_TYPES.APPLICATION_JSON;
return headers;
},
getEndpoint: ({ fn, gatewayRequestBody, providerOptions }) => {
let mappedFn = fn;
Expand All @@ -18,12 +26,18 @@ const StabilityAIAPIConfig: ProviderAPIConfig = {

switch (mappedFn) {
case 'imageGenerate': {
return `/generation/${gatewayRequestBody.model}/text-to-image`;
if (isStabilityV1Model(gatewayRequestBody.model))
return `/v1/generation/${gatewayRequestBody.model}/text-to-image`;
return `/v2beta/stable-image/generate/${gatewayRequestBody.model}`;
}
default:
return '';
}
},
transformToFormData: ({ gatewayRequestBody }) => {
if (isStabilityV1Model(gatewayRequestBody.model)) return false;
return true;
},
};

export default StabilityAIAPIConfig;
4 changes: 4 additions & 0 deletions src/providers/stability-ai/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export const STABILITY_V1_MODELS = [
'stable-diffusion-xl-1024-v1-0',
'stable-diffusion-v1-6',
];
18 changes: 7 additions & 11 deletions src/providers/stability-ai/imageGenerate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
generateInvalidProviderResponseError,
} from '../utils';

export const StabilityAIImageGenerateConfig: ProviderConfig = {
export const StabilityAIImageGenerateV1Config: ProviderConfig = {
prompt: {
param: 'text_prompts',
required: true,
Expand Down Expand Up @@ -60,15 +60,11 @@ export const StabilityAIImageGenerateConfig: ProviderConfig = {
},
};

interface StabilityAIImageGenerateResponse extends ImageGenerateResponse {
interface StabilityAIImageGenerateV1Response extends ImageGenerateResponse {
artifacts: ImageArtifact[];
}

interface StabilityAIImageGenerateResponse extends ImageGenerateResponse {
artifacts: ImageArtifact[];
}

interface StabilityAIImageGenerateErrorResponse {
interface StabilityAIImageGenerateV1ErrorResponse {
id: string;
name: string;
message: string;
Expand All @@ -80,10 +76,10 @@ interface ImageArtifact {
seed: number; // The seed associated with this image
}

export const StabilityAIImageGenerateResponseTransform: (
export const StabilityAIImageGenerateV1ResponseTransform: (
response:
| StabilityAIImageGenerateResponse
| StabilityAIImageGenerateErrorResponse,
| StabilityAIImageGenerateV1Response
| StabilityAIImageGenerateV1ErrorResponse,
responseStatus: number
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200 && 'message' in response) {
Expand All @@ -100,7 +96,7 @@ export const StabilityAIImageGenerateResponseTransform: (

if ('artifacts' in response) {
return {
created: `${new Date().getTime()}`, // Corrected method call
created: Math.floor(Date.now() / 1000), // Corrected method call
data: response.artifacts.map((art) => ({ b64_json: art.base64 })), // Corrected object creation within map
provider: STABILITY_AI,
};
Expand Down
Loading
Loading