From 936f36b9fc1679df41bebc1ce9b56fb6ba19fcbc Mon Sep 17 00:00:00 2001 From: Pierre-Marie Padiou Date: Wed, 31 Mar 2021 16:12:06 +0200 Subject: [PATCH] Refactor Postgres code (#1743) More symmetry between postgres and sqlite init. * define SqliteDatabases and PostgresDatabase Those classes implement traits `Databases`, `FileBackup` and `ExclusiveLock`. The goal is to have access to backend-specific attributes, particularly in tests. It arguably makes the `Databases` cleaner and simpler, with a nice symmetry between the `apply methods`. * replace 5s lock timeout by NOLOCK * use chaindir instead of datadir for jdbcurl file It is more consistent with sqlite, and makes sense because we don't want to mix up testnet and mainnet databases. * add tests on locks and jdbc url check --- .../main/scala/fr/acinq/eclair/Setup.scala | 2 +- .../scala/fr/acinq/eclair/db/Databases.scala | 300 +++++++-------- .../fr/acinq/eclair/db/pg/PgChannelsDb.scala | 4 +- .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 11 +- .../fr/acinq/eclair/db/pg/PgPeersDb.scala | 7 +- .../acinq/eclair/db/pg/PgPendingRelayDb.scala | 2 +- .../scala/fr/acinq/eclair/db/pg/PgUtils.scala | 362 ++++++++++-------- .../scala/fr/acinq/eclair/StartupSpec.scala | 2 +- .../scala/fr/acinq/eclair/TestConstants.scala | 106 +---- .../scala/fr/acinq/eclair/TestDatabases.scala | 81 ++++ .../db/sqlite/SqliteWalletDbSpec.scala | 6 +- .../blockchain/fee/DbFeeProviderSpec.scala | 4 +- ...iteAuditDbSpec.scala => AuditDbSpec.scala} | 18 +- ...nnelsDbSpec.scala => ChannelsDbSpec.scala} | 35 +- ...ratesDbSpec.scala => FeeratesDbSpec.scala} | 10 +- .../eclair/db/FileBackupHandlerSpec.scala | 11 +- ...etworkDbSpec.scala => NetworkDbSpec.scala} | 24 +- ...mentsDbSpec.scala => PaymentsDbSpec.scala} | 22 +- ...itePeersDbSpec.scala => PeersDbSpec.scala} | 14 +- ...yDbSpec.scala => PendingRelayDbSpec.scala} | 14 +- .../fr/acinq/eclair/db/PgUtilsSpec.scala | 114 ++++++ .../fr/acinq/eclair/db/SqliteUtilsSpec.scala | 5 +- 22 files changed, 649 insertions(+), 505 deletions(-) create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqliteAuditDbSpec.scala => AuditDbSpec.scala} (98%) rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqliteChannelsDbSpec.scala => ChannelsDbSpec.scala} (90%) rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqliteFeeratesDbSpec.scala => FeeratesDbSpec.scala} (93%) rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqliteNetworkDbSpec.scala => NetworkDbSpec.scala} (95%) rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqlitePaymentsDbSpec.scala => PaymentsDbSpec.scala} (98%) rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqlitePeersDbSpec.scala => PeersDbSpec.scala} (89%) rename eclair-core/src/test/scala/fr/acinq/eclair/db/{SqlitePendingRelayDbSpec.scala => PendingRelayDbSpec.scala} (89%) create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index 70642371df..c22304c24e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -98,7 +98,7 @@ class Setup(datadir: File, logger.info(s"instanceid=$instanceId") - val databases = Databases.init(config.getConfig("db"), instanceId, datadir, chaindir, db) + val databases = Databases.init(config.getConfig("db"), instanceId, chaindir, db) /** * This counter holds the current 
blockchain height. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala index 773aafdf79..2e0f52aef2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala @@ -16,89 +16,146 @@ package fr.acinq.eclair.db -import java.io.File -import java.nio.file._ -import java.sql.{Connection, DriverManager} -import java.util.UUID - import akka.actor.ActorSystem import com.typesafe.config.Config -import fr.acinq.eclair.db.pg.PgUtils.LockType.LockType +import com.zaxxer.hikari.{HikariConfig, HikariDataSource} +import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler import fr.acinq.eclair.db.pg.PgUtils._ import fr.acinq.eclair.db.pg._ import fr.acinq.eclair.db.sqlite._ import grizzled.slf4j.Logging -import javax.sql.DataSource +import java.io.File +import java.nio.file._ +import java.sql.{Connection, DriverManager} +import java.util.UUID import scala.concurrent.duration._ trait Databases { + //@formatter:off + def network: NetworkDb + def audit: AuditDb + def channels: ChannelsDb + def peers: PeersDb + def payments: PaymentsDb + def pendingRelay: PendingRelayDb + //@formatter:on +} - val network: NetworkDb +object Databases extends Logging { - val audit: AuditDb + trait FileBackup { + this: Databases => + def backup(backupFile: File): Unit + } - val channels: ChannelsDb + trait ExclusiveLock { + this: Databases => + def obtainExclusiveLock(): Unit + } - val peers: PeersDb + case class SqliteDatabases private (network: SqliteNetworkDb, + audit: SqliteAuditDb, + channels: SqliteChannelsDb, + peers: SqlitePeersDb, + payments: SqlitePaymentsDb, + pendingRelay: SqlitePendingRelayDb, + private val backupConnection: Connection) extends Databases with FileBackup { + override def backup(backupFile: File): Unit = SqliteUtils.using(backupConnection.createStatement()) { + statement => { + statement.executeUpdate(s"backup to ${backupFile.getAbsolutePath}") + } + } + } - val payments: PaymentsDb + object SqliteDatabases { + def apply(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection): Databases = SqliteDatabases( + network = new SqliteNetworkDb(networkJdbc), + audit = new SqliteAuditDb(auditJdbc), + channels = new SqliteChannelsDb(eclairJdbc), + peers = new SqlitePeersDb(eclairJdbc), + payments = new SqlitePaymentsDb(eclairJdbc), + pendingRelay = new SqlitePendingRelayDb(eclairJdbc), + backupConnection = eclairJdbc + ) + } - val pendingRelay: PendingRelayDb -} + case class PostgresDatabases private (network: PgNetworkDb, + audit: PgAuditDb, + channels: PgChannelsDb, + peers: PgPeersDb, + payments: PgPaymentsDb, + pendingRelay: PgPendingRelayDb, + dataSource: HikariDataSource, + lock: PgLock) extends Databases with ExclusiveLock { + override def obtainExclusiveLock(): Unit = lock.obtainExclusiveLock(dataSource) + } -object Databases extends Logging { + object PostgresDatabases { + def apply(hikariConfig: HikariConfig, + instanceId: UUID, + lock: PgLock = PgLock.NoLock, + jdbcUrlFile_opt: Option[File])(implicit system: ActorSystem): PostgresDatabases = { + + jdbcUrlFile_opt.foreach(jdbcUrlFile => checkIfDatabaseUrlIsUnchanged(hikariConfig.getJdbcUrl, jdbcUrlFile)) + + implicit val ds: HikariDataSource = new HikariDataSource(hikariConfig) + implicit val implicitLock: PgLock = lock + + val databases = PostgresDatabases( + network = new PgNetworkDb, + audit = new PgAuditDb, + channels = new PgChannelsDb, + peers = new 
PgPeersDb, + payments = new PgPaymentsDb, + pendingRelay = new PgPendingRelayDb, + dataSource = ds, + lock = lock) + + lock match { + case PgLock.NoLock => () + case l: PgLock.LeaseLock => + // we obtain a lock right now... + databases.obtainExclusiveLock() + // ...and renew the lease regularly + import system.dispatcher + system.scheduler.scheduleWithFixedDelay(l.leaseRenewInterval, l.leaseRenewInterval)(() => databases.obtainExclusiveLock()) + } - trait FileBackup { this: Databases => - def backup(backupFile: File): Unit - } + databases + } - trait ExclusiveLock { this: Databases => - def obtainExclusiveLock(): Unit + private def checkIfDatabaseUrlIsUnchanged(url: String, urlFile: File): Unit = { + def readString(path: Path): String = Files.readAllLines(path).get(0) + + def writeString(path: Path, string: String): Unit = Files.write(path, java.util.Arrays.asList(string)) + + if (urlFile.exists()) { + val oldUrl = readString(urlFile.toPath) + if (oldUrl != url) + throw JdbcUrlChanged(oldUrl, url) + } else { + writeString(urlFile.toPath, url) + } + } } - def init(dbConfig: Config, instanceId: UUID, datadir: File, chaindir: File, db: Option[Databases] = None)(implicit system: ActorSystem): Databases = { + def init(dbConfig: Config, instanceId: UUID, chaindir: File, db: Option[Databases] = None)(implicit system: ActorSystem): Databases = { db match { case Some(d) => d case None => dbConfig.getString("driver") match { - case "sqlite" => Databases.sqliteJDBC(chaindir) - case "postgres" => - val pg = Databases.setupPgDatabases(dbConfig, instanceId, datadir, { ex => - logger.error("fatal error: Cannot obtain lock on the database.\n", ex) - sys.exit(-2) - }) - if (LockType(dbConfig.getString("postgres.lock-type")) == LockType.LEASE) { - val dbLockLeaseRenewInterval = dbConfig.getDuration("postgres.lease.renew-interval").toSeconds.seconds - val dbLockLeaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds - if (dbLockLeaseInterval <= dbLockLeaseRenewInterval) - throw new RuntimeException("Invalid configuration: `db.postgres.lease.interval` must be greater than `db.postgres.lease.renew-interval`") - import system.dispatcher - system.scheduler.scheduleWithFixedDelay(dbLockLeaseRenewInterval, dbLockLeaseRenewInterval)(new Runnable { - override def run(): Unit = { - try { - pg.obtainExclusiveLock() - } catch { - case e: Throwable => - logger.error("fatal error: Cannot obtain the database lease.\n", e) - sys.exit(-1) - } - } - }) - } - pg - case driver => throw new RuntimeException(s"Unknown database driver `$driver`") + case "sqlite" => Databases.sqlite(chaindir) + case "postgres" => Databases.postgres(dbConfig, instanceId, chaindir) + case driver => throw new RuntimeException(s"unknown database driver `$driver`") } } } /** - * Given a parent folder it creates or loads all the databases from a JDBC connection - * - * @param dbdir - * @return - */ - def sqliteJDBC(dbdir: File): Databases = { + * Given a parent folder it creates or loads all the databases from a JDBC connection + */ + def sqlite(dbdir: File): Databases = { dbdir.mkdir() var sqliteEclair: Connection = null var sqliteNetwork: Connection = null @@ -109,127 +166,52 @@ object Databases extends Logging { sqliteAudit = DriverManager.getConnection(s"jdbc:sqlite:${new File(dbdir, "audit.sqlite")}") SqliteUtils.obtainExclusiveLock(sqliteEclair) // there should only be one process writing to this file logger.info("successful lock on eclair.sqlite") - sqliteDatabaseByConnections(sqliteAudit, sqliteNetwork, sqliteEclair) + 
SqliteDatabases(sqliteAudit, sqliteNetwork, sqliteEclair) } catch { - case t: Throwable => { + case t: Throwable => logger.error("could not create connection to sqlite databases: ", t) if (sqliteEclair != null) sqliteEclair.close() if (sqliteNetwork != null) sqliteNetwork.close() if (sqliteAudit != null) sqliteAudit.close() throw t - } - } - } - - def postgresJDBC(database: String, host: String, port: Int, - username: Option[String], password: Option[String], - poolProperties: Map[String, Long], - instanceId: UUID, - databaseLeaseInterval: FiniteDuration, - lockExceptionHandler: LockExceptionHandler = { _ => () }, - lockType: LockType = LockType.NONE, datadir: File): Databases with ExclusiveLock = { - val url = s"jdbc:postgresql://${host}:${port}/${database}" - - checkIfDatabaseUrlIsUnchanged(url, datadir) - - implicit val lock: DatabaseLock = lockType match { - case LockType.NONE => NoLock - case LockType.LEASE => LeaseLock(instanceId, databaseLeaseInterval, lockExceptionHandler) - case _ => throw new RuntimeException(s"Unknown postgres lock type: `$lockType`") - } - - import com.zaxxer.hikari.{HikariConfig, HikariDataSource} - - val config = new HikariConfig() - config.setJdbcUrl(url) - username.foreach(config.setUsername) - password.foreach(config.setPassword) - poolProperties.get("max-size").foreach(x => config.setMaximumPoolSize(x.toInt)) - poolProperties.get("connection-timeout").foreach(config.setConnectionTimeout) - poolProperties.get("idle-timeout").foreach(config.setIdleTimeout) - poolProperties.get("max-life-time").foreach(config.setMaxLifetime) - - implicit val ds: DataSource = new HikariDataSource(config) - - val databases = new Databases with ExclusiveLock { - override val network = new PgNetworkDb - override val audit = new PgAuditDb - override val channels = new PgChannelsDb - override val peers = new PgPeersDb - override val payments = new PgPaymentsDb - override val pendingRelay = new PgPendingRelayDb - override def obtainExclusiveLock(): Unit = lock.obtainExclusiveLock - } - databases.obtainExclusiveLock() - databases - } - - def sqliteDatabaseByConnections(auditJdbc: Connection, networkJdbc: Connection, eclairJdbc: Connection): Databases = new Databases with FileBackup { - override val network = new SqliteNetworkDb(networkJdbc) - override val audit = new SqliteAuditDb(auditJdbc) - override val channels = new SqliteChannelsDb(eclairJdbc) - override val peers = new SqlitePeersDb(eclairJdbc) - override val payments = new SqlitePaymentsDb(eclairJdbc) - override val pendingRelay = new SqlitePendingRelayDb(eclairJdbc) - override def backup(backupFile: File): Unit = { - - SqliteUtils.using(eclairJdbc.createStatement()) { - statement => { - statement.executeUpdate(s"backup to ${backupFile.getAbsolutePath}") - } - } - } } - def setupPgDatabases(dbConfig: Config, instanceId: UUID, datadir: File, lockExceptionHandler: LockExceptionHandler): Databases with ExclusiveLock = { + def postgres(dbConfig: Config, instanceId: UUID, dbdir: File, lockExceptionHandler: LockFailureHandler = LockFailureHandler.logAndStop)(implicit system: ActorSystem): PostgresDatabases = { val database = dbConfig.getString("postgres.database") val host = dbConfig.getString("postgres.host") val port = dbConfig.getInt("postgres.port") - val username = if (dbConfig.getIsNull("postgres.username") || dbConfig.getString("postgres.username").isEmpty) - None - else - Some(dbConfig.getString("postgres.username")) - val password = if (dbConfig.getIsNull("postgres.password") || 
dbConfig.getString("postgres.password").isEmpty) - None - else - Some(dbConfig.getString("postgres.password")) - val properties = { - val poolConfig = dbConfig.getConfig("postgres.pool") - Map.empty - .updated("max-size", poolConfig.getInt("max-size").toLong) - .updated("connection-timeout", poolConfig.getDuration("connection-timeout").toMillis) - .updated("idle-timeout", poolConfig.getDuration("idle-timeout").toMillis) - .updated("max-life-time", poolConfig.getDuration("max-life-time").toMillis) - + val username = if (dbConfig.getIsNull("postgres.username") || dbConfig.getString("postgres.username").isEmpty) None else Some(dbConfig.getString("postgres.username")) + val password = if (dbConfig.getIsNull("postgres.password") || dbConfig.getString("postgres.password").isEmpty) None else Some(dbConfig.getString("postgres.password")) + + val hikariConfig = new HikariConfig() + hikariConfig.setJdbcUrl(s"jdbc:postgresql://$host:$port/$database") + username.foreach(hikariConfig.setUsername) + password.foreach(hikariConfig.setPassword) + val poolConfig = dbConfig.getConfig("postgres.pool") + hikariConfig.setMaximumPoolSize(poolConfig.getInt("max-size")) + hikariConfig.setConnectionTimeout(poolConfig.getDuration("connection-timeout").toMillis) + hikariConfig.setIdleTimeout(poolConfig.getDuration("idle-timeout").toMillis) + hikariConfig.setMaxLifetime(poolConfig.getDuration("max-life-time").toMillis) + + val lock = dbConfig.getString("postgres.lock-type") match { + case "none" => PgLock.NoLock + case "lease" => + val leaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds + val leaseRenewInterval = dbConfig.getDuration("postgres.lease.renew-interval").toSeconds.seconds + require(leaseInterval > leaseRenewInterval, "invalid configuration: `db.postgres.lease.interval` must be greater than `db.postgres.lease.renew-interval`") + PgLock.LeaseLock(instanceId, leaseInterval, leaseRenewInterval, lockExceptionHandler) + case unknownLock => throw new RuntimeException(s"unknown postgres lock type: `$unknownLock`") } - val lockType = LockType(dbConfig.getString("postgres.lock-type")) - val leaseInterval = dbConfig.getDuration("postgres.lease.interval").toSeconds.seconds - Databases.postgresJDBC( - database = database, host = host, port = port, - username = username, password = password, - poolProperties = properties, + val jdbcUrlFile = new File(dbdir, "last_jdbcurl") + + Databases.PostgresDatabases( + hikariConfig = hikariConfig, instanceId = instanceId, - databaseLeaseInterval = leaseInterval, - lockExceptionHandler = lockExceptionHandler, lockType = lockType, datadir = datadir + lock = lock, + jdbcUrlFile_opt = Some(jdbcUrlFile) ) } - private def checkIfDatabaseUrlIsUnchanged(url: String, datadir: File ): Unit = { - val urlFile = new File(datadir, "last_jdbcurl") - - def readString(path: Path): String = Files.readAllLines(path).get(0) - - def writeString(path: Path, string: String): Unit = Files.write(path, java.util.Arrays.asList(string)) - - if (urlFile.exists()) { - val oldUrl = readString(urlFile.toPath) - if (oldUrl != url) - throw new RuntimeException(s"The database URL has changed since the last start. 
It was `$oldUrl`, now it's `$url`") - } else { - writeString(urlFile.toPath, url) - } - } - -} \ No newline at end of file +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index 0f60b7c8c4..aa28ea16aa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -22,7 +22,7 @@ import fr.acinq.eclair.channel.HasCommitments import fr.acinq.eclair.db.ChannelsDb import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics -import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock +import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import grizzled.slf4j.Logging @@ -30,7 +30,7 @@ import java.sql.Statement import javax.sql.DataSource import scala.collection.immutable.Queue -class PgChannelsDb(implicit ds: DataSource, lock: DatabaseLock) extends ChannelsDb with Logging { +class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb with Logging { import PgUtils.ExtendedResultSet._ import PgUtils._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index 778af2bff3..e5f134fff6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -16,27 +16,26 @@ package fr.acinq.eclair.db.pg -import java.sql.ResultSet -import java.util.UUID - import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.MilliSatoshi import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics -import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock import fr.acinq.eclair.db._ +import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.payment.{PaymentFailed, PaymentRequest, PaymentSent} import fr.acinq.eclair.wire.protocol.CommonCodecs import grizzled.slf4j.Logging -import javax.sql.DataSource import scodec.Attempt import scodec.bits.BitVector import scodec.codecs._ +import java.sql.ResultSet +import java.util.UUID +import javax.sql.DataSource import scala.collection.immutable.Queue import scala.concurrent.duration._ -class PgPaymentsDb(implicit ds: DataSource, lock: DatabaseLock) extends PaymentsDb with Logging { +class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb with Logging { import PgUtils.ExtendedResultSet._ import PgUtils._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala index 6e8be88100..9bae575e67 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala @@ -20,12 +20,13 @@ import fr.acinq.bitcoin.Crypto import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.PeersDb -import fr.acinq.eclair.db.pg.PgUtils.DatabaseLock +import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.wire.protocol._ -import javax.sql.DataSource import scodec.bits.BitVector -class PgPeersDb(implicit ds: DataSource, lock: DatabaseLock) extends PeersDb { +import javax.sql.DataSource + +class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { import PgUtils.ExtendedResultSet._ import PgUtils._ 
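For reference, a minimal usage sketch of the new `PgLock` API (illustrative only, not part of the patch): it relies solely on the signatures introduced in `PgUtils.scala` above, and the lease durations, table and query are placeholder values.

  import fr.acinq.eclair.db.jdbc.JdbcUtils.using
  import fr.acinq.eclair.db.pg.PgUtils.PgLock
  import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler
  import java.util.UUID
  import javax.sql.DataSource
  import scala.concurrent.duration._

  // NoLock simply wraps queries in a transaction; LeaseLock additionally commits
  // only if this instance still holds a non-expired lease on the `lease` table.
  val lock: PgLock = PgLock.LeaseLock(
    instanceId = UUID.randomUUID(),
    leaseDuration = 5 minutes,          // placeholder values, not eclair defaults
    leaseRenewInterval = 1 minute,
    lockFailureHandler = LockFailureHandler.logAndStop)

  def countChannels(implicit ds: DataSource): Int =
    lock.withLock { connection =>
      // the surrounding transaction commits only if checkDatabaseLease() still sees our instance id
      using(connection.createStatement()) { statement =>
        val rs = statement.executeQuery("SELECT count(*) FROM local_channels") // placeholder query
        rs.next()
        rs.getInt(1)
      }
    }
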
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala index db6aae3e37..66754af103 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala @@ -27,7 +27,7 @@ import javax.sql.DataSource import scala.collection.immutable.Queue -class PgPendingRelayDb(implicit ds: DataSource, lock: DatabaseLock) extends PendingRelayDb { +class PgPendingRelayDb(implicit ds: DataSource, lock: PgLock) extends PendingRelayDb { import PgUtils.ExtendedResultSet._ import PgUtils._ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala index b73031cfdc..1bbeb25f7b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala @@ -16,92 +16,230 @@ package fr.acinq.eclair.db.pg -import java.sql.{Connection, Statement, Timestamp} -import java.util.UUID - import fr.acinq.eclair.db.jdbc.JdbcUtils +import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler.LockException import grizzled.slf4j.Logging -import javax.sql.DataSource import org.postgresql.util.{PGInterval, PSQLException} +import java.sql.{Connection, Statement, Timestamp} +import java.util.UUID +import javax.sql.DataSource import scala.concurrent.duration._ -import scala.util.{Failure, Success, Try} -object PgUtils extends JdbcUtils with Logging { +object PgUtils extends JdbcUtils { - val LeaseTable = "lease" + /** We raise this exception when the jdbc url changes, to prevent using a different server involuntarily. */ + case class JdbcUrlChanged(before: String, after: String) extends RuntimeException(s"The database URL has changed since the last start. 
It was `$before`, now it's `$after`") - val LockTimeout = 5 seconds + sealed trait PgLock { + def obtainExclusiveLock(implicit ds: DataSource): Unit - val TransactionIsolationLevel = Connection.TRANSACTION_SERIALIZABLE + def withLock[T](f: Connection => T)(implicit ds: DataSource): T + } - object LockType extends Enumeration { - type LockType = Value + object PgLock extends Logging { - val NONE, LEASE = Value + // @formatter:off + sealed trait LockFailure + object LockFailure { + case object TooManyLockAttempts extends LockFailure + case class AlreadyLocked(lockedBy: UUID) extends LockFailure + case object LeaseExpired extends LockFailure + case object NoLeaseInfo extends LockFailure + case class GeneralLockException(cause: Throwable) extends LockFailure + } + // @formatter:on - def apply(s: String): LockType = s match { - case "none" => NONE - case "lease" => LEASE - case _ => throw new RuntimeException(s"Unknown postgres lock type: `$s`") + type LockFailureHandler = LockFailure => Unit + + object LockFailureHandler { + def log: LockFailureHandler = { + case LockFailure.GeneralLockException(cause) => + logger.error("cannot obtain lock on the database.\n", cause) + case other => + logger.error(s"cannot obtain lock on the database ($other).") + } + + case class LockException(lockFailure: LockFailure) extends RuntimeException("a lock exception occurred") + + /** + * This handler is useful in tests + */ + def logAndThrow: LockFailureHandler = { ex => + log(ex) + throw LockException(ex) + } + + /** + * This is the recommended handler in production + */ + def logAndStop: LockFailureHandler = { ex => + log(ex) + logger.error("db locking error is a fatal error") + sys.exit(-2) + } } - } - case class LockLease(expiresAt: Timestamp, instanceId: UUID, expired: Boolean) - // @formatter:off - class TooManyLockAttempts(msg: String) extends RuntimeException(msg) - class UninitializedLockTable(msg: String) extends RuntimeException(msg) - class LockException(msg: String, cause: Option[Throwable] = None) extends RuntimeException(msg, cause.orNull) - class LeaseException(msg: String) extends RuntimeException(msg) - // @formatter:on + case object NoLock extends PgLock { + override def obtainExclusiveLock(implicit ds: DataSource): Unit = () - type LockExceptionHandler = LockException => Unit + override def withLock[T](f: Connection => T)(implicit ds: DataSource): T = + inTransaction(f) + } - sealed trait DatabaseLock { - def obtainExclusiveLock(implicit ds: DataSource): Unit + /** + * This class represents a lease based locking mechanism [[https://en.wikipedia.org/wiki/Lease_(computer_science]]. + * It allows only one process to access the database at a time. + * + * `obtainExclusiveLock` method updates the record in `lease` table with the instance id and the expiration date + * calculated as the current time plus the lease duration. If the current lease is not expired or it belongs to + * another instance `obtainExclusiveLock` throws an exception. + * + * withLock method executes its `f` function and reads the record from lease table to checks if this instance still + * holds the lease and it's not expired. If so, the database transaction gets committed, otherwise en exception is thrown. + * + * `lockExceptionHandler` provides a lock exception handler to customize the behavior when locking errors occur. 
+ */ + case class LeaseLock(instanceId: UUID, leaseDuration: FiniteDuration, leaseRenewInterval: FiniteDuration, lockFailureHandler: LockFailureHandler) extends PgLock { + + import LeaseLock._ + + override def obtainExclusiveLock(implicit ds: DataSource): Unit = { + obtainDatabaseLease(instanceId, leaseDuration) match { + case Right(_) => () + case Left(ex) => lockFailureHandler(ex) + } + } - def withLock[T](f: Connection => T)(implicit ds: DataSource): T - } + override def withLock[T](f: Connection => T)(implicit ds: DataSource): T = { + inTransaction { connection => + val res = f(connection) + checkDatabaseLease(connection, instanceId) match { + case Right(_) => () + case Left(ex) => + lockFailureHandler(ex) + // at this point, a sane failure handler would have either thrown an exception or stopped the app + // but we can't be careful enough so we make sure we throw here + throw LockException(ex) + } + res + } + } + } - case object NoLock extends DatabaseLock { - override def obtainExclusiveLock(implicit ds: DataSource): Unit = () + object LeaseLock { + + private val LeaseTable: String = "lease" + + /** We use a [[LeaseLock]] mechanism to get a [[LockLease]]. */ + case class LockLease(expiresAt: Timestamp, instanceId: UUID, expired: Boolean) + + private def obtainDatabaseLease(instanceId: UUID, leaseDuration: FiniteDuration, attempt: Int = 1)(implicit ds: DataSource): Either[LockFailure, LockLease] = synchronized { + logger.debug(s"trying to acquire database lease (attempt #$attempt) instance ID=$instanceId") + + // this is a recursive method, we need to make sure we don't enter an infinite loop + if (attempt > 3) return Left(LockFailure.TooManyLockAttempts) + + try { + inTransaction { implicit connection => + acquireExclusiveTableLock() + logger.debug("database lease was successfully acquired") + checkDatabaseLease(connection, instanceId) match { + case Right(_) => + Right(updateLease(instanceId, leaseDuration)) + case Left(LockFailure.LeaseExpired) => + // the previous lease has expired, we can take over + // this happens if we have stopped the app, waited some the previous lease to expire, and then restarted + Right(updateLease(instanceId, leaseDuration)) + case Left(LockFailure.NoLeaseInfo) => + // this is the first lock we ever put on the table + Right(updateLease(instanceId, leaseDuration, insertNew = true)) + case otherFailure => otherFailure + } + } + } catch { + case e: PSQLException if e.getServerErrorMessage != null && e.getServerErrorMessage.getSQLState == "42P01" => + withConnection { + connection => + logger.warn(s"table $LeaseTable does not exist, trying to create it") + initializeLeaseTable(connection) + obtainDatabaseLease(instanceId, leaseDuration, attempt + 1) + } + case t: Throwable => Left(LockFailure.GeneralLockException(t)) + } + } - override def withLock[T](f: Connection => T)(implicit ds: DataSource): T = - inTransaction(f) - } + private def initializeLeaseTable(implicit connection: Connection): Unit = { + using(connection.createStatement()) { + statement => + // allow only one row in the ownership lease table + statement.executeUpdate(s"CREATE TABLE IF NOT EXISTS $LeaseTable (id INTEGER PRIMARY KEY default(1), expires_at TIMESTAMP NOT NULL, instance VARCHAR NOT NULL, CONSTRAINT one_row CHECK (id = 1))") + } + } - /** - * This class represents a lease based locking mechanism [[https://en.wikipedia.org/wiki/Lease_(computer_science]]. - * It allows only one process to access the database at a time. 
- * - * `obtainExclusiveLock` method updates the record in `lease` table with the instance id and the expiration date - * calculated as the current time plus the lease duration. If the current lease is not expired or it belongs to - * another instance `obtainExclusiveLock` throws an exception. - * - * withLock method executes its `f` function and reads the record from lease table to checks if this instance still - * holds the lease and it's not expired. If so, the database transaction gets committed, otherwise en exception is thrown. - * - * `lockExceptionHandler` provides a lock exception handler to customize the behavior when locking errors occur. - */ - case class LeaseLock(instanceId: UUID, leaseDuration: FiniteDuration, lockExceptionHandler: LockExceptionHandler) extends DatabaseLock { - override def obtainExclusiveLock(implicit ds: DataSource): Unit = - obtainDatabaseLease(instanceId, leaseDuration) - - override def withLock[T](f: Connection => T)(implicit ds: DataSource): T = { - inTransaction { connection => - val res = f(connection) - checkDatabaseLease(connection, instanceId, lockExceptionHandler) - res + private def acquireExclusiveTableLock()(implicit connection: Connection): Unit = { + using(connection.createStatement()) { + statement => + statement.executeUpdate(s"LOCK TABLE $LeaseTable IN ACCESS EXCLUSIVE MODE NOWAIT") + } + } + + private def checkDatabaseLease(connection: Connection, instanceId: UUID): Either[LockFailure, LockLease] = { + try { + getCurrentLease(connection) match { + case Some(lease) => + if (lease.expired) { + Left(LockFailure.LeaseExpired) + } else if (lease.instanceId != instanceId) { + Left(LockFailure.AlreadyLocked(lease.instanceId)) + } else { + Right(lease) + } + case None => + Left(LockFailure.NoLeaseInfo) + } + } catch { + case t: Throwable => Left(LockFailure.GeneralLockException(t)) + } + } + + private def getCurrentLease(implicit connection: Connection): Option[LockLease] = { + using(connection.createStatement()) { + statement => + val rs = statement.executeQuery(s"SELECT expires_at, instance, now() > expires_at AS expired FROM $LeaseTable WHERE id = 1") + if (rs.next()) + Some(LockLease( + expiresAt = rs.getTimestamp("expires_at"), + instanceId = UUID.fromString(rs.getString("instance")), + expired = rs.getBoolean("expired"))) + else + None + } + } + + private def updateLease(instanceId: UUID, leaseDuration: FiniteDuration, insertNew: Boolean = false)(implicit connection: Connection): LockLease = { + val sql = if (insertNew) + s"INSERT INTO $LeaseTable (expires_at, instance) VALUES (now() + ?, ?)" + else + s"UPDATE $LeaseTable SET expires_at = now() + ?, instance = ? WHERE id = 1" + using(connection.prepareStatement(sql)) { + statement => + statement.setObject(1, new PGInterval(s"${leaseDuration.toSeconds} seconds")) + statement.setString(2, instanceId.toString) + statement.executeUpdate() + } + getCurrentLease.get // TODO: improve that (do INSERT/UPDATE+SELECT?) } } + } def inTransaction[T](connection: Connection)(f: Connection => T): T = { val autoCommit = connection.getAutoCommit connection.setAutoCommit(false) val isolationLevel = connection.getTransactionIsolation - connection.setTransactionIsolation(TransactionIsolationLevel) + connection.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE) try { val res = f(connection) connection.commit() @@ -123,10 +261,10 @@ object PgUtils extends JdbcUtils with Logging { } /** - * Several logical databases (channels, network, peers) may be stored in the same physical postgres database. 
- * We keep track of their respective version using a dedicated table. The version entry will be created if - * there is none but will never be updated here (use setVersion to do that). - */ + * Several logical databases (channels, network, peers) may be stored in the same physical postgres database. + * We keep track of their respective version using a dedicated table. The version entry will be created if + * there is none but will never be updated here (use setVersion to do that). + */ def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = { statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)") // if there was no version for the current db, then insert the current version @@ -138,110 +276,12 @@ object PgUtils extends JdbcUtils with Logging { } /** - * Updates the version for a particular logical database, it will overwrite the previous version. - */ + * Updates the version for a particular logical database, it will overwrite the previous version. + */ def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = { statement.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)") // overwrite the existing version statement.executeUpdate(s"UPDATE versions SET version=$newVersion WHERE db_name='$db_name'") } - private def obtainDatabaseLease(instanceId: UUID, leaseDuration: FiniteDuration, attempt: Int = 1)(implicit ds: DataSource): Unit = synchronized { - logger.debug(s"trying to acquire database lease (attempt #$attempt) instance ID=${instanceId}") - - if (attempt > 3) throw new TooManyLockAttempts("Too many attempts to acquire database lease") - - try { - inTransaction { implicit connection => - acquireExclusiveTableLock() - getCurrentLease match { - case Some(lease) => - if (lease.instanceId == instanceId || lease.expired) - updateLease(instanceId, leaseDuration) - else - throw new LeaseException(s"The database is locked by instance ID=${lease.instanceId}") - case None => - updateLease(instanceId, leaseDuration, insertNew = true) - } - } - logger.debug("database lease was successfully acquired") - } catch { - case e: PSQLException if (e.getServerErrorMessage != null && e.getServerErrorMessage.getSQLState == "42P01") => - withConnection { - connection => - logger.warn(s"table $LeaseTable does not exist, trying to recreate it") - initializeLeaseTable(connection) - obtainDatabaseLease(instanceId, leaseDuration, attempt + 1) - } - } - } - - private def initializeLeaseTable(implicit connection: Connection): Unit = { - using(connection.createStatement()) { - statement => - // allow only one row in the ownership lease table - statement.executeUpdate(s"CREATE TABLE IF NOT EXISTS $LeaseTable (id INTEGER PRIMARY KEY default(1), expires_at TIMESTAMP NOT NULL, instance VARCHAR NOT NULL, CONSTRAINT one_row CHECK (id = 1))") - } - } - - private def acquireExclusiveTableLock()(implicit connection: Connection): Unit = { - using(connection.createStatement()) { - statement => - statement.executeUpdate(s"SET lock_timeout TO '${LockTimeout.toSeconds}s'") - statement.executeUpdate(s"LOCK TABLE $LeaseTable IN ACCESS EXCLUSIVE MODE") - } - } - - private def checkDatabaseLease(connection: Connection, instanceId: UUID, lockExceptionHandler: LockExceptionHandler): Unit = { - Try { - getCurrentLease(connection) match { - case Some(lease) => - if (!(lease.instanceId == instanceId) || lease.expired) { - logger.info(s"database lease: $lease") - 
throw new LockException("This Eclair instance is not a database owner") - } - case None => - throw new LockException("No database lease info") - } - } match { - case Success(_) => () - case Failure(ex) => - val lex = ex match { - case e: LockException => e - case t: Throwable => new LockException("Cannot check database lease", Some(t)) - } - lockExceptionHandler(lex) - throw lex - } - } - - private def getCurrentLease(implicit connection: Connection): Option[LockLease] = { - using(connection.createStatement()) { - statement => - val rs = statement.executeQuery(s"SELECT expires_at, instance, now() > expires_at AS expired FROM $LeaseTable WHERE id = 1") - if (rs.next()) - Some(LockLease( - expiresAt = rs.getTimestamp("expires_at"), - instanceId = UUID.fromString(rs.getString("instance")), - expired = rs.getBoolean("expired"))) - else - None - } - } - - private def updateLease(instanceId: UUID, leaseDuration: FiniteDuration, insertNew: Boolean = false)(implicit connection: Connection): Unit = { - val sql = if (insertNew) - s"INSERT INTO $LeaseTable (expires_at, instance) VALUES (now() + ?, ?)" - else - s"UPDATE $LeaseTable SET expires_at = now() + ?, instance = ? WHERE id = 1" - using(connection.prepareStatement(sql)) { - statement => - statement.setObject(1, new PGInterval(s"${ - leaseDuration.toSeconds - } seconds")) - statement.setString(2, instanceId.toString) - statement.executeUpdate() - } - } - } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala index d111d389f2..b091a833c7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/StartupSpec.scala @@ -40,7 +40,7 @@ class StartupSpec extends AnyFunSuite { val nodeKeyManager = new LocalNodeKeyManager(randomBytes32, chainHash = Block.TestnetGenesisBlock.hash) val channelKeyManager = new LocalChannelKeyManager(randomBytes32, chainHash = Block.TestnetGenesisBlock.hash) val feeEstimator = new TestConstants.TestFeeEstimator - val db = TestConstants.inMemoryDb() + val db = TestDatabases.inMemoryDb() NodeParams.makeNodeParams(conf, UUID.fromString("01234567-0123-4567-89ab-0123456789ab"), nodeKeyManager, channelKeyManager, None, db, blockCount, feeEstimator) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 6eda03435a..a22856a89c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -16,26 +16,20 @@ package fr.acinq.eclair -import com.opentable.db.postgres.embedded.EmbeddedPostgres import fr.acinq.bitcoin.Crypto.PrivateKey -import fr.acinq.bitcoin.{Block, ByteVector32, SatoshiLong, Script} +import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi, SatoshiLong, Script} import fr.acinq.eclair.FeatureSupport.Optional import fr.acinq.eclair.Features._ import fr.acinq.eclair.NodeParams.BITCOIND -import fr.acinq.eclair.blockchain.fee._ +import fr.acinq.eclair.blockchain.fee.{FeeEstimator, FeeTargets, FeeratesPerKw, OnChainFeeConf, _} +import fr.acinq.eclair.channel.LocalParams import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} -import fr.acinq.eclair.db._ -import fr.acinq.eclair.db.pg.PgUtils.NoLock -import fr.acinq.eclair.db.pg._ -import fr.acinq.eclair.db.sqlite._ import fr.acinq.eclair.io.{Peer, PeerConnection} import fr.acinq.eclair.router.Router.RouterConf -import 
fr.acinq.eclair.wire.protocol -import fr.acinq.eclair.wire.protocol.{Color, EncodingType, NodeAddress} +import fr.acinq.eclair.wire.protocol.{Color, EncodingType, NodeAddress, OnionRoutingPacket} import org.scalatest.Tag import scodec.bits.ByteVector -import java.sql.{Connection, DriverManager, Statement} import java.util.UUID import java.util.concurrent.atomic.AtomicLong import scala.concurrent.duration._ @@ -46,11 +40,11 @@ import scala.concurrent.duration._ object TestConstants { val defaultBlockHeight = 400000 - val fundingSatoshis = 1000000L sat - val pushMsat = 200000000L msat - val feeratePerKw = FeeratePerKw(10000 sat) - val anchorOutputsFeeratePerKw = FeeratePerKw(2500 sat) - val emptyOnionPacket = protocol.OnionRoutingPacket(0, ByteVector.fill(33)(0), ByteVector.fill(1300)(0), ByteVector32.Zeroes) + val fundingSatoshis: Satoshi = 1000000L sat + val pushMsat: MilliSatoshi = 200000000L msat + val feeratePerKw: FeeratePerKw = FeeratePerKw(10000 sat) + val anchorOutputsFeeratePerKw: FeeratePerKw = FeeratePerKw(2500 sat) + val emptyOnionPacket: OnionRoutingPacket = OnionRoutingPacket(0, ByteVector.fill(33)(0), ByteVector.fill(1300)(0), ByteVector32.Zeroes) class TestFeeEstimator extends FeeEstimator { private var currentFeerates = FeeratesPerKw.single(feeratePerKw) @@ -64,76 +58,12 @@ object TestConstants { } } - sealed trait TestDatabases { - // @formatter:off - val connection: Connection - def network(): NetworkDb - def audit(): AuditDb - def channels(): ChannelsDb - def peers(): PeersDb - def payments(): PaymentsDb - def pendingRelay(): PendingRelayDb - def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int - def close(): Unit - // @formatter:on - } - - case class TestSqliteDatabases(connection: Connection = sqliteInMemory()) extends TestDatabases { - // @formatter:off - override def network(): NetworkDb = new SqliteNetworkDb(connection) - override def audit(): AuditDb = new SqliteAuditDb(connection) - override def channels(): ChannelsDb = new SqliteChannelsDb(connection) - override def peers(): PeersDb = new SqlitePeersDb(connection) - override def payments(): PaymentsDb = new SqlitePaymentsDb(connection) - override def pendingRelay(): PendingRelayDb = new SqlitePendingRelayDb(connection) - override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = SqliteUtils.getVersion(statement, db_name, currentVersion) - override def close(): Unit = () - // @formatter:on - } - - case class TestPgDatabases() extends TestDatabases { - private val pg = EmbeddedPostgres.start() - - override val connection: Connection = pg.getPostgresDatabase.getConnection - - import com.zaxxer.hikari.{HikariConfig, HikariDataSource} - - val config = new HikariConfig - config.setDataSource(pg.getPostgresDatabase) - - implicit val ds = new HikariDataSource(config) - implicit val lock = NoLock - - // @formatter:off - override def network(): NetworkDb = new PgNetworkDb - override def audit(): AuditDb = new PgAuditDb - override def channels(): ChannelsDb = new PgChannelsDb - override def peers(): PeersDb = new PgPeersDb - override def payments(): PaymentsDb = new PgPaymentsDb - override def pendingRelay(): PendingRelayDb = new PgPendingRelayDb - override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = PgUtils.getVersion(statement, db_name, currentVersion) - override def close(): Unit = pg.close() - // @formatter:on - } - - def sqliteInMemory(): Connection = DriverManager.getConnection("jdbc:sqlite::memory:") - - def forAllDbs(f: 
TestDatabases => Unit): Unit = { - // @formatter:off - def using(dbs: TestDatabases)(g: TestDatabases => Unit): Unit = try g(dbs) finally dbs.close() - using(TestSqliteDatabases())(f) - using(TestPgDatabases())(f) - // @formatter:on - } - - def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.sqliteDatabaseByConnections(connection, connection, connection) - case object TestFeature extends Feature { val rfcName = "test_feature" val mandatory = 50000 } - val pluginParams = new CustomFeaturePlugin { + val pluginParams: CustomFeaturePlugin = new CustomFeaturePlugin { // @formatter:off override def messageTags: Set[Int] = Set(60003) override def feature: Feature = TestFeature @@ -142,12 +72,12 @@ object TestConstants { } object Alice { - val seed = ByteVector32(ByteVector.fill(32)(1)) + val seed: ByteVector32 = ByteVector32(ByteVector.fill(32)(1)) val nodeKeyManager = new LocalNodeKeyManager(seed, Block.RegtestGenesisBlock.hash) val channelKeyManager = new LocalChannelKeyManager(seed, Block.RegtestGenesisBlock.hash) // This is a function, and not a val! When called will return a new NodeParams - def nodeParams = NodeParams( + def nodeParams: NodeParams = NodeParams( nodeKeyManager, channelKeyManager, blockCount = new AtomicLong(defaultBlockHeight), @@ -191,7 +121,7 @@ object TestConstants { feeProportionalMillionth = 10, reserveToFundingRatio = 0.01, // note: not used (overridden below) maxReserveToFundingRatio = 0.05, - db = inMemoryDb(sqliteInMemory()), + db = TestDatabases.inMemoryDb(), revocationTimeout = 20 seconds, autoReconnect = false, initialRandomReconnectDelay = 5 seconds, @@ -238,7 +168,7 @@ object TestConstants { instanceId = UUID.fromString("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") ) - def channelParams = Peer.makeChannelParams( + def channelParams: LocalParams = Peer.makeChannelParams( nodeParams, nodeParams.features, Script.write(Script.pay2wpkh(PrivateKey(randomBytes32).publicKey)), @@ -251,11 +181,11 @@ object TestConstants { } object Bob { - val seed = ByteVector32(ByteVector.fill(32)(2)) + val seed: ByteVector32 = ByteVector32(ByteVector.fill(32)(2)) val nodeKeyManager = new LocalNodeKeyManager(seed, Block.RegtestGenesisBlock.hash) val channelKeyManager = new LocalChannelKeyManager(seed, Block.RegtestGenesisBlock.hash) - def nodeParams = NodeParams( + def nodeParams: NodeParams = NodeParams( nodeKeyManager, channelKeyManager, blockCount = new AtomicLong(defaultBlockHeight), @@ -296,7 +226,7 @@ object TestConstants { feeProportionalMillionth = 10, reserveToFundingRatio = 0.01, // note: not used (overridden below) maxReserveToFundingRatio = 0.05, - db = inMemoryDb(sqliteInMemory()), + db = TestDatabases.inMemoryDb(), revocationTimeout = 20 seconds, autoReconnect = false, initialRandomReconnectDelay = 5 seconds, @@ -343,7 +273,7 @@ object TestConstants { instanceId = UUID.fromString("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") ) - def channelParams = Peer.makeChannelParams( + def channelParams: LocalParams = Peer.makeChannelParams( nodeParams, nodeParams.features, Script.write(Script.pay2wpkh(PrivateKey(randomBytes32).publicKey)), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala new file mode 100644 index 0000000000..84922a1e79 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -0,0 +1,81 @@ +package fr.acinq.eclair + +import akka.actor.ActorSystem +import com.opentable.db.postgres.embedded.EmbeddedPostgres +import 
fr.acinq.eclair.db.pg.PgUtils +import fr.acinq.eclair.db.pg.PgUtils.PgLock +import fr.acinq.eclair.db.sqlite.SqliteUtils +import fr.acinq.eclair.db._ +import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler + +import java.io.File +import java.sql.{Connection, DriverManager, Statement} +import java.util.UUID +import scala.concurrent.duration._ + + +/** + * Extends the regular [[fr.acinq.eclair.db.Databases]] trait with test-specific methods + */ +sealed trait TestDatabases extends Databases { + // @formatter:off + val connection: Connection + val db: Databases + override def network: NetworkDb = db.network + override def audit: AuditDb = db.audit + override def channels: ChannelsDb = db.channels + override def peers: PeersDb = db.peers + override def payments: PaymentsDb = db.payments + override def pendingRelay: PendingRelayDb = db.pendingRelay + def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int + def close(): Unit + // @formatter:on +} + +object TestDatabases { + + def sqliteInMemory(): Connection = DriverManager.getConnection("jdbc:sqlite::memory:") + + def inMemoryDb(connection: Connection = sqliteInMemory()): Databases = Databases.SqliteDatabases(connection, connection, connection) + + case class TestSqliteDatabases() extends TestDatabases { + // @formatter:off + override val connection: Connection = sqliteInMemory() + override lazy val db: Databases = Databases.SqliteDatabases(connection, connection, connection) + override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = SqliteUtils.getVersion(statement, db_name, currentVersion) + override def close(): Unit = () + // @formatter:on + } + + case class TestPgDatabases() extends TestDatabases { + private val pg = EmbeddedPostgres.start() + + import com.zaxxer.hikari.HikariConfig + + val hikariConfig = new HikariConfig + hikariConfig.setDataSource(pg.getPostgresDatabase) + + val lock: PgLock.LeaseLock = PgLock.LeaseLock(UUID.randomUUID(), 10 minutes, 8 minute, LockFailureHandler.logAndThrow) + + val jdbcUrlFile: File = new File(sys.props("tmp.dir"), s"jdbcUrlFile_${UUID.randomUUID()}.tmp") + jdbcUrlFile.deleteOnExit() + + implicit val system: ActorSystem = ActorSystem() + + // @formatter:off + override val connection: Connection = pg.getPostgresDatabase.getConnection + override lazy val db: Databases = Databases.PostgresDatabases(hikariConfig, UUID.randomUUID(), lock, jdbcUrlFile_opt = Some(jdbcUrlFile)) + override def getVersion(statement: Statement, db_name: String, currentVersion: Int): Int = PgUtils.getVersion(statement, db_name, currentVersion) + override def close(): Unit = pg.close() + // @formatter:on + } + + def forAllDbs(f: TestDatabases => Unit): Unit = { + // @formatter:off + def using(dbs: TestDatabases)(g: TestDatabases => Unit): Unit = try g(dbs) finally dbs.close() + using(TestSqliteDatabases())(f) + using(TestPgDatabases())(f) + // @formatter:on + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDbSpec.scala index bb6c91d384..21cca52fd8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/electrum/db/sqlite/SqliteWalletDbSpec.scala @@ -20,7 +20,7 @@ import fr.acinq.bitcoin.{Block, BlockHeader, OutPoint, Satoshi, Transaction, TxI import fr.acinq.eclair.blockchain.electrum.ElectrumClient 
import fr.acinq.eclair.blockchain.electrum.ElectrumClient.GetMerkleResponse import fr.acinq.eclair.blockchain.electrum.ElectrumWallet.PersistentData -import fr.acinq.eclair.{TestConstants, randomBytes, randomBytes32} +import fr.acinq.eclair.{TestDatabases, randomBytes, randomBytes32} import org.scalatest.funsuite.AnyFunSuite import scodec.Codec import scodec.bits.BitVector @@ -67,7 +67,7 @@ class SqliteWalletDbSpec extends AnyFunSuite { } test("add/get/list headers") { - val db = new SqliteWalletDb(TestConstants.sqliteInMemory()) + val db = new SqliteWalletDb(TestDatabases.sqliteInMemory()) val headers = makeHeaders(100) db.addHeaders(2016, headers) @@ -90,7 +90,7 @@ class SqliteWalletDbSpec extends AnyFunSuite { } test("serialize persistent data") { - val db = new SqliteWalletDb(TestConstants.sqliteInMemory()) + val db = new SqliteWalletDb(TestDatabases.sqliteInMemory()) assert(db.readPersistentData() == None) for (i <- 0 until 50) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/fee/DbFeeProviderSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/fee/DbFeeProviderSpec.scala index acb38d7016..4315b44f59 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/fee/DbFeeProviderSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/blockchain/fee/DbFeeProviderSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.blockchain.fee import akka.util.Timeout import fr.acinq.bitcoin.SatoshiLong -import fr.acinq.eclair.TestConstants +import fr.acinq.eclair.TestDatabases import fr.acinq.eclair.db.sqlite.SqliteFeeratesDb import org.scalatest.funsuite.AnyFunSuite @@ -31,7 +31,7 @@ class DbFeeProviderSpec extends AnyFunSuite { val feerates1: FeeratesPerKB = FeeratesPerKB(FeeratePerKB(800 sat), FeeratePerKB(100 sat), FeeratePerKB(200 sat), FeeratePerKB(300 sat), FeeratePerKB(400 sat), FeeratePerKB(500 sat), FeeratePerKB(600 sat), FeeratePerKB(700 sat), FeeratePerKB(800 sat)) test("db fee provider saves feerates in database") { - val sqlite = TestConstants.sqliteInMemory() + val sqlite = TestDatabases.sqliteInMemory() val db = new SqliteFeeratesDb(sqlite) val provider = new DbFeeProvider(db, new ConstantFeeProvider(feerates1)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala similarity index 98% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 7229e873e8..9d2f83b20b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteAuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{ByteVector32, SatoshiLong, Transaction} -import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, forAllDbs} import fr.acinq.eclair._ import fr.acinq.eclair.channel.Helpers.Closing.MutualClose import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError} @@ -35,20 +35,20 @@ import java.util.UUID import scala.concurrent.duration._ import scala.util.Random -class SqliteAuditDbSpec extends AnyFunSuite { +class AuditDbSpec extends AnyFunSuite { val ZERO_UUID: UUID = UUID.fromString("00000000-0000-0000-0000-000000000000") - test("init sqlite 2 times in a row") { + test("init database 2 times in 
a row") { forAllDbs { dbs => - val db1 = dbs.audit() - val db2 = dbs.audit() + val db1 = dbs.audit + val db2 = dbs.audit } } test("add/list events") { forAllDbs { dbs => - val db = dbs.audit() + val db = dbs.audit val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil) val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32) @@ -94,7 +94,7 @@ class SqliteAuditDbSpec extends AnyFunSuite { test("stats") { forAllDbs { dbs => - val db = dbs.audit() + val db = dbs.audit val n2 = randomKey.publicKey val n3 = randomKey.publicKey @@ -140,7 +140,7 @@ class SqliteAuditDbSpec extends AnyFunSuite { ignore("relay stats performance", Tag("perf")) { forAllDbs { dbs => - val db = dbs.audit() + val db = dbs.audit val nodeCount = 100 val channelCount = 1000 val eventCount = 100000 @@ -389,7 +389,7 @@ class SqliteAuditDbSpec extends AnyFunSuite { test("ignore invalid values in the DB") { forAllDbs { dbs => - val db = dbs.audit() + val db = dbs.audit val sqlite = dbs.connection val isPg = dbs.isInstanceOf[TestPgDatabases] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala similarity index 90% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala index 09ee1b690a..8353e22c39 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala @@ -18,11 +18,10 @@ package fr.acinq.eclair.db import com.softwaremill.quicklens._ import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, forAllDbs} import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.jdbc.JdbcUtils.using -import fr.acinq.eclair.db.pg.PgUtils -import fr.acinq.eclair.db.sqlite.{SqliteChannelsDb, SqliteUtils} +import fr.acinq.eclair.db.sqlite.SqliteChannelsDb import fr.acinq.eclair.db.sqlite.SqliteUtils.ExtendedResultSet._ import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec @@ -32,19 +31,19 @@ import scodec.bits.ByteVector import java.sql.SQLException -class SqliteChannelsDbSpec extends AnyFunSuite { +class ChannelsDbSpec extends AnyFunSuite { - test("init sqlite 2 times in a row") { + test("init database 2 times in a row") { forAllDbs { dbs => - val db1 = dbs.channels() - val db2 = dbs.channels() + val db1 = dbs.channels + val db2 = dbs.channels } } test("add/remove/list channels") { forAllDbs { dbs => - val db = dbs.channels() - dbs.pendingRelay() // needed by db.removeChannel + val db = dbs.channels + dbs.pendingRelay // needed by db.removeChannel val channel = ChannelCodecsSpec.normal @@ -75,7 +74,7 @@ class SqliteChannelsDbSpec extends AnyFunSuite { test("channel metadata") { forAllDbs { dbs => - val db = dbs.channels() + val db = dbs.channels val connection = dbs.connection val channel1 = ChannelCodecsSpec.normal @@ -137,7 +136,7 @@ class SqliteChannelsDbSpec extends AnyFunSuite { // create a v1 channels database using(sqlite.createStatement()) { statement => - SqliteUtils.getVersion(statement, "channels", 1) + dbs.getVersion(statement, "channels", 1) 
statement.execute("PRAGMA foreign_keys = ON") statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") @@ -156,7 +155,7 @@ class SqliteChannelsDbSpec extends AnyFunSuite { // check that db migration works val db = new SqliteChannelsDb(sqlite) using(sqlite.createStatement()) { statement => - assert(SqliteUtils.getVersion(statement, "channels", 1) == 3) // version changed from 1 -> 3 + assert(dbs.getVersion(statement, "channels", 1) == 3) // version changed from 1 -> 3 } assert(db.listLocalChannels() === List(channel)) db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail @@ -170,7 +169,7 @@ class SqliteChannelsDbSpec extends AnyFunSuite { // create a v2 channels database using(pg.createStatement()) { statement => - PgUtils.getVersion(statement, "channels", 2) + dbs.getVersion(statement, "channels", 2) statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") @@ -187,9 +186,9 @@ class SqliteChannelsDbSpec extends AnyFunSuite { } // check that db migration works - val db = dbs.channels() + val db = dbs.channels using(pg.createStatement()) { statement => - assert(PgUtils.getVersion(statement, "channels", 2) == 3) // version changed from 2 -> 3 + assert(dbs.getVersion(statement, "channels", 2) == 3) // version changed from 2 -> 3 } assert(db.listLocalChannels() === List(channel)) db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail @@ -199,7 +198,7 @@ class SqliteChannelsDbSpec extends AnyFunSuite { // create a v2 channels database using(sqlite.createStatement()) { statement => - SqliteUtils.getVersion(statement, "channels", 2) + dbs.getVersion(statement, "channels", 2) statement.execute("PRAGMA foreign_keys = ON") statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") @@ -217,9 +216,9 @@ class SqliteChannelsDbSpec extends AnyFunSuite { } // check that db migration works - val db = dbs.channels() + val db = dbs.channels using(sqlite.createStatement()) { statement => - assert(SqliteUtils.getVersion(statement, "channels", 2) == 3) // version changed from 2 -> 3 + assert(dbs.getVersion(statement, "channels", 2) == 3) // version changed from 2 -> 3 } assert(db.listLocalChannels() === List(channel)) db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteFeeratesDbSpec.scala 
b/eclair-core/src/test/scala/fr/acinq/eclair/db/FeeratesDbSpec.scala similarity index 93% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteFeeratesDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/FeeratesDbSpec.scala index 019f285a87..cf25249f30 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteFeeratesDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/FeeratesDbSpec.scala @@ -23,7 +23,7 @@ import fr.acinq.eclair.db.sqlite.SqliteFeeratesDb import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using} import org.scalatest.funsuite.AnyFunSuite -class SqliteFeeratesDbSpec extends AnyFunSuite { +class FeeratesDbSpec extends AnyFunSuite { val feerate = FeeratesPerKB( mempoolMinFee = FeeratePerKB(10000 sat), @@ -36,14 +36,14 @@ class SqliteFeeratesDbSpec extends AnyFunSuite { blocks_144 = FeeratePerKB(20000 sat), blocks_1008 = FeeratePerKB(10000 sat)) - test("init sqlite 2 times in a row") { - val sqlite = TestConstants.sqliteInMemory() + test("init database 2 times in a row") { + val sqlite = TestDatabases.sqliteInMemory() val db1 = new SqliteFeeratesDb(sqlite) val db2 = new SqliteFeeratesDb(sqlite) } test("add/get feerates") { - val sqlite = TestConstants.sqliteInMemory() + val sqlite = TestDatabases.sqliteInMemory() val db = new SqliteFeeratesDb(sqlite) db.addOrUpdateFeerates(feerate) @@ -51,7 +51,7 @@ class SqliteFeeratesDbSpec extends AnyFunSuite { } test("migration 1->2") { - val sqlite = TestConstants.sqliteInMemory() + val sqlite = TestDatabases.sqliteInMemory() using(sqlite.createStatement()) { statement => getVersion(statement, "feerates", 1) // this will set version to 1 diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/FileBackupHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/FileBackupHandlerSpec.scala index e5d52a1f52..79bd361610 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/FileBackupHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/FileBackupHandlerSpec.scala @@ -16,21 +16,22 @@ package fr.acinq.eclair.db -import java.io.File -import java.sql.DriverManager -import java.util.UUID import akka.testkit.TestProbe import fr.acinq.eclair.channel.ChannelPersisted import fr.acinq.eclair.db.Databases.FileBackup import fr.acinq.eclair.db.sqlite.SqliteChannelsDb import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec -import fr.acinq.eclair.{TestConstants, TestKitBaseClass, TestUtils, randomBytes32} +import fr.acinq.eclair.{TestConstants, TestDatabases, TestKitBaseClass, TestUtils, randomBytes32} import org.scalatest.funsuite.AnyFunSuiteLike +import java.io.File +import java.sql.DriverManager +import java.util.UUID + class FileBackupHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike { test("process backups") { - val db = TestConstants.inMemoryDb() + val db = TestDatabases.inMemoryDb() val wip = new File(TestUtils.BUILD_DIRECTORY, s"wip-${UUID.randomUUID()}") val dest = new File(TestUtils.BUILD_DIRECTORY, s"backup-${UUID.randomUUID()}") wip.deleteOnExit() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala similarity index 95% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala index 7f175adb60..e6e78a769a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala +++ 
b/eclair-core/src/test/scala/fr/acinq/eclair/db/NetworkDbSpec.scala @@ -20,26 +20,24 @@ import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto, Satoshi, SatoshiLong} import fr.acinq.eclair.FeatureSupport.Optional import fr.acinq.eclair.Features.VariableLengthOnion -import fr.acinq.eclair.TestConstants.{TestDatabases, TestPgDatabases, TestSqliteDatabases} +import fr.acinq.eclair.TestDatabases._ import fr.acinq.eclair.db.sqlite.SqliteUtils._ import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.wire.protocol.{Color, NodeAddress, Tor2} -import fr.acinq.eclair.{CltvExpiryDelta, Feature, FeatureSupport, Features, MilliSatoshiLong, ShortChannelId, TestConstants, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, ShortChannelId, TestDatabases, randomBytes32, randomKey} import org.scalatest.funsuite.AnyFunSuite import scala.collection.{SortedMap, mutable} -class SqliteNetworkDbSpec extends AnyFunSuite { - - import TestConstants.forAllDbs +class NetworkDbSpec extends AnyFunSuite { val shortChannelIds = (42 to (5000 + 42)).map(i => ShortChannelId(i)) - test("init sqlite 2 times in a row") { + test("init database 2 times in a row") { forAllDbs { dbs => - val db1 = dbs.network() - val db2 = dbs.network() + val db1 = dbs.network + val db2 = dbs.network } } @@ -85,7 +83,7 @@ class SqliteNetworkDbSpec extends AnyFunSuite { test("add/remove/list nodes") { forAllDbs { dbs => - val db = dbs.network() + val db = dbs.network val node_1 = Announcements.makeNodeAnnouncement(randomKey, "node-alice", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features.empty) val node_2 = Announcements.makeNodeAnnouncement(randomKey, "node-bob", Color(100.toByte, 200.toByte, 300.toByte), NodeAddress.fromParts("192.168.1.42", 42000).get :: Nil, Features(VariableLengthOnion -> Optional)) @@ -111,7 +109,7 @@ class SqliteNetworkDbSpec extends AnyFunSuite { test("correctly handle txids that start with 0") { forAllDbs { dbs => - val db = dbs.network() + val db = dbs.network val sig = ByteVector64.Zeroes val c = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) val txid = ByteVector32.fromValidHex("0001" * 16) @@ -121,7 +119,7 @@ class SqliteNetworkDbSpec extends AnyFunSuite { } def simpleTest(dbs: TestDatabases) = { - val db = dbs.network() + val db = dbs.network def sig = Crypto.sign(randomBytes32, randomKey) @@ -228,7 +226,7 @@ class SqliteNetworkDbSpec extends AnyFunSuite { test("remove many channels") { forAllDbs { dbs => - val db = dbs.network() + val db = dbs.network val sig = Crypto.sign(randomBytes32, randomKey) val priv = randomKey val pub = priv.publicKey @@ -250,7 +248,7 @@ class SqliteNetworkDbSpec extends AnyFunSuite { test("prune many channels") { forAllDbs { dbs => - val db = dbs.network() + val db = dbs.network db.addToPruned(shortChannelIds) shortChannelIds.foreach { id => assert(db.isPruned((id))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala similarity index 98% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 91274b3574..9360cd8144 
100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -18,26 +18,26 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, ByteVector32, Crypto} -import fr.acinq.eclair.TestConstants.{TestPgDatabases, TestSqliteDatabases, forAllDbs} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, forAllDbs} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.sqlite.SqlitePaymentsDb import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Router.{ChannelHop, NodeHop} import fr.acinq.eclair.wire.protocol.{ChannelUpdate, UnknownNextPeer} -import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, TestConstants, randomBytes32, randomBytes64, randomKey} +import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, TestDatabases, randomBytes32, randomBytes64, randomKey} import org.scalatest.funsuite.AnyFunSuite import java.util.UUID import scala.concurrent.duration._ -class SqlitePaymentsDbSpec extends AnyFunSuite { +class PaymentsDbSpec extends AnyFunSuite { - import SqlitePaymentsDbSpec._ + import PaymentsDbSpec._ - test("init sqlite 2 times in a row") { + test("init database 2 times in a row") { forAllDbs { dbs => - val db1 = dbs.payments() - val db2 = dbs.payments() + val db1 = dbs.payments + val db2 = dbs.payments } } @@ -319,7 +319,7 @@ class SqlitePaymentsDbSpec extends AnyFunSuite { test("add/retrieve/update incoming payments") { forAllDbs { dbs => - val db = dbs.payments() + val db = dbs.payments // can't receive a payment without an invoice associated with it assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32, 12345678 msat)) @@ -372,7 +372,7 @@ class SqlitePaymentsDbSpec extends AnyFunSuite { test("add/retrieve/update outgoing payments") { forAllDbs { dbs => - val db = dbs.payments() + val db = dbs.payments val parentId = UUID.randomUUID() val i1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), paymentHash1, davePriv, "Some invoice", CltvExpiryDelta(18), expirySeconds = None, timestamp = 0) @@ -428,7 +428,7 @@ class SqlitePaymentsDbSpec extends AnyFunSuite { } test("high level payments overview") { - val db = new SqlitePaymentsDb(TestConstants.sqliteInMemory()) + val db = new SqlitePaymentsDb(TestDatabases.sqliteInMemory()) // -- feed db with incoming payments val expiredInvoice = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(123 msat), randomBytes32, alicePriv, "incoming #1", CltvExpiryDelta(18), timestamp = 1) @@ -503,7 +503,7 @@ class SqlitePaymentsDbSpec extends AnyFunSuite { } -object SqlitePaymentsDbSpec { +object PaymentsDbSpec { val (alicePriv, bobPriv, carolPriv, davePriv) = (randomKey, randomKey, randomKey, randomKey) val (alice, bob, carol, dave) = (alicePriv.publicKey, bobPriv.publicKey, carolPriv.publicKey, davePriv.publicKey) val hop_ab = ChannelHop(alice, bob, ChannelUpdate(randomBytes64, randomBytes32, ShortChannelId(42), 1, 0, 0, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePeersDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala similarity index 89% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePeersDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala index b057015b4f..5fe3a702a0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePeersDbSpec.scala 
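(Illustrative aside, not part of the patch.) Every renamed spec keeps an "init database 2 times in a row" test: instantiating the same Db class twice against one connection must be harmless, because the schema setup is written to be idempotent and version-aware. Purely as an illustration of that shape, and not the actual eclair schema code, a sqlite-backed init could look like:

    // Illustrative only: the idempotent-init shape that the "init database 2 times
    // in a row" tests exercise; not the actual eclair code.
    import java.sql.Connection

    class ExampleDb(connection: Connection) {
      private val init = connection.createStatement()
      try {
        // running this twice is a no-op thanks to IF NOT EXISTS
        init.executeUpdate("CREATE TABLE IF NOT EXISTS versions (db_name TEXT NOT NULL PRIMARY KEY, version INTEGER NOT NULL)")
        init.executeUpdate("CREATE TABLE IF NOT EXISTS example (id INTEGER NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
        // record the schema version so a later release can detect it and run migrations,
        // as the ChannelsDbSpec migration tests above do via getVersion
        init.executeUpdate("INSERT OR IGNORE INTO versions (db_name, version) VALUES ('example', 1)")
      } finally {
        init.close()
      }
    }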
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala @@ -17,25 +17,25 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PublicKey +import fr.acinq.eclair.randomKey import fr.acinq.eclair.wire.protocol.{NodeAddress, Tor2, Tor3} -import fr.acinq.eclair.{TestConstants, randomKey} import org.scalatest.funsuite.AnyFunSuite -class SqlitePeersDbSpec extends AnyFunSuite { +class PeersDbSpec extends AnyFunSuite { - import TestConstants.forAllDbs + import fr.acinq.eclair.TestDatabases.forAllDbs - test("init sqlite 2 times in a row") { + test("init database 2 times in a row") { forAllDbs { dbs => - val db1 = dbs.peers() - val db2 = dbs.peers() + val db1 = dbs.peers + val db2 = dbs.peers } } test("add/remove/list peers") { forAllDbs { dbs => - val db = dbs.peers() + val db = dbs.peers case class TestCase(nodeId: PublicKey, nodeAddress: NodeAddress) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePendingRelayDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingRelayDbSpec.scala similarity index 89% rename from eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePendingRelayDbSpec.scala rename to eclair-core/src/test/scala/fr/acinq/eclair/db/PendingRelayDbSpec.scala index 58707263b9..a24220a67d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqlitePendingRelayDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingRelayDbSpec.scala @@ -17,25 +17,25 @@ package fr.acinq.eclair.db import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC} +import fr.acinq.eclair.randomBytes32 import fr.acinq.eclair.wire.protocol.FailureMessageCodecs -import fr.acinq.eclair.{TestConstants, randomBytes32} import org.scalatest.funsuite.AnyFunSuite -class SqlitePendingRelayDbSpec extends AnyFunSuite { +class PendingRelayDbSpec extends AnyFunSuite { - import TestConstants.forAllDbs + import fr.acinq.eclair.TestDatabases.forAllDbs - test("init sqlite 2 times in a row") { + test("init database 2 times in a row") { forAllDbs { dbs => - val db1 = dbs.pendingRelay() - val db2 = dbs.pendingRelay() + val db1 = dbs.pendingRelay + val db2 = dbs.pendingRelay } } test("add/remove/list messages") { forAllDbs { dbs => - val db = dbs.pendingRelay() + val db = dbs.pendingRelay val channelId1 = randomBytes32 val channelId2 = randomBytes32 diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala new file mode 100644 index 0000000000..388bf6b731 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PgUtilsSpec.scala @@ -0,0 +1,114 @@ +package fr.acinq.eclair.db + +import com.opentable.db.postgres.embedded.EmbeddedPostgres +import com.typesafe.config.{Config, ConfigFactory} +import fr.acinq.eclair.db.pg.PgUtils.JdbcUrlChanged +import fr.acinq.eclair.db.pg.PgUtils.PgLock.{LockFailure, LockFailureHandler} +import fr.acinq.eclair.{TestKitBaseClass, TestUtils} +import grizzled.slf4j.Logging +import org.scalatest.concurrent.Eventually +import org.scalatest.funsuite.AnyFunSuiteLike + +import java.io.File +import java.util.UUID + +class PgUtilsSpec extends TestKitBaseClass with AnyFunSuiteLike with Eventually { + + test("database lock") { + val pg = EmbeddedPostgres.start() + val config = PgUtilsSpec.testConfig(pg.getPort) + val datadir = new File(TestUtils.BUILD_DIRECTORY, s"pg_test_${UUID.randomUUID()}") + datadir.mkdirs() + val instanceId1 = UUID.randomUUID() + // this will lock the database for this instance id + val db = 
Databases.postgres(config, instanceId1, datadir, LockFailureHandler.logAndThrow) + + assert( + intercept[LockFailureHandler.LockException] { + // this will fail because the database is already locked for a different instance id + Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + }.lockFailure === LockFailure.AlreadyLocked(instanceId1)) + + // we can renew the lease at will + db.obtainExclusiveLock() + + // we wait significantly longer than the lease interval, and make sure that the lock is still there + Thread.sleep(10_000) + assert( + intercept[LockFailureHandler.LockException] { + // this will fail because the database is already locked for a different instance id + Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + }.lockFailure === LockFailure.AlreadyLocked(instanceId1)) + + // we close the first connection + db.dataSource.close() + eventually(db.dataSource.isClosed) + // we wait just a bit longer than the lease interval + Thread.sleep(6_000) + + // now we can put a lock with a different instance id + val instanceId2 = UUID.randomUUID() + val db2 = Databases.postgres(config, instanceId2, datadir, LockFailureHandler.logAndThrow) + + // we close the second connection + db2.dataSource.close() + eventually(db2.dataSource.isClosed) + + // but we don't wait for the previous lease to expire, so we can't take over right now + assert(intercept[LockFailureHandler.LockException] { + // this will fail because even if we have acquired the table lock, the previous lease still hasn't expired + Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + }.lockFailure === LockFailure.AlreadyLocked(instanceId2)) + + pg.close() + } + + test("jdbc url check") { + val pg = EmbeddedPostgres.start() + val config = PgUtilsSpec.testConfig(pg.getPort) + val datadir = new File(TestUtils.BUILD_DIRECTORY, s"pg_test_${UUID.randomUUID()}") + datadir.mkdirs() + // this will lock the database for this instance id + val db = Databases.postgres(config, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + + // we close the first connection + db.dataSource.close() + eventually(db.dataSource.isClosed) + + // here we change the config to simulate an involuntary change in the server we connect to + val config1 = ConfigFactory.parseString("postgres.port=1234").withFallback(config) + intercept[JdbcUrlChanged] { + Databases.postgres(config1, UUID.randomUUID(), datadir, LockFailureHandler.logAndThrow) + } + + pg.close() + } + +} + +object PgUtilsSpec extends Logging { + + def testConfig(port: Int): Config = ConfigFactory.parseString( + s""" + |postgres { + | database = "" + | host = "localhost" + | port = $port + | username = "postgres" + | password = "" + | pool { + | max-size = 10 // recommended value = number_of_cpu_cores * 2 + | connection-timeout = 30 seconds + | idle-timeout = 10 minutes + | max-life-time = 30 minutes + | } + | lease { + | interval = 5 seconds // lease-interval must be greater than lease-renew-interval + | renew-interval = 2 seconds + | } + | lock-type = "lease" // lease or none + |} + |""".stripMargin + ) + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala index e3639615f7..03082bb10b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteUtilsSpec.scala @@ -17,15 +17,14 @@ package fr.acinq.eclair.db import
java.sql.SQLException - -import fr.acinq.eclair.TestConstants +import fr.acinq.eclair.{TestConstants, TestDatabases} import fr.acinq.eclair.db.sqlite.SqliteUtils.using import org.scalatest.funsuite.AnyFunSuite class SqliteUtilsSpec extends AnyFunSuite { test("using with auto-commit disabled") { - val conn = TestConstants.sqliteInMemory() + val conn = TestDatabases.sqliteInMemory() using(conn.createStatement()) { statement => statement.executeUpdate("CREATE TABLE utils_test (id INTEGER NOT NULL PRIMARY KEY, updated_at INTEGER)")
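(Illustrative aside, not part of the patch.) The SqliteUtilsSpec hunk above keeps exercising the `using` loan helper from the JDBC utils. As a reference point only, and a simplified assumption rather than the exact eclair implementation, such a helper usually has the following shape: it hands the statement to the block, optionally wraps the block in a transaction by toggling auto-commit, and always closes the statement.

    // Simplified sketch of a JDBC loan-pattern helper in the spirit of the `using`
    // function tested above; the actual eclair implementation may differ.
    import java.sql.Statement

    object UsingSketch {
      def using[T <: Statement, U](statement: T, inTransaction: Boolean = false)(block: T => U): U = {
        try {
          if (inTransaction) statement.getConnection.setAutoCommit(false)
          val res = block(statement)
          if (inTransaction) statement.getConnection.commit()
          res
        } catch {
          case t: Throwable =>
            if (inTransaction) statement.getConnection.rollback()
            throw t
        } finally {
          if (inTransaction) statement.getConnection.setAutoCommit(true)
          statement.close()
        }
      }
    }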