diff --git a/README.md b/README.md index 1f5226d22c..fec7a74565 100644 --- a/README.md +++ b/README.md @@ -261,7 +261,7 @@ to a single process. - Overwrite this method to generate your custom socket id. - **Parameters** - `http.IncomingMessage`: a node request object - - **Returns** A socket id for connected client. + - `Function`: a callback method which contains an error (if there is) object and the generated id value

diff --git a/lib/server.js b/lib/server.js index d52120a587..d48c4ad793 100644 --- a/lib/server.js +++ b/lib/server.js @@ -281,8 +281,8 @@ function sendErrorMessage (req, res, code) { * @api public */ -Server.prototype.generateId = function (req) { - return base64id.generateId(); +Server.prototype.generateId = function (req, callback) { + callback(null, base64id.generateId()); }; /** @@ -294,52 +294,56 @@ Server.prototype.generateId = function (req) { */ Server.prototype.handshake = function (transportName, req) { - var id = this.generateId(req); - - debug('handshaking client "%s"', id); - - try { - var transport = new transports[transportName](req); - if ('polling' === transportName) { - transport.maxHttpBufferSize = this.maxHttpBufferSize; - transport.httpCompression = this.httpCompression; - } else if ('websocket' === transportName) { - transport.perMessageDeflate = this.perMessageDeflate; + var self = this; + this.generateId(req, function (err, id) { + if (err) { + sendErrorMessage(req, req.res, Server.errors.BAD_REQUEST); + return; } + debug('handshaking client "%s"', id); + + try { + var transport = new transports[transportName](req); + if ('polling' === transportName) { + transport.maxHttpBufferSize = self.maxHttpBufferSize; + transport.httpCompression = self.httpCompression; + } else if ('websocket' === transportName) { + transport.perMessageDeflate = self.perMessageDeflate; + } - if (req._query && req._query.b64) { - transport.supportsBinary = false; - } else { - transport.supportsBinary = true; + if (req._query && req._query.b64) { + transport.supportsBinary = false; + } else { + transport.supportsBinary = true; + } + } catch (e) { + sendErrorMessage(req, req.res, Server.errors.BAD_REQUEST); + return; + } + var socket = new Socket(id, self, transport, req); + + if (false !== self.cookie) { + transport.on('headers', function (headers) { + headers['Set-Cookie'] = cookieMod.serialize(self.cookie, id, + { + path: self.cookiePath, + httpOnly: self.cookiePath ? self.cookieHttpOnly : false + }); + }); } - } catch (e) { - sendErrorMessage(req, req.res, Server.errors.BAD_REQUEST); - return; - } - var socket = new Socket(id, this, transport, req); - var self = this; - if (false !== this.cookie) { - transport.on('headers', function (headers) { - headers['Set-Cookie'] = cookieMod.serialize(self.cookie, id, - { - path: self.cookiePath, - httpOnly: self.cookiePath ? self.cookieHttpOnly : false - }); - }); - } + transport.onRequest(req); - transport.onRequest(req); + self.clients[id] = socket; + self.clientsCount++; - this.clients[id] = socket; - this.clientsCount++; + socket.once('close', function () { + delete self.clients[id]; + self.clientsCount--; + }); - socket.once('close', function () { - delete self.clients[id]; - self.clientsCount--; + self.emit('connection', socket); }); - - this.emit('connection', socket); }; /** diff --git a/test/server.js b/test/server.js index 59bb41d2ff..c081503fe3 100644 --- a/test/server.js +++ b/test/server.js @@ -1,3 +1,4 @@ +'use strict'; /* eslint-disable standard/no-callback-literal */ /** @@ -234,8 +235,8 @@ describe('server', function () { var customId = 'CustomId' + Date.now(); - engine.generateId = function (req) { - return customId; + engine.generateId = function (req, callback) { + callback(null, customId); }; var socket = new eioc.Socket('ws://localhost:%d'.s(port)); @@ -249,6 +250,18 @@ describe('server', function () { }); }); + it('should disallow connection when custom id cannot be generated', function (done) { + let engine = listen({ allowUpgrades: false }, port => { + engine.generateId = (req, callback) => { + callback(new Error('no ID found')); + }; + + let socket = new eioc.Socket('ws://localhost:%d'.s(port)); + socket.on('open', () => done(new Error('should not be able to connect'))); + socket.on('error', () => done()); + }); + }); + it('should exchange handshake data', function (done) { listen({ allowUpgrades: false }, function (port) { var socket = new eioc.Socket('ws://localhost:%d'.s(port));