From 9424cb5150815e22e0db95fbb02f1b4ef9969c88 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 4 Apr 2024 13:56:44 -0600 Subject: [PATCH] GODRIVER-2800 Remove the Session interface --- internal/docexamples/examples.go | 26 +-- internal/integration/client_test.go | 3 +- internal/integration/crud_helpers_test.go | 60 +++---- internal/integration/mongos_pinning_test.go | 27 ++-- internal/integration/sessions_test.go | 13 +- internal/integration/unified/entity.go | 6 +- .../unified/testrunner_operation.go | 4 +- internal/integration/unified_spec_test.go | 22 ++- mongo/client.go | 12 +- mongo/client_test.go | 2 +- mongo/crud_examples_test.go | 9 +- mongo/session.go | 151 +++++------------- mongo/with_transactions_test.go | 2 +- 13 files changed, 133 insertions(+), 204 deletions(-) diff --git a/internal/docexamples/examples.go b/internal/docexamples/examples.go index b08447c15c..7e43919cb6 100644 --- a/internal/docexamples/examples.go +++ b/internal/docexamples/examples.go @@ -1760,7 +1760,9 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { events := client.Database("reporting").Collection("events") return client.UseSession(ctx, func(sctx mongo.SessionContext) error { - err := sctx.StartTransaction(options.Transaction(). + sess := mongo.SessionFromContext(sctx) + + err := sess.StartTransaction(options.Transaction(). SetReadConcern(readconcern.Snapshot()). SetWriteConcern(writeconcern.Majority()), ) @@ -1770,19 +1772,19 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error { _, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } _, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } for { - err = sctx.CommitTransaction(sctx) + err = sess.CommitTransaction(sctx) switch e := err.(type) { case nil: return nil @@ -1830,8 +1832,10 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session // CommitWithRetry is an example function demonstrating transaction commit with retry logic. func CommitWithRetry(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) + for { - err := sctx.CommitTransaction(sctx) + err := sess.CommitTransaction(sctx) switch e := err.(type) { case nil: log.Println("Transaction committed.") @@ -1892,8 +1896,10 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { } commitWithRetry := func(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) + for { - err := sctx.CommitTransaction(sctx) + err := sess.CommitTransaction(sctx) switch e := err.(type) { case nil: log.Println("Transaction committed.") @@ -1918,7 +1924,9 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { employees := client.Database("hr").Collection("employees") events := client.Database("reporting").Collection("events") - err := sctx.StartTransaction(options.Transaction(). + sess := mongo.SessionFromContext(sctx) + + err := sess.StartTransaction(options.Transaction(). SetReadConcern(readconcern.Snapshot()). SetWriteConcern(writeconcern.Majority()), ) @@ -1928,13 +1936,13 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { _, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } _, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}}) if err != nil { - sctx.AbortTransaction(sctx) + sess.AbortTransaction(sctx) log.Println("caught exception during transaction, aborting.") return err } diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 8350db58e0..1aa50a5bf5 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -374,8 +374,7 @@ func TestClient(t *testing.T) { sess, err := mt.Client.StartSession(tc.opts) assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - xs := sess.(mongo.XSession) - consistent := xs.ClientSession().Consistent + consistent := sess.ClientSession().Consistent assert.Equal(mt, tc.consistent, consistent, "expected consistent to be %v, got %v", tc.consistent, consistent) }) } diff --git a/internal/integration/crud_helpers_test.go b/internal/integration/crud_helpers_test.go index 80abf29231..0677e6489d 100644 --- a/internal/integration/crud_helpers_test.go +++ b/internal/integration/crud_helpers_test.go @@ -158,7 +158,7 @@ type watcher interface { Watch(context.Context, interface{}, ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) } -func executeAggregate(mt *mtest.T, agg aggregator, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeAggregate(mt *mtest.T, agg aggregator, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() var pipeline []interface{} @@ -198,7 +198,7 @@ func executeAggregate(mt *mtest.T, agg aggregator, sess mongo.Session, args bson return agg.Aggregate(context.Background(), pipeline, opts) } -func executeWatch(mt *mtest.T, w watcher, sess mongo.Session, args bson.Raw) (*mongo.ChangeStream, error) { +func executeWatch(mt *mtest.T, w watcher, sess *mongo.Session, args bson.Raw) (*mongo.ChangeStream, error) { mt.Helper() pipeline := []interface{}{} @@ -227,7 +227,7 @@ func executeWatch(mt *mtest.T, w watcher, sess mongo.Session, args bson.Raw) (*m return w.Watch(context.Background(), pipeline) } -func executeCountDocuments(mt *mtest.T, sess mongo.Session, args bson.Raw) (int64, error) { +func executeCountDocuments(mt *mtest.T, sess *mongo.Session, args bson.Raw) (int64, error) { mt.Helper() filter := emptyDoc @@ -265,7 +265,7 @@ func executeCountDocuments(mt *mtest.T, sess mongo.Session, args bson.Raw) (int6 return mt.Coll.CountDocuments(context.Background(), filter, opts) } -func executeInsertOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.InsertOneResult, error) { +func executeInsertOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.InsertOneResult, error) { mt.Helper() doc := emptyDoc @@ -299,7 +299,7 @@ func executeInsertOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.In return mt.Coll.InsertOne(context.Background(), doc, opts) } -func executeInsertMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.InsertManyResult, error) { +func executeInsertMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.InsertManyResult, error) { mt.Helper() var docs []interface{} @@ -362,7 +362,7 @@ func setFindModifiers(modifiersDoc bson.Raw, opts *options.FindOptions) { } } -func executeFind(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeFind(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() filter := emptyDoc @@ -410,7 +410,7 @@ func executeFind(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, return mt.Coll.Find(context.Background(), filter, opts) } -func executeRunCommand(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeRunCommand(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() cmd := emptyDoc @@ -443,7 +443,7 @@ func executeRunCommand(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.Si return mt.DB.RunCommand(context.Background(), cmd, opts) } -func executeListCollections(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeListCollections(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() filter := emptyDoc @@ -472,7 +472,7 @@ func executeListCollections(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mo return mt.DB.ListCollections(context.Background(), filter) } -func executeListCollectionNames(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]string, error) { +func executeListCollectionNames(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]string, error) { mt.Helper() filter := emptyDoc @@ -501,7 +501,7 @@ func executeListCollectionNames(mt *mtest.T, sess mongo.Session, args bson.Raw) return mt.DB.ListCollectionNames(context.Background(), filter) } -func executeListDatabaseNames(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]string, error) { +func executeListDatabaseNames(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]string, error) { mt.Helper() filter := emptyDoc @@ -530,7 +530,7 @@ func executeListDatabaseNames(mt *mtest.T, sess mongo.Session, args bson.Raw) ([ return mt.Client.ListDatabaseNames(context.Background(), filter) } -func executeListDatabases(mt *mtest.T, sess mongo.Session, args bson.Raw) (mongo.ListDatabasesResult, error) { +func executeListDatabases(mt *mtest.T, sess *mongo.Session, args bson.Raw) (mongo.ListDatabasesResult, error) { mt.Helper() filter := emptyDoc @@ -559,7 +559,7 @@ func executeListDatabases(mt *mtest.T, sess mongo.Session, args bson.Raw) (mongo return mt.Client.ListDatabases(context.Background(), filter) } -func executeFindOne(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -587,7 +587,7 @@ func executeFindOne(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.Singl return mt.Coll.FindOne(context.Background(), filter) } -func executeListIndexes(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Cursor, error) { +func executeListIndexes(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.Cursor, error) { mt.Helper() // no arguments expected. add a Fatal in case arguments are added in the future @@ -604,7 +604,7 @@ func executeListIndexes(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo. return mt.Coll.Indexes().List(context.Background()) } -func executeDistinct(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]interface{}, error) { +func executeDistinct(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]interface{}, error) { mt.Helper() var fieldName string @@ -641,7 +641,7 @@ func executeDistinct(mt *mtest.T, sess mongo.Session, args bson.Raw) ([]interfac return mt.Coll.Distinct(context.Background(), fieldName, filter, opts) } -func executeFindOneAndDelete(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOneAndDelete(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -680,7 +680,7 @@ func executeFindOneAndDelete(mt *mtest.T, sess mongo.Session, args bson.Raw) *mo return mt.Coll.FindOneAndDelete(context.Background(), filter, opts) } -func executeFindOneAndUpdate(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOneAndUpdate(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -737,7 +737,7 @@ func executeFindOneAndUpdate(mt *mtest.T, sess mongo.Session, args bson.Raw) *mo return mt.Coll.FindOneAndUpdate(context.Background(), filter, update, opts) } -func executeFindOneAndReplace(mt *mtest.T, sess mongo.Session, args bson.Raw) *mongo.SingleResult { +func executeFindOneAndReplace(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { mt.Helper() filter := emptyDoc @@ -790,7 +790,7 @@ func executeFindOneAndReplace(mt *mtest.T, sess mongo.Session, args bson.Raw) *m return mt.Coll.FindOneAndReplace(context.Background(), filter, replacement, opts) } -func executeDeleteOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { +func executeDeleteOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { mt.Helper() filter := emptyDoc @@ -826,7 +826,7 @@ func executeDeleteOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.De return mt.Coll.DeleteOne(context.Background(), filter, opts) } -func executeDeleteMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { +func executeDeleteMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.DeleteResult, error) { mt.Helper() filter := emptyDoc @@ -862,7 +862,7 @@ func executeDeleteMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.D return mt.Coll.DeleteMany(context.Background(), filter, opts) } -func executeUpdateOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { +func executeUpdateOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { mt.Helper() filter := emptyDoc @@ -910,7 +910,7 @@ func executeUpdateOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Up return mt.Coll.UpdateOne(context.Background(), filter, update, opts) } -func executeUpdateMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { +func executeUpdateMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { mt.Helper() filter := emptyDoc @@ -958,7 +958,7 @@ func executeUpdateMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.U return mt.Coll.UpdateMany(context.Background(), filter, update, opts) } -func executeReplaceOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { +func executeReplaceOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.UpdateResult, error) { mt.Helper() filter := emptyDoc @@ -1009,7 +1009,7 @@ type withTransactionArgs struct { Options bson.Raw `bson:"options"` } -func runWithTransactionOperations(mt *mtest.T, operations []*operation, sess mongo.Session) error { +func runWithTransactionOperations(mt *mtest.T, operations []*operation, sess *mongo.Session) error { mt.Helper() for _, op := range operations { @@ -1037,7 +1037,7 @@ func runWithTransactionOperations(mt *mtest.T, operations []*operation, sess mon return nil } -func executeWithTransaction(mt *mtest.T, sess mongo.Session, args bson.Raw) error { +func executeWithTransaction(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() var testArgs withTransactionArgs @@ -1052,7 +1052,7 @@ func executeWithTransaction(mt *mtest.T, sess mongo.Session, args bson.Raw) erro return err } -func executeBulkWrite(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.BulkWriteResult, error) { +func executeBulkWrite(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.BulkWriteResult, error) { mt.Helper() models := createBulkWriteModels(mt, args.Lookup("requests").Array()) @@ -1196,7 +1196,7 @@ func createBulkWriteModel(mt *mtest.T, rawModel bson.Raw) mongo.WriteModel { return nil } -func executeEstimatedDocumentCount(mt *mtest.T, sess mongo.Session, args bson.Raw) (int64, error) { +func executeEstimatedDocumentCount(mt *mtest.T, sess *mongo.Session, args bson.Raw) (int64, error) { mt.Helper() // no arguments expected. add a Fatal in case arguments are added in the future @@ -1255,7 +1255,7 @@ func executeGridFSDownloadByName(mt *mtest.T, bucket *mongo.GridFSBucket, args b return bucket.DownloadToStreamByName(context.Background(), file, new(bytes.Buffer)) } -func executeCreateIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (string, error) { +func executeCreateIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (string, error) { mt.Helper() model := mongo.IndexModel{ @@ -1289,7 +1289,7 @@ func executeCreateIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (string, return mt.Coll.Indexes().CreateOne(context.Background(), model) } -func executeDropIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (bson.Raw, error) { +func executeDropIndex(mt *mtest.T, sess *mongo.Session, args bson.Raw) (bson.Raw, error) { mt.Helper() var name string @@ -1318,7 +1318,7 @@ func executeDropIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (bson.Raw, return mt.Coll.Indexes().DropOne(context.Background(), name) } -func executeDropCollection(mt *mtest.T, sess mongo.Session, args bson.Raw) error { +func executeDropCollection(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() var collName string @@ -1348,7 +1348,7 @@ func executeDropCollection(mt *mtest.T, sess mongo.Session, args bson.Raw) error return coll.Drop(context.Background(), dco) } -func executeCreateCollection(mt *mtest.T, sess mongo.Session, args bson.Raw) error { +func executeCreateCollection(mt *mtest.T, sess *mongo.Session, args bson.Raw) error { mt.Helper() cco := options.CreateCollection() diff --git a/internal/integration/mongos_pinning_test.go b/internal/integration/mongos_pinning_test.go index 06b31762c9..f91f16018e 100644 --- a/internal/integration/mongos_pinning_test.go +++ b/internal/integration/mongos_pinning_test.go @@ -31,21 +31,22 @@ func TestMongosPinning(t *testing.T) { mt.Run("unpin for next transaction", func(mt *mtest.T) { addresses := map[string]struct{}{} - _ = mt.Client.UseSession(context.Background(), func(sc mongo.SessionContext) error { + _ = mt.Client.UseSession(context.Background(), func(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) // Insert a document in a transaction to pin session to a mongos - err := sc.StartTransaction() + err := sess.StartTransaction() assert.Nil(mt, err, "StartTransaction error: %v", err) - _, err = mt.Coll.InsertOne(sc, bson.D{{"x", 1}}) + _, err = mt.Coll.InsertOne(sctx, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) - err = sc.CommitTransaction(sc) + err = sess.CommitTransaction(sctx) assert.Nil(mt, err, "CommitTransaction error: %v", err) for i := 0; i < 50; i++ { // Call Find in a new transaction to unpin from the old mongos and select a new one - err = sc.StartTransaction() + err = sess.StartTransaction() assert.Nil(mt, err, iterationErrmsg("StartTransaction", i, err)) - cursor, err := mt.Coll.Find(sc, bson.D{}) + cursor, err := mt.Coll.Find(sctx, bson.D{}) assert.Nil(mt, err, iterationErrmsg("Find", i, err)) assert.True(mt, cursor.Next(context.Background()), "Next returned false on iteration %v", i) @@ -55,7 +56,7 @@ func TestMongosPinning(t *testing.T) { err = descConn.Close() assert.Nil(mt, err, iterationErrmsg("connection Close", i, err)) - err = sc.CommitTransaction(sc) + err = sess.CommitTransaction(sctx) assert.Nil(mt, err, iterationErrmsg("CommitTransaction", i, err)) } return nil @@ -64,18 +65,20 @@ func TestMongosPinning(t *testing.T) { }) mt.Run("unpin for non transaction operation", func(mt *mtest.T) { addresses := map[string]struct{}{} - _ = mt.Client.UseSession(context.Background(), func(sc mongo.SessionContext) error { + _ = mt.Client.UseSession(context.Background(), func(sctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(sctx) + // Insert a document in a transaction to pin session to a mongos - err := sc.StartTransaction() + err := sess.StartTransaction() assert.Nil(mt, err, "StartTransaction error: %v", err) - _, err = mt.Coll.InsertOne(sc, bson.D{{"x", 1}}) + _, err = mt.Coll.InsertOne(sctx, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) - err = sc.CommitTransaction(sc) + err = sess.CommitTransaction(sctx) assert.Nil(mt, err, "CommitTransaction error: %v", err) for i := 0; i < 50; i++ { // Call Find with the session but outside of a transaction - cursor, err := mt.Coll.Find(sc, bson.D{}) + cursor, err := mt.Coll.Find(sctx, bson.D{}) assert.Nil(mt, err, iterationErrmsg("Find", i, err)) assert.True(mt, cursor.Next(context.Background()), "Next returned false on iteration %v", i) diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index a9ae56cba9..a95af5b15c 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -35,14 +35,14 @@ func TestSessionPool(t *testing.T) { sess, err := mt.Client.StartSession() assert.Nil(mt, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - initialLastUsedTime := getSessionLastUsedTime(mt, sess) + initialLastUsedTime := sess.ClientSession().LastUsed err = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error { return mt.Client.Ping(sc, readpref.Primary()) }) assert.Nil(mt, err, "WithSession error: %v", err) - newLastUsedTime := getSessionLastUsedTime(mt, sess) + newLastUsedTime := sess.ClientSession().LastUsed assert.True(mt, newLastUsedTime.After(initialLastUsedTime), "last used time %s is not after the initial last used time %s", newLastUsedTime, initialLastUsedTime) }) @@ -63,7 +63,6 @@ func TestSessions(t *testing.T) { defer sess.EndSession(context.Background()) ctx := mongo.NewSessionContext(context.Background(), sess) - assert.Equal(mt, sess.ID(), ctx.ID(), "expected Session ID %v, got %v", sess.ID(), ctx.ID()) gotSess := mongo.SessionFromContext(ctx) assert.NotNil(mt, gotSess, "expected SessionFromContext to return non-nil value, got nil") @@ -513,7 +512,7 @@ type sessionFunction struct { params []interface{} // should not include context } -func (sf sessionFunction) execute(mt *mtest.T, sess mongo.Session) error { +func (sf sessionFunction) execute(mt *mtest.T, sess *mongo.Session) error { var target reflect.Value switch sf.target { case "client": @@ -639,9 +638,3 @@ func extractSentSessionID(mt *mtest.T) []byte { _, data := lsid.Document().Lookup("id").Binary() return data } - -func getSessionLastUsedTime(mt *mtest.T, sess mongo.Session) time.Time { - xsess, ok := sess.(mongo.XSession) - assert.True(mt, ok, "expected session to implement mongo.XSession, but got %T", sess) - return xsess.ClientSession().LastUsed -} diff --git a/internal/integration/unified/entity.go b/internal/integration/unified/entity.go index 75bbee6035..873a828da7 100644 --- a/internal/integration/unified/entity.go +++ b/internal/integration/unified/entity.go @@ -191,7 +191,7 @@ type EntityMap struct { clientEntities map[string]*clientEntity dbEntites map[string]*mongo.Database collEntities map[string]*mongo.Collection - sessions map[string]mongo.Session + sessions map[string]*mongo.Session gridfsBuckets map[string]*mongo.GridFSBucket bsonValues map[string]bson.RawValue eventListEntities map[string][]bson.Raw @@ -225,7 +225,7 @@ func newEntityMap() *EntityMap { clientEntities: make(map[string]*clientEntity), collEntities: make(map[string]*mongo.Collection), dbEntites: make(map[string]*mongo.Database), - sessions: make(map[string]mongo.Session), + sessions: make(map[string]*mongo.Session), eventListEntities: make(map[string][]bson.Raw), bsonArrayEntities: make(map[string][]bson.Raw), successValues: make(map[string]int32), @@ -422,7 +422,7 @@ func (em *EntityMap) database(id string) (*mongo.Database, error) { return db, nil } -func (em *EntityMap) session(id string) (mongo.Session, error) { +func (em *EntityMap) session(id string) (*mongo.Session, error) { sess, ok := em.sessions[id] if !ok { return nil, newEntityNotFoundError("session", id) diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index 38f81bfed3..1079f33840 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -443,8 +443,8 @@ func waitForEvent(ctx context.Context, args waitForEventArguments) error { } } -func extractClientSession(sess mongo.Session) *session.Client { - return sess.(mongo.XSession).ClientSession() +func extractClientSession(sess *mongo.Session) *session.Client { + return sess.ClientSession() } func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error { diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index c62de30698..97f854f7e7 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -366,12 +366,12 @@ func createBucket(mt *mtest.T, testFile testFile, testCase *testCase) { testCase.bucket = mt.DB.GridFSBucket(bucketOpts) } -func runOperation(mt *mtest.T, testCase *testCase, op *operation, sess0, sess1 mongo.Session) error { +func runOperation(mt *mtest.T, testCase *testCase, op *operation, sess0, sess1 *mongo.Session) error { if op.Name == "count" { mt.Skip("count has been deprecated") } - var sess mongo.Session + var sess *mongo.Session if sessVal, err := op.Arguments.LookupErr("session"); err == nil { sessStr := sessVal.StringValue() switch sessStr { @@ -442,14 +442,10 @@ func executeGridFSOperation(mt *mtest.T, bucket *mongo.GridFSBucket, op *operati return nil } -func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, sess mongo.Session) error { +func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, sess *mongo.Session) error { var clientSession *session.Client if sess != nil { - xsess, ok := sess.(mongo.XSession) - if !ok { - return fmt.Errorf("expected session type %T to implement mongo.XSession", sess) - } - clientSession = xsess.ClientSession() + clientSession = sess.ClientSession() } switch op.Name { @@ -635,7 +631,7 @@ func lastTwoIDs(mt *mtest.T) (bson.RawValue, bson.RawValue) { return first, second } -func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeSessionOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "startTransaction": var txnOpts *options.TransactionOptions @@ -654,7 +650,7 @@ func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) err } } -func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeCollectionOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "countDocuments": // no results to verify with count @@ -798,7 +794,7 @@ func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) return nil } -func executeDatabaseOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeDatabaseOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "runCommand": res := executeRunCommand(mt, sess, op.Arguments) @@ -853,7 +849,7 @@ func executeDatabaseOperation(mt *mtest.T, op *operation, sess mongo.Session) er return nil } -func executeClientOperation(mt *mtest.T, op *operation, sess mongo.Session) error { +func executeClientOperation(mt *mtest.T, op *operation, sess *mongo.Session) error { switch op.Name { case "listDatabaseNames": _, err := executeListDatabaseNames(mt, sess, op.Arguments) @@ -882,7 +878,7 @@ func executeClientOperation(mt *mtest.T, op *operation, sess mongo.Session) erro return nil } -func setupSessions(mt *mtest.T, test *testCase) (mongo.Session, mongo.Session) { +func setupSessions(mt *mtest.T, test *testCase) (*mongo.Session, *mongo.Session) { mt.Helper() var sess0Opts, sess1Opts *options.SessionOptions diff --git a/mongo/client.go b/mongo/client.go index 40c0b3c411..0bae485a22 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -374,7 +374,7 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { // // If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read // concern, write concern, or read preference will be used, respectively. -func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) { +func (c *Client) StartSession(opts ...*options.SessionOptions) (*Session, error) { if c.sessionPool == nil { return nil, ErrClientDisconnected } @@ -439,7 +439,7 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) sess.RetryWrite = false sess.RetryRead = c.retryReads - return &sessionImpl{ + return &Session{ clientSession: sess, client: c, deployment: c.deployment, @@ -786,7 +786,7 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts // If the ctx parameter already contains a Session, that Session will be replaced with the one provided. // // Any error returned by the fn callback will be returned without any modifications. -func WithSession(ctx context.Context, sess Session, fn func(SessionContext) error) error { +func WithSession(ctx context.Context, sess *Session, fn func(SessionContext) error) error { return fn(NewSessionContext(ctx, sess)) } @@ -809,7 +809,11 @@ func (c *Client) UseSession(ctx context.Context, fn func(SessionContext) error) // // UseSessionWithOptions is safe to call from multiple goroutines concurrently. However, the SessionContext passed to // the UseSessionWithOptions callback function is not safe for concurrent use by multiple goroutines. -func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.SessionOptions, fn func(SessionContext) error) error { +func (c *Client) UseSessionWithOptions( + ctx context.Context, + opts *options.SessionOptions, + fn func(SessionContext) error, +) error { defaultSess, err := c.StartSession(opts) if err != nil { return err diff --git a/mongo/client_test.go b/mongo/client_test.go index ddba3062fe..e5d08642b3 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -380,7 +380,7 @@ func TestClient(t *testing.T) { // Do an application operation and create the number of sessions specified by the test. _, err = coll.CountDocuments(bgCtx, bson.D{}) assert.Nil(t, err, "CountDocuments error: %v", err) - var sessions []Session + var sessions []*Session for i := 0; i < tc.numSessions; i++ { sess, err := client.StartSession() assert.Nil(t, err, "StartSession error at index %d: %v", i, err) diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index 47127006f3..a92b7509fe 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -684,11 +684,12 @@ func ExampleClient_UseSessionWithOptions() { context.TODO(), opts, func(ctx mongo.SessionContext) error { + sess := mongo.SessionFromContext(ctx) // Use the mongo.SessionContext as the Context parameter for // InsertOne and FindOne so both operations are run under the new // Session. - if err := ctx.StartTransaction(); err != nil { + if err := sess.StartTransaction(); err != nil { return err } @@ -699,7 +700,7 @@ func ExampleClient_UseSessionWithOptions() { // context.Background() to ensure that the abort can complete // successfully even if the context passed to mongo.WithSession // is changed to have a timeout. - _ = ctx.AbortTransaction(context.Background()) + _ = sess.AbortTransaction(context.Background()) return err } @@ -713,7 +714,7 @@ func ExampleClient_UseSessionWithOptions() { // context.Background() to ensure that the abort can complete // successfully even if the context passed to mongo.WithSession // is changed to have a timeout. - _ = ctx.AbortTransaction(context.Background()) + _ = sess.AbortTransaction(context.Background()) return err } fmt.Println(result) @@ -721,7 +722,7 @@ func ExampleClient_UseSessionWithOptions() { // Use context.Background() to ensure that the commit can complete // successfully even if the context passed to mongo.WithSession is // changed to have a timeout. - return ctx.CommitTransaction(context.Background()) + return sess.CommitTransaction(context.Background()) }) if err != nil { log.Fatal(err) diff --git a/mongo/session.go b/mongo/session.go index dcd83f650c..45f224e5bd 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -27,6 +27,25 @@ var ErrWrongClient = errors.New("session was not created by this client") var withTransactionTimeout = 120 * time.Second +// Session is a MongoDB logical session. Sessions can be used to enable causal +// consistency for a group of operations or to execute operations in an ACID +// transaction. A new Session can be created from a Client instance. A Session +// created from a Client must only be used to execute operations using that +// Client or a Database or Collection created from that Client. For more +// information about sessions, and their use cases, see +// https://www.mongodb.com/docs/manual/reference/server-sessions/, +// https://www.mongodb.com/docs/manual/core/read-isolation-consistency-recency/#causal-consistency, and +// https://www.mongodb.com/docs/manual/core/transactions/. +// +// Implementations of Session are not safe for concurrent use by multiple +// goroutines. +type Session struct { + clientSession *session.Client + client *Client + deployment driver.Deployment + didCommitAfterStart bool // true if commit was called after start with no other operations +} + // SessionContext combines the context.Context and mongo.Session interfaces. It should be used as the Context arguments // to operations that should be executed in a session. // @@ -37,35 +56,31 @@ var withTransactionTimeout = 120 * time.Second // the provided callback. The other is to use NewSessionContext to explicitly create a SessionContext. type SessionContext interface { context.Context - Session } -type sessionContext struct { +type sessionCtx struct { context.Context - Session } -type sessionKey struct { -} +type sessionKey struct{} // NewSessionContext creates a new SessionContext associated with the given Context and Session parameters. -func NewSessionContext(ctx context.Context, sess Session) SessionContext { - return &sessionContext{ +func NewSessionContext(ctx context.Context, sess *Session) SessionContext { + return &sessionCtx{ Context: context.WithValue(ctx, sessionKey{}, sess), - Session: sess, } } // SessionFromContext extracts the mongo.Session object stored in a Context. This can be used on a SessionContext that // was created implicitly through one of the callback-based session APIs or explicitly by calling NewSessionContext. If // there is no Session stored in the provided Context, nil is returned. -func SessionFromContext(ctx context.Context) Session { +func SessionFromContext(ctx context.Context) *Session { val := ctx.Value(sessionKey{}) if val == nil { return nil } - sess, ok := val.(Session) + sess, ok := val.(*Session) if !ok { return nil } @@ -73,104 +88,18 @@ func SessionFromContext(ctx context.Context) Session { return sess } -// Session is an interface that represents a MongoDB logical session. Sessions can be used to enable causal consistency -// for a group of operations or to execute operations in an ACID transaction. A new Session can be created from a Client -// instance. A Session created from a Client must only be used to execute operations using that Client or a Database or -// Collection created from that Client. Custom implementations of this interface should not be used in production. For -// more information about sessions, and their use cases, see -// https://www.mongodb.com/docs/manual/reference/server-sessions/, -// https://www.mongodb.com/docs/manual/core/read-isolation-consistency-recency/#causal-consistency, and -// https://www.mongodb.com/docs/manual/core/transactions/. -// -// Implementations of Session are not safe for concurrent use by multiple goroutines. -type Session interface { - // StartTransaction starts a new transaction, configured with the given options, on this - // session. This method returns an error if there is already a transaction in-progress for this - // session. - StartTransaction(...*options.TransactionOptions) error - - // AbortTransaction aborts the active transaction for this session. This method returns an error - // if there is no active transaction for this session or if the transaction has been committed - // or aborted. - AbortTransaction(context.Context) error - - // CommitTransaction commits the active transaction for this session. This method returns an - // error if there is no active transaction for this session or if the transaction has been - // aborted. - CommitTransaction(context.Context) error - - // WithTransaction starts a transaction on this session and runs the fn callback. Errors with - // the TransientTransactionError and UnknownTransactionCommitResult labels are retried for up to - // 120 seconds. Inside the callback, the SessionContext must be used as the Context parameter - // for any operations that should be part of the transaction. If the ctx parameter already has a - // Session attached to it, it will be replaced by this session. The fn callback may be run - // multiple times during WithTransaction due to retry attempts, so it must be idempotent. - // Non-retryable operation errors or any operation errors that occur after the timeout expires - // will be returned without retrying. If the callback fails, the driver will call - // AbortTransaction. Because this method must succeed to ensure that server-side resources are - // properly cleaned up, context deadlines and cancellations will not be respected during this - // call. For a usage example, see the Client.StartSession method documentation. - WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), - opts ...*options.TransactionOptions) (interface{}, error) - - // EndSession aborts any existing transactions and close the session. - EndSession(context.Context) - - // ClusterTime returns the current cluster time document associated with the session. - ClusterTime() bson.Raw - - // OperationTime returns the current operation time document associated with the session. - OperationTime() *primitive.Timestamp - - // Client the Client associated with the session. - Client() *Client - - // ID returns the current ID document associated with the session. The ID document is in the - // form {"id": }. - ID() bson.Raw - - // AdvanceClusterTime advances the cluster time for a session. This method returns an error if - // the session has ended. - AdvanceClusterTime(bson.Raw) error - - // AdvanceOperationTime advances the operation time for a session. This method returns an error - // if the session has ended. - AdvanceOperationTime(*primitive.Timestamp) error - - session() -} - -// XSession is an unstable interface for internal use only. -// -// Deprecated: This interface is unstable because it provides access to a session.Client object, which exists in the -// "x" package. It should not be used by applications and may be changed or removed in any release. -type XSession interface { - ClientSession() *session.Client -} - -// sessionImpl represents a set of sequential operations executed by an application that are related in some way. -type sessionImpl struct { - clientSession *session.Client - client *Client - deployment driver.Deployment - didCommitAfterStart bool // true if commit was called after start with no other operations -} - -var _ Session = &sessionImpl{} -var _ XSession = &sessionImpl{} - // ClientSession implements the XSession interface. -func (s *sessionImpl) ClientSession() *session.Client { +func (s *Session) ClientSession() *session.Client { return s.clientSession } // ID implements the Session interface. -func (s *sessionImpl) ID() bson.Raw { +func (s *Session) ID() bson.Raw { return bson.Raw(s.clientSession.SessionID) } // EndSession implements the Session interface. -func (s *sessionImpl) EndSession(ctx context.Context) { +func (s *Session) EndSession(ctx context.Context) { if s.clientSession.TransactionInProgress() { // ignore all errors aborting during an end session _ = s.AbortTransaction(ctx) @@ -179,7 +108,7 @@ func (s *sessionImpl) EndSession(ctx context.Context) { } // WithTransaction implements the Session interface. -func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), +func (s *Session) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { timeout := time.NewTimer(withTransactionTimeout) defer timeout.Stop() @@ -259,7 +188,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo } // StartTransaction implements the Session interface. -func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error { +func (s *Session) StartTransaction(opts ...*options.TransactionOptions) error { err := s.clientSession.CheckStartTransaction() if err != nil { return err @@ -296,7 +225,7 @@ func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) erro } // AbortTransaction implements the Session interface. -func (s *sessionImpl) AbortTransaction(ctx context.Context) error { +func (s *Session) AbortTransaction(ctx context.Context) error { err := s.clientSession.CheckAbortTransaction() if err != nil { return err @@ -322,7 +251,7 @@ func (s *sessionImpl) AbortTransaction(ctx context.Context) error { } // CommitTransaction implements the Session interface. -func (s *sessionImpl) CommitTransaction(ctx context.Context) error { +func (s *Session) CommitTransaction(ctx context.Context) error { err := s.clientSession.CheckCommitTransaction() if err != nil { return err @@ -366,39 +295,35 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { } // ClusterTime implements the Session interface. -func (s *sessionImpl) ClusterTime() bson.Raw { +func (s *Session) ClusterTime() bson.Raw { return s.clientSession.ClusterTime } // AdvanceClusterTime implements the Session interface. -func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error { +func (s *Session) AdvanceClusterTime(d bson.Raw) error { return s.clientSession.AdvanceClusterTime(d) } // OperationTime implements the Session interface. -func (s *sessionImpl) OperationTime() *primitive.Timestamp { +func (s *Session) OperationTime() *primitive.Timestamp { return s.clientSession.OperationTime } // AdvanceOperationTime implements the Session interface. -func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error { +func (s *Session) AdvanceOperationTime(ts *primitive.Timestamp) error { return s.clientSession.AdvanceOperationTime(ts) } // Client implements the Session interface. -func (s *sessionImpl) Client() *Client { +func (s *Session) Client() *Client { return s.client } -// session implements the Session interface. -func (*sessionImpl) session() { -} - // sessionFromContext checks for a sessionImpl in the argued context and returns the session if it // exists func sessionFromContext(ctx context.Context) *session.Client { s := ctx.Value(sessionKey{}) - if ses, ok := s.(*sessionImpl); ses != nil && ok { + if ses, ok := s.(*Session); ses != nil && ok { return ses.clientSession } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index f65ba7b4f1..eacc12d864 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -325,7 +325,7 @@ func TestConvenientTransactions(t *testing.T) { "expected timeout error error; got %v", commitErr) // Assert session state is not Committed. - clientSession := session.(XSession).ClientSession() + clientSession := session.ClientSession() assert.False(t, clientSession.TransactionCommitted(), "expected session state to not be Committed") // AbortTransaction without error.