1 changed files with 269 additions and 0 deletions
@ -0,0 +1,269 @@ |
|||
<?php |
|||
/** |
|||
* Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> |
|||
* This file is licensed under the Affero General Public License version 3 or |
|||
* later. |
|||
* See the COPYING-README file. |
|||
*/ |
|||
|
|||
namespace Test\LanguageModel; |
|||
|
|||
use OC\AppFramework\Bootstrap\Coordinator; |
|||
use OC\LanguageModel\Db\TaskMapper; |
|||
use OC\LanguageModel\LanguageModelManager; |
|||
use OC\LanguageModel\TaskBackgroundJob; |
|||
use OCP\BackgroundJob\IJobList; |
|||
use OCP\Common\Exception\NotFoundException; |
|||
use OCP\EventDispatcher\IEventDispatcher; |
|||
use OCP\IServerContainer; |
|||
use OCP\LanguageModel\Events\TaskFailedEvent; |
|||
use OCP\LanguageModel\Events\TaskSuccessfulEvent; |
|||
use OCP\LanguageModel\FreePromptTask; |
|||
use OCP\LanguageModel\HeadlineTask; |
|||
use OCP\LanguageModel\IHeadlineProvider; |
|||
use OCP\LanguageModel\ILanguageModelManager; |
|||
use OCP\LanguageModel\ILanguageModelProvider; |
|||
use OCP\LanguageModel\ILanguageModelTask; |
|||
use OCP\LanguageModel\ISummaryProvider; |
|||
use OCP\LanguageModel\SummaryTask; |
|||
use OCP\LanguageModel\TopicsTask; |
|||
use OCP\PreConditionNotMetException; |
|||
use Psr\Log\LoggerInterface; |
|||
use Test\BackgroundJob\DummyJobList; |
|||
|
|||
class TestVanillaLanguageModelProvider implements ILanguageModelProvider { |
|||
public bool $ran = false; |
|||
|
|||
public function getName(): string { |
|||
return 'TEST Vanilla LLM Provider'; |
|||
} |
|||
|
|||
public function prompt(string $prompt): string { |
|||
$this->ran = true; |
|||
return $prompt . ' Free Prompt'; |
|||
} |
|||
} |
|||
|
|||
class TestFailingLanguageModelProvider implements ILanguageModelProvider { |
|||
public bool $ran = false; |
|||
|
|||
public function getName(): string { |
|||
return 'TEST Vanilla LLM Provider'; |
|||
} |
|||
|
|||
public function prompt(string $prompt): string { |
|||
$this->ran = true; |
|||
throw new \Exception('ERROR'); |
|||
} |
|||
} |
|||
|
|||
class TestFullLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider { |
|||
public function getName(): string { |
|||
return 'TEST Full LLM Provider'; |
|||
} |
|||
|
|||
public function prompt(string $prompt): string { |
|||
return $prompt . ' Free Prompt'; |
|||
} |
|||
|
|||
public function findHeadline(string $text): string { |
|||
return $text . ' Headline'; |
|||
} |
|||
|
|||
public function summarize(string $text): string { |
|||
return $text. ' Summarize'; |
|||
} |
|||
} |
|||
|
|||
class LanguageModelManagerTest extends \Test\TestCase { |
|||
private ILanguageModelManager $languageModelManager; |
|||
private Coordinator $coordinator; |
|||
|
|||
protected function setUp(): void { |
|||
parent::setUp(); |
|||
|
|||
$this->languageModelManager = new LanguageModelManager( |
|||
\OC::$server->get(IServerContainer::class), |
|||
$this->coordinator = \OC::$server->get(Coordinator::class), |
|||
\OC::$server->get(LoggerInterface::class), |
|||
\OC::$server->get(IJobList::class), |
|||
\OC::$server->get(TaskMapper::class), |
|||
); |
|||
} |
|||
|
|||
public function testShouldNotHaveAnyProviders() { |
|||
$this->assertCount(0, $this->languageModelManager->getAvailableTasks()); |
|||
$this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes()); |
|||
$this->assertFalse($this->languageModelManager->hasProviders()); |
|||
$this->expectException(PreConditionNotMetException::class); |
|||
$this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)); |
|||
} |
|||
|
|||
public function testProviderShouldBeRegisteredAndRun() { |
|||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); |
|||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); |
|||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); |
|||
$this->assertTrue($this->languageModelManager->hasProviders()); |
|||
$this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); |
|||
|
|||
// Summaries are not implemented by the vanilla provider, only free prompt
|
|||
$this->expectException(PreConditionNotMetException::class); |
|||
$this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null)); |
|||
} |
|||
|
|||
public function testProviderShouldBeRegisteredAndScheduled() { |
|||
// register provider
|
|||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); |
|||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); |
|||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); |
|||
$this->assertTrue($this->languageModelManager->hasProviders()); |
|||
|
|||
// create task object
|
|||
$task = new FreePromptTask('Hello', 'test', null); |
|||
$this->assertNull($task->getId()); |
|||
$this->assertNull($task->getOutput()); |
|||
|
|||
// schedule works
|
|||
$this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); |
|||
$this->languageModelManager->scheduleTask($task); |
|||
|
|||
// Task object is up-to-date
|
|||
$this->assertNotNull($task->getId()); |
|||
$this->assertNull($task->getOutput()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); |
|||
|
|||
// Task object retrieved from db is up-to-date
|
|||
$task2 = $this->languageModelManager->getTask($task->getId()); |
|||
$this->assertEquals($task->getId(), $task2->getId()); |
|||
$this->assertEquals('Hello', $task2->getInput()); |
|||
$this->assertNull($task2->getOutput()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); |
|||
|
|||
/** @var IEventDispatcher $eventDispatcher */ |
|||
$eventDispatcher = \OC::$server->get(IEventDispatcher::class); |
|||
$successfulEventFired = false; |
|||
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { |
|||
$successfulEventFired = true; |
|||
$t = $event->getTask(); |
|||
$this->assertEquals($task->getId(), $t->getId()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus()); |
|||
$this->assertEquals('Hello Free Prompt', $t->getOutput()); |
|||
}); |
|||
$failedEventFired = false; |
|||
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { |
|||
$failedEventFired = true; |
|||
$t = $event->getTask(); |
|||
$this->assertEquals($task->getId(), $t->getId()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus()); |
|||
$this->assertEquals('ERROR', $event->getErrorMessage()); |
|||
}); |
|||
|
|||
// run background job
|
|||
/** @var TaskBackgroundJob $bgJob */ |
|||
$bgJob = \OC::$server->get(TaskBackgroundJob::class); |
|||
$bgJob->setArgument(['taskId' => $task->getId()]); |
|||
$bgJob->start(new DummyJobList()); |
|||
$provider = \OC::$server->get(TestVanillaLanguageModelProvider::class); |
|||
$this->assertTrue($provider->ran); |
|||
$this->assertTrue($successfulEventFired); |
|||
$this->assertFalse($failedEventFired); |
|||
|
|||
// Task object retrieved from db is up-to-date
|
|||
$task3 = $this->languageModelManager->getTask($task->getId()); |
|||
$this->assertEquals($task->getId(), $task3->getId()); |
|||
$this->assertEquals('Hello', $task3->getInput()); |
|||
$this->assertEquals('Hello Free Prompt', $task3->getOutput()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus()); |
|||
} |
|||
|
|||
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() { |
|||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); |
|||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class); |
|||
$this->assertCount(3, $this->languageModelManager->getAvailableTasks()); |
|||
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes()); |
|||
$this->assertTrue($this->languageModelManager->hasProviders()); |
|||
|
|||
// Try free prompt again
|
|||
$this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); |
|||
|
|||
// Try headline task
|
|||
$this->assertEquals('Hello Headline', $this->languageModelManager->runTask(new HeadlineTask('Hello', 'test', null))); |
|||
|
|||
// Try summary task
|
|||
$this->assertEquals('Hello Summarize', $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null))); |
|||
|
|||
// Topics are not implemented by both the vanilla provider and the full provider
|
|||
$this->expectException(PreConditionNotMetException::class); |
|||
$this->languageModelManager->runTask(new TopicsTask('Hello', 'test', null)); |
|||
} |
|||
|
|||
public function testNonexistentTask() { |
|||
$this->expectException(NotFoundException::class); |
|||
$this->languageModelManager->getTask(98765432456); |
|||
} |
|||
|
|||
public function testTaskFailure() { |
|||
// register provider
|
|||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class); |
|||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); |
|||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); |
|||
$this->assertTrue($this->languageModelManager->hasProviders()); |
|||
|
|||
// create task object
|
|||
$task = new FreePromptTask('Hello', 'test', null); |
|||
$this->assertNull($task->getId()); |
|||
$this->assertNull($task->getOutput()); |
|||
|
|||
// schedule works
|
|||
$this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); |
|||
$this->languageModelManager->scheduleTask($task); |
|||
|
|||
// Task object is up-to-date
|
|||
$this->assertNotNull($task->getId()); |
|||
$this->assertNull($task->getOutput()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); |
|||
|
|||
// Task object retrieved from db is up-to-date
|
|||
$task2 = $this->languageModelManager->getTask($task->getId()); |
|||
$this->assertEquals($task->getId(), $task2->getId()); |
|||
$this->assertEquals('Hello', $task2->getInput()); |
|||
$this->assertNull($task2->getOutput()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); |
|||
|
|||
/** @var IEventDispatcher $eventDispatcher */ |
|||
$eventDispatcher = \OC::$server->get(IEventDispatcher::class); |
|||
$successfulEventFired = false; |
|||
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { |
|||
$successfulEventFired = true; |
|||
$t = $event->getTask(); |
|||
$this->assertEquals($task->getId(), $t->getId()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus()); |
|||
$this->assertEquals('Hello Free Prompt', $t->getOutput()); |
|||
}); |
|||
$failedEventFired = false; |
|||
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { |
|||
$failedEventFired = true; |
|||
$t = $event->getTask(); |
|||
$this->assertEquals($task->getId(), $t->getId()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus()); |
|||
$this->assertEquals('ERROR', $event->getErrorMessage()); |
|||
}); |
|||
|
|||
// run background job
|
|||
/** @var TaskBackgroundJob $bgJob */ |
|||
$bgJob = \OC::$server->get(TaskBackgroundJob::class); |
|||
$bgJob->setArgument(['taskId' => $task->getId()]); |
|||
$bgJob->start(new DummyJobList()); |
|||
$provider = \OC::$server->get(TestFailingLanguageModelProvider::class); |
|||
$this->assertTrue($provider->ran); |
|||
$this->assertTrue($failedEventFired); |
|||
$this->assertFalse($successfulEventFired); |
|||
|
|||
// Task object retrieved from db is up-to-date
|
|||
$task3 = $this->languageModelManager->getTask($task->getId()); |
|||
$this->assertEquals($task->getId(), $task3->getId()); |
|||
$this->assertEquals('Hello', $task3->getInput()); |
|||
$this->assertNull($task3->getOutput()); |
|||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus()); |
|||
} |
|||
} |
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue