diff --git a/tests/lib/LanguageModel/LanguageModelManagerTest.php b/tests/lib/LanguageModel/LanguageModelManagerTest.php new file mode 100644 index 0000000000000..80cf4348a0118 --- /dev/null +++ b/tests/lib/LanguageModel/LanguageModelManagerTest.php @@ -0,0 +1,269 @@ + + * This file is licensed under the Affero General Public License version 3 or + * later. + * See the COPYING-README file. + */ + +namespace Test\LanguageModel; + +use OC\AppFramework\Bootstrap\Coordinator; +use OC\LanguageModel\Db\TaskMapper; +use OC\LanguageModel\LanguageModelManager; +use OC\LanguageModel\TaskBackgroundJob; +use OCP\BackgroundJob\IJobList; +use OCP\Common\Exception\NotFoundException; +use OCP\EventDispatcher\IEventDispatcher; +use OCP\IServerContainer; +use OCP\LanguageModel\Events\TaskFailedEvent; +use OCP\LanguageModel\Events\TaskSuccessfulEvent; +use OCP\LanguageModel\FreePromptTask; +use OCP\LanguageModel\HeadlineTask; +use OCP\LanguageModel\IHeadlineProvider; +use OCP\LanguageModel\ILanguageModelManager; +use OCP\LanguageModel\ILanguageModelProvider; +use OCP\LanguageModel\ILanguageModelTask; +use OCP\LanguageModel\ISummaryProvider; +use OCP\LanguageModel\SummaryTask; +use OCP\LanguageModel\TopicsTask; +use OCP\PreConditionNotMetException; +use Psr\Log\LoggerInterface; +use Test\BackgroundJob\DummyJobList; + +class TestVanillaLanguageModelProvider implements ILanguageModelProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function prompt(string $prompt): string { + $this->ran = true; + return $prompt . ' Free Prompt'; + } +} + +class TestFailingLanguageModelProvider implements ILanguageModelProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function prompt(string $prompt): string { + $this->ran = true; + throw new \Exception('ERROR'); + } +} + +class TestFullLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider { + public function getName(): string { + return 'TEST Full LLM Provider'; + } + + public function prompt(string $prompt): string { + return $prompt . ' Free Prompt'; + } + + public function findHeadline(string $text): string { + return $text . ' Headline'; + } + + public function summarize(string $text): string { + return $text. ' Summarize'; + } +} + +class LanguageModelManagerTest extends \Test\TestCase { + private ILanguageModelManager $languageModelManager; + private Coordinator $coordinator; + + protected function setUp(): void { + parent::setUp(); + + $this->languageModelManager = new LanguageModelManager( + \OC::$server->get(IServerContainer::class), + $this->coordinator = \OC::$server->get(Coordinator::class), + \OC::$server->get(LoggerInterface::class), + \OC::$server->get(IJobList::class), + \OC::$server->get(TaskMapper::class), + ); + } + + public function testShouldNotHaveAnyProviders() { + $this->assertCount(0, $this->languageModelManager->getAvailableTasks()); + $this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes()); + $this->assertFalse($this->languageModelManager->hasProviders()); + $this->expectException(PreConditionNotMetException::class); + $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)); + } + + public function testProviderShouldBeRegisteredAndRun() { + $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); + $this->assertCount(1, $this->languageModelManager->getAvailableTasks()); + $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); + $this->assertTrue($this->languageModelManager->hasProviders()); + $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); + + // Summaries are not implemented by the vanilla provider, only free prompt + $this->expectException(PreConditionNotMetException::class); + $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null)); + } + + public function testProviderShouldBeRegisteredAndScheduled() { + // register provider + $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); + $this->assertCount(1, $this->languageModelManager->getAvailableTasks()); + $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); + $this->assertTrue($this->languageModelManager->hasProviders()); + + // create task object + $task = new FreePromptTask('Hello', 'test', null); + $this->assertNull($task->getId()); + $this->assertNull($task->getOutput()); + + // schedule works + $this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); + $this->languageModelManager->scheduleTask($task); + + // Task object is up-to-date + $this->assertNotNull($task->getId()); + $this->assertNull($task->getOutput()); + $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); + + // Task object retrieved from db is up-to-date + $task2 = $this->languageModelManager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task2->getId()); + $this->assertEquals('Hello', $task2->getInput()); + $this->assertNull($task2->getOutput()); + $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); + + /** @var IEventDispatcher $eventDispatcher */ + $eventDispatcher = \OC::$server->get(IEventDispatcher::class); + $successfulEventFired = false; + $eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { + $successfulEventFired = true; + $t = $event->getTask(); + $this->assertEquals($task->getId(), $t->getId()); + $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus()); + $this->assertEquals('Hello Free Prompt', $t->getOutput()); + }); + $failedEventFired = false; + $eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { + $failedEventFired = true; + $t = $event->getTask(); + $this->assertEquals($task->getId(), $t->getId()); + $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus()); + $this->assertEquals('ERROR', $event->getErrorMessage()); + }); + + // run background job + /** @var TaskBackgroundJob $bgJob */ + $bgJob = \OC::$server->get(TaskBackgroundJob::class); + $bgJob->setArgument(['taskId' => $task->getId()]); + $bgJob->start(new DummyJobList()); + $provider = \OC::$server->get(TestVanillaLanguageModelProvider::class); + $this->assertTrue($provider->ran); + $this->assertTrue($successfulEventFired); + $this->assertFalse($failedEventFired); + + // Task object retrieved from db is up-to-date + $task3 = $this->languageModelManager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task3->getId()); + $this->assertEquals('Hello', $task3->getInput()); + $this->assertEquals('Hello Free Prompt', $task3->getOutput()); + $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus()); + } + + public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() { + $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); + $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class); + $this->assertCount(3, $this->languageModelManager->getAvailableTasks()); + $this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes()); + $this->assertTrue($this->languageModelManager->hasProviders()); + + // Try free prompt again + $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); + + // Try headline task + $this->assertEquals('Hello Headline', $this->languageModelManager->runTask(new HeadlineTask('Hello', 'test', null))); + + // Try summary task + $this->assertEquals('Hello Summarize', $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null))); + + // Topics are not implemented by both the vanilla provider and the full provider + $this->expectException(PreConditionNotMetException::class); + $this->languageModelManager->runTask(new TopicsTask('Hello', 'test', null)); + } + + public function testNonexistentTask() { + $this->expectException(NotFoundException::class); + $this->languageModelManager->getTask(98765432456); + } + + public function testTaskFailure() { + // register provider + $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class); + $this->assertCount(1, $this->languageModelManager->getAvailableTasks()); + $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); + $this->assertTrue($this->languageModelManager->hasProviders()); + + // create task object + $task = new FreePromptTask('Hello', 'test', null); + $this->assertNull($task->getId()); + $this->assertNull($task->getOutput()); + + // schedule works + $this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); + $this->languageModelManager->scheduleTask($task); + + // Task object is up-to-date + $this->assertNotNull($task->getId()); + $this->assertNull($task->getOutput()); + $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); + + // Task object retrieved from db is up-to-date + $task2 = $this->languageModelManager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task2->getId()); + $this->assertEquals('Hello', $task2->getInput()); + $this->assertNull($task2->getOutput()); + $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); + + /** @var IEventDispatcher $eventDispatcher */ + $eventDispatcher = \OC::$server->get(IEventDispatcher::class); + $successfulEventFired = false; + $eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { + $successfulEventFired = true; + $t = $event->getTask(); + $this->assertEquals($task->getId(), $t->getId()); + $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus()); + $this->assertEquals('Hello Free Prompt', $t->getOutput()); + }); + $failedEventFired = false; + $eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { + $failedEventFired = true; + $t = $event->getTask(); + $this->assertEquals($task->getId(), $t->getId()); + $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus()); + $this->assertEquals('ERROR', $event->getErrorMessage()); + }); + + // run background job + /** @var TaskBackgroundJob $bgJob */ + $bgJob = \OC::$server->get(TaskBackgroundJob::class); + $bgJob->setArgument(['taskId' => $task->getId()]); + $bgJob->start(new DummyJobList()); + $provider = \OC::$server->get(TestFailingLanguageModelProvider::class); + $this->assertTrue($provider->ran); + $this->assertTrue($failedEventFired); + $this->assertFalse($successfulEventFired); + + // Task object retrieved from db is up-to-date + $task3 = $this->languageModelManager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task3->getId()); + $this->assertEquals('Hello', $task3->getInput()); + $this->assertNull($task3->getOutput()); + $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus()); + } +}