Skip to content

Commit

Permalink
Option for user to automatically add Amazon Bedrock integrations (#1819)
Browse files Browse the repository at this point in the history
* Option for user to automatically add Amazon Bedrock integrations when read only instance role (IAM) is set #1759
  • Loading branch information
ramanan-ravi authored Dec 6, 2023
1 parent e7e51b6 commit 6468a55
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 109 deletions.
6 changes: 3 additions & 3 deletions deepfence_server/apiDocs/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,9 +726,9 @@ func (d *OpenAPIDocs) AddIntegrationOperations() {
d.AddOperation("addGenerativeAiIntegrationBedrock", http.MethodPost, "/deepfence/generative-ai-integration/bedrock",
"Add AWS Bedrock Generative AI Integration", "Add a new AWS Bedrock Generative AI Integration",
http.StatusOK, []string{tagGenerativeAi}, bearerToken, new(AddGenerativeAiBedrockIntegration), new(MessageResponse))
d.AddOperation("autoAddGenerativeAiIntegrationBedrock", http.MethodPost, "/deepfence/generative-ai-integration/bedrock/auto-add",
"Automatically add AWS Bedrock Generative AI Integration", "Automatically add AWS Bedrock Generative AI Integrations using IAM role",
http.StatusOK, []string{tagGenerativeAi}, bearerToken, new(AutoAddGenerativeAiBedrockIntegration), new(MessageResponse))
d.AddOperation("autoAddGenerativeAiIntegration", http.MethodPost, "/deepfence/generative-ai-integration/auto-add",
"Automatically add Generative AI Integration", "Automatically add Generative AI Integrations using IAM role",
http.StatusAccepted, []string{tagGenerativeAi}, bearerToken, nil, nil)

d.AddOperation("listGenerativeAiIntegration", http.MethodGet, "/deepfence/generative-ai-integration",
"List Generative AI Integrations", "List all the added Generative AI Integrations",
Expand Down
157 changes: 62 additions & 95 deletions deepfence_server/handler/generative_ai_integration.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handler

import (
"context"
"database/sql"
"encoding/json"
"errors"
Expand All @@ -16,6 +15,7 @@ import (
"github.com/deepfence/ThreatMapper/deepfence_utils/encryption"
"github.com/deepfence/ThreatMapper/deepfence_utils/log"
postgresqlDb "github.com/deepfence/ThreatMapper/deepfence_utils/postgresql/postgresql-db"
"github.com/deepfence/ThreatMapper/deepfence_utils/utils"
"github.com/go-chi/chi/v5"
httpext "github.com/go-playground/pkg/v5/net/http"
)
Expand All @@ -25,104 +25,86 @@ var (
ErrGenerativeAIIntegrationExists = BadDecoding{
err: errors.New("similar integration already exists"),
}
ErrBedrockNoActiveModel = BadDecoding{
err: bedrock.ErrBedrockNoActiveModel,
}
)

func (h *Handler) AddBedrockIntegrationUsingIAMRole(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
var req model.AutoAddGenerativeAiBedrockIntegration
err := httpext.DecodeJSON(r, httpext.NoQueryParams, MaxPostRequestSize, &req)
if err != nil {
log.Error().Msgf("%v", err)
h.respondError(&BadDecoding{err}, w)
return
}

addModelRequests, err := bedrock.ListBedrockModels(req.AWSRegion)
func (h *Handler) AddGenerativeAIIntegrationUsingIAMRole(w http.ResponseWriter, r *http.Request) {
// Only AWS at the moment
foundModel, err := bedrock.CheckBedrockModelAvailability()
if err != nil {
log.Error().Msgf("%v", err)
h.respondError(&BadDecoding{err}, w)
return
}

if len(addModelRequests) == 0 {
log.Error().Msgf("%v", ErrBedrockNoActiveModel)
h.respondError(&ErrBedrockNoActiveModel, w)
if !foundModel {
h.respondError(&BadDecoding{err: bedrock.ErrBedrockNoActiveModel}, w)
return
}

ctx := r.Context()
pgClient, err := directory.PostgresClient(ctx)
worker, err := directory.Worker(r.Context())
if err != nil {
h.respondError(&InternalServerError{err}, w)
h.respondError(err, w)
return
}

// encrypt secret
aesValue, err := model.GetAESValueForEncryption(ctx, pgClient)
user, statusCode, _, err := h.GetUserFromJWT(r.Context())
if err != nil {
log.Error().Msgf(err.Error())
h.respondError(&InternalServerError{err}, w)
h.respondWithErrorCode(err, w, statusCode)
return
}
aes := encryption.AES{}
err = json.Unmarshal(aesValue, &aes)
data := utils.AutoFetchGenerativeAIIntegrationsParameters{
CloudProvider: "aws",
UserID: user.ID,
}
dataJson, err := json.Marshal(data)

Check failure on line 57 in deepfence_server/handler/generative_ai_integration.go

View workflow job for this annotation

GitHub Actions / lint-server

ST1003: var dataJson should be dataJSON (stylecheck)
if err != nil {
log.Error().Msgf(err.Error())
h.respondError(&InternalServerError{err}, w)
h.respondError(err, w)
return
}
user, statusCode, _, err := h.GetUserFromJWT(ctx)
err = worker.Enqueue(utils.AutoFetchGenerativeAIIntegrations, dataJson, utils.DefaultTaskOpts()...)
if err != nil {
h.respondWithErrorCode(err, w, statusCode)
h.respondError(err, w)
return
}

for _, addModelRequest := range addModelRequests {
err = h.AddGenerativeAiIntegrationHelper(ctx, addModelRequest, aes, user, pgClient)
if err != nil {
log.Warn().Msgf(err.Error())
continue
}
}

httpext.JSON(w, http.StatusOK, model.MessageResponse{Message: api_messages.SuccessIntegrationCreated})
w.WriteHeader(http.StatusAccepted)
}

func (h *Handler) AddOpenAiIntegration(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
var req model.AddGenerativeAiOpenAIIntegration
err := httpext.DecodeJSON(r, httpext.NoQueryParams, MaxPostRequestSize, &req)
if err != nil {
log.Error().Msgf("%v", err)
h.respondError(&BadDecoding{err}, w)
return
}
h.AddGenerativeAiIntegration(req, w, r)
AddGenerativeAiIntegration[model.AddGenerativeAiOpenAIIntegration](w, r, h)
}

func (h *Handler) AddBedrockIntegration(w http.ResponseWriter, r *http.Request) {
AddGenerativeAiIntegration[model.AddGenerativeAiBedrockIntegration](w, r, h)
}

func AddGenerativeAiIntegration[T model.AddGenerativeAiIntegrationRequest](w http.ResponseWriter, r *http.Request, h *Handler) {
defer r.Body.Close()
var req model.AddGenerativeAiBedrockIntegration
var req T
err := httpext.DecodeJSON(r, httpext.NoQueryParams, MaxPostRequestSize, &req)
if err != nil {
log.Error().Msgf("%v", err)
h.respondError(&BadDecoding{err}, w)
return
}
h.AddGenerativeAiIntegration(req, w, r)
}

func (h *Handler) AddGenerativeAiIntegration(req model.AddGenerativeAiIntegrationRequest, w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
pgClient, err := directory.PostgresClient(ctx)
if err != nil {
h.respondError(&InternalServerError{err}, w)
return
}

obj, err := generative_ai_integration.NewGenerativeAiIntegration(ctx, req)
if err != nil {
log.Error().Msgf("%v", err)
h.respondError(&BadDecoding{err}, w)
return
}
err = obj.ValidateConfig(h.Validator)
if err != nil {
h.respondError(&ValidatorError{err: err}, w)
return
}

// encrypt secret
aesValue, err := model.GetAESValueForEncryption(ctx, pgClient)
if err != nil {
Expand All @@ -137,60 +119,37 @@ func (h *Handler) AddGenerativeAiIntegration(req model.AddGenerativeAiIntegratio
h.respondError(&InternalServerError{err}, w)
return
}
user, statusCode, _, err := h.GetUserFromJWT(ctx)
if err != nil {
h.respondWithErrorCode(err, w, statusCode)
return
}

err = h.AddGenerativeAiIntegrationHelper(ctx, req, aes, user, pgClient)
err = obj.EncryptSecret(aes)
if err != nil {
log.Error().Msgf(err.Error())
h.respondError(err, w)
h.respondError(&InternalServerError{err}, w)
return
}

h.AuditUserActivity(r, EventGenerativeAIIntegration, ActionCreate, map[string]interface{}{"integration_type": req.GetIntegrationType()}, true)

err = httpext.JSON(w, http.StatusOK, model.MessageResponse{Message: api_messages.SuccessIntegrationCreated})
if err != nil {
log.Error().Msg(err.Error())
}
}

func (h *Handler) AddGenerativeAiIntegrationHelper(ctx context.Context, req model.AddGenerativeAiIntegrationRequest, aes encryption.AES, user *model.User, pgClient *postgresqlDb.Queries) error {
obj, err := generative_ai_integration.NewGenerativeAiIntegration(ctx, req)
if err != nil {
return &BadDecoding{err}
}
err = obj.ValidateConfig(h.Validator)
if err != nil {
return &ValidatorError{err: err}
}
err = obj.VerifyAuth(ctx)
if err != nil {
return &BadDecoding{err: err}
}

err = obj.EncryptSecret(aes)
if err != nil {
return err
}

// add integration to database
// before that check if integration already exists
integrationExists, err := req.IntegrationExists(ctx, pgClient)
if err != nil {
return err
log.Error().Msgf(err.Error())
h.respondError(&InternalServerError{err}, w)
return
}
if integrationExists {
return &ErrGenerativeAIIntegrationExists
h.respondError(&ErrGenerativeAIIntegrationExists, w)
return
}

user, statusCode, _, err := h.GetUserFromJWT(ctx)
if err != nil {
h.respondWithErrorCode(err, w, statusCode)
return
}

// store the integration in db
bConfig, err := json.Marshal(obj)
if err != nil {
return err
h.respondWithErrorCode(err, w, statusCode)
return
}

arg := postgresqlDb.CreateGenerativeAiIntegrationParams{
Expand All @@ -201,14 +160,22 @@ func (h *Handler) AddGenerativeAiIntegrationHelper(ctx context.Context, req mode
}
dbIntegration, err := pgClient.CreateGenerativeAiIntegration(ctx, arg)
if err != nil {
return err
log.Error().Msgf(err.Error())
h.respondError(&InternalServerError{err}, w)
return
}

h.AuditUserActivity(r, EventGenerativeAIIntegration, ActionCreate, map[string]interface{}{"integration_type": req.GetIntegrationType()}, true)

err = pgClient.UpdateGenerativeAiIntegrationDefault(ctx, dbIntegration.ID)
if err != nil {
log.Warn().Msgf(err.Error())
}
return nil

err = httpext.JSON(w, http.StatusOK, model.MessageResponse{Message: api_messages.SuccessIntegrationCreated})
if err != nil {
log.Error().Msg(err.Error())
}
}

func (h *Handler) GetGenerativeAiIntegrations(w http.ResponseWriter, r *http.Request) {
Expand Down
4 changes: 0 additions & 4 deletions deepfence_server/model/generative_ai_integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,3 @@ type GenerativeAiIntegrationListResponse struct {
LastErrorMsg string `json:"last_error_msg"`
DefaultIntegration bool `json:"default_integration"`
}

type AutoAddGenerativeAiBedrockIntegration struct {
AWSRegion string `json:"aws_region" validate:"required,oneof=us-east-1 us-east-2 us-west-1 us-west-2 af-south-1 ap-east-1 ap-south-1 ap-northeast-1 ap-northeast-2 ap-northeast-3 ap-southeast-1 ap-southeast-2 ap-southeast-3 ca-central-1 eu-central-1 eu-west-1 eu-west-2 eu-west-3 eu-south-1 eu-north-1 me-south-1 me-central-1 sa-east-1 us-gov-east-1 us-gov-west-1" required:"true" enum:"us-east-1,us-east-2,us-west-1,us-west-2,af-south-1,ap-east-1,ap-south-1,ap-northeast-1,ap-northeast-2,ap-northeast-3,ap-southeast-1,ap-southeast-2,ap-southeast-3,ca-central-1,eu-central-1,eu-west-1,eu-west-2,eu-west-3,eu-south-1,eu-north-1,me-south-1,me-central-1,sa-east-1,us-gov-east-1,us-gov-west-1"`
}
64 changes: 60 additions & 4 deletions deepfence_server/pkg/generative-ai-integration/bedrock/aimodels.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
package bedrock

import (
"encoding/json"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/bedrock"
"github.com/aws/aws-sdk-go/service/bedrockruntime"
"github.com/deepfence/ThreatMapper/deepfence_server/model"
"github.com/deepfence/ThreatMapper/deepfence_utils/log"
)

func CheckBedrockModelAvailability() (bool, error) {
foundModel := false
for _, region := range BedrockRegions {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(region),
})
if err != nil {
return false, err
}
svc := bedrock.New(sess)

models, err := svc.ListFoundationModels(&bedrock.ListFoundationModelsInput{

Check failure on line 25 in deepfence_server/pkg/generative-ai-integration/bedrock/aimodels.go

View workflow job for this annotation

GitHub Actions / lint-server

ineffectual assignment to err (ineffassign)
ByOutputModality: &textModality,
})
for _, modelSummary := range models.ModelSummaries {
if *modelSummary.ModelLifecycle.Status == modelLifecycleActive {
if _, ok := BedrockModelBody[*modelSummary.ModelId]; ok {
foundModel = true
break
}
}
}
if foundModel == true {

Check failure on line 36 in deepfence_server/pkg/generative-ai-integration/bedrock/aimodels.go

View workflow job for this annotation

GitHub Actions / lint-server

S1002: should omit comparison to bool constant, can be simplified to `foundModel` (gosimple)
break
}
}
return foundModel, nil
}

// ListBedrockModels Fetch enabled Bedrock models using IAM roles
func ListBedrockModels(region string) ([]model.AddGenerativeAiIntegrationRequest, error) {
sess, err := session.NewSession(&aws.Config{
Expand All @@ -15,24 +48,47 @@ func ListBedrockModels(region string) ([]model.AddGenerativeAiIntegrationRequest
if err != nil {
return nil, err
}

svc := bedrock.New(sess)
bedrockRuntimeSvc := bedrockruntime.New(sess)

models, err := svc.ListFoundationModels(&bedrock.ListFoundationModelsInput{
ByOutputModality: &textModality,
})
if err != nil {
return nil, err
}
resp := []model.AddGenerativeAiIntegrationRequest{}

var bedrockModels []model.AddGenerativeAiIntegrationRequest
message := "hello"

for _, modelSummary := range models.ModelSummaries {
if *modelSummary.ModelLifecycle.Status == modelLifecycleActive {
if _, ok := BedrockModelBody[*modelSummary.ModelId]; ok {
resp = append(resp, model.AddGenerativeAiBedrockIntegration{
if body, ok := BedrockModelBody[*modelSummary.ModelId]; ok {
body[BedrockModelBodyInputKey[*modelSummary.ModelId]] = body[BedrockModelBodyInputKey[*modelSummary.ModelId]].(string) + message + BedrockModelBodyInputSuffix[*modelSummary.ModelId]
bodyBytes, err := json.Marshal(body)
if err != nil {
log.Warn().Msg(err.Error())
continue
}
_, err = bedrockRuntimeSvc.InvokeModel(&bedrockruntime.InvokeModelInput{
Accept: &acceptHeader,
Body: bodyBytes,
ContentType: &contentTypeHeader,
ModelId: modelSummary.ModelId,
})
if err != nil {
log.Warn().Msg(err.Error())
continue
}

bedrockModels = append(bedrockModels, model.AddGenerativeAiBedrockIntegration{
AWSRegion: region,
UseIAMRole: true,
ModelID: *modelSummary.ModelId,
})
}
}
}
return resp, nil
return bedrockModels, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ const (
)

var (
BedrockRegions = []string{"us-east-1", "us-west-2", "ap-southeast-1", "ap-northeast-1", "eu-central-1"}

textModality = "TEXT"

contentTypeHeader = "application/json"
Expand Down
2 changes: 1 addition & 1 deletion deepfence_server/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ func SetupRoutes(r *chi.Mux, serverPort string, serveOpenapiDocs bool, ingestC c
r.Route("/generative-ai-integration", func(r chi.Router) {
r.Post("/openai", dfHandler.AuthHandler(ResourceIntegration, PermissionWrite, dfHandler.AddOpenAiIntegration))
r.Post("/bedrock", dfHandler.AuthHandler(ResourceIntegration, PermissionWrite, dfHandler.AddBedrockIntegration))
r.Post("/bedrock/auto-add", dfHandler.AuthHandler(ResourceIntegration, PermissionWrite, dfHandler.AddBedrockIntegrationUsingIAMRole))
r.Post("/auto-add", dfHandler.AuthHandler(ResourceIntegration, PermissionWrite, dfHandler.AddGenerativeAIIntegrationUsingIAMRole))

r.Get("/", dfHandler.AuthHandler(ResourceIntegration, PermissionRead, dfHandler.GetGenerativeAiIntegrations))
r.Route("/{integration_id}", func(r chi.Router) {
Expand Down
Loading

0 comments on commit 6468a55

Please sign in to comment.