mirror of
https://github.com/symfony/ai-bundle.git
synced 2026-03-23 23:12:08 +01:00
feat(aibundle): TraceableStore implemented
This commit is contained in:
@@ -6,6 +6,7 @@ CHANGELOG
|
||||
|
||||
* Move debug service decorating to compiler pass to cover user-defined services
|
||||
* Add `TraceableAgent`
|
||||
* Add `TraceableStore`
|
||||
|
||||
0.5
|
||||
---
|
||||
|
||||
@@ -225,6 +225,7 @@ return static function (ContainerConfigurator $container): void {
|
||||
tagged_iterator('ai.traceable_message_store'),
|
||||
tagged_iterator('ai.traceable_chat'),
|
||||
tagged_iterator('ai.traceable_agent'),
|
||||
tagged_iterator('ai.traceable_store'),
|
||||
])
|
||||
->tag('data_collector')
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ use Symfony\AI\AiBundle\Profiler\TraceableAgent;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableChat;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableStore;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
|
||||
use Symfony\Component\Clock\ClockInterface;
|
||||
use Symfony\Component\DependencyInjection\Compiler\CompilerPassInterface;
|
||||
@@ -24,7 +25,7 @@ use Symfony\Component\DependencyInjection\Reference;
|
||||
|
||||
use function Symfony\Component\String\u;
|
||||
|
||||
class DebugCompilerPass implements CompilerPassInterface
|
||||
final class DebugCompilerPass implements CompilerPassInterface
|
||||
{
|
||||
public function process(ContainerBuilder $container): void
|
||||
{
|
||||
@@ -87,5 +88,15 @@ class DebugCompilerPass implements CompilerPassInterface
|
||||
$suffix = u($agent)->afterLast('.')->toString();
|
||||
$container->setDefinition('ai.traceable_agent.'.$suffix, $traceableAgentDefinition);
|
||||
}
|
||||
|
||||
foreach (array_keys($container->findTaggedServiceIds('ai.store')) as $store) {
|
||||
$traceableStoreDefinition = (new Definition(TraceableStore::class))
|
||||
->setDecoratedService($store, priority: -1024)
|
||||
->setArguments([new Reference('.inner')])
|
||||
->addTag('ai.traceable_store')
|
||||
->addTag('kernel.reset', ['method' => 'reset']);
|
||||
$suffix = u($store)->afterLast('.')->toString();
|
||||
$container->setDefinition('ai.traceable_store.'.$suffix, $traceableStoreDefinition);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ use Symfony\Component\HttpKernel\DataCollector\LateDataCollectorInterface;
|
||||
* @phpstan-import-type MessageStoreData from TraceableMessageStore
|
||||
* @phpstan-import-type ChatData from TraceableChat
|
||||
* @phpstan-import-type AgentData from TraceableAgent
|
||||
* @phpstan-import-type StoreData from TraceableStore
|
||||
*
|
||||
* @phpstan-type CollectedPlatformCallData array{
|
||||
* model: string,
|
||||
@@ -62,12 +63,18 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
|
||||
*/
|
||||
private readonly array $agents;
|
||||
|
||||
/**
|
||||
* @var TraceableStore[]
|
||||
*/
|
||||
private readonly array $stores;
|
||||
|
||||
/**
|
||||
* @param iterable<TraceablePlatform> $platforms
|
||||
* @param iterable<TraceableToolbox> $toolboxes
|
||||
* @param iterable<TraceableMessageStore> $messageStores
|
||||
* @param iterable<TraceableChat> $chats
|
||||
* @param iterable<TraceableAgent> $agents
|
||||
* @param iterable<TraceableStore> $stores
|
||||
*/
|
||||
public function __construct(
|
||||
iterable $platforms,
|
||||
@@ -75,12 +82,14 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
|
||||
iterable $messageStores,
|
||||
iterable $chats,
|
||||
iterable $agents,
|
||||
iterable $stores,
|
||||
) {
|
||||
$this->platforms = iterator_to_array($platforms);
|
||||
$this->toolboxes = iterator_to_array($toolboxes);
|
||||
$this->messageStores = iterator_to_array($messageStores);
|
||||
$this->chats = iterator_to_array($chats);
|
||||
$this->agents = iterator_to_array($agents);
|
||||
$this->stores = iterator_to_array($stores);
|
||||
}
|
||||
|
||||
public function collect(Request $request, Response $response, ?\Throwable $exception = null): void
|
||||
@@ -97,6 +106,7 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
|
||||
'messages' => array_merge(...array_map(static fn (TraceableMessageStore $messageStore): array => $messageStore->calls, $this->messageStores)),
|
||||
'chats' => array_merge(...array_map(static fn (TraceableChat $chat): array => $chat->calls, $this->chats)),
|
||||
'agents' => array_merge(...array_map(static fn (TraceableAgent $agent): array => $agent->calls, $this->agents)),
|
||||
'stores' => array_merge(...array_map(static fn (TraceableStore $store): array => $store->calls, $this->stores)),
|
||||
];
|
||||
}
|
||||
|
||||
@@ -158,6 +168,14 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
|
||||
return $this->data['agents'] ?? [];
|
||||
}
|
||||
|
||||
/**
|
||||
* @return StoreData[]
|
||||
*/
|
||||
public function getStores(): array
|
||||
{
|
||||
return $this->data['stores'] ?? [];
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Tool[]
|
||||
*/
|
||||
|
||||
90
src/Profiler/TraceableStore.php
Normal file
90
src/Profiler/TraceableStore.php
Normal file
@@ -0,0 +1,90 @@
|
||||
<?php
|
||||
|
||||
/*
|
||||
* This file is part of the Symfony package.
|
||||
*
|
||||
* (c) Fabien Potencier <fabien@symfony.com>
|
||||
*
|
||||
* For the full copyright and license information, please view the LICENSE
|
||||
* file that was distributed with this source code.
|
||||
*/
|
||||
|
||||
namespace Symfony\AI\AiBundle\Profiler;
|
||||
|
||||
use Symfony\AI\Store\Document\VectorDocument;
|
||||
use Symfony\AI\Store\Query\QueryInterface;
|
||||
use Symfony\AI\Store\StoreInterface;
|
||||
use Symfony\Component\Clock\ClockInterface;
|
||||
use Symfony\Component\Clock\MonotonicClock;
|
||||
use Symfony\Contracts\Service\ResetInterface;
|
||||
|
||||
/**
|
||||
* @author Guillaume Loulier <personal@guillaumeloulier.fr>
|
||||
*
|
||||
* @phpstan-type StoreData array{
|
||||
* method: string,
|
||||
* documents?: VectorDocument|VectorDocument[],
|
||||
* query?: QueryInterface,
|
||||
* ids?: string[]|string,
|
||||
* options?: array<string, mixed>,
|
||||
* called_at: \DateTimeImmutable,
|
||||
* }
|
||||
*/
|
||||
final class TraceableStore implements StoreInterface, ResetInterface
|
||||
{
|
||||
/**
|
||||
* @var StoreData[]
|
||||
*/
|
||||
public array $calls = [];
|
||||
|
||||
public function __construct(
|
||||
private readonly StoreInterface $store,
|
||||
private readonly ClockInterface $clock = new MonotonicClock(),
|
||||
) {
|
||||
}
|
||||
|
||||
public function add(VectorDocument|array $documents): void
|
||||
{
|
||||
$this->calls[] = [
|
||||
'method' => 'add',
|
||||
'documents' => $documents,
|
||||
'called_at' => $this->clock->now(),
|
||||
];
|
||||
|
||||
$this->store->add($documents);
|
||||
}
|
||||
|
||||
public function query(QueryInterface $query, array $options = []): iterable
|
||||
{
|
||||
$this->calls[] = [
|
||||
'method' => 'query',
|
||||
'query' => $query,
|
||||
'options' => $options,
|
||||
'called_at' => $this->clock->now(),
|
||||
];
|
||||
|
||||
return $this->store->query($query, $options);
|
||||
}
|
||||
|
||||
public function remove(array|string $ids, array $options = []): void
|
||||
{
|
||||
$this->calls[] = [
|
||||
'method' => 'remove',
|
||||
'ids' => $ids,
|
||||
'options' => $options,
|
||||
'called_at' => $this->clock->now(),
|
||||
];
|
||||
|
||||
$this->store->remove($ids, $options);
|
||||
}
|
||||
|
||||
public function supports(string $queryClass): bool
|
||||
{
|
||||
return $this->store->supports($queryClass);
|
||||
}
|
||||
|
||||
public function reset(): void
|
||||
{
|
||||
$this->calls = [];
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,10 @@
|
||||
<b class="label">Chats sent</b>
|
||||
<span class="sf-toolbar-status">{{ collector.chats|length }}</span>
|
||||
</div>
|
||||
<div class="sf-toolbar-info-piece">
|
||||
<b class="label">Stores used</b>
|
||||
<span class="sf-toolbar-status">{{ collector.stores|length }}</span>
|
||||
</div>
|
||||
</div>
|
||||
{% endset %}
|
||||
|
||||
|
||||
@@ -17,12 +17,13 @@ use Symfony\AI\AiBundle\Profiler\TraceableAgent;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableChat;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableStore;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
|
||||
use Symfony\Component\Clock\ClockInterface;
|
||||
use Symfony\Component\DependencyInjection\ContainerBuilder;
|
||||
use Symfony\Component\DependencyInjection\Reference;
|
||||
|
||||
class DebugCompilerPassTest extends TestCase
|
||||
final class DebugCompilerPassTest extends TestCase
|
||||
{
|
||||
public function testProcessAddsTraceableDefinitionsInDebug()
|
||||
{
|
||||
@@ -34,6 +35,7 @@ class DebugCompilerPassTest extends TestCase
|
||||
$container->register('ai.chat.main', \stdClass::class)->addTag('ai.chat');
|
||||
$container->register('ai.toolbox.my_agent', \stdClass::class)->addTag('ai.toolbox');
|
||||
$container->register('ai.agent.my_agent', \stdClass::class)->addTag('ai.agent');
|
||||
$container->register('ai.store.store', \stdClass::class)->addTag('ai.store');
|
||||
|
||||
(new DebugCompilerPass())->process($container);
|
||||
|
||||
@@ -71,6 +73,13 @@ class DebugCompilerPassTest extends TestCase
|
||||
$this->assertEquals([new Reference('.inner')], $traceableAgent->getArguments());
|
||||
$this->assertTrue($traceableAgent->hasTag('ai.traceable_agent'));
|
||||
$this->assertSame([['method' => 'reset']], $traceableAgent->getTag('kernel.reset'));
|
||||
|
||||
$traceableStore = $container->getDefinition('ai.traceable_store.store');
|
||||
$this->assertSame(TraceableStore::class, $traceableStore->getClass());
|
||||
$this->assertSame(['ai.store.store', null, -1024], $traceableStore->getDecoratedService());
|
||||
$this->assertEquals([new Reference('.inner')], $traceableStore->getArguments());
|
||||
$this->assertTrue($traceableStore->hasTag('ai.traceable_store'));
|
||||
$this->assertSame([['method' => 'reset']], $traceableStore->getTag('kernel.reset'));
|
||||
}
|
||||
|
||||
public function testProcessSkipsWhenDebugDisabled()
|
||||
@@ -83,6 +92,7 @@ class DebugCompilerPassTest extends TestCase
|
||||
$container->register('ai.chat.main', \stdClass::class)->addTag('ai.chat');
|
||||
$container->register('ai.toolbox.my_agent', \stdClass::class)->addTag('ai.toolbox');
|
||||
$container->register('ai.agent.my_agent', \stdClass::class)->addTag('ai.agent');
|
||||
$container->register('ai.store.store', \stdClass::class)->addTag('ai.store');
|
||||
|
||||
(new DebugCompilerPass())->process($container);
|
||||
|
||||
@@ -91,5 +101,6 @@ class DebugCompilerPassTest extends TestCase
|
||||
$this->assertFalse($container->hasDefinition('ai.traceable_chat.main'));
|
||||
$this->assertFalse($container->hasDefinition('ai.traceable_toolbox.my_agent'));
|
||||
$this->assertFalse($container->hasDefinition('ai.traceable_agent.my_agent'));
|
||||
$this->assertFalse($container->hasDefinition('ai.traceable_store.store'));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ use Symfony\AI\AiBundle\Profiler\TraceableAgent;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableChat;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableStore;
|
||||
use Symfony\AI\Chat\Chat;
|
||||
use Symfony\AI\Chat\InMemory\Store as InMemoryStore;
|
||||
use Symfony\AI\Platform\Message\Content\Text;
|
||||
@@ -31,8 +32,12 @@ use Symfony\AI\Platform\Result\DeferredResult;
|
||||
use Symfony\AI\Platform\Result\RawResultInterface;
|
||||
use Symfony\AI\Platform\Result\StreamResult;
|
||||
use Symfony\AI\Platform\Result\TextResult;
|
||||
use Symfony\AI\Platform\Vector\Vector;
|
||||
use Symfony\AI\Store\Document\VectorDocument;
|
||||
use Symfony\AI\Store\InMemory\Store;
|
||||
use Symfony\Component\Clock\MockClock;
|
||||
use Symfony\Component\Clock\MonotonicClock;
|
||||
use Symfony\Component\Uid\Uuid;
|
||||
|
||||
class DataCollectorTest extends TestCase
|
||||
{
|
||||
@@ -48,7 +53,7 @@ class DataCollectorTest extends TestCase
|
||||
$result = $traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => false]);
|
||||
$this->assertSame('Assistant response', $result->asText());
|
||||
|
||||
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], []);
|
||||
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], [], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$this->assertCount(1, $dataCollector->getPlatformCalls());
|
||||
@@ -72,7 +77,7 @@ class DataCollectorTest extends TestCase
|
||||
$result = $traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => true]);
|
||||
$this->assertSame('Assistant response', implode('', iterator_to_array($result->asStream())));
|
||||
|
||||
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], []);
|
||||
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], [], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$this->assertCount(1, $dataCollector->getPlatformCalls());
|
||||
@@ -96,7 +101,7 @@ class DataCollectorTest extends TestCase
|
||||
// Invoke but do NOT consume the stream
|
||||
$traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => true]);
|
||||
|
||||
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], []);
|
||||
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], [], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$this->assertCount(1, $dataCollector->getPlatformCalls());
|
||||
@@ -136,7 +141,7 @@ class DataCollectorTest extends TestCase
|
||||
Message::ofUser('Hello World'),
|
||||
));
|
||||
|
||||
$dataCollector = new DataCollector([], [], [$traceableMessageStore], [], []);
|
||||
$dataCollector = new DataCollector([], [], [$traceableMessageStore], [], [], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$calls = $dataCollector->getMessages();
|
||||
@@ -159,7 +164,7 @@ class DataCollectorTest extends TestCase
|
||||
|
||||
$traceableChat->submit(Message::ofUser('Hello World'));
|
||||
|
||||
$dataCollector = new DataCollector([], [], [], [$traceableChat], []);
|
||||
$dataCollector = new DataCollector([], [], [], [$traceableChat], [], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$calls = $dataCollector->getChats();
|
||||
@@ -172,7 +177,7 @@ class DataCollectorTest extends TestCase
|
||||
|
||||
public function testGetNameReturnsShortName()
|
||||
{
|
||||
$dataCollector = new DataCollector([], [], [], [], []);
|
||||
$dataCollector = new DataCollector([], [], [], [], [], []);
|
||||
|
||||
$name = $dataCollector->getName();
|
||||
|
||||
@@ -188,7 +193,7 @@ class DataCollectorTest extends TestCase
|
||||
yield from [];
|
||||
})();
|
||||
|
||||
$dataCollector = new DataCollector([], $generator, [], [], []);
|
||||
$dataCollector = new DataCollector([], $generator, [], [], [], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$this->assertSame([], $dataCollector->getTools());
|
||||
@@ -209,7 +214,7 @@ class DataCollectorTest extends TestCase
|
||||
|
||||
$traceableAgent->call($messageBag);
|
||||
|
||||
$dataCollector = new DataCollector([], [], [], [], [$traceableAgent]);
|
||||
$dataCollector = new DataCollector([], [], [], [], [$traceableAgent], []);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$this->assertCount(1, $dataCollector->getAgents());
|
||||
@@ -220,4 +225,15 @@ class DataCollectorTest extends TestCase
|
||||
'called_at' => $clock->now(),
|
||||
], $dataCollector->getAgents()[0]);
|
||||
}
|
||||
|
||||
public function testCollectsDataForStores()
|
||||
{
|
||||
$traceableStore = new TraceableStore(new Store());
|
||||
$traceableStore->add(new VectorDocument(Uuid::v7()->toRfc4122(), new Vector([0.1, 0.2, 0.3])));
|
||||
|
||||
$dataCollector = new DataCollector([], [], [], [], [], [$traceableStore]);
|
||||
$dataCollector->lateCollect();
|
||||
|
||||
$this->assertCount(1, $dataCollector->getStores());
|
||||
}
|
||||
}
|
||||
|
||||
83
tests/Profiler/TraceableStoreTest.php
Normal file
83
tests/Profiler/TraceableStoreTest.php
Normal file
@@ -0,0 +1,83 @@
|
||||
<?php
|
||||
|
||||
/*
|
||||
* This file is part of the Symfony package.
|
||||
*
|
||||
* (c) Fabien Potencier <fabien@symfony.com>
|
||||
*
|
||||
* For the full copyright and license information, please view the LICENSE
|
||||
* file that was distributed with this source code.
|
||||
*/
|
||||
|
||||
namespace Symfony\AI\AiBundle\Tests\Profiler;
|
||||
|
||||
use PHPUnit\Framework\TestCase;
|
||||
use Symfony\AI\AiBundle\Profiler\TraceableStore;
|
||||
use Symfony\AI\Platform\Vector\Vector;
|
||||
use Symfony\AI\Store\Document\VectorDocument;
|
||||
use Symfony\AI\Store\InMemory\Store;
|
||||
use Symfony\AI\Store\Query\VectorQuery;
|
||||
use Symfony\Component\Clock\MockClock;
|
||||
use Symfony\Component\Uid\Uuid;
|
||||
|
||||
final class TraceableStoreTest extends TestCase
|
||||
{
|
||||
public function testStoreCanRetrieveDataOnNewDocuments()
|
||||
{
|
||||
$clock = new MockClock('2020-01-01 10:00:00');
|
||||
|
||||
$traceableStore = new TraceableStore(new Store(), $clock);
|
||||
|
||||
$document = new VectorDocument(Uuid::v7()->toRfc4122(), new Vector([0.1, 0.2, 0.3]));
|
||||
|
||||
$traceableStore->add($document);
|
||||
|
||||
$this->assertEquals([
|
||||
[
|
||||
'method' => 'add',
|
||||
'documents' => $document,
|
||||
'called_at' => $clock->now(),
|
||||
],
|
||||
], $traceableStore->calls);
|
||||
}
|
||||
|
||||
public function testStoreCanRetrieveDataOnQuery()
|
||||
{
|
||||
$clock = new MockClock('2020-01-01 10:00:00');
|
||||
|
||||
$traceableStore = new TraceableStore(new Store(), $clock);
|
||||
|
||||
$query = new VectorQuery(new Vector([0.1, 0.2, 0.3]));
|
||||
|
||||
$traceableStore->query($query);
|
||||
|
||||
$this->assertEquals([
|
||||
[
|
||||
'method' => 'query',
|
||||
'query' => $query,
|
||||
'options' => [],
|
||||
'called_at' => $clock->now(),
|
||||
],
|
||||
], $traceableStore->calls);
|
||||
}
|
||||
|
||||
public function testStoreCanRetrieveDataOnRemove()
|
||||
{
|
||||
$clock = new MockClock('2020-01-01 10:00:00');
|
||||
|
||||
$traceableStore = new TraceableStore(new Store(), $clock);
|
||||
|
||||
$uuid = Uuid::v7()->toRfc4122();
|
||||
|
||||
$traceableStore->remove([$uuid]);
|
||||
|
||||
$this->assertEquals([
|
||||
[
|
||||
'method' => 'remove',
|
||||
'ids' => [$uuid],
|
||||
'options' => [],
|
||||
'called_at' => $clock->now(),
|
||||
],
|
||||
], $traceableStore->calls);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user