Skip to content

Commit

Permalink
fix download model/dataset with specified branch (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanHH86 authored Oct 11, 2024
1 parent ea76c6d commit 2b7947c
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 13 deletions.
1 change: 1 addition & 0 deletions api/handler/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ func (h *DatasetHandler) AllFiles(ctx *gin.Context) {
req.Name = name
req.RepoType = types.DatasetRepo
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Ref = ""
detail, err := h.c.AllFiles(ctx, req)
if err != nil {
if errors.Is(err, component.ErrUnauthorized) {
Expand Down
4 changes: 2 additions & 2 deletions api/handler/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ func (h *RepoHandler) SDKListFiles(ctx *gin.Context) {
httpbase.BadRequest(ctx, err.Error())
return
}

files, err := h.c.SDKListFiles(ctx, common.RepoTypeFromContext(ctx), namespace, name, currentUser)
ref := ctx.Param("ref")
files, err := h.c.SDKListFiles(ctx, common.RepoTypeFromContext(ctx), namespace, name, ref, currentUser)
if err != nil {
if errors.Is(err, component.ErrUnauthorized) {
slog.Error("permission denied when accessing repo", slog.String("repo_type", string(common.RepoTypeFromContext(ctx))), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name)))
Expand Down
1 change: 1 addition & 0 deletions common/types/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ type GetAllFilesReq struct {
Name string `json:"name"`
RepoType RepositoryType `json:"repo_type"`
CurrentUser string `json:"current_user"`
Ref string `json:"ref"`
}

type LFSPointer struct {
Expand Down
8 changes: 4 additions & 4 deletions component/mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,13 @@ func (c *MirrorComponent) checkAndUpdateMirrorStatus(ctx context.Context, mirror
return nil
}

func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryType, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]*types.File, error) {
func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryType, ref string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]*types.File, error) {
var files []*types.File

getRepoFileTree := gitserver.GetRepoInfoByPathReq{
Namespace: namespace,
Name: repoName,
Ref: "",
Ref: ref,
Path: folder,
RepoType: repoType,
}
Expand All @@ -395,7 +395,7 @@ func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryTy
}
for _, file := range gitFiles {
if file.Type == "dir" {
subFiles, err := getAllFiles(namespace, repoName, file.Path, repoType, gsTree)
subFiles, err := getAllFiles(namespace, repoName, file.Path, repoType, ref, gsTree)
if err != nil {
return files, err
}
Expand Down Expand Up @@ -549,7 +549,7 @@ func (c *MirrorComponent) countMirrorProgress(ctx context.Context, mirror databa
namespaceAndName := strings.Split(mirror.Repository.Path, "/")
namespace := namespaceAndName[0]
name := namespaceAndName[1]
allFiles, err := getAllFiles(namespace, name, "", mirror.Repository.RepositoryType, c.git.GetRepoFileTree)
allFiles, err := getAllFiles(namespace, name, "", mirror.Repository.RepositoryType, "", c.git.GetRepoFileTree)

if err != nil {
slog.Error("fail to get all files of mirror repository", slog.Int64("mirrorId", mirror.ID), slog.String("namespace", namespace), slog.String("name", name), slog.String("error", err.Error()))
Expand Down
6 changes: 3 additions & 3 deletions component/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ func (c *ModelComponent) SDKModelInfo(ctx context.Context, namespace, name, ref,
}
}

filePaths, err := getFilePaths(namespace, name, "", types.ModelRepo, c.git.GetRepoFileTree)
filePaths, err := getFilePaths(namespace, name, "", types.ModelRepo, ref, c.git.GetRepoFileTree)
if err != nil {
return nil, fmt.Errorf("failed to get all %s files, error: %w", types.ModelRepo, err)
}
Expand Down Expand Up @@ -612,9 +612,9 @@ func (c *ModelComponent) getRelations(ctx context.Context, fromRepoID int64, cur
return rels, nil
}

func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryType, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]string, error) {
func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryType, ref string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]string, error) {
var filePaths []string
allFiles, err := getAllFiles(namespace, repoName, folder, repoType, gsTree)
allFiles, err := getAllFiles(namespace, repoName, folder, repoType, ref, gsTree)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion component/recom.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (rc *RecomComponent) calcQualityScore(ctx context.Context, repo *database.R
score := 0.0
// get file counts from git server
namespace, name := repo.NamespaceAndName()
files, err := getFilePaths(namespace, name, "", repo.RepositoryType, rc.gs.GetRepoFileTree)
files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gs.GetRepoFileTree)
if err != nil {
return 0, fmt.Errorf("failed to get repo file tree,%w", err)
}
Expand Down
10 changes: 7 additions & 3 deletions component/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ func (c *RepoComponent) UploadFile(ctx context.Context, req *types.CreateFileReq
return err
}

func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, userName string) (*types.SDKFiles, error) {
func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, ref, userName string) (*types.SDKFiles, error) {
var sdkFiles []types.SDKFile
repo, err := c.repo.FindByPath(ctx, repoType, namespace, name)
if err != nil || repo == nil {
Expand All @@ -1008,7 +1008,11 @@ func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.Reposit
return nil, ErrUnauthorized
}

filePaths, err := getFilePaths(namespace, name, "", repoType, c.git.GetRepoFileTree)
if ref == "" {
ref = repo.DefaultBranch
}

filePaths, err := getFilePaths(namespace, name, "", repoType, ref, c.git.GetRepoFileTree)
if err != nil {
return nil, fmt.Errorf("failed to get all %s files, error: %w", repoType, err)
}
Expand Down Expand Up @@ -2424,7 +2428,7 @@ func (c *RepoComponent) AllFiles(ctx context.Context, req types.GetAllFilesReq)
return nil, fmt.Errorf("users do not have permission to get all files for this repo")
}
}
allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, c.git.GetRepoFileTree)
allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, req.Ref, c.git.GetRepoFileTree)
if err != nil {
slog.Error("fail to get all files of repository", slog.Any("repoType", req.RepoType), slog.String("namespace", req.Namespace), slog.String("name", req.Name), slog.String("error", err.Error()))
return nil, err
Expand Down

0 comments on commit 2b7947c

Please sign in to comment.