From 2b7947c9ab79fb29e1c8c28ca3483bddb2f82c1b Mon Sep 17 00:00:00 2001 From: SeanHH86 <154984842+SeanHH86@users.noreply.github.com> Date: Fri, 11 Oct 2024 14:22:57 +0800 Subject: [PATCH] fix download model/dataset with specified branch (#136) --- api/handler/dataset.go | 1 + api/handler/repo.go | 4 ++-- common/types/file.go | 1 + component/mirror.go | 8 ++++---- component/model.go | 6 +++--- component/recom.go | 2 +- component/repo.go | 10 +++++++--- 7 files changed, 19 insertions(+), 13 deletions(-) diff --git a/api/handler/dataset.go b/api/handler/dataset.go index 9e856811..243fdcf5 100644 --- a/api/handler/dataset.go +++ b/api/handler/dataset.go @@ -339,6 +339,7 @@ func (h *DatasetHandler) AllFiles(ctx *gin.Context) { req.Name = name req.RepoType = types.DatasetRepo req.CurrentUser = httpbase.GetCurrentUser(ctx) + req.Ref = "" detail, err := h.c.AllFiles(ctx, req) if err != nil { if errors.Is(err, component.ErrUnauthorized) { diff --git a/api/handler/repo.go b/api/handler/repo.go index daf61d4d..6c967d78 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -725,8 +725,8 @@ func (h *RepoHandler) SDKListFiles(ctx *gin.Context) { httpbase.BadRequest(ctx, err.Error()) return } - - files, err := h.c.SDKListFiles(ctx, common.RepoTypeFromContext(ctx), namespace, name, currentUser) + ref := ctx.Param("ref") + files, err := h.c.SDKListFiles(ctx, common.RepoTypeFromContext(ctx), namespace, name, ref, currentUser) if err != nil { if errors.Is(err, component.ErrUnauthorized) { slog.Error("permission denied when accessing repo", slog.String("repo_type", string(common.RepoTypeFromContext(ctx))), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name))) diff --git a/common/types/file.go b/common/types/file.go index 41caf4e7..336554ce 100644 --- a/common/types/file.go +++ b/common/types/file.go @@ -133,6 +133,7 @@ type GetAllFilesReq struct { Name string `json:"name"` RepoType RepositoryType `json:"repo_type"` CurrentUser string `json:"current_user"` + Ref string `json:"ref"` } type LFSPointer struct { diff --git a/component/mirror.go b/component/mirror.go index 5e613ae3..2ab6066d 100644 --- a/component/mirror.go +++ b/component/mirror.go @@ -379,13 +379,13 @@ func (c *MirrorComponent) checkAndUpdateMirrorStatus(ctx context.Context, mirror return nil } -func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryType, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]*types.File, error) { +func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryType, ref string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]*types.File, error) { var files []*types.File getRepoFileTree := gitserver.GetRepoInfoByPathReq{ Namespace: namespace, Name: repoName, - Ref: "", + Ref: ref, Path: folder, RepoType: repoType, } @@ -395,7 +395,7 @@ func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryTy } for _, file := range gitFiles { if file.Type == "dir" { - subFiles, err := getAllFiles(namespace, repoName, file.Path, repoType, gsTree) + subFiles, err := getAllFiles(namespace, repoName, file.Path, repoType, ref, gsTree) if err != nil { return files, err } @@ -549,7 +549,7 @@ func (c *MirrorComponent) countMirrorProgress(ctx context.Context, mirror databa namespaceAndName := strings.Split(mirror.Repository.Path, "/") namespace := namespaceAndName[0] name := namespaceAndName[1] - allFiles, err := getAllFiles(namespace, name, "", mirror.Repository.RepositoryType, c.git.GetRepoFileTree) + allFiles, err := getAllFiles(namespace, name, "", mirror.Repository.RepositoryType, "", c.git.GetRepoFileTree) if err != nil { slog.Error("fail to get all files of mirror repository", slog.Int64("mirrorId", mirror.ID), slog.String("namespace", namespace), slog.String("name", name), slog.String("error", err.Error())) diff --git a/component/model.go b/component/model.go index 35bf83ba..ccc15ca7 100644 --- a/component/model.go +++ b/component/model.go @@ -494,7 +494,7 @@ func (c *ModelComponent) SDKModelInfo(ctx context.Context, namespace, name, ref, } } - filePaths, err := getFilePaths(namespace, name, "", types.ModelRepo, c.git.GetRepoFileTree) + filePaths, err := getFilePaths(namespace, name, "", types.ModelRepo, ref, c.git.GetRepoFileTree) if err != nil { return nil, fmt.Errorf("failed to get all %s files, error: %w", types.ModelRepo, err) } @@ -612,9 +612,9 @@ func (c *ModelComponent) getRelations(ctx context.Context, fromRepoID int64, cur return rels, nil } -func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryType, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]string, error) { +func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryType, ref string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]string, error) { var filePaths []string - allFiles, err := getAllFiles(namespace, repoName, folder, repoType, gsTree) + allFiles, err := getAllFiles(namespace, repoName, folder, repoType, ref, gsTree) if err != nil { return nil, err } diff --git a/component/recom.go b/component/recom.go index 255d1179..72b4f0fc 100644 --- a/component/recom.go +++ b/component/recom.go @@ -125,7 +125,7 @@ func (rc *RecomComponent) calcQualityScore(ctx context.Context, repo *database.R score := 0.0 // get file counts from git server namespace, name := repo.NamespaceAndName() - files, err := getFilePaths(namespace, name, "", repo.RepositoryType, rc.gs.GetRepoFileTree) + files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gs.GetRepoFileTree) if err != nil { return 0, fmt.Errorf("failed to get repo file tree,%w", err) } diff --git a/component/repo.go b/component/repo.go index 754148c9..b8c33e1c 100644 --- a/component/repo.go +++ b/component/repo.go @@ -993,7 +993,7 @@ func (c *RepoComponent) UploadFile(ctx context.Context, req *types.CreateFileReq return err } -func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, userName string) (*types.SDKFiles, error) { +func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, ref, userName string) (*types.SDKFiles, error) { var sdkFiles []types.SDKFile repo, err := c.repo.FindByPath(ctx, repoType, namespace, name) if err != nil || repo == nil { @@ -1008,7 +1008,11 @@ func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.Reposit return nil, ErrUnauthorized } - filePaths, err := getFilePaths(namespace, name, "", repoType, c.git.GetRepoFileTree) + if ref == "" { + ref = repo.DefaultBranch + } + + filePaths, err := getFilePaths(namespace, name, "", repoType, ref, c.git.GetRepoFileTree) if err != nil { return nil, fmt.Errorf("failed to get all %s files, error: %w", repoType, err) } @@ -2424,7 +2428,7 @@ func (c *RepoComponent) AllFiles(ctx context.Context, req types.GetAllFilesReq) return nil, fmt.Errorf("users do not have permission to get all files for this repo") } } - allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, c.git.GetRepoFileTree) + allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, req.Ref, c.git.GetRepoFileTree) if err != nil { slog.Error("fail to get all files of repository", slog.Any("repoType", req.RepoType), slog.String("namespace", req.Namespace), slog.String("name", req.Name), slog.String("error", err.Error())) return nil, err