From 053cb4489ca4fbbbd3c8d227fcb26a6efcebf71d Mon Sep 17 00:00:00 2001 From: yiling Date: Fri, 27 Dec 2024 17:25:18 +0800 Subject: [PATCH] add prompt handler tests --- api/handler/prompt.go | 54 +++--- api/handler/prompt_test.go | 338 +++++++++++++++++++++++++++++++++++++ 2 files changed, 364 insertions(+), 28 deletions(-) create mode 100644 api/handler/prompt_test.go diff --git a/api/handler/prompt.go b/api/handler/prompt.go index a8defe4b..0bd6c750 100644 --- a/api/handler/prompt.go +++ b/api/handler/prompt.go @@ -17,9 +17,9 @@ import ( ) type PromptHandler struct { - pc component.PromptComponent - sc component.SensitiveComponent - repo component.RepoComponent + prompt component.PromptComponent + sensitive component.SensitiveComponent + repo component.RepoComponent } func NewPromptHandler(cfg *config.Config) (*PromptHandler, error) { @@ -33,13 +33,12 @@ func NewPromptHandler(cfg *config.Config) (*PromptHandler, error) { } repo, err := component.NewRepoComponent(cfg) if err != nil { - return nil, fmt.Errorf("error creating repo component:%w", err) + return nil, fmt.Errorf("failed to create repo component: %w", err) } - return &PromptHandler{ - pc: promptComp, - sc: sc, - repo: repo, + prompt: promptComp, + sensitive: sc, + repo: repo, }, nil } @@ -89,7 +88,7 @@ func (h *PromptHandler) Index(ctx *gin.Context) { return } - prompts, total, err := h.pc.IndexPromptRepo(ctx, filter, per, page) + prompts, total, err := h.prompt.IndexPromptRepo(ctx, filter, per, page) if err != nil { slog.Error("Failed to get prompts dataset", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -125,7 +124,7 @@ func (h *PromptHandler) ListPrompt(ctx *gin.Context) { return } - detail, err := h.pc.Show(ctx, namespace, name, currentUser) + detail, err := h.prompt.Show(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -141,7 +140,7 @@ func (h *PromptHandler) ListPrompt(ctx *gin.Context) { Name: name, CurrentUser: currentUser, } - data, err := h.pc.ListPrompt(ctx, req) + data, err := h.prompt.ListPrompt(ctx, req) if err != nil { slog.Error("Failed to list prompts of repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -190,7 +189,7 @@ func (h *PromptHandler) GetPrompt(ctx *gin.Context) { CurrentUser: currentUser, Path: convertFilePathFromRoute(filePath), } - data, err := h.pc.GetPrompt(ctx, req) + data, err := h.prompt.GetPrompt(ctx, req) if err != nil { slog.Error("Failed to get prompt of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -232,7 +231,7 @@ func (h *PromptHandler) CreatePrompt(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, body) + _, err = h.sensitive.CheckRequestV2(ctx, body) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -244,8 +243,7 @@ func (h *PromptHandler) CreatePrompt(ctx *gin.Context) { Name: name, CurrentUser: currentUser, } - - data, err := h.pc.CreatePrompt(ctx, req, body) + data, err := h.prompt.CreatePrompt(ctx, req, body) if err != nil { slog.Error("Failed to create prompt file of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -294,7 +292,7 @@ func (h *PromptHandler) UpdatePrompt(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err = h.sc.CheckRequestV2(ctx, body) + _, err = h.sensitive.CheckRequestV2(ctx, body) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -307,7 +305,7 @@ func (h *PromptHandler) UpdatePrompt(ctx *gin.Context) { CurrentUser: currentUser, Path: convertFilePathFromRoute(filePath), } - data, err := h.pc.UpdatePrompt(ctx, req, body) + data, err := h.prompt.UpdatePrompt(ctx, req, body) if err != nil { slog.Error("Failed to update prompt file of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -357,7 +355,7 @@ func (h *PromptHandler) DeletePrompt(ctx *gin.Context) { CurrentUser: currentUser, Path: convertFilePathFromRoute(filePath), } - err = h.pc.DeletePrompt(ctx, req) + err = h.prompt.DeletePrompt(ctx, req) if err != nil { slog.Error("Failed to remove prompt file of repo", slog.Any("req", req), slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -387,7 +385,7 @@ func (h *PromptHandler) Relations(ctx *gin.Context) { return } currentUser := httpbase.GetCurrentUser(ctx) - detail, err := h.pc.Relations(ctx, namespace, name, currentUser) + detail, err := h.prompt.Relations(ctx, namespace, name, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { httpbase.UnauthorizedError(ctx, err) @@ -439,7 +437,7 @@ func (h *PromptHandler) SetRelations(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.pc.SetRelationModels(ctx, req) + err = h.prompt.SetRelationModels(ctx, req) if err != nil { slog.Error("Failed to set models for prompt", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -486,7 +484,7 @@ func (h *PromptHandler) AddModelRelation(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.pc.AddRelationModel(ctx, req) + err = h.prompt.AddRelationModel(ctx, req) if err != nil { slog.Error("Failed to add model for prompt", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -533,7 +531,7 @@ func (h *PromptHandler) DelModelRelation(ctx *gin.Context) { req.Name = name req.CurrentUser = currentUser - err = h.pc.DelRelationModel(ctx, req) + err = h.prompt.DelRelationModel(ctx, req) if err != nil { slog.Error("Failed to delete dataset for model", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -567,7 +565,7 @@ func (h *PromptHandler) Create(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -575,7 +573,7 @@ func (h *PromptHandler) Create(ctx *gin.Context) { } req.Username = currentUser - prompt, err := h.pc.CreatePromptRepo(ctx, req) + prompt, err := h.prompt.CreatePromptRepo(ctx, req) if err != nil { slog.Error("Failed to create prompt repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -616,7 +614,7 @@ func (h *PromptHandler) Update(ctx *gin.Context) { return } - _, err := h.sc.CheckRequestV2(ctx, req) + _, err := h.sensitive.CheckRequestV2(ctx, req) if err != nil { slog.Error("failed to check sensitive request", slog.Any("error", err)) httpbase.BadRequest(ctx, fmt.Errorf("sensitive check failed: %w", err).Error()) @@ -633,7 +631,7 @@ func (h *PromptHandler) Update(ctx *gin.Context) { req.Namespace = namespace req.Name = name - prompt, err := h.pc.UpdatePromptRepo(ctx, req) + prompt, err := h.prompt.UpdatePromptRepo(ctx, req) if err != nil { slog.Error("Failed to update prompt repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -669,7 +667,7 @@ func (h *PromptHandler) Delete(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - err = h.pc.RemoveRepo(ctx, namespace, name, currentUser) + err = h.prompt.RemoveRepo(ctx, namespace, name, currentUser) if err != nil { slog.Error("Failed to delete prompt repo", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -778,7 +776,7 @@ func (h *PromptHandler) Tags(ctx *gin.Context) { func (h *PromptHandler) UpdateTags(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) if currentUser == "" { - httpbase.UnauthorizedError(ctx, httpbase.ErrorNeedLogin) + httpbase.UnauthorizedError(ctx, component.ErrUserNotFound) return } namespace, name, err := common.GetNamespaceAndNameFromContext(ctx) diff --git a/api/handler/prompt_test.go b/api/handler/prompt_test.go new file mode 100644 index 00000000..3f7ad019 --- /dev/null +++ b/api/handler/prompt_test.go @@ -0,0 +1,338 @@ +package handler + +import ( + "fmt" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + mock_component "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/types" +) + +type PromptTester struct { + *GinTester + handler *PromptHandler + mocks struct { + prompt *mock_component.MockPromptComponent + sensitive *mock_component.MockSensitiveComponent + repo *mock_component.MockRepoComponent + } +} + +func NewPromptTester(t *testing.T) *PromptTester { + tester := &PromptTester{GinTester: NewGinTester()} + tester.mocks.prompt = mock_component.NewMockPromptComponent(t) + tester.mocks.sensitive = mock_component.NewMockSensitiveComponent(t) + tester.mocks.repo = mock_component.NewMockRepoComponent(t) + tester.handler = &PromptHandler{ + prompt: tester.mocks.prompt, sensitive: tester.mocks.sensitive, + repo: tester.mocks.repo, + } + tester.WithParam("name", "r") + tester.WithParam("namespace", "u") + return tester + +} + +func (t *PromptTester) WithHandleFunc(fn func(h *PromptHandler) gin.HandlerFunc) *PromptTester { + t.ginHandler = fn(t.handler) + return t + +} + +func TestPromptHandler_Index(t *testing.T) { + cases := []struct { + sort string + source string + error bool + }{ + {"most_download", "local", false}, + {"foo", "local", true}, + {"most_download", "bar", true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%+v", c), func(t *testing.T) { + + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Index + }) + + if !c.error { + tester.mocks.prompt.EXPECT().IndexPromptRepo(tester.ctx, &types.RepoFilter{ + Search: "foo", + Sort: c.sort, + Source: c.source, + }, 10, 1).Return([]types.PromptRes{ + {Name: "cc"}, + }, 100, nil) + } + + tester.AddPagination(1, 10).WithQuery("search", "foo"). + WithQuery("sort", c.sort). + WithQuery("source", c.source).Execute() + + if c.error { + require.Equal(t, 400, tester.response.Code) + } else { + tester.ResponseEqSimple(t, 200, gin.H{ + "data": []types.PromptRes{{Name: "cc"}}, + "total": 100, + }) + } + }) + } +} + +func TestPromptHandler_ListPrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.ListPrompt + }) + + tester.WithUser() + tester.mocks.prompt.EXPECT().Show(tester.ctx, "u", "r", "u").Return(&types.PromptRes{Name: "p"}, nil) + tester.mocks.prompt.EXPECT().ListPrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", + }).Return([]types.PromptOutput{{FilePath: "fp"}}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, gin.H{ + "detail": &types.PromptRes{Name: "p"}, + "prompts": []types.PromptOutput{{FilePath: "fp"}}, + }) +} + +func TestPromptHandler_GetPrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.GetPrompt + }) + + tester.WithUser().WithParam("file_path", "fp") + tester.mocks.prompt.EXPECT().GetPrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", Path: "fp", + }).Return(&types.PromptOutput{FilePath: "fp"}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.PromptOutput{FilePath: "fp"}) +} + +func TestPromptHandler_CreatePrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.CreatePrompt + }) + tester.RequireUser(t) + + req := &types.CreatePromptReq{Prompt: types.Prompt{ + Title: "t", Content: "c", Language: "l", + }} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.prompt.EXPECT().CreatePrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", + }, req).Return(&types.Prompt{Title: "p"}, nil) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Prompt{Title: "p"}) +} + +func TestPromptHandler_UpdatePrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.UpdatePrompt + }) + tester.RequireUser(t) + + req := &types.UpdatePromptReq{Prompt: types.Prompt{ + Title: "t", Content: "c", Language: "l", + }} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + tester.mocks.prompt.EXPECT().UpdatePrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", Path: "fp", + }, req).Return(&types.Prompt{Title: "p"}, nil) + tester.WithParam("file_path", "fp").WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Prompt{Title: "p"}) +} + +func TestPromptHandler_DeletePrompt(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.DeletePrompt + }) + tester.RequireUser(t) + + tester.WithUser().WithParam("file_path", "fp") + tester.mocks.prompt.EXPECT().DeletePrompt(tester.ctx, types.PromptReq{ + Namespace: "u", Name: "r", CurrentUser: "u", Path: "fp", + }).Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_Relations(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Relations + }) + + tester.WithUser() + tester.mocks.prompt.EXPECT().Relations(tester.ctx, "u", "r", "u").Return(&types.Relations{}, nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.Relations{}) +} + +func TestPromptHandler_SetRelations(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.SetRelations + }) + tester.RequireUser(t) + + req := types.RelationModels{Namespace: "u", Name: "r", CurrentUser: "u"} + tester.mocks.prompt.EXPECT().SetRelationModels(tester.ctx, req).Return(nil) + tester.WithBody(t, types.RelationModels{Name: "rm"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_AddModelRelation(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.AddModelRelation + }) + tester.RequireUser(t) + + req := types.RelationModel{Namespace: "u", Name: "r", CurrentUser: "u"} + tester.mocks.prompt.EXPECT().AddRelationModel(tester.ctx, req).Return(nil) + tester.WithBody(t, types.RelationModels{Name: "rm"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_DeleteModelRelation(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.AddModelRelation + }) + tester.RequireUser(t) + + req := types.RelationModel{Namespace: "u", Name: "r", CurrentUser: "u"} + tester.mocks.prompt.EXPECT().AddRelationModel(tester.ctx, req).Return(nil) + tester.WithBody(t, types.RelationModels{Name: "rm"}).Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_CreatePromptRepo(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Create + }) + tester.RequireUser(t) + + req := &types.CreatePromptRepoReq{CreateRepoReq: types.CreateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Username = "u" + tester.mocks.prompt.EXPECT().CreatePromptRepo(tester.ctx, &reqn).Return( + &types.PromptRes{Name: "p"}, nil, + ) + tester.WithBody(t, req).Execute() + + tester.ResponseEqSimple(t, 200, gin.H{ + "data": &types.PromptRes{Name: "p"}, + }) +} + +func TestPromptHandler_UpdatePromptRepo(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Update + }) + tester.RequireUser(t) + + req := &types.UpdatePromptRepoReq{UpdateRepoReq: types.UpdateRepoReq{}} + tester.mocks.sensitive.EXPECT().CheckRequestV2(tester.ctx, req).Return(true, nil) + reqn := *req + reqn.Namespace = "u" + reqn.Name = "r" + reqn.Username = "u" + tester.mocks.prompt.EXPECT().UpdatePromptRepo(tester.ctx, &reqn).Return( + &types.PromptRes{Name: "p"}, nil, + ) + tester.WithBody(t, req).Execute() + + tester.ResponseEq(t, 200, tester.OKText, &types.PromptRes{Name: "p"}) +} + +func TestPromptHandler_DeletePromptRepo(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Delete + }) + tester.RequireUser(t) + + tester.mocks.prompt.EXPECT().RemoveRepo(tester.ctx, "u", "r", "u").Return(nil) + tester.Execute() + + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_Branches(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Branches + }) + + tester.mocks.repo.EXPECT().Branches(tester.ctx, &types.GetBranchesReq{ + Namespace: "u", + Name: "r", + Per: 10, + Page: 1, + RepoType: types.PromptRepo, + CurrentUser: "u", + }).Return([]types.Branch{{Name: "main"}}, nil) + tester.WithUser().AddPagination(1, 10).Execute() + + tester.ResponseEq(t, 200, tester.OKText, []types.Branch{{Name: "main"}}) +} + +func TestPromptHandler_Tags(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.Tags + }) + + tester.mocks.repo.EXPECT().Tags(tester.ctx, &types.GetTagsReq{ + Namespace: "u", + Name: "r", + RepoType: types.PromptRepo, + CurrentUser: "u", + }).Return([]database.Tag{{Name: "main"}}, nil) + tester.WithUser().AddPagination(1, 10).Execute() + + tester.ResponseEq(t, 200, tester.OKText, []database.Tag{{Name: "main"}}) +} + +func TestPromptHandler_UpdateTags(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.UpdateTags + }) + tester.RequireUser(t) + + req := []string{"a", "b"} + tester.mocks.repo.EXPECT().UpdateTags(tester.ctx, "u", "r", types.PromptRepo, "cat", "u", req).Return(nil) + tester.WithBody(t, req).WithParam("category", "cat").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +} + +func TestPromptHandler_UpdateDownloads(t *testing.T) { + tester := NewPromptTester(t).WithHandleFunc(func(h *PromptHandler) gin.HandlerFunc { + return h.UpdateDownloads + }) + + tester.mocks.repo.EXPECT().UpdateDownloads(tester.ctx, &types.UpdateDownloadsReq{ + Namespace: "u", + Name: "r", + RepoType: types.PromptRepo, + Date: time.Date(2012, 12, 12, 0, 0, 0, 0, time.UTC), + ReqDate: "2012-12-12", + }).Return(nil) + tester.WithUser().WithBody(t, &types.UpdateDownloadsReq{ + ReqDate: time.Date(2012, 12, 12, 0, 0, 0, 0, time.UTC).Format("2006-01-02"), + }).WithParam("category", "cat").Execute() + tester.ResponseEq(t, 200, tester.OKText, nil) +}