|
|
|
@ -38,7 +38,10 @@ use OCP\TaskProcessing\ISynchronousProvider; |
|
|
|
use OCP\TaskProcessing\ITaskType; |
|
|
|
use OCP\TaskProcessing\ShapeDescriptor; |
|
|
|
use OCP\TaskProcessing\Task; |
|
|
|
use OCP\TaskProcessing\TaskTypes\TextToImage; |
|
|
|
use OCP\TaskProcessing\TaskTypes\TextToText; |
|
|
|
use OCP\TaskProcessing\TaskTypes\TextToTextSummary; |
|
|
|
use OCP\TextProcessing\SummaryTaskType; |
|
|
|
use PHPUnit\Framework\Constraint\IsInstanceOf; |
|
|
|
use Psr\Log\LoggerInterface; |
|
|
|
use Test\BackgroundJob\DummyJobList; |
|
|
|
@ -204,6 +207,85 @@ class BrokenSyncProvider implements IProvider, ISynchronousProvider { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { |
|
|
|
public bool $ran = false; |
|
|
|
|
|
|
|
public function getName(): string { |
|
|
|
return 'TEST Vanilla LLM Provider'; |
|
|
|
} |
|
|
|
|
|
|
|
public function process(string $prompt): string { |
|
|
|
$this->ran = true; |
|
|
|
return $prompt . ' Summarize'; |
|
|
|
} |
|
|
|
|
|
|
|
public function getTaskType(): string { |
|
|
|
return SummaryTaskType::class; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { |
|
|
|
public bool $ran = false; |
|
|
|
|
|
|
|
public function getName(): string { |
|
|
|
return 'TEST Vanilla LLM Provider'; |
|
|
|
} |
|
|
|
|
|
|
|
public function process(string $prompt): string { |
|
|
|
$this->ran = true; |
|
|
|
throw new \Exception('ERROR'); |
|
|
|
} |
|
|
|
|
|
|
|
public function getTaskType(): string { |
|
|
|
return SummaryTaskType::class; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider { |
|
|
|
public bool $ran = false; |
|
|
|
|
|
|
|
public function getId(): string { |
|
|
|
return 'test:successful'; |
|
|
|
} |
|
|
|
|
|
|
|
public function getName(): string { |
|
|
|
return 'TEST Provider'; |
|
|
|
} |
|
|
|
|
|
|
|
public function generate(string $prompt, array $resources): void { |
|
|
|
$this->ran = true; |
|
|
|
foreach($resources as $resource) { |
|
|
|
fwrite($resource, 'test'); |
|
|
|
fclose($resource); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public function getExpectedRuntime(): int { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
class FailingTextToImageProvider implements \OCP\TextToImage\IProvider { |
|
|
|
public bool $ran = false; |
|
|
|
|
|
|
|
public function getId(): string { |
|
|
|
return 'test:failing'; |
|
|
|
} |
|
|
|
|
|
|
|
public function getName(): string { |
|
|
|
return 'TEST Provider'; |
|
|
|
} |
|
|
|
|
|
|
|
public function generate(string $prompt, array $resources): void { |
|
|
|
$this->ran = true; |
|
|
|
throw new \RuntimeException('ERROR'); |
|
|
|
} |
|
|
|
|
|
|
|
public function getExpectedRuntime(): int { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* @group DB |
|
|
|
*/ |
|
|
|
@ -227,6 +309,10 @@ class TaskProcessingTest extends \Test\TestCase { |
|
|
|
BrokenSyncProvider::class => new BrokenSyncProvider(), |
|
|
|
AsyncProvider::class => new AsyncProvider(), |
|
|
|
AudioToImage::class => new AudioToImage(), |
|
|
|
SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(), |
|
|
|
FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(), |
|
|
|
SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(), |
|
|
|
FailingTextToImageProvider::class => new FailingTextToImageProvider(), |
|
|
|
]; |
|
|
|
|
|
|
|
$this->serverContainer = $this->createMock(IServerContainer::class); |
|
|
|
@ -257,6 +343,26 @@ class TaskProcessingTest extends \Test\TestCase { |
|
|
|
|
|
|
|
$this->eventDispatcher = $this->createMock(IEventDispatcher::class); |
|
|
|
|
|
|
|
$textProcessingManager = new \OC\TextProcessing\Manager( |
|
|
|
$this->serverContainer, |
|
|
|
$this->coordinator, |
|
|
|
\OC::$server->get(LoggerInterface::class), |
|
|
|
$this->jobList, |
|
|
|
\OC::$server->get(\OC\TextProcessing\Db\TaskMapper::class), |
|
|
|
\OC::$server->get(IConfig::class), |
|
|
|
); |
|
|
|
|
|
|
|
$text2imageManager = new \OC\TextToImage\Manager( |
|
|
|
$this->serverContainer, |
|
|
|
$this->coordinator, |
|
|
|
\OC::$server->get(LoggerInterface::class), |
|
|
|
$this->jobList, |
|
|
|
\OC::$server->get(\OC\TextToImage\Db\TaskMapper::class), |
|
|
|
\OC::$server->get(IConfig::class), |
|
|
|
\OC::$server->get(IAppDataFactory::class), |
|
|
|
); |
|
|
|
|
|
|
|
|
|
|
|
$this->manager = new Manager( |
|
|
|
$this->coordinator, |
|
|
|
$this->serverContainer, |
|
|
|
@ -266,8 +372,8 @@ class TaskProcessingTest extends \Test\TestCase { |
|
|
|
$this->eventDispatcher, |
|
|
|
\OC::$server->get(IAppDataFactory::class), |
|
|
|
\OC::$server->get(IRootFolder::class), |
|
|
|
\OC::$server->get(\OCP\TextProcessing\IManager::class), |
|
|
|
\OC::$server->get(\OCP\TextToImage\IManager::class), |
|
|
|
$textProcessingManager, |
|
|
|
$text2imageManager, |
|
|
|
\OC::$server->get(ISpeechToTextManager::class), |
|
|
|
); |
|
|
|
} |
|
|
|
@ -507,4 +613,127 @@ class TaskProcessingTest extends \Test\TestCase { |
|
|
|
$this->expectException(NotFoundException::class); |
|
|
|
$this->manager->getTask($task->getId()); |
|
|
|
} |
|
|
|
|
|
|
|
public function testShouldTransparentlyHandleTextProcessingProviders() { |
|
|
|
$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class) |
|
|
|
]); |
|
|
|
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
|
|
]); |
|
|
|
$taskTypes = $this->manager->getAvailableTaskTypes(); |
|
|
|
self::assertCount(1, $taskTypes); |
|
|
|
self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); |
|
|
|
self::assertTrue($this->manager->hasProviders()); |
|
|
|
$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); |
|
|
|
$this->manager->scheduleTask($task); |
|
|
|
|
|
|
|
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
|
|
|
|
|
|
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( |
|
|
|
\OCP\Server::get(ITimeFactory::class), |
|
|
|
$this->manager, |
|
|
|
$this->jobList, |
|
|
|
\OCP\Server::get(LoggerInterface::class), |
|
|
|
); |
|
|
|
$backgroundJob->start($this->jobList); |
|
|
|
|
|
|
|
$task = $this->manager->getTask($task->getId()); |
|
|
|
self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
|
|
self::assertIsArray($task->getOutput()); |
|
|
|
self::assertTrue(isset($task->getOutput()['output'])); |
|
|
|
self::assertEquals('Hello Summarize', $task->getOutput()['output']); |
|
|
|
self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran); |
|
|
|
} |
|
|
|
|
|
|
|
public function testShouldTransparentlyHandleFailingTextProcessingProviders() { |
|
|
|
$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class) |
|
|
|
]); |
|
|
|
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
|
|
]); |
|
|
|
$taskTypes = $this->manager->getAvailableTaskTypes(); |
|
|
|
self::assertCount(1, $taskTypes); |
|
|
|
self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); |
|
|
|
self::assertTrue($this->manager->hasProviders()); |
|
|
|
$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); |
|
|
|
$this->manager->scheduleTask($task); |
|
|
|
|
|
|
|
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
|
|
|
|
|
|
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( |
|
|
|
\OCP\Server::get(ITimeFactory::class), |
|
|
|
$this->manager, |
|
|
|
$this->jobList, |
|
|
|
\OCP\Server::get(LoggerInterface::class), |
|
|
|
); |
|
|
|
$backgroundJob->start($this->jobList); |
|
|
|
|
|
|
|
$task = $this->manager->getTask($task->getId()); |
|
|
|
self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
|
|
self::assertTrue($task->getOutput() === null); |
|
|
|
self::assertEquals('ERROR', $task->getErrorMessage()); |
|
|
|
self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran); |
|
|
|
} |
|
|
|
|
|
|
|
public function testShouldTransparentlyHandleText2ImageProviders() { |
|
|
|
$this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', SuccessfulTextToImageProvider::class) |
|
|
|
]); |
|
|
|
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
|
|
]); |
|
|
|
$taskTypes = $this->manager->getAvailableTaskTypes(); |
|
|
|
self::assertCount(1, $taskTypes); |
|
|
|
self::assertTrue(isset($taskTypes[TextToImage::ID])); |
|
|
|
self::assertTrue($this->manager->hasProviders()); |
|
|
|
$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); |
|
|
|
$this->manager->scheduleTask($task); |
|
|
|
|
|
|
|
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
|
|
|
|
|
|
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( |
|
|
|
\OCP\Server::get(ITimeFactory::class), |
|
|
|
$this->manager, |
|
|
|
$this->jobList, |
|
|
|
\OCP\Server::get(LoggerInterface::class), |
|
|
|
); |
|
|
|
$backgroundJob->start($this->jobList); |
|
|
|
|
|
|
|
$task = $this->manager->getTask($task->getId()); |
|
|
|
self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
|
|
self::assertIsArray($task->getOutput()); |
|
|
|
self::assertTrue(isset($task->getOutput()['images'])); |
|
|
|
self::assertIsArray($task->getOutput()['images']); |
|
|
|
self::assertCount(3, $task->getOutput()['images']); |
|
|
|
self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran); |
|
|
|
} |
|
|
|
|
|
|
|
public function testShouldTransparentlyHandleFailingText2ImageProviders() { |
|
|
|
$this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ |
|
|
|
new ServiceRegistration('test', FailingTextToImageProvider::class) |
|
|
|
]); |
|
|
|
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
|
|
]); |
|
|
|
$taskTypes = $this->manager->getAvailableTaskTypes(); |
|
|
|
self::assertCount(1, $taskTypes); |
|
|
|
self::assertTrue(isset($taskTypes[TextToImage::ID])); |
|
|
|
self::assertTrue($this->manager->hasProviders()); |
|
|
|
$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); |
|
|
|
$this->manager->scheduleTask($task); |
|
|
|
|
|
|
|
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
|
|
|
|
|
|
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( |
|
|
|
\OCP\Server::get(ITimeFactory::class), |
|
|
|
$this->manager, |
|
|
|
$this->jobList, |
|
|
|
\OCP\Server::get(LoggerInterface::class), |
|
|
|
); |
|
|
|
$backgroundJob->start($this->jobList); |
|
|
|
|
|
|
|
$task = $this->manager->getTask($task->getId()); |
|
|
|
self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
|
|
self::assertTrue($task->getOutput() === null); |
|
|
|
self::assertEquals('ERROR', $task->getErrorMessage()); |
|
|
|
self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran); |
|
|
|
} |
|
|
|
} |