Skip to content

Commit

Permalink
Allow userland middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
l-vo committed Feb 8, 2022
1 parent bbb3b82 commit b4a7efa
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 78 deletions.
59 changes: 59 additions & 0 deletions DependencyInjection/Compiler/MiddlewaresPass.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
<?php

namespace Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler;

use Doctrine\Bundle\DoctrineBundle\Middleware\ConnectionNameAwareInterface;
use Symfony\Component\DependencyInjection\ChildDefinition;
use Symfony\Component\DependencyInjection\Compiler\CompilerPassInterface;
use Symfony\Component\DependencyInjection\ContainerBuilder;

use function array_keys;
use function is_subclass_of;
use function sprintf;

class MiddlewaresPass implements CompilerPassInterface
{
/** @var string */
private $connexionDefsParam;

/** @var string */
private $middlewareTag;

public function __construct(
string $connexionDefsParam = 'doctrine.connections',
string $middlewareTag = 'doctrine.middleware'
) {
$this->connexionDefsParam = $connexionDefsParam;
$this->middlewareTag = $middlewareTag;
}

public function process(ContainerBuilder $container)
{
$middlewareAbstractDefs = [];
foreach (array_keys($container->findTaggedServiceIds($this->middlewareTag)) as $id) {
$def = $container->getDefinition($id);
$middlewareAbstractDefs[$id] = $def;
}

$connections = $container->getParameter($this->connexionDefsParam);
foreach ($connections as $name => $id) {
$middlewareDefs = [];
foreach ($middlewareAbstractDefs as $id => $abstractDef) {
$middlewareDefs[] = $childDef = $container->setDefinition(
sprintf('%s.%s', $id, $name),
new ChildDefinition($id)
);

if (! is_subclass_of($abstractDef->getClass(), ConnectionNameAwareInterface::class)) {
continue;
}

$childDef->addMethodCall('setConnectionName', [$name]);
}

$container
->getDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name))
->addMethodCall('setMiddlewares', [$middlewareDefs]);
}
}
}
32 changes: 15 additions & 17 deletions DependencyInjection/DoctrineExtension.php
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
use Symfony\Component\DependencyInjection\Alias;
use Symfony\Component\DependencyInjection\ChildDefinition;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\DependencyInjection\ContainerInterface;
use Symfony\Component\DependencyInjection\Definition;
use Symfony\Component\DependencyInjection\Exception\InvalidArgumentException;
use Symfony\Component\DependencyInjection\Loader\XmlFileLoader;
Expand Down Expand Up @@ -85,6 +84,8 @@ public function load(array $configs, ContainerBuilder $container)
$this->dbalLoad($config['dbal'], $container);

$this->loadMessengerServices($container);

$this->loadMiddlewares($container);
}

if (empty($config['orm'])) {
Expand Down Expand Up @@ -143,9 +144,18 @@ protected function dbalLoad(array $config, ContainerBuilder $container)
$container->setParameter('doctrine.connections', $connections);
$container->setParameter('doctrine.default_connection', $this->defaultConnection);

/** @psalm-suppress UndefinedClass */
if (class_exists(Middleware::class)) {
$container
->getDefinition('doctrine.dbal.logger')
->replaceArgument(0, null);
}

foreach ($config['connections'] as $name => $connection) {
$this->loadDbalConnection($name, $connection, $container);
}

$container->registerForAutoconfiguration(Middleware::class)->addTag('doctrine.middleware');
}

