Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor CSRF token #32216

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions routers/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func autoSignIn(ctx *context.Context) (bool, error) {
return false, err
}

ctx.Csrf.DeleteCookie(ctx)
ctx.Csrf.PrepareForSessionUser(ctx)
return true, nil
}

Expand Down Expand Up @@ -359,8 +359,8 @@ func handleSignInFull(ctx *context.Context, u *user_model.User, remember, obeyRe
ctx.Locale = middleware.Locale(ctx.Resp, ctx.Req)
}

// Clear whatever CSRF cookie has right now, force to generate a new one
ctx.Csrf.DeleteCookie(ctx)
// force to generate a new CSRF token
ctx.Csrf.PrepareForSessionUser(ctx)

// Register last login
if err := user_service.UpdateUser(ctx, u, &user_service.UpdateOptions{SetLastLogin: true}); err != nil {
Expand Down Expand Up @@ -804,6 +804,8 @@ func handleAccountActivation(ctx *context.Context, user *user_model.User) {
return
}

ctx.Csrf.PrepareForSessionUser(ctx)

if err := resetLocale(ctx, user); err != nil {
ctx.ServerError("resetLocale", err)
return
Expand Down
4 changes: 2 additions & 2 deletions routers/web/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,8 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model
return
}

// Clear whatever CSRF cookie has right now, force to generate a new one
ctx.Csrf.DeleteCookie(ctx)
// force to generate a new CSRF token
ctx.Csrf.PrepareForSessionUser(ctx)

if err := resetLocale(ctx, u); err != nil {
ctx.ServerError("resetLocale", err)
Expand Down
4 changes: 2 additions & 2 deletions services/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore

middleware.SetLocaleCookie(resp, user.Language, 0)

// Clear whatever CSRF has right now, force to generate a new one
// force to generate a new CSRF token
if ctx := gitea_context.GetWebContext(req); ctx != nil {
ctx.Csrf.DeleteCookie(ctx)
ctx.Csrf.PrepareForSessionUser(ctx)
}
}
4 changes: 1 addition & 3 deletions services/context/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,8 @@ func (c *csrfProtector) PrepareForSessionUser(ctx *Context) {
}

if needsNew {
// FIXME: actionId.
c.token = GenerateCsrfToken(c.opt.Secret, c.id, "POST", time.Now())
cookie := newCsrfCookie(&c.opt, c.token)
ctx.Resp.Header().Add("Set-Cookie", cookie.String())
ctx.Resp.Header().Add("Set-Cookie", newCsrfCookie(&c.opt, c.token).String())
}

ctx.Data["CsrfToken"] = c.token
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/admin_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func testSuccessfullEdit(t *testing.T, formData user_model.User) {

func makeRequest(t *testing.T, formData user_model.User, headerCode int) {
session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/users/"+strconv.Itoa(int(formData.ID))+"/edit")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/users/"+strconv.Itoa(int(formData.ID))+"/edit", map[string]string{
"_csrf": csrf,
"user_name": formData.Name,
Expand All @@ -72,7 +72,7 @@ func TestAdminDeleteUser(t *testing.T) {

session := loginUser(t, "user1")

csrf := GetCSRF(t, session, "/admin/users/8/edit")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/users/8/delete", map[string]string{
"_csrf": csrf,
})
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/api_httpsig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestHTTPSigCert(t *testing.T) {
defer tests.PrepareTestEnv(t)()
session := loginUser(t, "user1")

csrf := GetCSRF(t, session, "/user/settings/keys")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/user/settings/keys", map[string]string{
"_csrf": csrf,
"content": "user1",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/api_packages_container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ func TestPackageContainer(t *testing.T) {
newOwnerName := "newUsername"

req := NewRequestWithValues(t, "POST", "/user/settings", map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"name": newOwnerName,
"email": "user2@example.com",
"language": "en-US",
Expand All @@ -794,7 +794,7 @@ func TestPackageContainer(t *testing.T) {
t.Run(fmt.Sprintf("Catalog[%s]", newOwnerName), checkCatalog(newOwnerName))

req = NewRequestWithValues(t, "POST", "/user/settings", map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"name": user.Name,
"email": "user2@example.com",
"language": "en-US",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/attachment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ func createAttachment(t *testing.T, session *TestSession, csrf, repoURL, filenam
func TestCreateAnonymousAttachment(t *testing.T) {
defer tests.PrepareTestEnv(t)()
session := emptyTestSession(t)
createAttachment(t, session, GetCSRF(t, session, "/user/login"), "user2/repo1", "image.png", generateImg(), http.StatusSeeOther)
createAttachment(t, session, GetAnonymousCSRFToken(t, session), "user2/repo1", "image.png", generateImg(), http.StatusSeeOther)
}

func TestCreateIssueAttachment(t *testing.T) {
defer tests.PrepareTestEnv(t)()
const repoURL = "user2/repo1"
session := loginUser(t, "user2")
uuid := createAttachment(t, session, GetCSRF(t, session, repoURL), repoURL, "image.png", generateImg(), http.StatusOK)
uuid := createAttachment(t, session, GetUserCSRFToken(t, session), repoURL, "image.png", generateImg(), http.StatusOK)

req := NewRequest(t, "GET", repoURL+"/issues/new")
resp := session.MakeRequest(t, req, http.StatusOK)
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/auth_ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func addAuthSourceLDAP(t *testing.T, sshKeyAttribute, groupFilter string, groupM
groupTeamMap = groupMapParams[1]
}
session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/auths/new")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/auths/new", buildAuthSourceLDAPPayload(csrf, sshKeyAttribute, groupFilter, groupTeamMap, groupTeamMapRemoval))
session.MakeRequest(t, req, http.StatusSeeOther)
}
Expand Down Expand Up @@ -252,7 +252,7 @@ func TestLDAPUserSyncWithEmptyUsernameAttribute(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/auths/new")
csrf := GetUserCSRFToken(t, session)
payload := buildAuthSourceLDAPPayload(csrf, "", "", "", "")
payload["attribute_username"] = ""
req := NewRequestWithValues(t, "POST", "/admin/auths/new", payload)
Expand Down Expand Up @@ -487,7 +487,7 @@ func TestLDAPPreventInvalidGroupTeamMap(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/auths/new")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/auths/new", buildAuthSourceLDAPPayload(csrf, "", "", `{"NOT_A_VALID_JSON"["MISSING_DOUBLE_POINT"]}`, "off"))
session.MakeRequest(t, req, http.StatusOK) // StatusOK = failed, StatusSeeOther = ok
}
4 changes: 2 additions & 2 deletions tests/integration/change_default_branch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ func TestChangeDefaultBranch(t *testing.T) {
session := loginUser(t, owner.Name)
branchesURL := fmt.Sprintf("/%s/%s/settings/branches", owner.Name, repo.Name)

csrf := GetCSRF(t, session, branchesURL)
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", branchesURL, map[string]string{
"_csrf": csrf,
"action": "default_branch",
"branch": "DefaultBranch",
})
session.MakeRequest(t, req, http.StatusSeeOther)

csrf = GetCSRF(t, session, branchesURL)
csrf = GetUserCSRFToken(t, session)
req = NewRequestWithValues(t, "POST", branchesURL, map[string]string{
"_csrf": csrf,
"action": "default_branch",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/delete_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestUserDeleteAccount(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user8")
csrf := GetCSRF(t, session, "/user/settings/account")
csrf := GetUserCSRFToken(t, session)
urlStr := fmt.Sprintf("/user/settings/account/delete?password=%s", userPassword)
req := NewRequestWithValues(t, "POST", urlStr, map[string]string{
"_csrf": csrf,
Expand All @@ -48,7 +48,7 @@ func TestUserDeleteAccountStillOwnRepos(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user2")
csrf := GetCSRF(t, session, "/user/settings/account")
csrf := GetUserCSRFToken(t, session)
urlStr := fmt.Sprintf("/user/settings/account/delete?password=%s", userPassword)
req := NewRequestWithValues(t, "POST", urlStr, map[string]string{
"_csrf": csrf,
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/editor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestCreateFileOnProtectedBranch(t *testing.T) {
onGiteaRun(t, func(t *testing.T, u *url.URL) {
session := loginUser(t, "user2")

csrf := GetCSRF(t, session, "/user2/repo1/settings/branches")
csrf := GetUserCSRFToken(t, session)
// Change master branch to protected
req := NewRequestWithValues(t, "POST", "/user2/repo1/settings/branches/edit", map[string]string{
"_csrf": csrf,
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestCreateFileOnProtectedBranch(t *testing.T) {
assert.Contains(t, resp.Body.String(), "Cannot commit to protected branch "master".")

// remove the protected branch
csrf = GetCSRF(t, session, "/user2/repo1/settings/branches")
csrf = GetUserCSRFToken(t, session)

// Change master branch to protected
req = NewRequestWithValues(t, "POST", "/user2/repo1/settings/branches/1/delete", map[string]string{
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/empty_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
func testAPINewFile(t *testing.T, session *TestSession, user, repo, branch, treePath, content string) *httptest.ResponseRecorder {
url := fmt.Sprintf("/%s/%s/_new/%s", user, repo, branch)
req := NewRequestWithValues(t, "POST", url, map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"commit_choice": "direct",
"tree_path": treePath,
"content": content,
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestEmptyRepoAddFile(t *testing.T) {
doc := NewHTMLParser(t, resp.Body).Find(`input[name="commit_choice"]`)
assert.Empty(t, doc.AttrOr("checked", "_no_"))
req = NewRequestWithValues(t, "POST", "/user30/empty/_new/"+setting.Repository.DefaultBranch, map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"commit_choice": "direct",
"tree_path": "test-file.md",
"content": "newly-added-test-file",
Expand All @@ -89,7 +89,7 @@ func TestEmptyRepoUploadFile(t *testing.T) {

body := &bytes.Buffer{}
mpForm := multipart.NewWriter(body)
_ = mpForm.WriteField("_csrf", GetCSRF(t, session, "/user/settings"))
_ = mpForm.WriteField("_csrf", GetUserCSRFToken(t, session))
file, _ := mpForm.CreateFormFile("file", "uploaded-file.txt")
_, _ = io.Copy(file, bytes.NewBufferString("newly-uploaded-test-file"))
_ = mpForm.Close()
Expand All @@ -101,7 +101,7 @@ func TestEmptyRepoUploadFile(t *testing.T) {
assert.NoError(t, json.Unmarshal(resp.Body.Bytes(), &respMap))

req = NewRequestWithValues(t, "POST", "/user30/empty/_upload/"+setting.Repository.DefaultBranch, map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"commit_choice": "direct",
"files": respMap["uuid"],
"tree_path": "",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func doBranchProtectPRMerge(baseCtx *APITestContext, dstPath string) func(t *tes
func doProtectBranch(ctx APITestContext, branch, userToWhitelistPush, userToWhitelistForcePush, unprotectedFilePatterns string) func(t *testing.T) {
// We are going to just use the owner to set the protection.
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/settings/branches", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)))
csrf := GetUserCSRFToken(t, ctx.Session)

formData := map[string]string{
"_csrf": csrf,
Expand Down Expand Up @@ -644,7 +644,7 @@ func doPushCreate(ctx APITestContext, u *url.URL) func(t *testing.T) {

func doBranchDelete(ctx APITestContext, owner, repo, branch string) func(*testing.T) {
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/branches", url.PathEscape(owner), url.PathEscape(repo)))
csrf := GetUserCSRFToken(t, ctx.Session)

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/branches/delete?name=%s", url.PathEscape(owner), url.PathEscape(repo), url.QueryEscape(branch)), map[string]string{
"_csrf": csrf,
Expand Down
26 changes: 11 additions & 15 deletions tests/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,23 +486,19 @@ func VerifyJSONSchema(t testing.TB, resp *httptest.ResponseRecorder, schemaFile
assert.True(t, result.Valid())
}

// GetCSRF returns CSRF token from body
// If it fails, it means the CSRF token is not found in the response body returned by the url with the given session.
// In this case, you should find a better url to get it.
func GetCSRF(t testing.TB, session *TestSession, urlStr string) string {
// GetUserCSRFToken returns CSRF token for current user
func GetUserCSRFToken(t testing.TB, session *TestSession) string {
t.Helper()
req := NewRequest(t, "GET", urlStr)
resp := session.MakeRequest(t, req, http.StatusOK)
doc := NewHTMLParser(t, resp.Body)
csrf := doc.GetCSRF()
require.NotEmpty(t, csrf)
return csrf
cookie := session.GetCookie("_csrf")
require.NotEmpty(t, cookie)
return cookie.Value
}

// GetCSRFFrom returns CSRF token from body
func GetCSRFFromCookie(t testing.TB, session *TestSession, urlStr string) string {
// GetUserCSRFToken returns CSRF token for anonymous user (not logged in)
func GetAnonymousCSRFToken(t testing.TB, session *TestSession) string {
t.Helper()
req := NewRequest(t, "GET", urlStr)
session.MakeRequest(t, req, http.StatusOK)
return session.GetCookie("_csrf").Value
resp := session.MakeRequest(t, NewRequest(t, "GET", "/user/login"), http.StatusOK)
csrfToken := NewHTMLParser(t, resp.Body).GetCSRF()
require.NotEmpty(t, csrfToken)
return csrfToken
}
20 changes: 10 additions & 10 deletions tests/integration/issue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,21 @@ func TestEditIssue(t *testing.T) {
issueURL := testNewIssue(t, session, "user2", "repo1", "Title", "Description")

req := NewRequestWithValues(t, "POST", fmt.Sprintf("%s/content", issueURL), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": "modified content",
"context": fmt.Sprintf("/%s/%s", "user2", "repo1"),
})
session.MakeRequest(t, req, http.StatusOK)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("%s/content", issueURL), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": "modified content",
"context": fmt.Sprintf("/%s/%s", "user2", "repo1"),
})
session.MakeRequest(t, req, http.StatusBadRequest)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("%s/content", issueURL), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": "modified content",
"content_version": "1",
"context": fmt.Sprintf("/%s/%s", "user2", "repo1"),
Expand Down Expand Up @@ -246,11 +246,11 @@ func TestIssueCommentDelete(t *testing.T) {

// Using the ID of a comment that does not belong to the repository must fail
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d/delete", "user5", "repo4", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
})
session.MakeRequest(t, req, http.StatusNotFound)
req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d/delete", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
})
session.MakeRequest(t, req, http.StatusOK)
unittest.AssertNotExistsBean(t, &issues_model.Comment{ID: commentID})
Expand All @@ -270,13 +270,13 @@ func TestIssueCommentUpdate(t *testing.T) {

// Using the ID of a comment that does not belong to the repository must fail
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user5", "repo4", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusNotFound)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusOK)
Expand All @@ -298,21 +298,21 @@ func TestIssueCommentUpdateSimultaneously(t *testing.T) {
modifiedContent := comment.Content + "MODIFIED"

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusOK)

modifiedContent = comment.Content + "2"

req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusBadRequest)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
"content_version": "1",
})
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/mirror_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func testMirrorPush(t *testing.T, u *url.URL) {

func doCreatePushMirror(ctx APITestContext, address, username, password string) func(t *testing.T) {
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)))
csrf := GetUserCSRFToken(t, ctx.Session)

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)), map[string]string{
"_csrf": csrf,
Expand All @@ -101,7 +101,7 @@ func doCreatePushMirror(ctx APITestContext, address, username, password string)

func doRemovePushMirror(ctx APITestContext, address, username, password string, pushMirrorID int) func(t *testing.T) {
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)))
csrf := GetUserCSRFToken(t, ctx.Session)

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)), map[string]string{
"_csrf": csrf,
Expand Down
Loading
Loading