From b8af7cf8a26487644eeb1df8024f3d8453538413 Mon Sep 17 00:00:00 2001 From: Thomas Bruyelle Date: Mon, 11 Jun 2018 13:09:39 +0200 Subject: [PATCH] Add support for ssl dial string (#184) * Add support for ssl dial string * Ensure we dont override user settings * update examples * update ssl value parsing * PingSsl test * skip test requiring system certificates --- example_test.go | 16 +++++++++++++++- session.go | 19 +++++++++++++++++++ session_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/example_test.go b/example_test.go index d176d5f5c..9775ba9e1 100644 --- a/example_test.go +++ b/example_test.go @@ -137,7 +137,21 @@ func ExampleSession_concurrency() { func ExampleDial_usingSSL() { // To connect via TLS/SSL (enforced for MongoDB Atlas for example) requires - // configuring the dialer to use a TLS connection: + // to set the ssl query param to true. + url := "mongodb://localhost:40003?ssl=true" + + session, err := Dial(url) + if err != nil { + panic(err) + } + + // Use session as normal + session.Close() +} + +func ExampleDial_tlsConfig() { + // You can define a custom tlsConfig, this one enables TLS, like if you have + // ssl=true in the connection string. url := "mongodb://localhost:40003" tlsConfig := &tls.Config{ diff --git a/session.go b/session.go index c053dba39..cd2a53e19 100644 --- a/session.go +++ b/session.go @@ -28,6 +28,7 @@ package mgo import ( "crypto/md5" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -294,6 +295,12 @@ const ( // The identifier of this client application. This parameter is used to // annotate logs / profiler output and cannot exceed 128 bytes. // +// ssl= +// +// true: Initiate the connection with TLS/SSL. +// false: Initiate the connection without TLS/SSL. +// The default value is false. +// // Relevant documentation: // // http://docs.mongodb.org/manual/reference/connection-string/ @@ -331,6 +338,7 @@ func ParseURL(url string) (*DialInfo, error) { if err != nil { return nil, err } + ssl := false direct := false mechanism := "" service := "" @@ -345,6 +353,10 @@ func ParseURL(url string) (*DialInfo, error) { safe := Safe{} for _, opt := range uinfo.options { switch opt.key { + case "ssl": + if v, err := strconv.ParseBool(opt.value); err == nil && v { + ssl = true + } case "authSource": source = opt.value case "authMechanism": @@ -460,6 +472,13 @@ func ParseURL(url string) (*DialInfo, error) { MinPoolSize: minPoolSize, MaxIdleTimeMS: maxIdleTimeMS, } + if ssl && info.DialServer == nil { + // Set DialServer only if nil, we don't want to override user's settings. + info.DialServer = func(addr *ServerAddr) (net.Conn, error) { + conn, err := tls.Dial("tcp", addr.String(), &tls.Config{}) + return conn, err + } + } return &info, nil } diff --git a/session_test.go b/session_test.go index f3ac2da92..0a897b61d 100644 --- a/session_test.go +++ b/session_test.go @@ -87,6 +87,15 @@ func (s *S) TestPing(c *C) { c.Assert(stats.ReceivedOps, Equals, 1) } +func (s *S) TestPingSsl(c *C) { + c.Skip("this test requires the usage of the system provided certificates") + session, err := mgo.Dial("localhost:40001?ssl=true") + c.Assert(err, IsNil) + defer session.Close() + + c.Assert(session.Ping(), IsNil) +} + func (s *S) TestDialIPAddress(c *C) { session, err := mgo.Dial("127.0.0.1:40001") c.Assert(err, IsNil) @@ -135,6 +144,25 @@ func (s *S) TestURLParsing(c *C) { } } +func (s *S) TestURLSsl(c *C) { + type test struct { + url string + nilDialServer bool + } + + tests := []test{ + {"localhost:40001", true}, + {"localhost:40001?ssl=false", true}, + {"localhost:40001?ssl=true", false}, + } + + for _, test := range tests { + info, err := mgo.ParseURL(test.url) + c.Assert(err, IsNil) + c.Assert(info.DialServer == nil, Equals, test.nilDialServer) + } +} + func (s *S) TestURLReadPreference(c *C) { type test struct { url string