Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
  • Loading branch information
marcelklehr committed Jul 7, 2023
1 parent ebc7631 commit 20cb993
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tests/lib/BackgroundJob/DummyJobList.php
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public function setLastRun(IJob $job): void {
}

public function hasReservedJob(?string $className = null): bool {
return $this->reserved[$className ?? ''];
return isset($this->reserved[$className ?? '']) && $this->reserved[$className ?? ''];
}

public function setHasReservedJob(?string $className, bool $hasReserved): void {
Expand Down
126 changes: 98 additions & 28 deletions tests/lib/LanguageModel/LanguageModelManagerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
namespace Test\LanguageModel;

use OC\AppFramework\Bootstrap\Coordinator;
use OC\AppFramework\Bootstrap\RegistrationContext;
use OC\AppFramework\Bootstrap\ServiceRegistration;
use OC\EventDispatcher\EventDispatcher;
use OC\LanguageModel\Db\Task;
use OC\LanguageModel\Db\TaskMapper;
use OC\LanguageModel\LanguageModelManager;
use OC\LanguageModel\TaskBackgroundJob;
use OCP\BackgroundJob\IJobList;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Utility\ITimeFactory;
use OCP\Common\Exception\NotFoundException;
use OCP\EventDispatcher\IEventDispatcher;
use OCP\IServerContainer;
Expand Down Expand Up @@ -82,16 +87,69 @@ class LanguageModelManagerTest extends \Test\TestCase {
protected function setUp(): void {
parent::setUp();

$this->providers = [
TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider(),
TestFullLanguageModelProvider::class => new TestFullLanguageModelProvider(),
TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider(),
];

$this->serverContainer = $this->createMock(IServerContainer::class);
$this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
return $this->providers[$class];
});

$this->eventDispatcher = new EventDispatcher(
new \Symfony\Component\EventDispatcher\EventDispatcher(),
$this->serverContainer,
\OC::$server->get(LoggerInterface::class),
);

$this->registrationContext = $this->createMock(RegistrationContext::class);
$this->coordinator = $this->createMock(Coordinator::class);
$this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);

$this->taskMapper = $this->createMock(TaskMapper::class);
$this->tasksDb = [];
$this->taskMapper
->expects($this->any())
->method('insert')
->willReturnCallback(function (Task $task) {
$task->setId(count($this->tasksDb) ? max(array_keys($this->tasksDb)) : 1);
$this->tasksDb[$task->getId()] = $task->toRow();
return $task;
});
$this->taskMapper
->expects($this->any())
->method('update')
->willReturnCallback(function (Task $task) {
$this->tasksDb[$task->getId()] = $task->toRow();
return $task;
});
$this->taskMapper
->expects($this->any())
->method('find')
->willReturnCallback(function (int $id) {
if (!isset($this->tasksDb[$id])) {
throw new DoesNotExistException('Could not find it');
}
return Task::fromRow($this->tasksDb[$id]);
});

$this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
$this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
});

$this->languageModelManager = new LanguageModelManager(
\OC::$server->get(IServerContainer::class),
$this->coordinator = \OC::$server->get(Coordinator::class),
$this->serverContainer,
$this->coordinator,
\OC::$server->get(LoggerInterface::class),
\OC::$server->get(IJobList::class),
\OC::$server->get(TaskMapper::class),
$this->jobList,
$this->taskMapper,
);
}

public function testShouldNotHaveAnyProviders() {
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([]);
$this->assertCount(0, $this->languageModelManager->getAvailableTasks());
$this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes());
$this->assertFalse($this->languageModelManager->hasProviders());
Expand All @@ -100,7 +158,9 @@ public function testShouldNotHaveAnyProviders() {
}

public function testProviderShouldBeRegisteredAndRun() {
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
]);
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
Expand All @@ -113,7 +173,9 @@ public function testProviderShouldBeRegisteredAndRun() {

public function testProviderShouldBeRegisteredAndScheduled() {
// register provider
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
]);
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
Expand All @@ -139,18 +201,18 @@ public function testProviderShouldBeRegisteredAndScheduled() {
$this->assertNull($task2->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());

/** @var IEventDispatcher $eventDispatcher */
$eventDispatcher = \OC::$server->get(IEventDispatcher::class);
/** @var IEventDispatcher $this->eventDispatcher */
$this->eventDispatcher = \OC::$server->get(IEventDispatcher::class);
$successfulEventFired = false;
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$successfulEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus());
$this->assertEquals('Hello Free Prompt', $t->getOutput());
});
$failedEventFired = false;
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$failedEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
Expand All @@ -159,11 +221,14 @@ public function testProviderShouldBeRegisteredAndScheduled() {
});

// run background job
/** @var TaskBackgroundJob $bgJob */
$bgJob = \OC::$server->get(TaskBackgroundJob::class);
$bgJob = new TaskBackgroundJob(
\OC::$server->get(ITimeFactory::class),
$this->languageModelManager,
$this->eventDispatcher,
);
$bgJob->setArgument(['taskId' => $task->getId()]);
$bgJob->start(new DummyJobList());
$provider = \OC::$server->get(TestVanillaLanguageModelProvider::class);
$bgJob->start($this->jobList);
$provider = $this->providers[TestVanillaLanguageModelProvider::class];
$this->assertTrue($provider->ran);
$this->assertTrue($successfulEventFired);
$this->assertFalse($failedEventFired);
Expand All @@ -173,12 +238,14 @@ public function testProviderShouldBeRegisteredAndScheduled() {
$this->assertEquals($task->getId(), $task3->getId());
$this->assertEquals('Hello', $task3->getInput());
$this->assertEquals('Hello Free Prompt', $task3->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus());
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task3->getStatus());
}

public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() {
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class);
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class),
new ServiceRegistration('test', TestFullLanguageModelProvider::class),
]);
$this->assertCount(3, $this->languageModelManager->getAvailableTasks());
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
Expand All @@ -204,7 +271,9 @@ public function testNonexistentTask() {

public function testTaskFailure() {
// register provider
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class);
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
new ServiceRegistration('test', TestFailingLanguageModelProvider::class),
]);
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
Expand All @@ -230,18 +299,16 @@ public function testTaskFailure() {
$this->assertNull($task2->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());

/** @var IEventDispatcher $eventDispatcher */
$eventDispatcher = \OC::$server->get(IEventDispatcher::class);
$successfulEventFired = false;
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$successfulEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus());
$this->assertEquals('Hello Free Prompt', $t->getOutput());
});
$failedEventFired = false;
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$failedEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
Expand All @@ -250,11 +317,14 @@ public function testTaskFailure() {
});

// run background job
/** @var TaskBackgroundJob $bgJob */
$bgJob = \OC::$server->get(TaskBackgroundJob::class);
$bgJob = new TaskBackgroundJob(
\OC::$server->get(ITimeFactory::class),
$this->languageModelManager,
$this->eventDispatcher,
);
$bgJob->setArgument(['taskId' => $task->getId()]);
$bgJob->start(new DummyJobList());
$provider = \OC::$server->get(TestFailingLanguageModelProvider::class);
$bgJob->start($this->jobList);
$provider = $this->providers[TestFailingLanguageModelProvider::class];
$this->assertTrue($provider->ran);
$this->assertTrue($failedEventFired);
$this->assertFalse($successfulEventFired);
Expand All @@ -264,6 +334,6 @@ public function testTaskFailure() {
$this->assertEquals($task->getId(), $task3->getId());
$this->assertEquals('Hello', $task3->getInput());
$this->assertNull($task3->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus());
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task3->getStatus());
}
}

0 comments on commit 20cb993

Please sign in to comment.