Skip to content

Commit

Permalink
feat: use one transaction for /admin/recovery/code (#4225)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr authored Nov 27, 2024
1 parent 5e26610 commit 3e87e0c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
1 change: 1 addition & 0 deletions selfservice/strategy/code/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type (
x.WriterProvider
x.LoggingProvider
x.TracingProvider
x.TransactionPersistenceProvider

config.Provider

Expand Down
31 changes: 19 additions & 12 deletions selfservice/strategy/code/strategy_recovery_admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
package code

import (
"context"
"net/http"
"net/url"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
Expand Down Expand Up @@ -184,16 +187,12 @@ func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http.
})).
WithMetaLabel(text.NewInfoNodeLabelRecoveryCode()),
)
rawCode := GenerateCode()

recoveryFlow.UI.Nodes.
Append(node.NewInputField("method", s.RecoveryStrategyID(), node.CodeGroup, node.InputAttributeTypeSubmit).
WithMetaLabel(text.NewInfoNodeLabelContinue()))

if err := s.deps.RecoveryFlowPersister().CreateRecoveryFlow(ctx, recoveryFlow); err != nil {
s.deps.Writer().WriteError(w, r, err)
return
}

id, err := s.deps.IdentityPool().GetIdentity(ctx, p.IdentityID, identity.ExpandDefault)
if notFoundErr := sqlcon.ErrNoRows; errors.As(err, &notFoundErr) {
s.deps.Writer().WriteError(w, r, notFoundErr.WithReasonf("could not find identity"))
Expand All @@ -203,14 +202,22 @@ func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http.
return
}

rawCode := GenerateCode()
if err := s.deps.TransactionalPersisterProvider().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := s.deps.RecoveryFlowPersister().CreateRecoveryFlow(ctx, recoveryFlow); err != nil {
return err
}

if _, err := s.deps.RecoveryCodePersister().CreateRecoveryCode(ctx, &CreateRecoveryCodeParams{
RawCode: rawCode,
CodeType: RecoveryCodeTypeAdmin,
ExpiresIn: expiresIn,
FlowID: recoveryFlow.ID,
IdentityID: id.ID,
}); err != nil {
return err
}

if _, err := s.deps.RecoveryCodePersister().CreateRecoveryCode(ctx, &CreateRecoveryCodeParams{
RawCode: rawCode,
CodeType: RecoveryCodeTypeAdmin,
ExpiresIn: expiresIn,
FlowID: recoveryFlow.ID,
IdentityID: id.ID,
return nil
}); err != nil {
s.deps.Writer().WriteError(w, r, err)
return
Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/link/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type (
x.WriterProvider
x.LoggingProvider
x.TracingProvider
x.TransactionPersistenceProvider

config.Provider

Expand Down
15 changes: 9 additions & 6 deletions selfservice/strategy/link/strategy_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net/url"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"
Expand Down Expand Up @@ -171,11 +173,6 @@ func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http.
return
}

if err := s.d.RecoveryFlowPersister().CreateRecoveryFlow(r.Context(), req); err != nil {
s.d.Writer().WriteError(w, r, err)
return
}

id, err := s.d.IdentityPool().GetIdentity(r.Context(), p.IdentityID, identity.ExpandDefault)
if errors.Is(err, sqlcon.ErrNoRows) {
s.d.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("The requested identity id does not exist.").WithWrap(err)))
Expand All @@ -186,7 +183,13 @@ func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http.
}

token := NewAdminRecoveryToken(id.ID, req.ID, expiresIn)
if err := s.d.RecoveryTokenPersister().CreateRecoveryToken(r.Context(), token); err != nil {
if err := s.d.TransactionalPersisterProvider().Transaction(r.Context(), func(ctx context.Context, c *pop.Connection) error {
if err := s.d.RecoveryFlowPersister().CreateRecoveryFlow(ctx, req); err != nil {
return err
}

return s.d.RecoveryTokenPersister().CreateRecoveryToken(ctx, token)
}); err != nil {
s.d.Writer().WriteError(w, r, err)
return
}
Expand Down

0 comments on commit 3e87e0c

Please sign in to comment.