You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

739 lines
26 KiB

  1. <?php
  2. /**
  3. * Copyright (c) 2024 Marcel Klehr <mklehr@gmx.net>
  4. * This file is licensed under the Affero General Public License version 3 or
  5. * later.
  6. * See the COPYING-README file.
  7. */
  8. namespace Test\TextProcessing;
  9. use OC\AppFramework\Bootstrap\Coordinator;
  10. use OC\AppFramework\Bootstrap\RegistrationContext;
  11. use OC\AppFramework\Bootstrap\ServiceRegistration;
  12. use OC\EventDispatcher\EventDispatcher;
  13. use OC\TaskProcessing\Db\TaskMapper;
  14. use OC\TaskProcessing\Manager;
  15. use OC\TaskProcessing\RemoveOldTasksBackgroundJob;
  16. use OCP\AppFramework\Utility\ITimeFactory;
  17. use OCP\BackgroundJob\IJobList;
  18. use OCP\EventDispatcher\IEventDispatcher;
  19. use OCP\Files\AppData\IAppDataFactory;
  20. use OCP\Files\IAppData;
  21. use OCP\Files\IRootFolder;
  22. use OCP\IConfig;
  23. use OCP\IDBConnection;
  24. use OCP\IServerContainer;
  25. use OCP\PreConditionNotMetException;
  26. use OCP\SpeechToText\ISpeechToTextManager;
  27. use OCP\TaskProcessing\EShapeType;
  28. use OCP\TaskProcessing\Events\TaskFailedEvent;
  29. use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
  30. use OCP\TaskProcessing\Exception\NotFoundException;
  31. use OCP\TaskProcessing\Exception\ProcessingException;
  32. use OCP\TaskProcessing\Exception\ValidationException;
  33. use OCP\TaskProcessing\IManager;
  34. use OCP\TaskProcessing\IProvider;
  35. use OCP\TaskProcessing\ISynchronousProvider;
  36. use OCP\TaskProcessing\ITaskType;
  37. use OCP\TaskProcessing\ShapeDescriptor;
  38. use OCP\TaskProcessing\Task;
  39. use OCP\TaskProcessing\TaskTypes\TextToImage;
  40. use OCP\TaskProcessing\TaskTypes\TextToText;
  41. use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
  42. use OCP\TextProcessing\SummaryTaskType;
  43. use PHPUnit\Framework\Constraint\IsInstanceOf;
  44. use Psr\Log\LoggerInterface;
  45. use Test\BackgroundJob\DummyJobList;
  /**
 * Custom task type used only by these tests: takes an audio file and yields a spectrogram image.
 */
class AudioToImage implements ITaskType {
	/** Identifier returned by getId() and used when scheduling tasks of this type. */
	public const ID = 'test:audiotoimage';

	public function getId(): string {
		return self::ID;
	}

	/** The class name doubles as the display name. */
	public function getName(): string {
		return self::class;
	}

	/** The class name doubles as the description. */
	public function getDescription(): string {
		return self::class;
	}

	/**
	 * @return array<string, ShapeDescriptor> one required audio input slot
	 */
	public function getInputShape(): array {
		$audioSlot = new ShapeDescriptor('Audio', 'The audio', EShapeType::Audio);
		return ['audio' => $audioSlot];
	}

	/**
	 * @return array<string, ShapeDescriptor> one required image output slot
	 */
	public function getOutputShape(): array {
		$imageSlot = new ShapeDescriptor('Spectrogram', 'The audio spectrogram', EShapeType::Image);
		return ['spectrogram' => $imageSlot];
	}
}
  /**
 * Provider for the {@see AudioToImage} task type that does not implement ISynchronousProvider.
 * In the tests, its tasks are completed by calling the manager's setTaskResult() directly.
 */
class AsyncProvider implements IProvider {
	public function getId(): string {
		// Kept distinct from SuccessfulSyncProvider's 'test:sync:success' so the two providers
		// cannot be confused when identified by ID.
		return 'test:async:success';
	}

	public function getName(): string {
		return self::class;
	}

