Browse Source

test: Add more tests for legacy pass-through

Signed-off-by: Marcel Klehr <mklehr@gmx.net>
pull/45094/head
Marcel Klehr 2 years ago
parent
commit
bd5dfd0b5f
  1. 5
      lib/private/TaskProcessing/Manager.php
  2. 233
      tests/lib/TaskProcessing/TaskProcessingTest.php

5
lib/private/TaskProcessing/Manager.php

@ -38,6 +38,7 @@ use OCP\Files\GenericFileException;
use OCP\Files\IAppData;
use OCP\Files\IRootFolder;
use OCP\Files\NotPermittedException;
use OCP\Files\SimpleFS\ISimpleFile;
use OCP\IL10N;
use OCP\IServerContainer;
use OCP\L10N\IFactory;
@ -265,7 +266,7 @@ class Manager implements IManager {
$resources = [];
$files = [];
for ($i = 0; $i < $input['numberOfImages']; $i++) {
$file = $folder->newFile( time() . '-' . rand(1, 100000) . '-' . $i);
$file = $folder->newFile(time() . '-' . rand(1, 100000) . '-' . $i);
$files[] = $file;
$resource = $file->write();
if ($resource !== false && $resource !== true && is_resource($resource)) {
@ -282,7 +283,7 @@ class Manager implements IManager {
} catch (\RuntimeException $e) {
throw new ProcessingException($e->getMessage(), 0, $e);
}
return ['images' => array_map(fn (File $file) => $file->getContent(), $files)];
return ['images' => array_map(fn (ISimpleFile $file) => $file->getContent(), $files)];
}
};
$newProviders[$newProvider->getId()] = $newProvider;

233
tests/lib/TaskProcessing/TaskProcessingTest.php

@ -38,7 +38,10 @@ use OCP\TaskProcessing\ISynchronousProvider;
use OCP\TaskProcessing\ITaskType;
use OCP\TaskProcessing\ShapeDescriptor;
use OCP\TaskProcessing\Task;
use OCP\TaskProcessing\TaskTypes\TextToImage;
use OCP\TaskProcessing\TaskTypes\TextToText;
use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
use OCP\TextProcessing\SummaryTaskType;
use PHPUnit\Framework\Constraint\IsInstanceOf;
use Psr\Log\LoggerInterface;
use Test\BackgroundJob\DummyJobList;
@ -204,6 +207,85 @@ class BrokenSyncProvider implements IProvider, ISynchronousProvider {
}
}
class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
public bool $ran = false;
public function getName(): string {
return 'TEST Vanilla LLM Provider';
}
public function process(string $prompt): string {
$this->ran = true;
return $prompt . ' Summarize';
}
public function getTaskType(): string {
return SummaryTaskType::class;
}
}
class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
public bool $ran = false;
public function getName(): string {
return 'TEST Vanilla LLM Provider';
}
public function process(string $prompt): string {
$this->ran = true;
throw new \Exception('ERROR');
}
public function getTaskType(): string {
return SummaryTaskType::class;
}
}
class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider {
public bool $ran = false;
public function getId(): string {
return 'test:successful';
}
public function getName(): string {
return 'TEST Provider';
}
public function generate(string $prompt, array $resources): void {
$this->ran = true;
foreach($resources as $resource) {
fwrite($resource, 'test');
fclose($resource);
}
}
public function getExpectedRuntime(): int {
return 1;
}
}
class FailingTextToImageProvider implements \OCP\TextToImage\IProvider {
public bool $ran = false;
public function getId(): string {
return 'test:failing';
}
public function getName(): string {
return 'TEST Provider';
}
public function generate(string $prompt, array $resources): void {
$this->ran = true;
throw new \RuntimeException('ERROR');
}
public function getExpectedRuntime(): int {
return 1;
}
}
/**
* @group DB
*/
@ -227,6 +309,10 @@ class TaskProcessingTest extends \Test\TestCase {
BrokenSyncProvider::class => new BrokenSyncProvider(),
AsyncProvider::class => new AsyncProvider(),
AudioToImage::class => new AudioToImage(),
SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(),
FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(),
SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(),
FailingTextToImageProvider::class => new FailingTextToImageProvider(),
];
$this->serverContainer = $this->createMock(IServerContainer::class);
@ -257,6 +343,26 @@ class TaskProcessingTest extends \Test\TestCase {
$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
$textProcessingManager = new \OC\TextProcessing\Manager(
$this->serverContainer,
$this->coordinator,
\OC::$server->get(LoggerInterface::class),
$this->jobList,
\OC::$server->get(\OC\TextProcessing\Db\TaskMapper::class),
\OC::$server->get(IConfig::class),
);
$text2imageManager = new \OC\TextToImage\Manager(
$this->serverContainer,
$this->coordinator,
\OC::$server->get(LoggerInterface::class),
$this->jobList,
\OC::$server->get(\OC\TextToImage\Db\TaskMapper::class),
\OC::$server->get(IConfig::class),
\OC::$server->get(IAppDataFactory::class),
);
$this->manager = new Manager(
$this->coordinator,
$this->serverContainer,
@ -266,8 +372,8 @@ class TaskProcessingTest extends \Test\TestCase {
$this->eventDispatcher,
\OC::$server->get(IAppDataFactory::class),
\OC::$server->get(IRootFolder::class),
\OC::$server->get(\OCP\TextProcessing\IManager::class),
\OC::$server->get(\OCP\TextToImage\IManager::class),
$textProcessingManager,
$text2imageManager,
\OC::$server->get(ISpeechToTextManager::class),
);
}
@ -507,4 +613,127 @@ class TaskProcessingTest extends \Test\TestCase {
$this->expectException(NotFoundException::class);
$this->manager->getTask($task->getId());
}
public function testShouldTransparentlyHandleTextProcessingProviders() {
$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class)
]);
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
$this->manager->scheduleTask($task);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
\OCP\Server::get(ITimeFactory::class),
$this->manager,
$this->jobList,
\OCP\Server::get(LoggerInterface::class),
);
$backgroundJob->start($this->jobList);
$task = $this->manager->getTask($task->getId());
self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
self::assertIsArray($task->getOutput());
self::assertTrue(isset($task->getOutput()['output']));
self::assertEquals('Hello Summarize', $task->getOutput()['output']);
self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran);
}
public function testShouldTransparentlyHandleFailingTextProcessingProviders() {
$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class)
]);
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
$this->manager->scheduleTask($task);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
\OCP\Server::get(ITimeFactory::class),
$this->manager,
$this->jobList,
\OCP\Server::get(LoggerInterface::class),
);
$backgroundJob->start($this->jobList);
$task = $this->manager->getTask($task->getId());
self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
self::assertTrue($task->getOutput() === null);
self::assertEquals('ERROR', $task->getErrorMessage());
self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran);
}
public function testShouldTransparentlyHandleText2ImageProviders() {
$this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
new ServiceRegistration('test', SuccessfulTextToImageProvider::class)
]);
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertTrue(isset($taskTypes[TextToImage::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
$this->manager->scheduleTask($task);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
\OCP\Server::get(ITimeFactory::class),
$this->manager,
$this->jobList,
\OCP\Server::get(LoggerInterface::class),
);
$backgroundJob->start($this->jobList);
$task = $this->manager->getTask($task->getId());
self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
self::assertIsArray($task->getOutput());
self::assertTrue(isset($task->getOutput()['images']));
self::assertIsArray($task->getOutput()['images']);
self::assertCount(3, $task->getOutput()['images']);
self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran);
}
public function testShouldTransparentlyHandleFailingText2ImageProviders() {
$this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
new ServiceRegistration('test', FailingTextToImageProvider::class)
]);
$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
]);
$taskTypes = $this->manager->getAvailableTaskTypes();
self::assertCount(1, $taskTypes);
self::assertTrue(isset($taskTypes[TextToImage::ID]));
self::assertTrue($this->manager->hasProviders());
$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
$this->manager->scheduleTask($task);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
\OCP\Server::get(ITimeFactory::class),
$this->manager,
$this->jobList,
\OCP\Server::get(LoggerInterface::class),
);
$backgroundJob->start($this->jobList);
$task = $this->manager->getTask($task->getId());
self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
self::assertTrue($task->getOutput() === null);
self::assertEquals('ERROR', $task->getErrorMessage());
self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran);
}
}
Loading…
Cancel
Save