You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1409 lines
51 KiB

  1. <?php
  2. /**
  3. * SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
  4. * SPDX-License-Identifier: AGPL-3.0-or-later
  5. */
  6. namespace Test\TaskProcessing;
  7. use OC\AppFramework\Bootstrap\Coordinator;
  8. use OC\AppFramework\Bootstrap\RegistrationContext;
  9. use OC\AppFramework\Bootstrap\ServiceRegistration;
  10. use OC\EventDispatcher\EventDispatcher;
  11. use OC\TaskProcessing\Db\TaskMapper;
  12. use OC\TaskProcessing\Manager;
  13. use OC\TaskProcessing\RemoveOldTasksBackgroundJob;
  14. use OC\TaskProcessing\SynchronousBackgroundJob;
  15. use OCP\App\IAppManager;
  16. use OCP\AppFramework\Utility\ITimeFactory;
  17. use OCP\BackgroundJob\IJobList;
  18. use OCP\EventDispatcher\IEventDispatcher;
  19. use OCP\Files\AppData\IAppDataFactory;
  20. use OCP\Files\Config\ICachedMountInfo;
  21. use OCP\Files\Config\IUserMountCache;
  22. use OCP\Files\File;
  23. use OCP\Files\IRootFolder;
  24. use OCP\Http\Client\IClientService;
  25. use OCP\IAppConfig;
  26. use OCP\ICacheFactory;
  27. use OCP\IConfig;
  28. use OCP\IDBConnection;
  29. use OCP\IServerContainer;
  30. use OCP\IUser;
  31. use OCP\IUserManager;
  32. use OCP\IUserSession;
  33. use OCP\L10N\IFactory;
  34. use OCP\Server;
  35. use OCP\TaskProcessing\EShapeType;
  36. use OCP\TaskProcessing\Events\GetTaskProcessingProvidersEvent;
  37. use OCP\TaskProcessing\Events\TaskFailedEvent;
  38. use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
  39. use OCP\TaskProcessing\Exception\NotFoundException;
  40. use OCP\TaskProcessing\Exception\PreConditionNotMetException;
  41. use OCP\TaskProcessing\Exception\ProcessingException;
  42. use OCP\TaskProcessing\Exception\UnauthorizedException;
  43. use OCP\TaskProcessing\Exception\ValidationException;
  44. use OCP\TaskProcessing\IManager;
  45. use OCP\TaskProcessing\IProvider;
  46. use OCP\TaskProcessing\ISynchronousProvider;
  47. use OCP\TaskProcessing\ITaskType;
  48. use OCP\TaskProcessing\ITriggerableProvider;
  49. use OCP\TaskProcessing\ShapeDescriptor;
  50. use OCP\TaskProcessing\Task;
  51. use OCP\TaskProcessing\TaskTypes\TextToImage;
  52. use OCP\TaskProcessing\TaskTypes\TextToText;
  53. use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
  54. use OCP\TextProcessing\SummaryTaskType;
  55. use PHPUnit\Framework\Constraint\IsInstanceOf;
  56. use Psr\Log\LoggerInterface;
  57. use Test\BackgroundJob\DummyJobList;
  58. class AudioToImage implements ITaskType {
  59. public const ID = 'test:audiotoimage';
  60. public function getId(): string {
  61. return self::ID;
  62. }
  63. public function getName(): string {
  64. return self::class;
  65. }
  66. public function getDescription(): string {
  67. return self::class;
  68. }
  69. public function getInputShape(): array {
  70. return [
  71. 'audio' => new ShapeDescriptor('Audio', 'The audio', EShapeType::Audio),
  72. ];
  73. }
  74. public function getOutputShape(): array {
  75. return [
  76. 'spectrogram' => new ShapeDescriptor('Spectrogram', 'The audio spectrogram', EShapeType::Image),
  77. ];
  78. }
  79. }
  80. class AsyncProvider implements IProvider {
  81. public function getId(): string {
  82. return 'test:sync:success';
  83. }
  84. public function getName(): string {
  85. return self::class;
  86. }
  87. public function getTaskTypeId(): string {
  88. return AudioToImage::ID;
  89. }
  90. public function getExpectedRuntime(): int {
  91. return 10;
  92. }
  93. public function getOptionalInputShape(): array {
  94. return [
  95. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  96. ];
  97. }
  98. public function getOptionalOutputShape(): array {
  99. return [
  100. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  101. ];
  102. }
  103. public function getInputShapeEnumValues(): array {
  104. return [];
  105. }
  106. public function getInputShapeDefaults(): array {
  107. return [];
  108. }
  109. public function getOptionalInputShapeEnumValues(): array {
  110. return [];
  111. }
  112. public function getOptionalInputShapeDefaults(): array {
  113. return [];
  114. }
  115. public function getOutputShapeEnumValues(): array {
  116. return [];
  117. }
  118. public function getOptionalOutputShapeEnumValues(): array {
  119. return [];
  120. }
  121. }
  122. class SuccessfulSyncProvider implements IProvider, ISynchronousProvider {
  123. public const ID = 'test:sync:success';
  124. public function getId(): string {
  125. return self::ID;
  126. }
  127. public function getName(): string {
  128. return self::class;
  129. }
  130. public function getTaskTypeId(): string {
  131. return TextToText::ID;
  132. }
  133. public function getExpectedRuntime(): int {
  134. return 10;
  135. }
  136. public function getOptionalInputShape(): array {
  137. return [
  138. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  139. ];
  140. }
  141. public function getOptionalOutputShape(): array {
  142. return [
  143. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  144. ];
  145. }
  146. public function process(?string $userId, array $input, callable $reportProgress): array {
  147. return ['output' => $input['input']];
  148. }
  149. public function getInputShapeEnumValues(): array {
  150. return [];
  151. }
  152. public function getInputShapeDefaults(): array {
  153. return [];
  154. }
  155. public function getOptionalInputShapeEnumValues(): array {
  156. return [];
  157. }
  158. public function getOptionalInputShapeDefaults(): array {
  159. return [];
  160. }
  161. public function getOutputShapeEnumValues(): array {
  162. return [];
  163. }
  164. public function getOptionalOutputShapeEnumValues(): array {
  165. return [];
  166. }
  167. }
  168. class FailingSyncProvider implements IProvider, ISynchronousProvider {
  169. public const ERROR_MESSAGE = 'Failure';
  170. public function getId(): string {
  171. return 'test:sync:fail';
  172. }
  173. public function getName(): string {
  174. return self::class;
  175. }
  176. public function getTaskTypeId(): string {
  177. return TextToText::ID;
  178. }
  179. public function getExpectedRuntime(): int {
  180. return 10;
  181. }
  182. public function getOptionalInputShape(): array {
  183. return [
  184. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  185. ];
  186. }
  187. public function getOptionalOutputShape(): array {
  188. return [
  189. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  190. ];
  191. }
  192. public function process(?string $userId, array $input, callable $reportProgress): array {
  193. throw new ProcessingException(self::ERROR_MESSAGE);
  194. }
  195. public function getInputShapeEnumValues(): array {
  196. return [];
  197. }
  198. public function getInputShapeDefaults(): array {
  199. return [];
  200. }
  201. public function getOptionalInputShapeEnumValues(): array {
  202. return [];
  203. }
  204. public function getOptionalInputShapeDefaults(): array {
  205. return [];
  206. }
  207. public function getOutputShapeEnumValues(): array {
  208. return [];
  209. }
  210. public function getOptionalOutputShapeEnumValues(): array {
  211. return [];
  212. }
  213. }
  214. class BrokenSyncProvider implements IProvider, ISynchronousProvider {
  215. public function getId(): string {
  216. return 'test:sync:broken-output';
  217. }
  218. public function getName(): string {
  219. return self::class;
  220. }
  221. public function getTaskTypeId(): string {
  222. return TextToText::ID;
  223. }
  224. public function getExpectedRuntime(): int {
  225. return 10;
  226. }
  227. public function getOptionalInputShape(): array {
  228. return [
  229. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  230. ];
  231. }
  232. public function getOptionalOutputShape(): array {
  233. return [
  234. 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text),
  235. ];
  236. }
  237. public function process(?string $userId, array $input, callable $reportProgress): array {
  238. return [];
  239. }
  240. public function getInputShapeEnumValues(): array {
  241. return [];
  242. }
  243. public function getInputShapeDefaults(): array {
  244. return [];
  245. }
  246. public function getOptionalInputShapeEnumValues(): array {
  247. return [];
  248. }
  249. public function getOptionalInputShapeDefaults(): array {
  250. return [];
  251. }
  252. public function getOutputShapeEnumValues(): array {
  253. return [];
  254. }
  255. public function getOptionalOutputShapeEnumValues(): array {
  256. return [];
  257. }
  258. }
  259. class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
  260. public bool $ran = false;
  261. public function getName(): string {
  262. return 'TEST Vanilla LLM Provider';
  263. }
  264. public function process(string $prompt): string {
  265. $this->ran = true;
  266. return $prompt . ' Summarize';
  267. }
  268. public function getTaskType(): string {
  269. return SummaryTaskType::class;
  270. }
  271. }
  272. class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
  273. public bool $ran = false;
  274. public function getName(): string {
  275. return 'TEST Vanilla LLM Provider';
  276. }
  277. public function process(string $prompt): string {
  278. $this->ran = true;
  279. throw new \Exception('ERROR');
  280. }
  281. public function getTaskType(): string {
  282. return SummaryTaskType::class;
  283. }
  284. }
  285. class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider {
  286. public bool $ran = false;
  287. public function getId(): string {
  288. return 'test:successful';
  289. }
  290. public function getName(): string {
  291. return 'TEST Provider';
  292. }
  293. public function generate(string $prompt, array $resources): void {
  294. $this->ran = true;
  295. foreach ($resources as $resource) {
  296. fwrite($resource, 'test');
  297. }
  298. }
  299. public function getExpectedRuntime(): int {
  300. return 1;
  301. }
  302. }
  303. class FailingTextToImageProvider implements \OCP\TextToImage\IProvider {
  304. public bool $ran = false;
  305. public function getId(): string {
  306. return 'test:failing';
  307. }
  308. public function getName(): string {
  309. return 'TEST Provider';
  310. }
  311. public function generate(string $prompt, array $resources): void {
  312. $this->ran = true;
  313. throw new \RuntimeException('ERROR');
  314. }
  315. public function getExpectedRuntime(): int {
  316. return 1;
  317. }
  318. }
  319. class ExternalProvider implements IProvider {
  320. public const ID = 'event:external:provider';
  321. public const TASK_TYPE_ID = 'event:external:tasktype';
  322. public function getId(): string {
  323. return self::ID;
  324. }
  325. public function getName(): string {
  326. return 'External Provider via Event';
  327. }
  328. public function getTaskTypeId(): string {
  329. return self::TASK_TYPE_ID;
  330. }
  331. public function getExpectedRuntime(): int {
  332. return 5;
  333. }
  334. public function getOptionalInputShape(): array {
  335. return [];
  336. }
  337. public function getOptionalOutputShape(): array {
  338. return [];
  339. }
  340. public function getInputShapeEnumValues(): array {
  341. return [];
  342. }
  343. public function getInputShapeDefaults(): array {
  344. return [];
  345. }
  346. public function getOptionalInputShapeEnumValues(): array {
  347. return [];
  348. }
  349. public function getOptionalInputShapeDefaults(): array {
  350. return [];
  351. }
  352. public function getOutputShapeEnumValues(): array {
  353. return [];
  354. }
  355. public function getOptionalOutputShapeEnumValues(): array {
  356. return [];
  357. }
  358. }
  359. class ExternalTriggerableProvider implements ITriggerableProvider {
  360. public const ID = 'event:external:provider:triggerable';
  361. public const TASK_TYPE_ID = TextToText::ID;
  362. public function getId(): string {
  363. return self::ID;
  364. }
  365. public function getName(): string {
  366. return 'External Triggerable Provider via Event';
  367. }
  368. public function getTaskTypeId(): string {
  369. return self::TASK_TYPE_ID;
  370. }
  371. public function trigger(): void {
  372. }
  373. public function getExpectedRuntime(): int {
  374. return 5;
  375. }
  376. public function getOptionalInputShape(): array {
  377. return [];
  378. }
  379. public function getOptionalOutputShape(): array {
  380. return [];
  381. }
  382. public function getInputShapeEnumValues(): array {
  383. return [];
  384. }
  385. public function getInputShapeDefaults(): array {
  386. return [];
  387. }
  388. public function getOptionalInputShapeEnumValues(): array {
  389. return [];
  390. }
  391. public function getOptionalInputShapeDefaults(): array {
  392. return [];
  393. }
  394. public function getOutputShapeEnumValues(): array {
  395. return [];
  396. }
  397. public function getOptionalOutputShapeEnumValues(): array {
  398. return [];
  399. }
  400. }
  401. class ConflictingExternalProvider implements IProvider {
  402. // Same ID as SuccessfulSyncProvider
  403. public const ID = 'test:sync:success';
  404. public const TASK_TYPE_ID = 'event:external:tasktype'; // Can be different task type
  405. public function getId(): string {
  406. return self::ID;
  407. }
  408. public function getName(): string {
  409. return 'Conflicting External Provider';
  410. }
  411. public function getTaskTypeId(): string {
  412. return self::TASK_TYPE_ID;
  413. }
  414. public function getExpectedRuntime(): int {
  415. return 50;
  416. }
  417. public function getOptionalInputShape(): array {
  418. return [];
  419. }
  420. public function getOptionalOutputShape(): array {
  421. return [];
  422. }
  423. public function getInputShapeEnumValues(): array {
  424. return [];
  425. }
  426. public function getInputShapeDefaults(): array {
  427. return [];
  428. }
  429. public function getOptionalInputShapeEnumValues(): array {
  430. return [];
  431. }
  432. public function getOptionalInputShapeDefaults(): array {
  433. return [];
  434. }
  435. public function getOutputShapeEnumValues(): array {
  436. return [];
  437. }
  438. public function getOptionalOutputShapeEnumValues(): array {
  439. return [];
  440. }
  441. }
  442. class ExternalTaskType implements ITaskType {
  443. public const ID = 'event:external:tasktype';
  444. public function getId(): string {
  445. return self::ID;
  446. }
  447. public function getName(): string {
  448. return 'External Task Type via Event';
  449. }
  450. public function getDescription(): string {
  451. return 'A task type added via event';
  452. }
  453. public function getInputShape(): array {
  454. return ['external_input' => new ShapeDescriptor('Ext In', '', EShapeType::Text)];
  455. }
  456. public function getOutputShape(): array {
  457. return ['external_output' => new ShapeDescriptor('Ext Out', '', EShapeType::Text)];
  458. }
  459. }
  460. class ConflictingExternalTaskType implements ITaskType {
  461. // Same ID as built-in TextToText
  462. public const ID = TextToText::ID;
  463. public function getId(): string {
  464. return self::ID;
  465. }
  466. public function getName(): string {
  467. return 'Conflicting External Task Type';
  468. }
  469. public function getDescription(): string {
  470. return 'Overrides built-in TextToText';
  471. }
  472. public function getInputShape(): array {
  473. return ['override_input' => new ShapeDescriptor('Override In', '', EShapeType::Number)];
  474. }
  475. public function getOutputShape(): array {
  476. return ['override_output' => new ShapeDescriptor('Override Out', '', EShapeType::Number)];
  477. }
  478. }
  479. /**
  480. * @group DB
  481. */
  482. class TaskProcessingTest extends \Test\TestCase {
  483. private IManager $manager;
  484. private Coordinator $coordinator;
  485. private array $providers;
  486. private IServerContainer $serverContainer;
  487. private IEventDispatcher $eventDispatcher;
  488. private RegistrationContext $registrationContext;
  489. private TaskMapper $taskMapper;
  490. private IJobList $jobList;
  491. private IUserMountCache $userMountCache;
  492. private IRootFolder $rootFolder;
  493. private IConfig $config;
  494. private IAppConfig $appConfig;
  495. public const TEST_USER = 'testuser';
  496. protected function setUp(): void {
  497. parent::setUp();
  498. $this->providers = [
  499. SuccessfulSyncProvider::class => new SuccessfulSyncProvider(),
  500. FailingSyncProvider::class => new FailingSyncProvider(),
  501. BrokenSyncProvider::class => new BrokenSyncProvider(),
  502. AsyncProvider::class => new AsyncProvider(),
  503. AudioToImage::class => new AudioToImage(),
  504. SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(),
  505. FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(),
  506. SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(),
  507. FailingTextToImageProvider::class => new FailingTextToImageProvider(),
  508. ExternalProvider::class => new ExternalProvider(),
  509. ExternalTriggerableProvider::class => new ExternalTriggerableProvider(),
  510. ConflictingExternalProvider::class => new ConflictingExternalProvider(),
  511. ExternalTaskType::class => new ExternalTaskType(),
  512. ConflictingExternalTaskType::class => new ConflictingExternalTaskType(),
  513. ];
  514. $userManager = Server::get(IUserManager::class);
  515. if (!$userManager->userExists(self::TEST_USER)) {
  516. $userManager->createUser(self::TEST_USER, 'test');
  517. }
  518. $this->serverContainer = $this->createMock(IServerContainer::class);
  519. $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
  520. return $this->providers[$class];
  521. });
  522. $this->eventDispatcher = new EventDispatcher(
  523. new \Symfony\Component\EventDispatcher\EventDispatcher(),
  524. $this->serverContainer,
  525. Server::get(LoggerInterface::class),
  526. );
  527. $this->registrationContext = $this->createMock(RegistrationContext::class);
  528. $this->coordinator = $this->createMock(Coordinator::class);
  529. $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);
  530. $this->rootFolder = Server::get(IRootFolder::class);
  531. $this->taskMapper = Server::get(TaskMapper::class);
  532. $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
  533. $this->jobList->expects($this->any())->method('add')->willReturnCallback(function (): void {
  534. });
  535. $this->eventDispatcher = $this->createMock(IEventDispatcher::class);
  536. $this->configureEventDispatcherMock();
  537. $text2imageManager = new \OC\TextToImage\Manager(
  538. $this->serverContainer,
  539. $this->coordinator,
  540. Server::get(LoggerInterface::class),
  541. $this->jobList,
  542. Server::get(\OC\TextToImage\Db\TaskMapper::class),
  543. Server::get(IConfig::class),
  544. Server::get(IAppDataFactory::class),
  545. );
  546. $this->userMountCache = $this->createMock(IUserMountCache::class);
  547. $this->config = Server::get(IConfig::class);
  548. $this->appConfig = Server::get(IAppConfig::class);
  549. $this->manager = new Manager(
  550. $this->appConfig,
  551. $this->coordinator,
  552. $this->serverContainer,
  553. Server::get(LoggerInterface::class),
  554. $this->taskMapper,
  555. $this->jobList,
  556. $this->eventDispatcher,
  557. Server::get(IAppDataFactory::class),
  558. Server::get(IRootFolder::class),
  559. $text2imageManager,
  560. $this->userMountCache,
  561. Server::get(IClientService::class),
  562. Server::get(IAppManager::class),
  563. $userManager,
  564. Server::get(IUserSession::class),
  565. Server::get(ICacheFactory::class),
  566. Server::get(IFactory::class),
  567. );
  568. }
  569. private function getFile(string $name, string $content): File {
  570. $folder = $this->rootFolder->getUserFolder(self::TEST_USER);
  571. $file = $folder->newFile($name, $content);
  572. return $file;
  573. }
  574. public function testShouldNotHaveAnyProviders(): void {
  575. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
  576. self::assertCount(0, $this->manager->getAvailableTaskTypes());
  577. self::assertCount(0, $this->manager->getAvailableTaskTypeIds());
  578. self::assertFalse($this->manager->hasProviders());
  579. self::expectException(PreConditionNotMetException::class);
  580. $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null));
  581. }
  582. public function testProviderShouldBeRegisteredAndTaskTypeDisabled(): void {
  583. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  584. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  585. ]);
  586. $taskProcessingTypeSettings = [
  587. TextToText::ID => false,
  588. ];
  589. $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true);
  590. self::assertCount(0, $this->manager->getAvailableTaskTypes());
  591. self::assertCount(1, $this->manager->getAvailableTaskTypes(true));
  592. self::assertCount(0, $this->manager->getAvailableTaskTypeIds());
  593. self::assertCount(1, $this->manager->getAvailableTaskTypeIds(true));
  594. self::assertTrue($this->manager->hasProviders());
  595. self::expectException(PreConditionNotMetException::class);
  596. $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null));
  597. }
  598. public function testProviderShouldBeRegisteredAndTaskFailValidation(): void {
  599. $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', '', lazy: true);
  600. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  601. new ServiceRegistration('test', BrokenSyncProvider::class)
  602. ]);
  603. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  604. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  605. self::assertTrue($this->manager->hasProviders());
  606. $task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null);
  607. self::assertNull($task->getId());
  608. self::expectException(ValidationException::class);
  609. $this->manager->scheduleTask($task);
  610. }
  611. public function testProviderShouldBeRegisteredAndTaskWithFilesFailValidation(): void {
  612. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  613. new ServiceRegistration('test', AudioToImage::class)
  614. ]);
  615. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  616. new ServiceRegistration('test', AsyncProvider::class)
  617. ]);
  618. $user = $this->createMock(IUser::class);
  619. $user->expects($this->any())->method('getUID')->willReturn(null);
  620. $mount = $this->createMock(ICachedMountInfo::class);
  621. $mount->expects($this->any())->method('getUser')->willReturn($user);
  622. $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
  623. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  624. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  625. self::assertTrue($this->manager->hasProviders());
  626. $audioId = $this->getFile('audioInput', 'Hello')->getId();
  627. $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', null);
  628. self::assertNull($task->getId());
  629. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  630. self::expectException(UnauthorizedException::class);
  631. $this->manager->scheduleTask($task);
  632. }
  633. public function testProviderShouldBeRegisteredAndFail(): void {
  634. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  635. new ServiceRegistration('test', FailingSyncProvider::class)
  636. ]);
  637. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  638. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  639. self::assertTrue($this->manager->hasProviders());
  640. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  641. self::assertNull($task->getId());
  642. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  643. $this->manager->scheduleTask($task);
  644. self::assertNotNull($task->getId());
  645. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  646. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  647. $backgroundJob = new SynchronousBackgroundJob(
  648. Server::get(ITimeFactory::class),
  649. $this->manager,
  650. $this->jobList,
  651. Server::get(LoggerInterface::class),
  652. );
  653. $backgroundJob->start($this->jobList);
  654. $task = $this->manager->getTask($task->getId());
  655. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  656. self::assertEquals(FailingSyncProvider::ERROR_MESSAGE, $task->getErrorMessage());
  657. }
  658. public function testProviderShouldBeRegisteredAndFailOutputValidation(): void {
  659. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  660. new ServiceRegistration('test', BrokenSyncProvider::class)
  661. ]);
  662. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  663. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  664. self::assertTrue($this->manager->hasProviders());
  665. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  666. self::assertNull($task->getId());
  667. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  668. $this->manager->scheduleTask($task);
  669. self::assertNotNull($task->getId());
  670. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  671. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  672. $backgroundJob = new SynchronousBackgroundJob(
  673. Server::get(ITimeFactory::class),
  674. $this->manager,
  675. $this->jobList,
  676. Server::get(LoggerInterface::class),
  677. );
  678. $backgroundJob->start($this->jobList);
  679. $task = $this->manager->getTask($task->getId());
  680. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  681. self::assertEquals('The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec', $task->getErrorMessage());
  682. }
  683. public function testProviderShouldBeRegisteredAndRun(): void {
  684. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  685. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  686. ]);
  687. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  688. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  689. $taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]];
  690. self::assertTrue(isset($taskTypeStruct['inputShape']['input']));
  691. self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType());
  692. self::assertTrue(isset($taskTypeStruct['optionalInputShape']['optionalKey']));
  693. self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalInputShape']['optionalKey']->getShapeType());
  694. self::assertTrue(isset($taskTypeStruct['outputShape']['output']));
  695. self::assertEquals(EShapeType::Text, $taskTypeStruct['outputShape']['output']->getShapeType());
  696. self::assertTrue(isset($taskTypeStruct['optionalOutputShape']['optionalKey']));
  697. self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalOutputShape']['optionalKey']->getShapeType());
  698. self::assertTrue($this->manager->hasProviders());
  699. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  700. self::assertNull($task->getId());
  701. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  702. $this->manager->scheduleTask($task);
  703. self::assertNotNull($task->getId());
  704. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  705. // Task object retrieved from db is up-to-date
  706. $task2 = $this->manager->getTask($task->getId());
  707. self::assertEquals($task->getId(), $task2->getId());
  708. self::assertEquals(['input' => 'Hello'], $task2->getInput());
  709. self::assertNull($task2->getOutput());
  710. self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
  711. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  712. $backgroundJob = new SynchronousBackgroundJob(
  713. Server::get(ITimeFactory::class),
  714. $this->manager,
  715. $this->jobList,
  716. Server::get(LoggerInterface::class),
  717. );
  718. $backgroundJob->start($this->jobList);
  719. $task = $this->manager->getTask($task->getId());
  720. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage());
  721. self::assertEquals(['output' => 'Hello'], $task->getOutput());
  722. self::assertEquals(1, $task->getProgress());
  723. }
  724. public function testTaskTypeExplicitlyEnabled(): void {
  725. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  726. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  727. ]);
  728. $taskProcessingTypeSettings = [
  729. TextToText::ID => true,
  730. ];
  731. $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true);
  732. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  733. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  734. self::assertTrue($this->manager->hasProviders());
  735. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  736. self::assertNull($task->getId());
  737. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  738. $this->manager->scheduleTask($task);
  739. self::assertNotNull($task->getId());
  740. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  741. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  742. $backgroundJob = new SynchronousBackgroundJob(
  743. Server::get(ITimeFactory::class),
  744. $this->manager,
  745. $this->jobList,
  746. Server::get(LoggerInterface::class),
  747. );
  748. $backgroundJob->start($this->jobList);
  749. $task = $this->manager->getTask($task->getId());
  750. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage());
  751. self::assertEquals(['output' => 'Hello'], $task->getOutput());
  752. self::assertEquals(1, $task->getProgress());
  753. }
  754. public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningRawFileData(): void {
  755. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  756. new ServiceRegistration('test', AudioToImage::class)
  757. ]);
  758. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  759. new ServiceRegistration('test', AsyncProvider::class)
  760. ]);
  761. $user = $this->createMock(IUser::class);
  762. $user->expects($this->any())->method('getUID')->willReturn('testuser');
  763. $mount = $this->createMock(ICachedMountInfo::class);
  764. $mount->expects($this->any())->method('getUser')->willReturn($user);
  765. $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
  766. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  767. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  768. self::assertTrue($this->manager->hasProviders());
  769. $audioId = $this->getFile('audioInput', 'Hello')->getId();
  770. $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser');
  771. self::assertNull($task->getId());
  772. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  773. $this->manager->scheduleTask($task);
  774. self::assertNotNull($task->getId());
  775. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  776. // Task object retrieved from db is up-to-date
  777. $task2 = $this->manager->getTask($task->getId());
  778. self::assertEquals($task->getId(), $task2->getId());
  779. self::assertEquals(['audio' => $audioId], $task2->getInput());
  780. self::assertNull($task2->getOutput());
  781. self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
  782. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  783. $this->manager->setTaskProgress($task2->getId(), 0.1);
  784. $input = $this->manager->prepareInputData($task2);
  785. self::assertTrue(isset($input['audio']));
  786. self::assertInstanceOf(File::class, $input['audio']);
  787. self::assertEquals($audioId, $input['audio']->getId());
  788. $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => 'World']);
  789. $task = $this->manager->getTask($task->getId());
  790. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  791. self::assertEquals(1, $task->getProgress());
  792. self::assertTrue(isset($task->getOutput()['spectrogram']));
  793. $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['spectrogram'], '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  794. self::assertNotNull($node);
  795. self::assertInstanceOf(File::class, $node);
  796. self::assertEquals('World', $node->getContent());
  797. }
  798. public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningFileIds(): void {
  799. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  800. new ServiceRegistration('test', AudioToImage::class)
  801. ]);
  802. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  803. new ServiceRegistration('test', AsyncProvider::class)
  804. ]);
  805. $user = $this->createMock(IUser::class);
  806. $user->expects($this->any())->method('getUID')->willReturn('testuser');
  807. $mount = $this->createMock(ICachedMountInfo::class);
  808. $mount->expects($this->any())->method('getUser')->willReturn($user);
  809. $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]);
  810. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  811. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  812. self::assertTrue($this->manager->hasProviders());
  813. $audioId = $this->getFile('audioInput', 'Hello')->getId();
  814. $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser');
  815. self::assertNull($task->getId());
  816. self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
  817. $this->manager->scheduleTask($task);
  818. self::assertNotNull($task->getId());
  819. self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
  820. // Task object retrieved from db is up-to-date
  821. $task2 = $this->manager->getTask($task->getId());
  822. self::assertEquals($task->getId(), $task2->getId());
  823. self::assertEquals(['audio' => $audioId], $task2->getInput());
  824. self::assertNull($task2->getOutput());
  825. self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
  826. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  827. $this->manager->setTaskProgress($task2->getId(), 0.1);
  828. $input = $this->manager->prepareInputData($task2);
  829. self::assertTrue(isset($input['audio']));
  830. self::assertInstanceOf(File::class, $input['audio']);
  831. self::assertEquals($audioId, $input['audio']->getId());
  832. $outputFileId = $this->getFile('audioOutput', 'World')->getId();
  833. $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => $outputFileId], true);
  834. $task = $this->manager->getTask($task->getId());
  835. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  836. self::assertEquals(1, $task->getProgress());
  837. self::assertTrue(isset($task->getOutput()['spectrogram']));
  838. $node = $this->rootFolder->getFirstNodeById($task->getOutput()['spectrogram']);
  839. self::assertNotNull($node, 'fileId:' . $task->getOutput()['spectrogram']);
  840. self::assertInstanceOf(File::class, $node);
  841. self::assertEquals('World', $node->getContent());
  842. }
  843. public function testNonexistentTask(): void {
  844. $this->expectException(NotFoundException::class);
  845. $this->manager->getTask(2147483646);
  846. }
  847. public function testOldTasksShouldBeCleanedUp(): void {
  848. $currentTime = new \DateTime('now');
  849. $timeFactory = $this->createMock(ITimeFactory::class);
  850. $timeFactory->expects($this->any())->method('getDateTime')->willReturnCallback(fn () => $currentTime);
  851. $timeFactory->expects($this->any())->method('getTime')->willReturnCallback(fn () => $currentTime->getTimestamp());
  852. $this->taskMapper = new TaskMapper(
  853. Server::get(IDBConnection::class),
  854. $timeFactory,
  855. );
  856. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  857. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  858. ]);
  859. self::assertCount(1, $this->manager->getAvailableTaskTypes());
  860. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  861. self::assertTrue($this->manager->hasProviders());
  862. $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null);
  863. $this->manager->scheduleTask($task);
  864. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  865. $backgroundJob = new SynchronousBackgroundJob(
  866. Server::get(ITimeFactory::class),
  867. $this->manager,
  868. $this->jobList,
  869. Server::get(LoggerInterface::class),
  870. );
  871. $backgroundJob->start($this->jobList);
  872. $task = $this->manager->getTask($task->getId());
  873. $currentTime = $currentTime->add(new \DateInterval('P1Y'));
  874. // run background job
  875. $bgJob = new RemoveOldTasksBackgroundJob(
  876. $timeFactory,
  877. $this->manager,
  878. $this->taskMapper,
  879. Server::get(LoggerInterface::class),
  880. Server::get(IAppDataFactory::class),
  881. );
  882. $bgJob->setArgument([]);
  883. $bgJob->start($this->jobList);
  884. $this->expectException(NotFoundException::class);
  885. $this->manager->getTask($task->getId());
  886. }
  887. public function testShouldTransparentlyHandleTextProcessingProviders(): void {
  888. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
  889. new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class)
  890. ]);
  891. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  892. ]);
  893. $taskTypes = $this->manager->getAvailableTaskTypes();
  894. self::assertCount(1, $taskTypes);
  895. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  896. self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
  897. self::assertTrue($this->manager->hasProviders());
  898. $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
  899. $this->manager->scheduleTask($task);
  900. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  901. $backgroundJob = new SynchronousBackgroundJob(
  902. Server::get(ITimeFactory::class),
  903. $this->manager,
  904. $this->jobList,
  905. Server::get(LoggerInterface::class),
  906. );
  907. $backgroundJob->start($this->jobList);
  908. $task = $this->manager->getTask($task->getId());
  909. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  910. self::assertIsArray($task->getOutput());
  911. self::assertTrue(isset($task->getOutput()['output']));
  912. self::assertEquals('Hello Summarize', $task->getOutput()['output']);
  913. self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran);
  914. }
  915. public function testShouldTransparentlyHandleFailingTextProcessingProviders(): void {
  916. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
  917. new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class)
  918. ]);
  919. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  920. ]);
  921. $taskTypes = $this->manager->getAvailableTaskTypes();
  922. self::assertCount(1, $taskTypes);
  923. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  924. self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
  925. self::assertTrue($this->manager->hasProviders());
  926. $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
  927. $this->manager->scheduleTask($task);
  928. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  929. $backgroundJob = new SynchronousBackgroundJob(
  930. Server::get(ITimeFactory::class),
  931. $this->manager,
  932. $this->jobList,
  933. Server::get(LoggerInterface::class),
  934. );
  935. $backgroundJob->start($this->jobList);
  936. $task = $this->manager->getTask($task->getId());
  937. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  938. self::assertTrue($task->getOutput() === null);
  939. self::assertEquals('ERROR', $task->getErrorMessage());
  940. self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran);
  941. }
  942. public function testShouldTransparentlyHandleText2ImageProviders(): void {
  943. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
  944. new ServiceRegistration('test', SuccessfulTextToImageProvider::class)
  945. ]);
  946. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  947. ]);
  948. $taskTypes = $this->manager->getAvailableTaskTypes();
  949. self::assertCount(1, $taskTypes);
  950. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  951. self::assertTrue(isset($taskTypes[TextToImage::ID]));
  952. self::assertTrue($this->manager->hasProviders());
  953. $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
  954. $this->manager->scheduleTask($task);
  955. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
  956. $backgroundJob = new SynchronousBackgroundJob(
  957. Server::get(ITimeFactory::class),
  958. $this->manager,
  959. $this->jobList,
  960. Server::get(LoggerInterface::class),
  961. );
  962. $backgroundJob->start($this->jobList);
  963. $task = $this->manager->getTask($task->getId());
  964. self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
  965. self::assertIsArray($task->getOutput());
  966. self::assertTrue(isset($task->getOutput()['images']));
  967. self::assertIsArray($task->getOutput()['images']);
  968. self::assertCount(3, $task->getOutput()['images']);
  969. self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran);
  970. $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['images'][0], '/' . $this->rootFolder->getAppDataDirectoryName() . '/');
  971. self::assertNotNull($node);
  972. self::assertInstanceOf(File::class, $node);
  973. self::assertEquals('test', $node->getContent());
  974. }
  975. public function testShouldTransparentlyHandleFailingText2ImageProviders(): void {
  976. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
  977. new ServiceRegistration('test', FailingTextToImageProvider::class)
  978. ]);
  979. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  980. ]);
  981. $taskTypes = $this->manager->getAvailableTaskTypes();
  982. self::assertCount(1, $taskTypes);
  983. self::assertCount(1, $this->manager->getAvailableTaskTypeIds());
  984. self::assertTrue(isset($taskTypes[TextToImage::ID]));
  985. self::assertTrue($this->manager->hasProviders());
  986. $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
  987. $this->manager->scheduleTask($task);
  988. $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
  989. $backgroundJob = new SynchronousBackgroundJob(
  990. Server::get(ITimeFactory::class),
  991. $this->manager,
  992. $this->jobList,
  993. Server::get(LoggerInterface::class),
  994. );
  995. $backgroundJob->start($this->jobList);
  996. $task = $this->manager->getTask($task->getId());
  997. self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
  998. self::assertTrue($task->getOutput() === null);
  999. self::assertEquals('ERROR', $task->getErrorMessage());
  1000. self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran);
  1001. }
  1002. public function testMergeProvidersLocalAndEvent() {
  1003. // Arrange: Local provider registered, DIFFERENT external provider via event
  1004. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  1005. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  1006. ]);
  1007. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1008. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1009. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1010. $externalProvider = new ExternalProvider(); // ID = 'event:external:provider'
  1011. $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]);
  1012. $this->manager = $this->createManagerInstance();
  1013. // Act
  1014. $providers = $this->manager->getProviders();
  1015. // Assert: Both providers should be present
  1016. self::assertArrayHasKey(SuccessfulSyncProvider::ID, $providers);
  1017. self::assertInstanceOf(SuccessfulSyncProvider::class, $providers[SuccessfulSyncProvider::ID]);
  1018. self::assertArrayHasKey(ExternalProvider::ID, $providers);
  1019. self::assertInstanceOf(ExternalProvider::class, $providers[ExternalProvider::ID]);
  1020. self::assertCount(2, $providers);
  1021. }
  1022. public function testGetProvidersIncludesExternalViaEvent() {
  1023. // Arrange: No local providers, one external provider via event
  1024. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
  1025. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1026. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1027. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1028. $externalProvider = new ExternalProvider();
  1029. $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]);
  1030. $this->manager = $this->createManagerInstance(); // Create manager with configured mocks
  1031. // Act
  1032. $providers = $this->manager->getProviders(); // Returns ID-indexed array
  1033. // Assert
  1034. self::assertArrayHasKey(ExternalProvider::ID, $providers);
  1035. self::assertInstanceOf(ExternalProvider::class, $providers[ExternalProvider::ID]);
  1036. self::assertCount(1, $providers);
  1037. self::assertTrue($this->manager->hasProviders());
  1038. }
  1039. public function testGetAvailableTaskTypesIncludesExternalViaEvent() {
  1040. // Arrange: No local types/providers, one external type and provider via event
  1041. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
  1042. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([]);
  1043. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1044. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1045. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1046. $externalProvider = new ExternalProvider(); // Provides ExternalTaskType
  1047. $externalTaskType = new ExternalTaskType();
  1048. $this->configureEventDispatcherMock(
  1049. providersToAdd: [$externalProvider],
  1050. taskTypesToAdd: [$externalTaskType]
  1051. );
  1052. $this->manager = $this->createManagerInstance();
  1053. // Act
  1054. $availableTypes = $this->manager->getAvailableTaskTypes();
  1055. // Assert
  1056. self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes);
  1057. self::assertContains(ExternalTaskType::ID, $this->manager->getAvailableTaskTypeIds());
  1058. self::assertEquals(ExternalTaskType::ID, $externalProvider->getTaskTypeId(), 'Test Sanity: Provider must handle the Task Type');
  1059. self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']);
  1060. // Check if shapes match the external type/provider
  1061. self::assertArrayHasKey('external_input', $availableTypes[ExternalTaskType::ID]['inputShape']);
  1062. self::assertArrayHasKey('external_output', $availableTypes[ExternalTaskType::ID]['outputShape']);
  1063. self::assertEmpty($availableTypes[ExternalTaskType::ID]['optionalInputShape']); // From ExternalProvider
  1064. }
  1065. public function testLocalProviderWinsConflictWithEvent() {
  1066. // Arrange: Local provider registered, conflicting external provider via event
  1067. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  1068. new ServiceRegistration('test', SuccessfulSyncProvider::class)
  1069. ]);
  1070. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1071. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1072. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1073. $conflictingExternalProvider = new ConflictingExternalProvider(); // ID = 'test:sync:success'
  1074. $this->configureEventDispatcherMock(providersToAdd: [$conflictingExternalProvider]);
  1075. $this->manager = $this->createManagerInstance();
  1076. // Act
  1077. $providers = $this->manager->getProviders();
  1078. // Assert: Only the local provider should be present for the conflicting ID
  1079. self::assertArrayHasKey(SuccessfulSyncProvider::ID, $providers);
  1080. self::assertInstanceOf(SuccessfulSyncProvider::class, $providers[SuccessfulSyncProvider::ID]);
  1081. self::assertCount(1, $providers); // Ensure no extra provider was added
  1082. }
  1083. public function testTriggerableProviderWithNoOtherRunningTasks() {
  1084. // Arrange: Local provider registered, conflicting external provider via event
  1085. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
  1086. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1087. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1088. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1089. $externalProvider = $this->createPartialMock(ExternalTriggerableProvider::class, ['trigger']);
  1090. $externalProvider->expects($this->once())->method('trigger');
  1091. $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]);
  1092. $this->manager = $this->createManagerInstance();
  1093. // Act
  1094. $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar');
  1095. $this->manager->scheduleTask($task);
  1096. }
  1097. public function testTriggerableProviderWithOtherRunningTasks() {
  1098. // Arrange: Local provider registered, conflicting external provider via event
  1099. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]);
  1100. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1101. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1102. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1103. $externalProvider = $this->createPartialMock(ExternalTriggerableProvider::class, ['trigger']);
  1104. $externalProvider->expects($this->once())->method('trigger');
  1105. $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]);
  1106. $this->manager = $this->createManagerInstance();
  1107. $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar');
  1108. $this->manager->scheduleTask($task);
  1109. $this->manager->lockTask($task);
  1110. // Act
  1111. $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar');
  1112. $this->manager->scheduleTask($task);
  1113. }
  1114. public function testMergeTaskTypesLocalAndEvent() {
  1115. // Arrange: Local type registered, DIFFERENT external type via event
  1116. $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
  1117. new ServiceRegistration('test', AsyncProvider::class)
  1118. ]);
  1119. $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([
  1120. new ServiceRegistration('test', AudioToImage::class)
  1121. ]);
  1122. $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
  1123. $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]);
  1124. $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]);
  1125. $externalTaskType = new ExternalTaskType(); // ID = 'event:external:tasktype'
  1126. $externalProvider = new ExternalProvider(); // Handles 'event:external:tasktype'
  1127. $this->configureEventDispatcherMock(
  1128. providersToAdd: [$externalProvider],
  1129. taskTypesToAdd: [$externalTaskType]
  1130. );
  1131. $this->manager = $this->createManagerInstance();
  1132. // Act
  1133. $availableTypes = $this->manager->getAvailableTaskTypes();
  1134. $availableTypeIds = $this->manager->getAvailableTaskTypeIds();
  1135. // Assert: Both task types should be available
  1136. self::assertContains(AudioToImage::ID, $availableTypeIds);
  1137. self::assertArrayHasKey(AudioToImage::ID, $availableTypes);
  1138. self::assertEquals(AudioToImage::class, $availableTypes[AudioToImage::ID]['name']);
  1139. self::assertContains(ExternalTaskType::ID, $availableTypeIds);
  1140. self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes);
  1141. self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']);
  1142. self::assertCount(2, $availableTypes);
  1143. }
  1144. private function createManagerInstance(): Manager {
  1145. // Clear potentially cached config values if needed
  1146. $this->appConfig->deleteKey('core', 'ai.taskprocessing_type_preferences');
  1147. // Re-create Text2ImageManager if its state matters or mocks change
  1148. $text2imageManager = new \OC\TextToImage\Manager(
  1149. $this->serverContainer,
  1150. $this->coordinator,
  1151. Server::get(LoggerInterface::class),
  1152. $this->jobList,
  1153. Server::get(\OC\TextToImage\Db\TaskMapper::class),
  1154. $this->config, // Use the shared config mock
  1155. Server::get(IAppDataFactory::class),
  1156. );
  1157. return new Manager(
  1158. $this->appConfig,
  1159. $this->coordinator,
  1160. $this->serverContainer,
  1161. Server::get(LoggerInterface::class),
  1162. $this->taskMapper,
  1163. $this->jobList,
  1164. $this->eventDispatcher, // Use the potentially reconfigured mock
  1165. Server::get(IAppDataFactory::class),
  1166. $this->rootFolder,
  1167. $text2imageManager,
  1168. $this->userMountCache,
  1169. Server::get(IClientService::class),
  1170. Server::get(IAppManager::class),
  1171. Server::get(IUserManager::class),
  1172. Server::get(IUserSession::class),
  1173. Server::get(ICacheFactory::class),
  1174. Server::get(IFactory::class),
  1175. );
  1176. }
  1177. private function configureEventDispatcherMock(
  1178. array $providersToAdd = [],
  1179. array $taskTypesToAdd = [],
  1180. ?int $expectedCalls = null,
  1181. ): void {
  1182. $dispatchExpectation = $expectedCalls === null ? $this->any() : $this->exactly($expectedCalls);
  1183. $this->eventDispatcher->expects($dispatchExpectation)
  1184. ->method('dispatchTyped')
  1185. ->willReturnCallback(function (object $event) use ($providersToAdd, $taskTypesToAdd): void {
  1186. if ($event instanceof GetTaskProcessingProvidersEvent) {
  1187. foreach ($providersToAdd as $providerInstance) {
  1188. $event->addProvider($providerInstance);
  1189. }
  1190. foreach ($taskTypesToAdd as $taskTypeInstance) {
  1191. $event->addTaskType($taskTypeInstance);
  1192. }
  1193. }
  1194. });
  1195. }
  1196. }