	public function getTaskTypeId(): string {
		return AudioToImage::ID;
	}

	public function getExpectedRuntime(): int {
		return 10;
	}

	/**
	 * @return array<string, ShapeDescriptor> an optional text input slot
	 */
	public function getOptionalInputShape(): array {
		return [
			'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
		];
	}

	/**
	 * @return array<string, ShapeDescriptor> an optional text output slot
	 */
	public function getOptionalOutputShape(): array {
		return [
			'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
		];
	}
}
  /**
 * Synchronous TextToText provider that echoes its 'input' back as 'output'.
 */
class SuccessfulSyncProvider implements IProvider, ISynchronousProvider {
	public function getId(): string {
		return 'test:sync:success';
	}

	public function getName(): string {
		return self::class;
	}

	public function getTaskTypeId(): string {
		return TextToText::ID;
	}

	public function getExpectedRuntime(): int {
		return 10;
	}

	public function getOptionalInputShape(): array {
		return $this->optionalKeyShape();
	}

	public function getOptionalOutputShape(): array {
		return $this->optionalKeyShape();
	}

	/**
	 * Echoes the input text unchanged.
	 *
	 * @return array{output: mixed}
	 */
	public function process(?string $userId, array $input): array {
		$echoed = $input['input'];
		return ['output' => $echoed];
	}

	/**
	 * Builds a fresh single-slot shape with an optional text key.
	 *
	 * @return array<string, ShapeDescriptor>
	 */
	private function optionalKeyShape(): array {
		$descriptor = new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text);
		return ['optionalKey' => $descriptor];
	}
}
  /**
 * Synchronous TextToText provider whose processing always throws a ProcessingException.
 */
class FailingSyncProvider implements IProvider, ISynchronousProvider {
	/** Message carried by the exception thrown from process(); tests compare the task's error message against it. */
	public const ERROR_MESSAGE = 'Failure';

	public function getId(): string {
		return 'test:sync:fail';
	}

	public function getName(): string {
		return self::class;
	}

	public function getTaskTypeId(): string {
		return TextToText::ID;
	}

	public function getExpectedRuntime(): int {
		return 10;
	}

	public function getOptionalInputShape(): array {
		return $this->optionalKeyShape();
	}

	public function getOptionalOutputShape(): array {
		return $this->optionalKeyShape();
	}

	/**
	 * @throws ProcessingException always
	 */
	public function process(?string $userId, array $input): array {
		$failure = new ProcessingException(self::ERROR_MESSAGE);
		throw $failure;
	}

	/**
	 * @return array<string, ShapeDescriptor> a single optional text slot
	 */
	private function optionalKeyShape(): array {
		$descriptor = new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text);
		return ['optionalKey' => $descriptor];
	}
}
  /**
 * Synchronous TextToText provider that "succeeds" with an empty result, i.e. without
 * the 'output' key, so its result should fail output validation.
 */
class BrokenSyncProvider implements IProvider, ISynchronousProvider {
	public function getId(): string {
		return 'test:sync:broken-output';
	}

	public function getName(): string {
		return self::class;
	}

	public function getTaskTypeId(): string {
		return TextToText::ID;
	}

	public function getExpectedRuntime(): int {
		return 10;
	}

	public function getOptionalInputShape(): array {
		return $this->optionalKeyShape();
	}

	public function getOptionalOutputShape(): array {
		return $this->optionalKeyShape();
	}

	/**
	 * Deliberately returns no output keys at all.
	 */
	public function process(?string $userId, array $input): array {
		$nothing = [];
		return $nothing;
	}

	/**
	 * @return array<string, ShapeDescriptor> a single optional text slot
	 */
	private function optionalKeyShape(): array {
		$descriptor = new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text);
		return ['optionalKey' => $descriptor];
	}
}
  /**
 * Legacy TextProcessing summary provider that appends " Summarize" to the prompt.
 * $ran records whether process() was called.
 */
class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
	public bool $ran = false;

	public function getName(): string {
		return 'TEST Vanilla LLM Provider';
	}

	public function getTaskType(): string {
		return SummaryTaskType::class;
	}

	public function process(string $prompt): string {
		$this->ran = true;
		return "{$prompt} Summarize";
	}
}
  /**
 * Legacy TextProcessing summary provider whose process() always throws.
 * $ran records whether process() was called before it failed.
 */
class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
	public bool $ran = false;

	public function getName(): string {
		return 'TEST Vanilla LLM Provider';
	}

	public function getTaskType(): string {
		return SummaryTaskType::class;
	}

	/**
	 * @throws \Exception always, with message 'ERROR'
	 */
	public function process(string $prompt): string {
		$this->ran = true;
		$error = new \Exception('ERROR');
		throw $error;
	}
}
  /**
 * Legacy TextToImage provider that writes the bytes 'test' into every resource it is given
 * and closes each one. $ran records whether generate() was called.
 */
class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider {
	public bool $ran = false;

	public function getId(): string {
		return 'test:successful';
	}

	public function getName(): string {
		return 'TEST Provider';
	}

	public function getExpectedRuntime(): int {
		return 1;
	}

	/**
	 * @param resource[] $resources writable streams; each one is written to and then closed here
	 */
	public function generate(string $prompt, array $resources): void {
		$this->ran = true;
		foreach ($resources as $stream) {
			fwrite($stream, 'test');
			fclose($stream);
		}
	}
}
  /**
 * Legacy TextToImage provider whose generate() always throws a RuntimeException.
 * $ran records whether generate() was called.
 */
class FailingTextToImageProvider implements \OCP\TextToImage\IProvider {
	public bool $ran = false;

	public function getId(): string {
		return 'test:failing';
	}

	public function getName(): string {
		return 'TEST Provider';
	}

	public function getExpectedRuntime(): int {
		return 1;
	}

	/**
	 * @throws \RuntimeException always, with message 'ERROR'
	 */
	public function generate(string $prompt, array $resources): void {
		$this->ran = true;
		$error = new \RuntimeException('ERROR');
		throw $error;
	}
}
  /**
 * Tests for the TaskProcessing {@see Manager}: provider registration, input/output
 * validation, synchronous and asynchronous processing, cleanup of old tasks, and
 * handling of legacy TextProcessing / TextToImage providers.
 *
 * @group DB
 */
class TaskProcessingTest extends \Test\TestCase {
	private IManager $manager;
	private Coordinator $coordinator;
	/** @var array<string, object> instances handed out by the mocked server container, keyed by class name */
	private array $providers;
	private IServerContainer $serverContainer;
	private IEventDispatcher $eventDispatcher;
	private RegistrationContext $registrationContext;
	private TaskMapper $taskMapper;
	private IJobList $jobList;
	private IAppData $appData;

	protected function setUp(): void {
		parent::setUp();

		$this->providers = [
			SuccessfulSyncProvider::class => new SuccessfulSyncProvider(),
			FailingSyncProvider::class => new FailingSyncProvider(),
			BrokenSyncProvider::class => new BrokenSyncProvider(),
			AsyncProvider::class => new AsyncProvider(),
			AudioToImage::class => new AudioToImage(),
			SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(),
			FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(),
			SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(),
			FailingTextToImageProvider::class => new FailingTextToImageProvider(),
		];

		// The mocked container hands out the instances from $this->providers by class name.
		$this->serverContainer = $this->createMock(IServerContainer::class);
		$this->serverContainer->expects($this->any())->method('get')->willReturnCallback(fn ($class) => $this->providers[$class]);

		// A mock rather than a real dispatcher, so each test can assert which task event is dispatched.
		$this->eventDispatcher = $this->createMock(IEventDispatcher::class);

		$this->registrationContext = $this->createMock(RegistrationContext::class);
		$this->coordinator = $this->createMock(Coordinator::class);
		$this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);

		$this->taskMapper = \OCP\Server::get(TaskMapper::class);

		// Job registrations are ignored; the tests start the background jobs explicitly instead.
		$this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
		$this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
		});

		$textProcessingManager = new \OC\TextProcessing\Manager(
			$this->serverContainer,
			$this->coordinator,
			\OC::$server->get(LoggerInterface::class),
			$this->jobList,
			\OC::$server->get(\OC\TextProcessing\Db\TaskMapper::class),
			\OC::$server->get(IConfig::class),
		);

		$text2imageManager = new \OC\TextToImage\Manager(
			$this->serverContainer,
			$this->coordinator,
			\OC::$server->get(LoggerInterface::class),
			$this->jobList,
			\OC::$server->get(\OC\TextToImage\Db\TaskMapper::class),
			\OC::$server->get(IConfig::class),
			\OC::$server->get(IAppDataFactory::class),
		);

		$this->manager = new Manager(
			$this->coordinator,
			$this->serverContainer,
			\OC::$server->get(LoggerInterface::class),
			$this->taskMapper,
			$this->jobList,
			$this->eventDispatcher,
			\OC::$server->get(IAppDataFactory::class),
			\OC::$server->get(IRootFolder::class),
			$textProcessingManager,
			$text2imageManager,
			\OC::$server->get(ISpeechToTextManager::class),
		);
	}

	/**
	 * Creates a file in the 'test' folder of the core app data and returns it as a regular file node.
	 *
	 * @throws \Exception if the freshly created file cannot be looked up by its ID
	 */
	private function getFile(string $name, string $content): \OCP\Files\File {
		/** @var IRootFolder $rootFolder */
		$rootFolder = \OC::$server->get(IRootFolder::class);
		$this->appData = \OC::$server->get(IAppDataFactory::class)->get('core');
		try {
			$folder = $this->appData->getFolder('test');
		} catch (\OCP\Files\NotFoundException $e) {
			$folder = $this->appData->newFolder('test');
		}
		$file = $folder->newFile($name, $content);
		$inputFile = $rootFolder->getFirstNodeByIdInPath($file->getId(), '/' . $rootFolder->getAppDataDirectoryName() . '/');
		if (!$inputFile instanceof \OCP\Files\File) {
			throw new \Exception('PEBCAK');
		}
		return $inputFile;
	}

	/**
	 * Starts one run of \OC\TaskProcessing\SynchronousBackgroundJob against the manager under test.
	 * The tests rely on it to process tasks scheduled for synchronous providers.
	 */
	private function runSynchronousBackgroundJob(): void {
		$backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
			\OCP\Server::get(ITimeFactory::class),
			$this->manager,
			$this->jobList,
			\OCP\Server::get(LoggerInterface::class),
		);
		$backgroundJob->start($this->jobList);
	}

	public function testShouldNotHaveAnyProviders(): void {
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
		self::assertCount(0, $this->manager->getAvailableTaskTypes());
		self::assertFalse($this->manager->hasProviders());
		$this->expectException(PreConditionNotMetException::class);
		$this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null));
	}

	public function testProviderShouldBeRegisteredAndTaskFailValidation(): void {
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
			new ServiceRegistration('test', BrokenSyncProvider::class)
		]);
		self::assertCount(1, $this->manager->getAvailableTaskTypes());
		self::assertTrue($this->manager->hasProviders());
		$task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null);
		self::assertNull($task->getId());
		$this->expectException(ValidationException::class);
		$this->manager->scheduleTask($task);
	}

	public function testProviderShouldBeRegisteredAndFail(): void {
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
			new ServiceRegistration('test', FailingSyncProvider::class)
		]);
		self::assertCount(1, $this->manager->getAvailableTaskTypes());
		self::assertTrue($this->manager->hasProviders());
		$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
		self::assertNull($task->getId());
		self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
		$this->manager->scheduleTask($task);
		self::assertNotNull($task->getId());
		self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
		self::assertEquals(FailingSyncProvider::ERROR_MESSAGE, $task->getErrorMessage());
	}

	public function testProviderShouldBeRegisteredAndFailOutputValidation(): void {
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
			new ServiceRegistration('test', BrokenSyncProvider::class)
		]);
		self::assertCount(1, $this->manager->getAvailableTaskTypes());
		self::assertTrue($this->manager->hasProviders());
		$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
		self::assertNull($task->getId());
		self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
		$this->manager->scheduleTask($task);
		self::assertNotNull($task->getId());
		self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
		self::assertEquals('The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec', $task->getErrorMessage());
	}

	public function testProviderShouldBeRegisteredAndRun(): void {
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulSyncProvider::class)
		]);
		self::assertCount(1, $this->manager->getAvailableTaskTypes());
		$taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]];
		self::assertArrayHasKey('input', $taskTypeStruct['inputShape']);
		self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType());
		self::assertArrayHasKey('optionalKey', $taskTypeStruct['optionalInputShape']);
		self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalInputShape']['optionalKey']->getShapeType());
		self::assertArrayHasKey('output', $taskTypeStruct['outputShape']);
		self::assertEquals(EShapeType::Text, $taskTypeStruct['outputShape']['output']->getShapeType());
		self::assertArrayHasKey('optionalKey', $taskTypeStruct['optionalOutputShape']);
		self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalOutputShape']['optionalKey']->getShapeType());
		self::assertTrue($this->manager->hasProviders());

		$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
		self::assertNull($task->getId());
		self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
		$this->manager->scheduleTask($task);
		self::assertNotNull($task->getId());
		self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		// Task object retrieved from db is up-to-date
		$task2 = $this->manager->getTask($task->getId());
		self::assertEquals($task->getId(), $task2->getId());
		self::assertEquals(['input' => 'Hello'], $task2->getInput());
		self::assertNull($task2->getOutput());
		self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage());
		self::assertEquals(['output' => 'Hello'], $task->getOutput());
		self::assertEquals(1, $task->getProgress());
	}

	public function testAsyncProviderWithFilesShouldBeRegisteredAndRun(): void {
		$this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
			new ServiceRegistration('test', AudioToImage::class)
		]);
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
			new ServiceRegistration('test', AsyncProvider::class)
		]);
		self::assertCount(1, $this->manager->getAvailableTaskTypes());
		self::assertTrue($this->manager->hasProviders());

		$audioId = $this->getFile('audioInput', 'Hello')->getId();
		$task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', null);
		self::assertNull($task->getId());
		self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
		$this->manager->scheduleTask($task);
		self::assertNotNull($task->getId());
		self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());

		// Task object retrieved from db is up-to-date
		$task2 = $this->manager->getTask($task->getId());
		self::assertEquals($task->getId(), $task2->getId());
		self::assertEquals(['audio' => $audioId], $task2->getInput());
		self::assertNull($task2->getOutput());
		self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());

		// No background job here: the test plays the role of the async provider.
		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
		$this->manager->setTaskProgress($task2->getId(), 0.1);
		$input = $this->manager->prepareInputData($task2);
		self::assertArrayHasKey('audio', $input);
		self::assertInstanceOf(\OCP\Files\File::class, $input['audio']);
		self::assertEquals($audioId, $input['audio']->getId());

		$this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => 'World']);

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
		self::assertEquals(1, $task->getProgress());
		$output = $task->getOutput();
		self::assertIsArray($output);
		self::assertArrayHasKey('spectrogram', $output);

		// The stored output is a file ID pointing into app data.
		$root = \OCP\Server::get(IRootFolder::class);
		$node = $root->getFirstNodeByIdInPath($output['spectrogram'], '/' . $root->getAppDataDirectoryName() . '/');
		self::assertNotNull($node);
		self::assertInstanceOf(\OCP\Files\File::class, $node);
		self::assertEquals('World', $node->getContent());
	}

	public function testNonexistentTask(): void {
		$this->expectException(NotFoundException::class);
		$this->manager->getTask(2147483646);
	}

	public function testOldTasksShouldBeCleanedUp(): void {
		// A mutable \DateTime is deliberate. The arrow functions capture the object handle,
		// so advancing $currentTime in place later is visible through the mocked time factory.
		$currentTime = new \DateTime('now');
		$timeFactory = $this->createMock(ITimeFactory::class);
		$timeFactory->expects($this->any())->method('getDateTime')->willReturnCallback(fn () => $currentTime);
		$timeFactory->expects($this->any())->method('getTime')->willReturnCallback(fn () => $currentTime->getTimestamp());

		// Only the cleanup job below uses this mapper; the manager keeps the mapper it was built with in setUp().
		$this->taskMapper = new TaskMapper(
			\OCP\Server::get(IDBConnection::class),
			$timeFactory,
		);

		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulSyncProvider::class)
		]);
		self::assertCount(1, $this->manager->getAvailableTaskTypes());
		self::assertTrue($this->manager->hasProviders());
		$task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
		$this->manager->scheduleTask($task);

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
		$this->runSynchronousBackgroundJob();
		$task = $this->manager->getTask($task->getId());

		// Jump one year ahead (mutates the captured object), so the cleanup job treats the task as old;
		// the assertion below expects it to be deleted.
		$currentTime->add(new \DateInterval('P1Y'));

		// run background job
		$bgJob = new RemoveOldTasksBackgroundJob(
			$timeFactory,
			$this->taskMapper,
			\OC::$server->get(LoggerInterface::class),
		);
		$bgJob->setArgument([]);
		$bgJob->start($this->jobList);

		$this->expectException(NotFoundException::class);
		$this->manager->getTask($task->getId());
	}

	public function testShouldTransparentlyHandleTextProcessingProviders(): void {
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class)
		]);
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
		]);
		$taskTypes = $this->manager->getAvailableTaskTypes();
		self::assertCount(1, $taskTypes);
		self::assertArrayHasKey(TextToTextSummary::ID, $taskTypes);
		self::assertTrue($this->manager->hasProviders());

		$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
		$this->manager->scheduleTask($task);

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
		self::assertIsArray($task->getOutput());
		self::assertArrayHasKey('output', $task->getOutput());
		self::assertEquals('Hello Summarize', $task->getOutput()['output']);
		self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran);
	}

	public function testShouldTransparentlyHandleFailingTextProcessingProviders(): void {
		$this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
			new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class)
		]);
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
		]);
		$taskTypes = $this->manager->getAvailableTaskTypes();
		self::assertCount(1, $taskTypes);
		self::assertArrayHasKey(TextToTextSummary::ID, $taskTypes);
		self::assertTrue($this->manager->hasProviders());

		$task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
		$this->manager->scheduleTask($task);

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
		self::assertNull($task->getOutput());
		self::assertEquals('ERROR', $task->getErrorMessage());
		self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran);
	}

	public function testShouldTransparentlyHandleText2ImageProviders(): void {
		$this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
			new ServiceRegistration('test', SuccessfulTextToImageProvider::class)
		]);
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
		]);
		$taskTypes = $this->manager->getAvailableTaskTypes();
		self::assertCount(1, $taskTypes);
		self::assertArrayHasKey(TextToImage::ID, $taskTypes);
		self::assertTrue($this->manager->hasProviders());

		$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
		$this->manager->scheduleTask($task);

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
		self::assertIsArray($task->getOutput());
		self::assertArrayHasKey('images', $task->getOutput());
		self::assertIsArray($task->getOutput()['images']);
		self::assertCount(3, $task->getOutput()['images']);
		self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran);
	}

	public function testShouldTransparentlyHandleFailingText2ImageProviders(): void {
		$this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
			new ServiceRegistration('test', FailingTextToImageProvider::class)
		]);
		$this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
		]);
		$taskTypes = $this->manager->getAvailableTaskTypes();
		self::assertCount(1, $taskTypes);
		self::assertArrayHasKey(TextToImage::ID, $taskTypes);
		self::assertTrue($this->manager->hasProviders());

		$task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
		$this->manager->scheduleTask($task);

		$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
		$this->runSynchronousBackgroundJob();

		$task = $this->manager->getTask($task->getId());
		self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
		self::assertNull($task->getOutput());
		self::assertEquals('ERROR', $task->getErrorMessage());
		self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran);
	}
}
  622. }