Skip to content

Commit

Permalink
Keep SQLite AI incremented after rollback (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvorisek authored Feb 23, 2024
1 parent df7dbb9 commit 548bc1c
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/test-unit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ jobs:
php -d opcache.enable_cli=1 vendor/bin/phpunit --exclude-group none $(if [ -n "$LOG_COVERAGE" ]; then echo --coverage-text; else echo --no-coverage; fi) --fail-on-warning --fail-on-risky $(if vendor/bin/phpunit --version | grep -q '^PHPUnit 9\.'; then echo -v; else echo --fail-on-notice --fail-on-deprecation --display-notices --display-deprecations --display-warnings --display-errors --display-incomplete --display-skipped; fi)
if [ -n "$LOG_COVERAGE" ]; then mv coverage/phpunit.cov coverage/phpunit-sqlite.cov; fi
- name: "Run tests: SQLite 3.25.3"
run: |
apk add sqlite-dev=3.25.3-r0 --repository=https://dl-cdn.alpinelinux.org/alpine/v3.6/main
php -d opcache.enable_cli=1 vendor/bin/phpunit --exclude-group none $(if [ -n "$LOG_COVERAGE" ]; then echo --coverage-text; else echo --no-coverage; fi) --fail-on-warning --fail-on-risky $(if vendor/bin/phpunit --version | grep -q '^PHPUnit 9\.'; then echo -v; else echo --fail-on-notice --fail-on-deprecation --display-notices --display-deprecations --display-warnings --display-errors --display-incomplete --display-skipped; fi)
if [ -n "$LOG_COVERAGE" ]; then mv coverage/phpunit.cov coverage/phpunit-sqlite325.cov; fi
- name: "Run tests: MySQL - PDO"
if: success() || failure()
env:
Expand Down
2 changes: 1 addition & 1 deletion phpstan.neon.dist
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ parameters:
-
message: '~^Class Doctrine\\DBAL\\Platforms\\SqlitePlatform referenced with incorrect case: Doctrine\\DBAL\\Platforms\\SQLitePlatform\.$~'
path: '*'
count: 25
count: 24

# TODO these rules are generated, this ignores should be fixed in the code
# for src/Schema/TestCase.php
Expand Down
1 change: 1 addition & 0 deletions src/Persistence/Sql/Sqlite/Connection.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ protected static function createDbalConfiguration(): Configuration
$configuration->setMiddlewares([
...$configuration->getMiddlewares(),
new EnableForeignKeys(),
new PreserveAutoincrementOnRollbackMiddleware(),
]);

return $configuration;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
<?php

declare(strict_types=1);

namespace Atk4\Data\Persistence\Sql\Sqlite;

use Atk4\Data\Persistence\Sql\Exception;
use Doctrine\DBAL\Driver\Middleware\AbstractConnectionMiddleware;

class PreserveAutoincrementOnRollbackConnectionMiddleware extends AbstractConnectionMiddleware
{
private static string $libraryVersion;

private function createExpressionFromStringLiteral(string $value): Expression
{
return new Expression('\'' . str_replace('\'', '\'\'', $value) . '\'');
}

/**
* @return array<string, array<string, int>>
*/
protected function listSequences(): array
{
if ((self::$libraryVersion ?? null) === null) {
$getLibraryVersionSql = (new Query())
->field('sqlite_version()')
->render()[0];
self::$libraryVersion = $this->query($getLibraryVersionSql)->fetchOne();
}

if (version_compare(self::$libraryVersion, '3.37') < 0) {
$listAllSchemasSql = (new Query())
->table('pragma_database_list')
->field('name')
->render()[0];
$allSchemas = $this->query($listAllSchemasSql)->fetchFirstColumn();

$schemas = [];
foreach ($allSchemas as $schema) {
$dummySelectFromSqliteSequenceTableSql = (new Query())
->table($schema . '.sqlite_sequence')
->field('name')
->render()[0];
try {
$this->query($dummySelectFromSqliteSequenceTableSql)->fetchFirstColumn();
$schemas[] = $schema;
} catch (\Exception $e) {
while ($e->getPrevious() !== null) {
$e = $e->getPrevious();
}

if (!str_contains($e->getMessage(), 'HY000')
|| !str_contains($e->getMessage(), 'no such table: ' . $schema . '.sqlite_sequence')
) {
throw $e;
}
}
}
} else {
$listSchemasSql = (new Query())
->table('pragma_table_list')
->field('schema')
->where('name', $this->createExpressionFromStringLiteral('sqlite_sequence'))
->render()[0];
$schemas = $this->query($listSchemasSql)->fetchFirstColumn();
}

$res = [];
if ($schemas !== []) {
$listSequencesSql = implode("\nUNION ALL\n", array_map(function (string $schema) {
return (new Query())
->table($schema . '.sqlite_sequence')
->field($this->createExpressionFromStringLiteral($schema), 'schema')
->field('name')
->field('seq', 'value')
->render()[0];
}, $schemas));

$res = [];
foreach ($this->query($listSequencesSql)->fetchAllAssociative() as $row) {
$value = (int) $row['value'];
if (!is_int($row['value']) && (string) $value !== $row['value']) {
throw (new Exception('Unexpected SQLite sequence value'))
->addMoreInfo('value', $row['value']);
}

$res[$row['schema']][$row['name']] = $value;
}
}

return $res;
}

/**
* @param array<string, array<string, int>> $beforeRollbackSequences
*/
protected function restoreSequencesIfDecremented(array $beforeRollbackSequences): void
{
$afterRollbackSequences = $this->listSequences();

foreach ($beforeRollbackSequences as $schema => $beforeRollbackSequences2) {
foreach ($beforeRollbackSequences2 as $table => $beforeRollbackValue) {
$afterRollbackValue = $afterRollbackSequences[$schema][$table] ?? null;
if ($afterRollbackValue >= $beforeRollbackValue) {
continue;
}

if ($afterRollbackValue === null) { // https://sqlite.org/forum/info/3e7cc380f0a159c6
$query = (new Query())
->mode('insert')
->set('name', $this->createExpressionFromStringLiteral($table));
} else {
$query = (new Query())
->mode('update')
->where('name', $this->createExpressionFromStringLiteral($table));
}
$query->table($schema . '.sqlite_sequence');
$query->set('seq', $this->createExpressionFromStringLiteral((string) $beforeRollbackValue));

$this->exec($query->render()[0]);
}
}
}

#[\Override]
public function exec(string $sql): int
{
$isRollback = str_starts_with(strtoupper(ltrim($sql)), 'ROLLBACK ');

if ($isRollback) {
$beforeRollbackSequences = $this->listSequences();
}

$res = parent::exec($sql);

if ($isRollback) {
$this->restoreSequencesIfDecremented($beforeRollbackSequences);
}

return $res;
}

#[\Override]
public function rollBack()
{
$beforeRollbackSequences = $this->listSequences();

$res = parent::rollBack();

$this->restoreSequencesIfDecremented($beforeRollbackSequences);

return $res;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<?php

declare(strict_types=1);

namespace Atk4\Data\Persistence\Sql\Sqlite;

use Doctrine\DBAL\Driver;
use Doctrine\DBAL\Driver\Connection;
use Doctrine\DBAL\Driver\Middleware;
use Doctrine\DBAL\Driver\Middleware\AbstractDriverMiddleware;

class PreserveAutoincrementOnRollbackMiddleware implements Middleware
{
#[\Override]
public function wrap(Driver $driver): Driver
{
return new class($driver) extends AbstractDriverMiddleware {
#[\Override]
public function connect(
#[\SensitiveParameter]
array $params
): Connection {
return new PreserveAutoincrementOnRollbackConnectionMiddleware(parent::connect($params));
}
};
}
}
7 changes: 0 additions & 7 deletions tests/Persistence/Sql/WithDb/SelectTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -620,13 +620,6 @@ public function testImportAndAutoincrement(): void
self::assertSame(103, $m->insert(['f1' => 'N']));
});

// TODO workaround SQLite to be consistent with other databases
// https://stackoverflow.com/questions/27947712/sqlite-repeats-primary-key-autoincrement-value-after-rollback
// https://github.com/atk4/data/issues/1162
if ($this->getDatabasePlatform() instanceof SQLitePlatform) {
return;
}

$invokeInAtomicAndThrowFx(static function () use ($invokeInAtomicAndThrowFx, $m) {
self::assertSame(104, $m->insert(['f1' => 'O1']));
$invokeInAtomicAndThrowFx(static function () use ($invokeInAtomicAndThrowFx, $m) {
Expand Down

0 comments on commit 548bc1c

Please sign in to comment.