|
|
|
@ -9,10 +9,15 @@ |
|
|
|
namespace Test\LanguageModel; |
|
|
|
|
|
|
|
use OC\AppFramework\Bootstrap\Coordinator; |
|
|
|
use OC\AppFramework\Bootstrap\RegistrationContext; |
|
|
|
use OC\AppFramework\Bootstrap\ServiceRegistration; |
|
|
|
use OC\EventDispatcher\EventDispatcher; |
|
|
|
use OC\LanguageModel\Db\Task; |
|
|
|
use OC\LanguageModel\Db\TaskMapper; |
|
|
|
use OC\LanguageModel\LanguageModelManager; |
|
|
|
use OC\LanguageModel\TaskBackgroundJob; |
|
|
|
use OCP\BackgroundJob\IJobList; |
|
|
|
use OCP\AppFramework\Db\DoesNotExistException; |
|
|
|
use OCP\AppFramework\Utility\ITimeFactory; |
|
|
|
use OCP\Common\Exception\NotFoundException; |
|
|
|
use OCP\EventDispatcher\IEventDispatcher; |
|
|
|
use OCP\IServerContainer; |
|
|
|
@ -82,16 +87,69 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
protected function setUp(): void { |
|
|
|
parent::setUp(); |
|
|
|
|
|
|
|
$this->providers = [ |
|
|
|
TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider(), |
|
|
|
TestFullLanguageModelProvider::class => new TestFullLanguageModelProvider(), |
|
|
|
TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider(), |
|
|
|
]; |
|
|
|
|
|
|
|
$this->serverContainer = $this->createMock(IServerContainer::class); |
|
|
|
$this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) { |
|
|
|
return $this->providers[$class]; |
|
|
|
}); |
|
|
|
|
|
|
|
$this->eventDispatcher = new EventDispatcher( |
|
|
|
new \Symfony\Component\EventDispatcher\EventDispatcher(), |
|
|
|
$this->serverContainer, |
|
|
|
\OC::$server->get(LoggerInterface::class), |
|
|
|
); |
|
|
|
|
|
|
|
$this->registrationContext = $this->createMock(RegistrationContext::class); |
|
|
|
$this->coordinator = $this->createMock(Coordinator::class); |
|
|
|
$this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext); |
|
|
|
|
|
|
|
$this->taskMapper = $this->createMock(TaskMapper::class); |
|
|
|
$this->tasksDb = []; |
|
|
|
$this->taskMapper |
|
|
|
->expects($this->any()) |
|
|
|
->method('insert') |
|
|
|
->willReturnCallback(function (Task $task) { |
|
|
|
$task->setId(count($this->tasksDb) ? max(array_keys($this->tasksDb)) : 1); |
|
|
|
$this->tasksDb[$task->getId()] = $task->toRow(); |
|
|
|
return $task; |
|
|
|
}); |
|
|
|
$this->taskMapper |
|
|
|
->expects($this->any()) |
|
|
|
->method('update') |
|
|
|
->willReturnCallback(function (Task $task) { |
|
|
|
$this->tasksDb[$task->getId()] = $task->toRow(); |
|
|
|
return $task; |
|
|
|
}); |
|
|
|
$this->taskMapper |
|
|
|
->expects($this->any()) |
|
|
|
->method('find') |
|
|
|
->willReturnCallback(function (int $id) { |
|
|
|
if (!isset($this->tasksDb[$id])) { |
|
|
|
throw new DoesNotExistException('Could not find it'); |
|
|
|
} |
|
|
|
return Task::fromRow($this->tasksDb[$id]); |
|
|
|
}); |
|
|
|
|
|
|
|
$this->jobList = $this->createPartialMock(DummyJobList::class, ['add']); |
|
|
|
$this->jobList->expects($this->any())->method('add')->willReturnCallback(function () { |
|
|
|
}); |
|
|
|
|
|
|
|
$this->languageModelManager = new LanguageModelManager( |
|
|
|
\OC::$server->get(IServerContainer::class), |
|
|
|
$this->coordinator = \OC::$server->get(Coordinator::class), |
|
|
|
$this->serverContainer, |
|
|
|
$this->coordinator, |
|
|
|
\OC::$server->get(LoggerInterface::class), |
|
|
|
\OC::$server->get(IJobList::class), |
|
|
|
\OC::$server->get(TaskMapper::class), |
|
|
|
$this->jobList, |
|
|
|
$this->taskMapper, |
|
|
|
); |
|
|
|
} |
|
|
|
|
|
|
|
public function testShouldNotHaveAnyProviders() { |
|
|
|
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([]); |
|
|
|
$this->assertCount(0, $this->languageModelManager->getAvailableTasks()); |
|
|
|
$this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes()); |
|
|
|
$this->assertFalse($this->languageModelManager->hasProviders()); |
|
|
|
@ -100,7 +158,9 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
} |
|
|
|
|
|
|
|
public function testProviderShouldBeRegisteredAndRun() { |
|
|
|
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); |
|
|
|
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class) |
|
|
|
]); |
|
|
|
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); |
|
|
|
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); |
|
|
|
$this->assertTrue($this->languageModelManager->hasProviders()); |
|
|
|
@ -113,7 +173,9 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
|
|
|
|
public function testProviderShouldBeRegisteredAndScheduled() { |
|
|
|
// register provider
|
|
|
|
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); |
|
|
|
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class) |
|
|
|
]); |
|
|
|
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); |
|
|
|
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); |
|
|
|
$this->assertTrue($this->languageModelManager->hasProviders()); |
|
|
|
@ -139,10 +201,10 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
$this->assertNull($task2->getOutput()); |
|
|
|
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); |
|
|
|
|
|
|
|
/** @var IEventDispatcher $eventDispatcher */ |
|
|
|
$eventDispatcher = \OC::$server->get(IEventDispatcher::class); |
|
|
|
/** @var IEventDispatcher $this->eventDispatcher */ |
|
|
|
$this->eventDispatcher = \OC::$server->get(IEventDispatcher::class); |
|
|
|
$successfulEventFired = false; |
|
|
|
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { |
|
|
|
$this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { |
|
|
|
$successfulEventFired = true; |
|
|
|
$t = $event->getTask(); |
|
|
|
$this->assertEquals($task->getId(), $t->getId()); |
|
|
|
@ -150,7 +212,7 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
$this->assertEquals('Hello Free Prompt', $t->getOutput()); |
|
|
|
}); |
|
|
|
$failedEventFired = false; |
|
|
|
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { |
|
|
|
$this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { |
|
|
|
$failedEventFired = true; |
|
|
|
$t = $event->getTask(); |
|
|
|
$this->assertEquals($task->getId(), $t->getId()); |
|
|
|
@ -159,11 +221,14 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
}); |
|
|
|
|
|
|
|
// run background job
|
|
|
|
/** @var TaskBackgroundJob $bgJob */ |
|
|
|
$bgJob = \OC::$server->get(TaskBackgroundJob::class); |
|
|
|
$bgJob = new TaskBackgroundJob( |
|
|
|
\OC::$server->get(ITimeFactory::class), |
|
|
|
$this->languageModelManager, |
|
|
|
$this->eventDispatcher, |
|
|
|
); |
|
|
|
$bgJob->setArgument(['taskId' => $task->getId()]); |
|
|
|
$bgJob->start(new DummyJobList()); |
|
|
|
$provider = \OC::$server->get(TestVanillaLanguageModelProvider::class); |
|
|
|
$bgJob->start($this->jobList); |
|
|
|
$provider = $this->providers[TestVanillaLanguageModelProvider::class]; |
|
|
|
$this->assertTrue($provider->ran); |
|
|
|
$this->assertTrue($successfulEventFired); |
|
|
|
$this->assertFalse($failedEventFired); |
|
|
|
@ -173,12 +238,14 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
$this->assertEquals($task->getId(), $task3->getId()); |
|
|
|
$this->assertEquals('Hello', $task3->getInput()); |
|
|
|
$this->assertEquals('Hello Free Prompt', $task3->getOutput()); |
|
|
|
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus()); |
|
|
|
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task3->getStatus()); |
|
|
|
} |
|
|
|
|
|
|
|
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() { |
|
|
|
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class); |
|
|
|
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class); |
|
|
|
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', TestVanillaLanguageModelProvider::class), |
|
|
|
new ServiceRegistration('test', TestFullLanguageModelProvider::class), |
|
|
|
]); |
|
|
|
$this->assertCount(3, $this->languageModelManager->getAvailableTasks()); |
|
|
|
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes()); |
|
|
|
$this->assertTrue($this->languageModelManager->hasProviders()); |
|
|
|
@ -204,7 +271,9 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
|
|
|
|
public function testTaskFailure() { |
|
|
|
// register provider
|
|
|
|
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class); |
|
|
|
$this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', TestFailingLanguageModelProvider::class), |
|
|
|
]); |
|
|
|
$this->assertCount(1, $this->languageModelManager->getAvailableTasks()); |
|
|
|
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); |
|
|
|
$this->assertTrue($this->languageModelManager->hasProviders()); |
|
|
|
@ -230,10 +299,8 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
$this->assertNull($task2->getOutput()); |
|
|
|
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); |
|
|
|
|
|
|
|
/** @var IEventDispatcher $eventDispatcher */ |
|
|
|
$eventDispatcher = \OC::$server->get(IEventDispatcher::class); |
|
|
|
$successfulEventFired = false; |
|
|
|
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { |
|
|
|
$this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) { |
|
|
|
$successfulEventFired = true; |
|
|
|
$t = $event->getTask(); |
|
|
|
$this->assertEquals($task->getId(), $t->getId()); |
|
|
|
@ -241,7 +308,7 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
$this->assertEquals('Hello Free Prompt', $t->getOutput()); |
|
|
|
}); |
|
|
|
$failedEventFired = false; |
|
|
|
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { |
|
|
|
$this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) { |
|
|
|
$failedEventFired = true; |
|
|
|
$t = $event->getTask(); |
|
|
|
$this->assertEquals($task->getId(), $t->getId()); |
|
|
|
@ -250,11 +317,14 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
}); |
|
|
|
|
|
|
|
// run background job
|
|
|
|
/** @var TaskBackgroundJob $bgJob */ |
|
|
|
$bgJob = \OC::$server->get(TaskBackgroundJob::class); |
|
|
|
$bgJob = new TaskBackgroundJob( |
|
|
|
\OC::$server->get(ITimeFactory::class), |
|
|
|
$this->languageModelManager, |
|
|
|
$this->eventDispatcher, |
|
|
|
); |
|
|
|
$bgJob->setArgument(['taskId' => $task->getId()]); |
|
|
|
$bgJob->start(new DummyJobList()); |
|
|
|
$provider = \OC::$server->get(TestFailingLanguageModelProvider::class); |
|
|
|
$bgJob->start($this->jobList); |
|
|
|
$provider = $this->providers[TestFailingLanguageModelProvider::class]; |
|
|
|
$this->assertTrue($provider->ran); |
|
|
|
$this->assertTrue($failedEventFired); |
|
|
|
$this->assertFalse($successfulEventFired); |
|
|
|
@ -264,6 +334,6 @@ class LanguageModelManagerTest extends \Test\TestCase { |
|
|
|
$this->assertEquals($task->getId(), $task3->getId()); |
|
|
|
$this->assertEquals('Hello', $task3->getInput()); |
|
|
|
$this->assertNull($task3->getOutput()); |
|
|
|
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus()); |
|
|
|
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task3->getStatus()); |
|
|
|
} |
|
|
|
} |