-
-
Notifications
You must be signed in to change notification settings - Fork 4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
- Loading branch information
1 parent
27e1c86
commit ebc7631
Showing
1 changed file
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
<?php | ||
/** | ||
* Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> | ||
* This file is licensed under the Affero General Public License version 3 or | ||
* later. | ||
* See the COPYING-README file. | ||
*/ | ||
|
||
namespace Test\LanguageModel; | ||
|
||
use OC\AppFramework\Bootstrap\Coordinator; | ||
use OC\LanguageModel\Db\TaskMapper; | ||
use OC\LanguageModel\LanguageModelManager; | ||
use OC\LanguageModel\TaskBackgroundJob; | ||
use OCP\BackgroundJob\IJobList; | ||
use OCP\Common\Exception\NotFoundException; | ||
use OCP\EventDispatcher\IEventDispatcher; | ||
use OCP\IServerContainer; | ||
use OCP\LanguageModel\Events\TaskFailedEvent; | ||
use OCP\LanguageModel\Events\TaskSuccessfulEvent; | ||
use OCP\LanguageModel\FreePromptTask; | ||
use OCP\LanguageModel\HeadlineTask; | ||
use OCP\LanguageModel\IHeadlineProvider; | ||
use OCP\LanguageModel\ILanguageModelManager; | ||
use OCP\LanguageModel\ILanguageModelProvider; | ||
use OCP\LanguageModel\ILanguageModelTask; | ||
use OCP\LanguageModel\ISummaryProvider; | ||
use OCP\LanguageModel\SummaryTask; | ||
use OCP\LanguageModel\TopicsTask; | ||
use OCP\PreConditionNotMetException; | ||
use Psr\Log\LoggerInterface; | ||
use Test\BackgroundJob\DummyJobList; | ||
|
||
class TestVanillaLanguageModelProvider implements ILanguageModelProvider { | ||
public bool $ran = false; | ||
|
||
public function getName(): string { | ||
return 'TEST Vanilla LLM Provider'; | ||
} | ||
|
||
public function prompt(string $prompt): string { | ||
$this->ran = true; | ||
return $prompt . ' Free Prompt'; | ||
} | ||
} | ||
|
||
class TestFailingLanguageModelProvider implements ILanguageModelProvider { | ||
public bool $ran = false; | ||
|
||
public function getName(): string { | ||
return 'TEST Vanilla LLM Provider'; | ||
} | ||
|
||
public function prompt(string $prompt): string { | ||
$this->ran = true; | ||
throw new \Exception('ERROR'); | ||
} | ||
} | ||
|
||
class TestFullLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider { | ||
public function getName(): string { | ||
return 'TEST Full LLM Provider'; | ||
} | ||
|
||
public function prompt(string $prompt): string { | ||
return $prompt . ' Free Prompt'; | ||
} | ||
|
||
public function findHeadline(string $text): string { | ||
return $text . ' Headline'; | ||
} | ||
|
||
public function summarize(string $text): string { | ||
return $text. ' Summarize'; | ||
} | ||
} | ||
|
||
class LanguageModelManagerTest extends \Test\TestCase { | ||
private ILanguageModelManager $languageModelManager; | ||
private Coordinator $coordinator; | ||
|
||
protected function setUp(): void { | ||
parent::setUp(); | ||
|
||
$this->languageModelManager = new LanguageModelManager( | ||
\OC::$server->get(IServerContainer::class), | ||
$this->coordinator = \OC::$server->get(Coordinator::class), | ||
\OC::$server->get(LoggerInterface::class), | ||
\OC::$server->get(IJobList::class), | ||
\OC::$server->get(TaskMapper::class), | ||
); | ||
} | ||
|
||
public function testShouldNotHaveAnyProviders() { | ||
$this->assertCount(0, $this->languageModelManager->getAvailableTasks()); | ||
$this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes()); | ||
$this->assertFalse($this->languageModelManager->hasProviders()); | ||
$this->expectException(PreConditionNotMetException::class); | ||
$this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)); | ||
} | ||
|
||
public function testProviderShouldBeRegisteredAndRun() { | ||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); | ||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); | ||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); | ||
$this->assertTrue($this->languageModelManager->hasProviders()); | ||
$this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); | ||
|
||
// Summaries are not implemented by the vanilla provider, only free prompt | ||
$this->expectException(PreConditionNotMetException::class); | ||
$this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null)); | ||
} | ||
|
||
public function testProviderShouldBeRegisteredAndScheduled() { | ||
// register provider | ||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); | ||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); | ||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); | ||
$this->assertTrue($this->languageModelManager->hasProviders()); | ||
|
||
// create task object | ||
$task = new FreePromptTask('Hello', 'test', null); | ||
$this->assertNull($task->getId()); | ||
$this->assertNull($task->getOutput()); | ||
|
||
// schedule works | ||
$this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); | ||
$this->languageModelManager->scheduleTask($task); | ||
|
||
// Task object is up-to-date | ||
$this->assertNotNull($task->getId()); | ||
$this->assertNull($task->getOutput()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); | ||
|
||
// Task object retrieved from db is up-to-date | ||
$task2 = $this->languageModelManager->getTask($task->getId()); | ||
$this->assertEquals($task->getId(), $task2->getId()); | ||
$this->assertEquals('Hello', $task2->getInput()); | ||
$this->assertNull($task2->getOutput()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); | ||
|
||
/** @var IEventDispatcher $eventDispatcher */ | ||
$eventDispatcher = \OC::$server->get(IEventDispatcher::class); | ||
$successfulEventFired = false; | ||
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { | ||
$successfulEventFired = true; | ||
$t = $event->getTask(); | ||
$this->assertEquals($task->getId(), $t->getId()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus()); | ||
$this->assertEquals('Hello Free Prompt', $t->getOutput()); | ||
}); | ||
$failedEventFired = false; | ||
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { | ||
$failedEventFired = true; | ||
$t = $event->getTask(); | ||
$this->assertEquals($task->getId(), $t->getId()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus()); | ||
$this->assertEquals('ERROR', $event->getErrorMessage()); | ||
}); | ||
|
||
// run background job | ||
/** @var TaskBackgroundJob $bgJob */ | ||
$bgJob = \OC::$server->get(TaskBackgroundJob::class); | ||
$bgJob->setArgument(['taskId' => $task->getId()]); | ||
$bgJob->start(new DummyJobList()); | ||
$provider = \OC::$server->get(TestVanillaLanguageModelProvider::class); | ||
$this->assertTrue($provider->ran); | ||
$this->assertTrue($successfulEventFired); | ||
$this->assertFalse($failedEventFired); | ||
|
||
// Task object retrieved from db is up-to-date | ||
$task3 = $this->languageModelManager->getTask($task->getId()); | ||
$this->assertEquals($task->getId(), $task3->getId()); | ||
$this->assertEquals('Hello', $task3->getInput()); | ||
$this->assertEquals('Hello Free Prompt', $task3->getOutput()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus()); | ||
} | ||
|
||
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() { | ||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); | ||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class); | ||
$this->assertCount(3, $this->languageModelManager->getAvailableTasks()); | ||
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes()); | ||
$this->assertTrue($this->languageModelManager->hasProviders()); | ||
|
||
// Try free prompt again | ||
$this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); | ||
|
||
// Try headline task | ||
$this->assertEquals('Hello Headline', $this->languageModelManager->runTask(new HeadlineTask('Hello', 'test', null))); | ||
|
||
// Try summary task | ||
$this->assertEquals('Hello Summarize', $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null))); | ||
|
||
// Topics are not implemented by both the vanilla provider and the full provider | ||
$this->expectException(PreConditionNotMetException::class); | ||
$this->languageModelManager->runTask(new TopicsTask('Hello', 'test', null)); | ||
} | ||
|
||
public function testNonexistentTask() { | ||
$this->expectException(NotFoundException::class); | ||
$this->languageModelManager->getTask(98765432456); | ||
} | ||
|
||
public function testTaskFailure() { | ||
// register provider | ||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class); | ||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); | ||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); | ||
$this->assertTrue($this->languageModelManager->hasProviders()); | ||
|
||
// create task object | ||
$task = new FreePromptTask('Hello', 'test', null); | ||
$this->assertNull($task->getId()); | ||
$this->assertNull($task->getOutput()); | ||
|
||
// schedule works | ||
$this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); | ||
$this->languageModelManager->scheduleTask($task); | ||
|
||
// Task object is up-to-date | ||
$this->assertNotNull($task->getId()); | ||
$this->assertNull($task->getOutput()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); | ||
|
||
// Task object retrieved from db is up-to-date | ||
$task2 = $this->languageModelManager->getTask($task->getId()); | ||
$this->assertEquals($task->getId(), $task2->getId()); | ||
$this->assertEquals('Hello', $task2->getInput()); | ||
$this->assertNull($task2->getOutput()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); | ||
|
||
/** @var IEventDispatcher $eventDispatcher */ | ||
$eventDispatcher = \OC::$server->get(IEventDispatcher::class); | ||
$successfulEventFired = false; | ||
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { | ||
$successfulEventFired = true; | ||
$t = $event->getTask(); | ||
$this->assertEquals($task->getId(), $t->getId()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus()); | ||
$this->assertEquals('Hello Free Prompt', $t->getOutput()); | ||
}); | ||
$failedEventFired = false; | ||
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { | ||
$failedEventFired = true; | ||
$t = $event->getTask(); | ||
$this->assertEquals($task->getId(), $t->getId()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus()); | ||
$this->assertEquals('ERROR', $event->getErrorMessage()); | ||
}); | ||
|
||
// run background job | ||
/** @var TaskBackgroundJob $bgJob */ | ||
$bgJob = \OC::$server->get(TaskBackgroundJob::class); | ||
$bgJob->setArgument(['taskId' => $task->getId()]); | ||
$bgJob->start(new DummyJobList()); | ||
$provider = \OC::$server->get(TestFailingLanguageModelProvider::class); | ||
$this->assertTrue($provider->ran); | ||
$this->assertTrue($failedEventFired); | ||
$this->assertFalse($successfulEventFired); | ||
|
||
// Task object retrieved from db is up-to-date | ||
$task3 = $this->languageModelManager->getTask($task->getId()); | ||
$this->assertEquals($task->getId(), $task3->getId()); | ||
$this->assertEquals('Hello', $task3->getInput()); | ||
$this->assertNull($task3->getOutput()); | ||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus()); | ||
} | ||
} |