/**
Expand All @@ -160,7 +170,6 @@ protected function loadDbalConnection($name, array $connection, ContainerBuilder
$configuration = $container->setDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name), new ChildDefinition('doctrine.dbal.connection.configuration'));
$logger = null;
if ($connection['logging']) {
$this->useMiddlewaresIfAvailable($connection, $container, $name, $configuration);
$logger = new Reference('doctrine.dbal.logger');
}

Expand Down Expand Up @@ -1073,25 +1082,14 @@ private function createArrayAdapterCachePool(ContainerBuilder $container, string
return $id;
}

/** @param array<string, mixed> $connection */
protected function useMiddlewaresIfAvailable(array $connection, ContainerBuilder $container, string $name, Definition $configuration): void
private function loadMiddlewares(ContainerBuilder $container): void
{
/** @psalm-suppress UndefinedClass */
if (! interface_exists(Middleware::class)) {
if (! class_exists(Middleware::class)) {
return;
}

$container
->getDefinition('doctrine.dbal.logger')
->replaceArgument(0, null);

$loggingMiddlewareDef = $container->setDefinition(
sprintf('doctrine.dbal.%s_connection.logging_middleware', $name),
new ChildDefinition('doctrine.dbal.logging_middleware')
);
$loggingMiddlewareDef->addArgument(new Reference('logger', ContainerInterface::NULL_ON_INVALID_REFERENCE));
$loggingMiddlewareDef->addTag('monolog.logger', ['channel' => 'doctrine']);

$configuration->addMethodCall('setMiddlewares', [[$loggingMiddlewareDef]]);
$loader = new XmlFileLoader($container, new FileLocator(__DIR__ . '/../Resources/config'));
$loader->load('middlewares.xml');
}
}
8 changes: 8 additions & 0 deletions DoctrineBundle.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\DbalSchemaFilterPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\EntityListenerPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\IdGeneratorPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\MiddlewaresPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\RemoveProfilerControllerPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\ServiceRepositoryCompilerPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\WellKnownSchemaFilterPass;
use Doctrine\Common\Util\ClassUtils;
use Doctrine\DBAL\Driver\Middleware;
use Doctrine\ORM\EntityManagerInterface;
use Doctrine\ORM\Proxy\Autoloader;
use Symfony\Bridge\Doctrine\DependencyInjection\CompilerPass\DoctrineValidationPass;
Expand All @@ -26,6 +28,7 @@
use function assert;
use function class_exists;
use function clearstatcache;
use function interface_exists;
use function spl_autoload_unregister;

class DoctrineBundle extends Bundle
Expand Down Expand Up @@ -60,6 +63,11 @@ public function build(ContainerBuilder $container)
$container->addCompilerPass(new CacheSchemaSubscriberPass(), PassConfig::TYPE_BEFORE_REMOVING, -10);
$container->addCompilerPass(new RemoveProfilerControllerPass());

/** @psalm-suppress UndefinedClass */
if (interface_exists(Middleware::class)) {
$container->addCompilerPass(new MiddlewaresPass());
}

if (! class_exists(RegisterUidTypePass::class)) {
return;
}
Expand Down
1 change: 1 addition & 0 deletions Mapping/ClassMetadataFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ protected function doLoadMetadata($class, $parent, $rootEntityFound, array $nonS
{
parent::doLoadMetadata($class, $parent, $rootEntityFound, $nonSuperclassParents);

/** @psalm-suppress TypeDoesNotContainType **/
if (! $class instanceof ClassMetadataInfo) {
return;
}
Expand Down
8 changes: 8 additions & 0 deletions Middleware/ConnectionNameAwareInterface.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<?php

namespace Doctrine\Bundle\DoctrineBundle\Middleware;

interface ConnectionNameAwareInterface
{
public function setConnectionName(string $name): void;
}
3 changes: 0 additions & 3 deletions Resources/config/dbal.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@
<argument type="service" id="debug.stopwatch" on-invalid="null" />
</service>

<service id="doctrine.dbal.logging_middleware" class="Doctrine\DBAL\Logging\Middleware" abstract="true">
</service>

<service id="data_collector.doctrine" class="%doctrine.data_collector.class%" public="false">
<tag name="data_collector" template="@Doctrine/Collector/db.html.twig" id="db" priority="250" />
<argument type="service" id="doctrine" />
Expand Down
14 changes: 14 additions & 0 deletions Resources/config/middlewares.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<?xml version="1.0" ?>

<container xmlns="http://symfony.com/schema/dic/services"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://symfony.com/schema/dic/services http://symfony.com/schema/dic/services/services-1.0.xsd">

<services>
<service id="doctrine.dbal.logging_middleware" class="Doctrine\DBAL\Logging\Middleware" abstract="true">
<argument type="service" id="logger" on-invalid="null" />
<tag name="monolog.logger" channel="doctrine" />
<tag name="doctrine.middleware" />
</service>
</services>
</container>
170 changes: 170 additions & 0 deletions Tests/DependencyInjection/Compiler/MiddlewarePassTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
<?php

namespace Doctrine\Bundle\DoctrineBundle\Tests\DependencyInjection\Compiler;

use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\MiddlewaresPass;
use Doctrine\Bundle\DoctrineBundle\DependencyInjection\DoctrineExtension;
use Doctrine\Bundle\DoctrineBundle\Middleware\ConnectionNameAwareInterface;
use Doctrine\DBAL\Driver;
use Doctrine\DBAL\Driver\Middleware;
use PHPUnit\Framework\TestCase;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\DependencyInjection\ParameterBag\ParameterBag;

use function sprintf;

class MiddlewarePassTest extends TestCase
{
/** @return array<string, mixed[]> */
public function provideAddMiddleware(): array
{
return [
'not connection name aware' => [Middleware1::class, false],
'connection name aware' => [Middleware2::class, true],
];
}

/** @dataProvider provideAddMiddleware */
public function testAddMiddleware(string $middlewareClass, bool $connectionNameAware): void
{
/** @psalm-suppress UndefinedClass */
if (interface_exists(Middleware::class)) {
$this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class));
}

$container = $this->createContainer(static function (ContainerBuilder $container) use ($middlewareClass) {
$container
->register('middleware', $middlewareClass)
->setAbstract(true)
->addTag('doctrine.middleware');

$container
->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration')
->setPublic(true); // Avoid removal and inlining

$container
->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration')
->setPublic(true); // Avoid removal and inlining
});

$this->assertMiddlewareInjected('conn1', $middlewareClass, $connectionNameAware, $container);
$this->assertMiddlewareInjected('conn2', $middlewareClass, $connectionNameAware, $container);
}

public function testAddMiddlewareWithAutoconfigure(): void
{
/** @psalm-suppress UndefinedClass */
if (interface_exists(Middleware::class)) {
$this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class));
}

