diff --git a/integrations/pull_update_test.go b/integrations/pull_update_test.go index 2dc966316e6c0..9d04eeafe67f6 100644 --- a/integrations/pull_update_test.go +++ b/integrations/pull_update_test.go @@ -47,6 +47,34 @@ func TestAPIPullUpdate(t *testing.T) { }) } +func TestAPIPullUpdateByRebase(t *testing.T) { + onGiteaRun(t, func(t *testing.T, giteaURL *url.URL) { + //Create PR to test + user := models.AssertExistsAndLoadBean(t, &models.User{ID: 2}).(*models.User) + org26 := models.AssertExistsAndLoadBean(t, &models.User{ID: 26}).(*models.User) + pr := createOutdatedPR(t, user, org26) + + //Test GetDiverging + diffCount, err := pull_service.GetDiverging(pr) + assert.NoError(t, err) + assert.EqualValues(t, 1, diffCount.Behind) + assert.EqualValues(t, 1, diffCount.Ahead) + assert.NoError(t, pr.LoadBaseRepo()) + assert.NoError(t, pr.LoadIssue()) + + session := loginUser(t, "user2") + token := getTokenForLoggedInUser(t, session) + req := NewRequestf(t, "POST", "/api/v1/repos/%s/%s/pulls/%d/update?style=rebase&token="+token, pr.BaseRepo.OwnerName, pr.BaseRepo.Name, pr.Issue.Index) + session.MakeRequest(t, req, http.StatusOK) + + //Test GetDiverging after update + diffCount, err = pull_service.GetDiverging(pr) + assert.NoError(t, err) + assert.EqualValues(t, 0, diffCount.Behind) + assert.EqualValues(t, 1, diffCount.Ahead) + }) +} + func createOutdatedPR(t *testing.T, actor, forkOrg *models.User) *models.PullRequest { baseRepo, err := repo_service.CreateRepository(actor, actor, models.CreateRepoOptions{ Name: "repo-pr-update", diff --git a/routers/api/v1/repo/pull.go b/routers/api/v1/repo/pull.go index eff998ee996a1..bb7503cd4658b 100644 --- a/routers/api/v1/repo/pull.go +++ b/routers/api/v1/repo/pull.go @@ -1030,6 +1030,11 @@ func UpdatePullRequest(ctx *context.APIContext) { // type: integer // format: int64 // required: true + // - name: style + // in: query + // description: how to update pull request + // type: string + // enum: [merge, rebase] // responses: // "200": // "$ref": "#/responses/empty" @@ -1076,7 +1081,9 @@ func UpdatePullRequest(ctx *context.APIContext) { return } - allowedUpdate, err := pull_service.IsUserAllowedToUpdate(pr, ctx.User) + rebase := ctx.Query("style") == "rebase" + + allowedUpdate, err := pull_service.IsUserAllowedToUpdate(pr, ctx.User, rebase) if err != nil { ctx.Error(http.StatusInternalServerError, "IsUserAllowedToMerge", err) return @@ -1090,7 +1097,7 @@ func UpdatePullRequest(ctx *context.APIContext) { // default merge commit message message := fmt.Sprintf("Merge branch '%s' into %s", pr.BaseBranch, pr.HeadBranch) - if err = pull_service.Update(pr, ctx.User, message); err != nil { + if err = pull_service.Update(pr, ctx.User, message, rebase); err != nil { if models.IsErrMergeConflicts(err) { ctx.Error(http.StatusConflict, "Update", "merge failed because of conflict") return diff --git a/routers/web/repo/pull.go b/routers/web/repo/pull.go index 28f94c841701e..035e3374ebb63 100644 --- a/routers/web/repo/pull.go +++ b/routers/web/repo/pull.go @@ -439,7 +439,13 @@ func PrepareViewPullInfo(ctx *context.Context, issue *models.Issue) *git.Compare } if headBranchExist { - ctx.Data["UpdateAllowed"], err = pull_service.IsUserAllowedToUpdate(pull, ctx.User) + b, err := models.GetProtectedBranchBy(pull.HeadRepoID, pull.HeadBranch) + if err != nil { + ctx.ServerError("GetProtectedBranchBy", err) + return nil + } + ctx.Data["UpdateByRebaseNotAllowed"] = b != nil + ctx.Data["UpdateAllowed"], err = pull_service.IsUserAllowedToUpdate(pull, ctx.User, false) if err != nil { ctx.ServerError("IsUserAllowedToUpdate", err) return nil @@ -712,6 +718,8 @@ func UpdatePullRequest(ctx *context.Context) { return } + rebase := ctx.Query("style") == "rebase" + if err := issue.PullRequest.LoadBaseRepo(); err != nil { ctx.ServerError("LoadBaseRepo", err) return @@ -721,7 +729,7 @@ func UpdatePullRequest(ctx *context.Context) { return } - allowedUpdate, err := pull_service.IsUserAllowedToUpdate(issue.PullRequest, ctx.User) + allowedUpdate, err := pull_service.IsUserAllowedToUpdate(issue.PullRequest, ctx.User, rebase) if err != nil { ctx.ServerError("IsUserAllowedToMerge", err) return @@ -737,7 +745,7 @@ func UpdatePullRequest(ctx *context.Context) { // default merge commit message message := fmt.Sprintf("Merge branch '%s' into %s", issue.PullRequest.BaseBranch, issue.PullRequest.HeadBranch) - if err = pull_service.Update(issue.PullRequest, ctx.User, message); err != nil { + if err = pull_service.Update(issue.PullRequest, ctx.User, message, rebase); err != nil { if models.IsErrMergeConflicts(err) { conflictError := err.(models.ErrMergeConflicts) flashError, err := ctx.HTMLString(string(tplAlertDetails), map[string]interface{}{ diff --git a/services/pull/update.go b/services/pull/update.go index f4f7859a49ec1..e2b3a7c616267 100644 --- a/services/pull/update.go +++ b/services/pull/update.go @@ -5,7 +5,11 @@ package pull import ( + "errors" "fmt" + "os" + "strings" + "time" "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/git" @@ -13,7 +17,7 @@ import ( ) // Update updates pull request with base branch. -func Update(pull *models.PullRequest, doer *models.User, message string) error { +func Update(pull *models.PullRequest, doer *models.User, message string, rebase bool) error { //use merge functions but switch repo's and branch's pr := &models.PullRequest{ HeadRepoID: pull.BaseRepoID, @@ -37,7 +41,11 @@ func Update(pull *models.PullRequest, doer *models.User, message string) error { return fmt.Errorf("HeadBranch of PR %d is up to date", pull.Index) } - _, err = rawMerge(pr, doer, models.MergeStyleMerge, message) + if rebase { + err = doRebase(pr, doer) + } else { + _, err = rawMerge(pr, doer, models.MergeStyleMerge, message) + } defer func() { go AddTestPullRequestTask(doer, pr.HeadRepo.ID, pr.HeadBranch, false, "", "") @@ -47,7 +55,7 @@ func Update(pull *models.PullRequest, doer *models.User, message string) error { } // IsUserAllowedToUpdate check if user is allowed to update PR with given permissions and branch protections -func IsUserAllowedToUpdate(pull *models.PullRequest, user *models.User) (bool, error) { +func IsUserAllowedToUpdate(pull *models.PullRequest, user *models.User, rebase bool) (bool, error) { if user == nil { return false, nil } @@ -68,6 +76,11 @@ func IsUserAllowedToUpdate(pull *models.PullRequest, user *models.User) (bool, e return false, err } + // can't do rebase on protected branch because need force push + if rebase && pr.ProtectedBranch != nil { + return false, err + } + // Update function need push permission if pr.ProtectedBranch != nil && !pr.ProtectedBranch.CanUserPush(user.ID) { return false, nil @@ -100,3 +113,83 @@ func GetDiverging(pr *models.PullRequest) (*git.DivergeObject, error) { diff, err := git.GetDivergingCommits(tmpRepo, "base", "tracking") return &diff, err } + +func doRebase(pr *models.PullRequest, doer *models.User) error { + // 1. Clone base repo. + tmpBasePath, err := createTemporaryRepo(pr) + if err != nil { + log.Error("CreateTemporaryPath: %v", err) + return err + } + defer func() { + if err := models.RemoveTemporaryPath(tmpBasePath); err != nil { + log.Error("Update-By-Rebase: RemoveTemporaryPath: %s", err) + } + }() + + baseBranch := "base" + trackingBranch := "tracking" + + // 2. checkout base branch + msg, err := git.NewCommand("checkout", baseBranch).RunInDir(tmpBasePath) + if err != nil { + return errors.New(msg) + } + + // 3. do rebase to ttacking branch + sig := doer.NewGitSig() + committer := sig + + // Determine if we should sign + signArg := "" + if git.CheckGitVersionAtLeast("1.7.9") == nil { + sign, keyID, signer, _ := pr.SignMerge(doer, tmpBasePath, "HEAD", trackingBranch) + if sign { + signArg = "-S" + keyID + if pr.BaseRepo.GetTrustModel() == models.CommitterTrustModel || pr.BaseRepo.GetTrustModel() == models.CollaboratorCommitterTrustModel { + committer = signer + } + } else if git.CheckGitVersionAtLeast("2.0.0") == nil { + signArg = "--no-gpg-sign" + } + } + + commitTimeStr := time.Now().Format(time.RFC3339) + + // Because this may call hooks we should pass in the environment + env := append(os.Environ(), + "GIT_AUTHOR_NAME="+sig.Name, + "GIT_AUTHOR_EMAIL="+sig.Email, + "GIT_AUTHOR_DATE="+commitTimeStr, + "GIT_COMMITTER_NAME="+committer.Name, + "GIT_COMMITTER_EMAIL="+committer.Email, + "GIT_COMMITTER_DATE="+commitTimeStr, + ) + + var outbuf, errbuf strings.Builder + err = git.NewCommand("rebase", trackingBranch, signArg).RunInDirTimeoutEnvFullPipeline(env, -1, tmpBasePath, &outbuf, &errbuf, nil) + if err != nil { + log.Error("git rebase [%s:%s -> %s:%s]: %v\n%s\n%s", pr.BaseRepo.FullName(), pr.BaseBranch, pr.HeadRepo.FullName(), pr.HeadBranch, err, outbuf.String(), errbuf.String()) + return fmt.Errorf("git rebase [%s:%s -> %s:%s]: %v\n%s\n%s", pr.BaseRepo.FullName(), pr.BaseBranch, pr.HeadRepo.FullName(), pr.HeadBranch, err, outbuf.String(), errbuf.String()) + } + + // 4. force push to base branch + env = models.FullPushingEnvironment(doer, doer, pr.BaseRepo, pr.BaseRepo.Name, pr.ID) + + outbuf.Reset() + errbuf.Reset() + if err := git.NewCommand("push", "-f", "origin", baseBranch+":refs/heads/"+pr.BaseBranch).RunInDirTimeoutEnvPipeline(env, -1, tmpBasePath, &outbuf, &errbuf); err != nil { + if strings.Contains(errbuf.String(), "! [remote rejected]") { + err := &git.ErrPushRejected{ + StdOut: outbuf.String(), + StdErr: errbuf.String(), + Err: err, + } + err.GenerateMessage() + return err + } + return fmt.Errorf("git force push: %s", errbuf.String()) + } + + return nil +} diff --git a/templates/repo/issue/view_content/pull.tmpl b/templates/repo/issue/view_content/pull.tmpl index 3bdec4becb02e..60642fc6dd358 100644 --- a/templates/repo/issue/view_content/pull.tmpl +++ b/templates/repo/issue/view_content/pull.tmpl @@ -281,7 +281,22 @@ {{$.i18n.Tr "repo.pulls.outdated_with_base_branch"}}