diff --git a/packages/pg-pool/index.js b/packages/pg-pool/index.js index f53a85ab1..0d54be906 100644 --- a/packages/pg-pool/index.js +++ b/packages/pg-pool/index.js @@ -277,30 +277,67 @@ class Pool extends EventEmitter { } else { this.log('new client connected') - if (this.options.maxLifetimeSeconds !== 0) { - const maxLifetimeTimeout = setTimeout(() => { - this.log('ending client due to expired lifetime') - this._expired.add(client) - const idleIndex = this._idle.findIndex((idleItem) => idleItem.client === client) - if (idleIndex !== -1) { - this._acquireClient( - client, - new PendingItem((err, client, clientRelease) => clientRelease()), - idleListener, - false - ) - } - }, this.options.maxLifetimeSeconds * 1000) - - maxLifetimeTimeout.unref() - client.once('end', () => clearTimeout(maxLifetimeTimeout)) + if (this.options.onConnect) { + let hookResult + try { + hookResult = this.options.onConnect(client) + } catch (hookErr) { + this._clients = this._clients.filter((c) => c !== client) + client.end(() => { + this._pulseQueue() + if (!pendingItem.timedOut) { + pendingItem.callback(hookErr, undefined, NOOP) + } + }) + return + } + if (hookResult && typeof hookResult.then === 'function') { + hookResult.then( + () => { + this._afterConnect(client, pendingItem, idleListener) + }, + (hookErr) => { + this._clients = this._clients.filter((c) => c !== client) + client.end(() => { + this._pulseQueue() + if (!pendingItem.timedOut) { + pendingItem.callback(hookErr, undefined, NOOP) + } + }) + } + ) + return + } } - return this._acquireClient(client, pendingItem, idleListener, true) + return this._afterConnect(client, pendingItem, idleListener) } }) } + _afterConnect(client, pendingItem, idleListener) { + if (this.options.maxLifetimeSeconds !== 0) { + const maxLifetimeTimeout = setTimeout(() => { + this.log('ending client due to expired lifetime') + this._expired.add(client) + const idleIndex = this._idle.findIndex((idleItem) => idleItem.client === client) + if (idleIndex !== -1) { + this._acquireClient( + client, + new PendingItem((err, client, clientRelease) => clientRelease()), + idleListener, + false + ) + } + }, this.options.maxLifetimeSeconds * 1000) + + maxLifetimeTimeout.unref() + client.once('end', () => clearTimeout(maxLifetimeTimeout)) + } + + return this._acquireClient(client, pendingItem, idleListener, true) + } + // acquire a client for a pending work item _acquireClient(client, pendingItem, idleListener, isNew) { if (isNew) { diff --git a/packages/pg-pool/test/lifecycle-hooks.js b/packages/pg-pool/test/lifecycle-hooks.js new file mode 100644 index 000000000..05a706d95 --- /dev/null +++ b/packages/pg-pool/test/lifecycle-hooks.js @@ -0,0 +1,164 @@ +const describe = require('mocha').describe +const it = require('mocha').it +const expect = require('expect.js') + +const Pool = require('..') + +describe('lifecycle hooks', () => { + it('are called on connect', async () => { + const pool = new Pool({ + onConnect: (client) => { + client.HOOK_CONNECT_COUNT = (client.HOOK_CONNECT_COUNT || 0) + 1 + }, + }) + const client = await pool.connect() + expect(client.HOOK_CONNECT_COUNT).to.equal(1) + client.release() + const client2 = await pool.connect() + expect(client).to.equal(client2) + expect(client2.HOOK_CONNECT_COUNT).to.equal(1) + client.release() + await pool.end() + }) + + it('are called on connect with an async hook', async () => { + const pool = new Pool({ + onConnect: async (client) => { + const res = await client.query('SELECT 1 AS num') + client.HOOK_CONNECT_RESULT = res.rows[0].num + }, + }) + const client = await pool.connect() + expect(client.HOOK_CONNECT_RESULT).to.equal(1) + const res = await client.query('SELECT 1 AS num') + expect(res.rows[0].num).to.equal(1) + client.release() + const client2 = await pool.connect() + expect(client).to.equal(client2) + expect(client2.HOOK_CONNECT_RESULT).to.equal(1) + client.release() + await pool.end() + }) + + it('errors out the connect call if the async connect hook rejects', async () => { + const pool = new Pool({ + onConnect: async (client) => { + await client.query('SELECT INVALID HERE') + }, + }) + try { + await pool.connect() + throw new Error('Expected connect to throw') + } catch (err) { + expect(err.message).to.contain('invalid') + } + await pool.end() + }) + + it('calls onConnect when using pool.query', async () => { + const pool = new Pool({ + onConnect: async (client) => { + const res = await client.query('SELECT 1 AS num') + client.HOOK_CONNECT_RESULT = res.rows[0].num + }, + }) + const res = await pool.query('SELECT $1::text AS name', ['brianc']) + expect(res.rows[0].name).to.equal('brianc') + const client = await pool.connect() + expect(client.HOOK_CONNECT_RESULT).to.equal(1) + client.release() + await pool.end() + }) + + it('recovers after a hook error', async () => { + let shouldError = true + const pool = new Pool({ + onConnect: () => { + if (shouldError) { + throw new Error('connect hook error') + } + }, + }) + try { + await pool.connect() + throw new Error('Expected connect to throw') + } catch (err) { + expect(err.message).to.equal('connect hook error') + } + shouldError = false + const client = await pool.connect() + const res = await client.query('SELECT 1 AS num') + expect(res.rows[0].num).to.equal(1) + client.release() + await pool.end() + }) + + it('calls onConnect for each new client', async () => { + let connectCount = 0 + const pool = new Pool({ + max: 2, + onConnect: async (client) => { + connectCount++ + await client.query('SELECT 1') + }, + }) + const client1 = await pool.connect() + const client2 = await pool.connect() + expect(connectCount).to.equal(2) + expect(client1).to.not.equal(client2) + client1.release() + client2.release() + await pool.end() + }) + + it('cleans up clients after repeated hook failures', async () => { + let errorCount = 0 + const pool = new Pool({ + max: 2, + onConnect: () => { + if (errorCount < 10) { + errorCount++ + throw new Error('connect hook error') + } + }, + }) + for (let i = 0; i < 10; i++) { + let threw = false + try { + await pool.connect() + } catch (err) { + threw = true + expect(err.message).to.equal('connect hook error') + } + expect(threw).to.equal(true) + } + expect(errorCount).to.equal(10) + expect(pool.totalCount).to.equal(0) + expect(pool.idleCount).to.equal(0) + const client1 = await pool.connect() + const res1 = await client1.query('SELECT 1 AS num') + expect(res1.rows[0].num).to.equal(1) + const client2 = await pool.connect() + const res2 = await client2.query('SELECT 2 AS num') + expect(res2.rows[0].num).to.equal(2) + expect(pool.totalCount).to.equal(2) + client1.release() + client2.release() + await pool.end() + }) + + it('errors out the connect call if the connect hook throws', async () => { + const pool = new Pool({ + onConnect: () => { + throw new Error('connect hook error') + }, + }) + try { + await pool.connect() + throw new Error('Expected connect to throw') + } catch (err) { + expect(err.message).to.equal('connect hook error') + } + await pool.end() + }) +}) diff --git a/packages/pg-pool/test/timeout.js b/packages/pg-pool/test/timeout.js deleted file mode 100644 index e69de29bb..000000000