Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix download model/dataset with specified branch #136

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/handler/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ func (h *DatasetHandler) AllFiles(ctx *gin.Context) {
req.Name = name
req.RepoType = types.DatasetRepo
req.CurrentUser = httpbase.GetCurrentUser(ctx)
req.Ref = ""
detail, err := h.c.AllFiles(ctx, req)
if err != nil {
if errors.Is(err, component.ErrUnauthorized) {
Expand Down
4 changes: 2 additions & 2 deletions api/handler/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ func (h *RepoHandler) SDKListFiles(ctx *gin.Context) {
httpbase.BadRequest(ctx, err.Error())
return
}

files, err := h.c.SDKListFiles(ctx, common.RepoTypeFromContext(ctx), namespace, name, currentUser)
ref := ctx.Param("ref")
files, err := h.c.SDKListFiles(ctx, common.RepoTypeFromContext(ctx), namespace, name, ref, currentUser)
if err != nil {
if errors.Is(err, component.ErrUnauthorized) {
slog.Error("permission denied when accessing repo", slog.String("repo_type", string(common.RepoTypeFromContext(ctx))), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name)))
Expand Down
1 change: 1 addition & 0 deletions common/types/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ type GetAllFilesReq struct {
Name string `json:"name"`
RepoType RepositoryType `json:"repo_type"`
CurrentUser string `json:"current_user"`
Ref string `json:"ref"`
}

type LFSPointer struct {
Expand Down
8 changes: 4 additions & 4 deletions component/mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,13 @@ func (c *MirrorComponent) checkAndUpdateMirrorStatus(ctx context.Context, mirror
return nil
}

func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryType, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]*types.File, error) {
func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryType, ref string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]*types.File, error) {
var files []*types.File

getRepoFileTree := gitserver.GetRepoInfoByPathReq{
Namespace: namespace,
Name: repoName,
Ref: "",
Ref: ref,
Path: folder,
RepoType: repoType,
}
Expand All @@ -395,7 +395,7 @@ func getAllFiles(namespace, repoName, folder string, repoType types.RepositoryTy
}
for _, file := range gitFiles {
if file.Type == "dir" {
subFiles, err := getAllFiles(namespace, repoName, file.Path, repoType, gsTree)
subFiles, err := getAllFiles(namespace, repoName, file.Path, repoType, ref, gsTree)
if err != nil {
return files, err
}
Expand Down Expand Up @@ -549,7 +549,7 @@ func (c *MirrorComponent) countMirrorProgress(ctx context.Context, mirror databa
namespaceAndName := strings.Split(mirror.Repository.Path, "/")
namespace := namespaceAndName[0]
name := namespaceAndName[1]
allFiles, err := getAllFiles(namespace, name, "", mirror.Repository.RepositoryType, c.git.GetRepoFileTree)
allFiles, err := getAllFiles(namespace, name, "", mirror.Repository.RepositoryType, "", c.git.GetRepoFileTree)

if err != nil {
slog.Error("fail to get all files of mirror repository", slog.Int64("mirrorId", mirror.ID), slog.String("namespace", namespace), slog.String("name", name), slog.String("error", err.Error()))
Expand Down
6 changes: 3 additions & 3 deletions component/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ func (c *ModelComponent) SDKModelInfo(ctx context.Context, namespace, name, ref,
}
}

filePaths, err := getFilePaths(namespace, name, "", types.ModelRepo, c.git.GetRepoFileTree)
filePaths, err := getFilePaths(namespace, name, "", types.ModelRepo, ref, c.git.GetRepoFileTree)
if err != nil {
return nil, fmt.Errorf("failed to get all %s files, error: %w", types.ModelRepo, err)
}
Expand Down Expand Up @@ -612,9 +612,9 @@ func (c *ModelComponent) getRelations(ctx context.Context, fromRepoID int64, cur
return rels, nil
}

func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryType, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]string, error) {
func getFilePaths(namespace, repoName, folder string, repoType types.RepositoryType, ref string, gsTree func(ctx context.Context, req gitserver.GetRepoInfoByPathReq) ([]*types.File, error)) ([]string, error) {
var filePaths []string
allFiles, err := getAllFiles(namespace, repoName, folder, repoType, gsTree)
allFiles, err := getAllFiles(namespace, repoName, folder, repoType, ref, gsTree)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion component/recom.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (rc *RecomComponent) calcQualityScore(ctx context.Context, repo *database.R
score := 0.0
// get file counts from git server
namespace, name := repo.NamespaceAndName()
files, err := getFilePaths(namespace, name, "", repo.RepositoryType, rc.gs.GetRepoFileTree)
files, err := getFilePaths(namespace, name, "", repo.RepositoryType, "", rc.gs.GetRepoFileTree)
if err != nil {
return 0, fmt.Errorf("failed to get repo file tree,%w", err)
}
Expand Down
10 changes: 7 additions & 3 deletions component/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ func (c *RepoComponent) UploadFile(ctx context.Context, req *types.CreateFileReq
return err
}

func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, userName string) (*types.SDKFiles, error) {
func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.RepositoryType, namespace, name, ref, userName string) (*types.SDKFiles, error) {
var sdkFiles []types.SDKFile
repo, err := c.repo.FindByPath(ctx, repoType, namespace, name)
if err != nil || repo == nil {
Expand All @@ -1008,7 +1008,11 @@ func (c *RepoComponent) SDKListFiles(ctx context.Context, repoType types.Reposit
return nil, ErrUnauthorized
}

filePaths, err := getFilePaths(namespace, name, "", repoType, c.git.GetRepoFileTree)
if ref == "" {
ref = repo.DefaultBranch
}

filePaths, err := getFilePaths(namespace, name, "", repoType, ref, c.git.GetRepoFileTree)
if err != nil {
return nil, fmt.Errorf("failed to get all %s files, error: %w", repoType, err)
}
Expand Down Expand Up @@ -2424,7 +2428,7 @@ func (c *RepoComponent) AllFiles(ctx context.Context, req types.GetAllFilesReq)
return nil, fmt.Errorf("users do not have permission to get all files for this repo")
}
}
allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, c.git.GetRepoFileTree)
allFiles, err := getAllFiles(req.Namespace, req.Name, "", req.RepoType, req.Ref, c.git.GetRepoFileTree)
if err != nil {
slog.Error("fail to get all files of repository", slog.Any("repoType", req.RepoType), slog.String("namespace", req.Namespace), slog.String("name", req.Name), slog.String("error", err.Error()))
return nil, err
Expand Down