Skip to content

Commit

Permalink
LLM OCP API: Change Tests to use EventDispatcher mock
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
  • Loading branch information
marcelklehr committed Jul 7, 2023
1 parent 9f405a1 commit bf2dcd6
Showing 1 changed file with 8 additions and 41 deletions.
49 changes: 8 additions & 41 deletions tests/lib/LanguageModel/LanguageModelManagerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
use OCP\LanguageModel\SummaryTask;
use OCP\LanguageModel\TopicsTask;
use OCP\PreConditionNotMetException;
use PHPUnit\Framework\Constraint\IsInstanceOf;
use Psr\Log\LoggerInterface;
use Test\BackgroundJob\DummyJobList;

Expand Down Expand Up @@ -63,7 +64,7 @@ public function prompt(string $prompt): string {
}
}

class TestFullLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider {
class TestAdvancedLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider {
public function getName(): string {
return 'TEST Full LLM Provider';
}
Expand All @@ -90,7 +91,7 @@ protected function setUp(): void {

$this->providers = [
TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider(),
TestFullLanguageModelProvider::class => new TestFullLanguageModelProvider(),
TestAdvancedLanguageModelProvider::class => new TestAdvancedLanguageModelProvider(),
TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider(),
];

Expand Down Expand Up @@ -214,24 +215,8 @@ public function testProviderShouldBeRegisteredAndScheduled() {
$this->assertNull($task2->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());

/** @var IEventDispatcher $this->eventDispatcher */
$this->eventDispatcher = \OC::$server->get(IEventDispatcher::class);
$successfulEventFired = false;
$this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$successfulEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus());
$this->assertEquals('Hello Free Prompt', $t->getOutput());
});
$failedEventFired = false;
$this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$failedEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus());
$this->assertEquals('ERROR', $event->getErrorMessage());
});
$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));

// run background job
$bgJob = new TaskBackgroundJob(
Expand All @@ -243,8 +228,6 @@ public function testProviderShouldBeRegisteredAndScheduled() {
$bgJob->start($this->jobList);
$provider = $this->providers[TestVanillaLanguageModelProvider::class];
$this->assertTrue($provider->ran);
$this->assertTrue($successfulEventFired);
$this->assertFalse($failedEventFired);

// Task object retrieved from db is up-to-date
$task3 = $this->languageModelManager->getTask($task->getId());
Expand All @@ -257,7 +240,7 @@ public function testProviderShouldBeRegisteredAndScheduled() {
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() {
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class),
new ServiceRegistration('test', TestFullLanguageModelProvider::class),
new ServiceRegistration('test', TestAdvancedLanguageModelProvider::class),
]);
$this->assertCount(3, $this->languageModelManager->getAvailableTaskClasses());
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes());
Expand Down Expand Up @@ -312,22 +295,8 @@ public function testTaskFailure() {
$this->assertNull($task2->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());

$successfulEventFired = false;
$this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$successfulEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus());
$this->assertEquals('Hello Free Prompt', $t->getOutput());
});
$failedEventFired = false;
$this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$failedEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus());
$this->assertEquals('ERROR', $event->getErrorMessage());
});
$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));

// run background job
$bgJob = new TaskBackgroundJob(
Expand All @@ -339,8 +308,6 @@ public function testTaskFailure() {
$bgJob->start($this->jobList);
$provider = $this->providers[TestFailingLanguageModelProvider::class];
$this->assertTrue($provider->ran);
$this->assertTrue($failedEventFired);
$this->assertFalse($successfulEventFired);

// Task object retrieved from db is up-to-date
$task3 = $this->languageModelManager->getTask($task->getId());
Expand Down

0 comments on commit bf2dcd6

Please sign in to comment.