diff --git a/builder/git/gitserver/gitaly/commit.go b/builder/git/gitserver/gitaly/commit.go index 672d0035..0af84b63 100644 --- a/builder/git/gitserver/gitaly/commit.go +++ b/builder/git/gitserver/gitaly/commit.go @@ -2,6 +2,7 @@ package gitaly import ( "context" + "errors" "fmt" "io" "math" @@ -142,7 +143,7 @@ func (c *Client) GetSingleCommit(ctx context.Context, req gitserver.GetRepoLastC if err != nil { return nil, err } - if commitResp != nil { + if commitResp != nil && commitResp.Commit != nil { commit = types.Commit{ ID: string(commitResp.Commit.Id), CommitterName: string(commitResp.Commit.Committer.Name), @@ -160,6 +161,8 @@ func (c *Client) GetSingleCommit(ctx context.Context, req gitserver.GetRepoLastC }) } + } else { + return nil, errors.New("commit not found") } result = types.CommitResponse{ Commit: &commit, diff --git a/cmd/csghub-server/cmd/start/start.go b/cmd/csghub-server/cmd/start/start.go index 628ef60a..39d4d85c 100644 --- a/cmd/csghub-server/cmd/start/start.go +++ b/cmd/csghub-server/cmd/start/start.go @@ -14,7 +14,6 @@ import ( func init() { Cmd.AddCommand(serverCmd) Cmd.AddCommand(rproxyCmd) - Cmd.AddCommand(syncServerCmd) } var Cmd = &cobra.Command{ diff --git a/cmd/csghub-server/cmd/start/syncserver.go b/cmd/csghub-server/cmd/start/syncserver.go deleted file mode 100644 index 4e422704..00000000 --- a/cmd/csghub-server/cmd/start/syncserver.go +++ /dev/null @@ -1,49 +0,0 @@ -package start - -import ( - "fmt" - - "github.com/spf13/cobra" - "opencsg.com/csghub-server/api/httpbase" - "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/multisync/router" -) - -var syncServerCmd = &cobra.Command{ - Use: "sync-server", - Short: "Start the multi source sync server", - Example: rproxyExample(), - RunE: func(*cobra.Command, []string) (err error) { - cfg, err := config.LoadConfig() - if err != nil { - return err - } - - dbConfig := database.DBConfig{ - Dialect: database.DatabaseDialect(cfg.Database.Driver), - DSN: cfg.Database.DSN, - } - database.InitDB(dbConfig) - r, err := router.NewRouter(cfg) - if err != nil { - return fmt.Errorf("failed to init router: %w", err) - } - server := httpbase.NewGracefulServer( - httpbase.GraceServerOpt{ - Port: cfg.Mirror.Port, - }, - r, - ) - server.Run() - - return nil - }, -} - -func syncServerExample() string { - return ` -# for development -csghub-server start sync-server -` -} diff --git a/component/tagparser/nameparser.go b/component/tagparser/nameparser.go index 0065fb6a..f6dd7061 100644 --- a/component/tagparser/nameparser.go +++ b/component/tagparser/nameparser.go @@ -36,14 +36,14 @@ func LibraryTag(filePath string) string { } func isPytorch(filename string) bool { - return strings.HasPrefix(filename, "pytorch_model") && strings.HasSuffix(filename, ".bin") + return (strings.HasPrefix(filename, "pytorch_model") && strings.HasSuffix(filename, ".bin")) || strings.HasSuffix(filename, ".pt") } func isTensorflow(filename string) bool { return strings.HasPrefix(filename, "tf_model") && strings.HasSuffix(filename, ".h5") } func isSafetensors(filename string) bool { - return strings.HasPrefix(filename, "model") && strings.HasSuffix(filename, ".safetensors") + return strings.HasSuffix(filename, ".safetensors") } func isJAX(filename string) bool { return strings.HasPrefix(filename, "flax_model") && strings.HasSuffix(filename, ".msgpack") diff --git a/component/tagparser/nameparser_test.go b/component/tagparser/nameparser_test.go index e0cd19df..fdeaea29 100644 --- a/component/tagparser/nameparser_test.go +++ b/component/tagparser/nameparser_test.go @@ -14,6 +14,7 @@ func TestLibraryTag(t *testing.T) { {name: "case insensitive", args: args{filePath: "Pytorch_model.Bin"}, want: "pytorch"}, {name: "pytorch", args: args{filePath: "pytorch_model.bin"}, want: "pytorch"}, + {name: "pytorch", args: args{filePath: "model.pt"}, want: "pytorch"}, {name: "pytorch", args: args{filePath: "pytorch_model_001.bin"}, want: "pytorch"}, {name: "not pytorch", args: args{filePath: "1-pytorch_model_001.bin"}, want: ""}, {name: "not pytorch", args: args{filePath: "pytorch_model-bin"}, want: ""}, @@ -25,8 +26,9 @@ func TestLibraryTag(t *testing.T) { {name: "safetensors", args: args{filePath: "model.safetensors"}, want: "safetensors"}, {name: "safetensors", args: args{filePath: "model_001.safetensors"}, want: "safetensors"}, - {name: "not safetensors", args: args{filePath: "1-model.safetensors"}, want: ""}, - {name: "not safetensors", args: args{filePath: "model-safetensors"}, want: ""}, + {name: "safetensors", args: args{filePath: "adpter_model.safetensors"}, want: "safetensors"}, + {name: "not safetensors", args: args{filePath: "1-test.safeten"}, want: ""}, + {name: "not safetensors", args: args{filePath: "test-safetensors"}, want: ""}, {name: "flax_model", args: args{filePath: "flax_model.msgpack"}, want: "jax"}, {name: "flax_model", args: args{filePath: "flax_model-001.msgpack"}, want: "jax"}, diff --git a/multisync/accounting/aync_quota_statement.go b/multisync/accounting/aync_quota_statement.go deleted file mode 100644 index 41386325..00000000 --- a/multisync/accounting/aync_quota_statement.go +++ /dev/null @@ -1,53 +0,0 @@ -package accounting - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "time" -) - -type SyncQuotaStatement struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - RepoPath string `json:"repo_path"` - RepoType string `json:"repo_type"` - CreatedAt time.Time `json:"created_at"` -} - -type SyncQuotaStatementRes struct { - Message string `json:"msg"` - Data SyncQuotaStatement `json:"data"` -} - -type GetSyncQuotaStatementsReq struct { - RepoPath string `json:"repo_path"` - RepoType string `json:"repo_type"` - AccessToken string `json:"-"` -} - -type CreateSyncQuotaStatementReq = GetSyncQuotaStatementsReq - -func (c *AccountingClient) CreateSyncQuotaStatement(opt *CreateSyncQuotaStatementReq) (*Response, error) { - header := http.Header{"content-type": []string{"application/json"}} - body, err := json.Marshal(&opt) - if err != nil { - return nil, err - } - if opt.AccessToken != "" { - header.Add("Authorization", "Bearer "+opt.AccessToken) - } - _, resp, err := c.getResponse("POST", "/accounting/multisync/downloads", header, bytes.NewReader(body)) - return resp, err -} - -func (c *AccountingClient) GetSyncQuotaStatement(opt *GetSyncQuotaStatementsReq) (*SyncQuotaStatement, *Response, error) { - s := new(SyncQuotaStatementRes) - header := http.Header{} - if opt.AccessToken != "" { - header.Add("Authorization", "Bearer "+opt.AccessToken) - } - resp, err := c.getParsedResponse("GET", fmt.Sprintf("/accounting/multisync/download?repo_path=%s&repo_type=%s", opt.RepoPath, opt.RepoType), header, nil, s) - return &s.Data, resp, err -} diff --git a/multisync/accounting/client.go b/multisync/accounting/client.go deleted file mode 100644 index db3db65d..00000000 --- a/multisync/accounting/client.go +++ /dev/null @@ -1,135 +0,0 @@ -package accounting - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "sync" - "time" - - "opencsg.com/csghub-server/common/config" -) - -type AccountingClient struct { - baseURL string - httpClient *http.Client - mutex sync.RWMutex - ctx context.Context -} - -type Response struct { - *http.Response -} - -func NewAccountingClient(config *config.Config) (*AccountingClient, error) { - if config.Accounting.Host == "" { - return nil, fmt.Errorf("accounting host should be configured") - } - - if config.Accounting.Port == 0 { - return nil, fmt.Errorf("accounting port should be configured") - } - if config.APIToken == "" { - return nil, fmt.Errorf("api token should be configured") - } - - return &AccountingClient{ - baseURL: fmt.Sprintf("%s:%d", config.Accounting.Host, config.Accounting.Port), - httpClient: &http.Client{ - Timeout: time.Second * 5, - }, - ctx: context.Background(), - }, nil -} - -func (c *AccountingClient) getParsedResponse(method, path string, header http.Header, body io.Reader, obj interface{}) (*Response, error) { - data, resp, err := c.getResponse(method, path, header, body) - if err != nil { - return resp, err - } - return resp, json.Unmarshal(data, obj) -} - -func (c *AccountingClient) getResponse(method, path string, header http.Header, body io.Reader) ([]byte, *Response, error) { - resp, err := c.doRequest(method, path, header, body) - if err != nil { - return nil, resp, err - } - defer resp.Body.Close() - - // check for errors - data, err := statusCodeToErr(resp) - if err != nil { - return data, resp, err - } - - // success (2XX), read body - data, err = io.ReadAll(resp.Body) - if err != nil { - return nil, resp, err - } - - return data, resp, nil -} - -// Converts a response for a HTTP status code indicating an error condition -// (non-2XX) to a well-known error value and response body. For non-problematic -// (2XX) status codes nil will be returned. Note that on a non-2XX response, the -// response body stream will have been read and, hence, is closed on return. -func statusCodeToErr(resp *Response) (body []byte, err error) { - // no error - if resp.StatusCode/100 == 2 { - return nil, nil - } - - // - // error: body will be read for details - // - defer resp.Body.Close() - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("body read on HTTP error %d: %v", resp.StatusCode, err) - } - - // Try to unmarshal and get an error message - errMap := make(map[string]interface{}) - if err = json.Unmarshal(data, &errMap); err != nil { - // when the JSON can't be parsed, data was probably empty or a - // plain string, so we try to return a helpful error anyway - path := resp.Request.URL.Path - method := resp.Request.Method - header := resp.Request.Header - return data, fmt.Errorf("unknown API Error: %d\nRequest: '%s' with '%s' method '%s' header and '%s' body", resp.StatusCode, path, method, header, string(data)) - } - - if msg, ok := errMap["message"]; ok { - return data, fmt.Errorf("%v", msg) - } - - // If no error message, at least give status and data - return data, fmt.Errorf("%s: %s", resp.Status, string(data)) -} - -func (c *AccountingClient) doRequest(method, path string, header http.Header, body io.Reader) (*Response, error) { - c.mutex.RLock() - req, err := http.NewRequestWithContext(c.ctx, method, c.baseURL+"/api/v1"+path, body) - if err != nil { - c.mutex.RUnlock() - return nil, err - } - - client := c.httpClient - c.mutex.RUnlock() - - for k, v := range header { - req.Header[k] = v - } - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - return &Response{resp}, nil -} diff --git a/multisync/accounting/sync_quota.go b/multisync/accounting/sync_quota.go deleted file mode 100644 index 77b6860f..00000000 --- a/multisync/accounting/sync_quota.go +++ /dev/null @@ -1,52 +0,0 @@ -package accounting - -import ( - "bytes" - "encoding/json" - "net/http" -) - -type GetSyncQuotaReq struct { - AccessToken string `json:"access_token"` -} - -type SyncQuota struct { - RepoCountLimit int64 `json:"repo_count_limit"` - TrafficLimit int64 `json:"traffic_limit"` - AccessToken string `json:"-"` - RepoCountUsed int64 `json:"repo_count_used"` - SpeedLimit int64 `json:"speed_limit"` - TrafficUsed int64 `json:"traffic_used"` -} - -type SyncQuotaRes struct { - Message string `json:"msg"` - Data SyncQuota `json:"data"` -} - -type CreateSyncQuotaReq = SyncQuota - -type UpdateSyncQuotaReq = SyncQuota - -func (c *AccountingClient) CreateOrUpdateSyncQuota(opt *CreateSyncQuotaReq) (*Response, error) { - header := http.Header{"content-type": []string{"application/json"}} - body, err := json.Marshal(&opt) - if err != nil { - return nil, err - } - if opt.AccessToken != "" { - header.Add("Authorization", "Bearer "+opt.AccessToken) - } - _, resp, err := c.getResponse("POST", "/accounting/multisync/quotas", header, bytes.NewReader(body)) - return resp, err -} - -func (c *AccountingClient) GetSyncQuota(opt *GetSyncQuotaReq) (*SyncQuota, *Response, error) { - s := new(SyncQuotaRes) - header := http.Header{} - if opt.AccessToken != "" { - header.Add("Authorization", "Bearer "+opt.AccessToken) - } - resp, err := c.getParsedResponse("GET", "/accounting/multisync/quota", header, nil, s) - return &s.Data, resp, err -} diff --git a/multisync/component/mirror_proxy.go b/multisync/component/mirror_proxy.go deleted file mode 100644 index 5935e3b1..00000000 --- a/multisync/component/mirror_proxy.go +++ /dev/null @@ -1,77 +0,0 @@ -package component - -import ( - "context" - "fmt" - "net/http" - "strconv" - - "github.com/gin-gonic/gin" - "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/multisync/accounting" - "opencsg.com/csghub-server/multisync/types" -) - -type MirrorProxyComponent struct { - ac *accounting.AccountingClient - user *database.UserStore -} - -func NewMirrorProxyComponent(config *config.Config) (*MirrorProxyComponent, error) { - ac, err := accounting.NewAccountingClient(config) - if err != nil { - return nil, err - } - return &MirrorProxyComponent{ - ac: ac, - user: database.NewUserStore(), - }, nil -} - -func (c *MirrorProxyComponent) Serve(ctx context.Context, req *types.GetSyncQuotaStatementReq) error { - sq, _, err := c.ac.GetSyncQuota(&accounting.GetSyncQuotaReq{ - AccessToken: req.AccessToken, - }) - if err != nil { - return fmt.Errorf("error getting sync quota: %v", err) - } - if sq.RepoCountLimit <= sq.RepoCountUsed { - return fmt.Errorf("sync repository count limit exceeded") - } - sqs, _, err := c.ac.GetSyncQuotaStatement(&accounting.GetSyncQuotaStatementsReq{ - AccessToken: req.AccessToken, - RepoPath: req.RepoPath, - RepoType: req.RepoType, - }) - if err != nil { - return fmt.Errorf("error getting sync quota statement: %v", err) - } - if sqs.ID != 0 { - return nil - } - resp, err := c.ac.CreateSyncQuotaStatement(&accounting.CreateSyncQuotaStatementReq{ - AccessToken: req.AccessToken, - RepoPath: req.RepoPath, - RepoType: req.RepoType, - }) - if err != nil { - return fmt.Errorf("error creating sync quota statement: %v", err) - } - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("error creating sync quota statement") - } - return nil -} - -func (c *MirrorProxyComponent) LfsDownload(ctx *gin.Context, token string) error { - sq, _, err := c.ac.GetSyncQuota(&accounting.GetSyncQuotaReq{ - AccessToken: token, - }) - if err != nil { - return fmt.Errorf("error getting sync quota: %v", err) - } - - ctx.Request.Header.Add("X-OPENCSG-Speed-Limit", strconv.FormatInt(sq.SpeedLimit, 10)) - return nil -} diff --git a/multisync/handler/mirror_proxy.go b/multisync/handler/mirror_proxy.go deleted file mode 100644 index 70f13b19..00000000 --- a/multisync/handler/mirror_proxy.go +++ /dev/null @@ -1,85 +0,0 @@ -package handler - -import ( - "fmt" - "log/slog" - "strings" - - "github.com/gin-gonic/gin" - "opencsg.com/csghub-server/api/httpbase" - "opencsg.com/csghub-server/builder/proxy" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/multisync/component" - "opencsg.com/csghub-server/multisync/types" -) - -const MirrorTokenHeaderKey = "X-OPENCSG-Sync-Token" - -type MirrorProxyHandler struct { - gitServerURL string - mpComp *component.MirrorProxyComponent -} - -func NewMirrorProxyHandler(config *config.Config) (*MirrorProxyHandler, error) { - mpComp, err := component.NewMirrorProxyComponent(config) - if err != nil { - return nil, fmt.Errorf("failed to create repo component,%w", err) - } - - return &MirrorProxyHandler{ - mpComp: mpComp, - gitServerURL: config.GitServer.URL, - }, nil -} - -func (r *MirrorProxyHandler) Serve(ctx *gin.Context) { - var req types.GetSyncQuotaStatementReq - token := getMirrorTokenFromContext(ctx) - repoType := ctx.Param("repo_type") - namespace := ctx.Param("namespace") - name := ctx.Param("name") - name, _ = strings.CutSuffix(name, ".git") - req.RepoPath = fmt.Sprintf("%s/%s", namespace, name) - req.RepoType = strings.TrimSuffix(repoType, "s") - req.AccessToken = token - - if strings.HasSuffix(ctx.Request.URL.Path, "git-upload-pack") { - err := r.mpComp.Serve(ctx, &req) - if err != nil { - slog.Error("failed to serve git upload pack request:", slog.Any("err", err)) - httpbase.BadRequest(ctx, err.Error()) - return - } - } - - path := strings.Replace(ctx.Request.URL.Path, fmt.Sprintf("%s/", repoType), fmt.Sprintf("%s_", repoType), 1) - rp, _ := proxy.NewReverseProxy(r.gitServerURL) - rp.ServeHTTP(ctx.Writer, ctx.Request, path) -} - -func (r *MirrorProxyHandler) ServeLFS(ctx *gin.Context) { - var req types.GetSyncQuotaStatementReq - token := getMirrorTokenFromContext(ctx) - repoType := ctx.Param("repo_type") - namespace := ctx.Param("namespace") - name := ctx.Param("name") - name, _ = strings.CutSuffix(name, ".git") - req.RepoPath = fmt.Sprintf("%s/%s", namespace, name) - req.RepoType = strings.TrimSuffix(repoType, "s") - req.AccessToken = token - - err := r.mpComp.LfsDownload(ctx, token) - if err != nil { - slog.Error("failed to serve lfs download request:", slog.Any("err", err)) - httpbase.BadRequest(ctx, err.Error()) - return - } - - path := strings.Replace(ctx.Request.URL.Path, fmt.Sprintf("%s/", repoType), fmt.Sprintf("%s_", repoType), 1) - rp, _ := proxy.NewReverseProxy(r.gitServerURL) - rp.ServeHTTP(ctx.Writer, ctx.Request, path) -} - -func getMirrorTokenFromContext(ctx *gin.Context) string { - return ctx.GetHeader(MirrorTokenHeaderKey) -} diff --git a/multisync/router/api.go b/multisync/router/api.go deleted file mode 100644 index a384fc8e..00000000 --- a/multisync/router/api.go +++ /dev/null @@ -1,47 +0,0 @@ -package router - -import ( - "fmt" - - "github.com/gin-gonic/gin" - "opencsg.com/csghub-server/api/middleware" - "opencsg.com/csghub-server/common/config" - "opencsg.com/csghub-server/multisync/handler" -) - -func NewRouter(config *config.Config) (*gin.Engine, error) { - r := gin.New() - r.Use(gin.Recovery()) - r.Use(middleware.Log()) - // store := cookie.NewStore([]byte(config.Mirror.SessionSecretKey)) - // store.Options(sessions.Options{ - // SameSite: http.SameSiteNoneMode, - // Secure: config.EnableHTTPS, - // }) - // r.Use(sessions.Sessions("jwt_session", store)) - // r.Use(middleware.BuildJwtSession(config.JWT.SigningKey)) - - mpHandler, err := handler.NewMirrorProxyHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating rproxy handler:%w", err) - } - rGroup := r.Group("/:repo_type/:namespace/:name") - { - rGroup.POST("/git-upload-pack", mpHandler.Serve) - rGroup.POST("/git-receive-pack", mpHandler.Serve) - rGroup.GET("/info/refs", mpHandler.Serve) - rGroup.GET("/HEAD", mpHandler.Serve) - rGroup.GET("/objects/info/alternates", mpHandler.Serve) - rGroup.GET("/objects/info/http-alternates", mpHandler.Serve) - rGroup.GET("/objects/info/packs", mpHandler.Serve) - rGroup.GET("/objects/info/:file", mpHandler.Serve) - rGroup.GET("/objects/:head/:hash", mpHandler.Serve) - rGroup.GET("/objects/pack/pack-:file", mpHandler.Serve) - rGroup.POST("/info/lfs/objects/batch", mpHandler.ServeLFS) - rGroup.GET("/info/lfs/objects/:oid", mpHandler.ServeLFS) - } - - // r.Any("/*api", handler.Serve) - - return r, nil -} diff --git a/multisync/types/mirror_proxy.go b/multisync/types/mirror_proxy.go deleted file mode 100644 index 21c9ef8d..00000000 --- a/multisync/types/mirror_proxy.go +++ /dev/null @@ -1,7 +0,0 @@ -package types - -type GetSyncQuotaStatementReq struct { - RepoPath string `json:"repo_path"` - RepoType string `json:"repo_type"` - AccessToken string `json:"access_token"` -}