$container = $this->createContainer(static function (ContainerBuilder $container) {
/** @psalm-suppress UndefinedClass */
$container
->register('middleware', Middleware3::class)
->setAutoconfigured(true);

$container
->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration')
->setPublic(true); // Avoid removal and inlining

$container
->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration')
->setPublic(true); // Avoid removal and inlining
});

/** @psalm-suppress UndefinedClass */
$this->assertMiddlewareInjected('conn1', Middleware3::class, false, $container);
/** @psalm-suppress UndefinedClass */
$this->assertMiddlewareInjected('conn2', Middleware3::class, false, $container);
}

private function createContainer(callable $func): ContainerBuilder
{
$container = new ContainerBuilder(new ParameterBag(['kernel.debug' => false]));

$container->registerExtension(new DoctrineExtension());
$container->loadFromExtension('doctrine', [
'dbal' => [
'connections' => [
'conn1' => ['url' => 'mysql://user:pass@server1.tld:3306/db1'],
'conn2' => ['url' => 'mysql://user:pass@server2.tld:3306/db2'],
],
],
]);

$container->addCompilerPass(new MiddlewaresPass());

$func($container);

$container->compile();

return $container;
}

private function assertMiddlewareInjected(
string $connName,
string $middlewareClass,
bool $connectionNameAware,
ContainerBuilder $container
): void {
$calls = $container->getDefinition('conf_' . $connName)->getMethodCalls();
$middlewareFound = [];
foreach ($calls as $call) {
if ($call[0] !== 'setMiddlewares' || ! isset($call[1][0])) {
continue;
}

foreach ($call[1][0] as $middlewareDefs) {
if ($middlewareDefs->getClass() !== $middlewareClass) {
continue;
}

$middlewareFound[] = $middlewareDefs;
}
}

$this->assertCount(1, $middlewareFound, sprintf(
'Middleware not injected in doctrine.dbal.%s_connection.configuration',
$connName
));

$callsFound = [];
foreach ($middlewareFound[0]->getMethodCalls() as $call) {
if ($call[0] !== 'setConnectionName') {
continue;
}

$callsFound[] = $call;
}

$this->assertCount($connectionNameAware ? 1 : 0, $callsFound);
if (! $connectionNameAware) {
return;
}

$this->assertSame($call[1][0] ?? null, $connName);
}
}

class Middleware1
{
}

class Middleware2 implements ConnectionNameAwareInterface
{
public function setConnectionName(string $name): void
{
}
}

/** @psalm-suppress UndefinedClass */
if (interface_exists(Middleware::class)) {
class Middleware3 implements Middleware
{
public function wrap(Driver $driver): Driver
{
return $driver;
}
}
}
Loading

0 comments on commit b4a7efa

Please sign in to comment.