feat(aibundle): TraceableStore implemented

This commit is contained in:
Guillaume Loulier
2026-02-20 11:33:57 +01:00
parent d32182f73f
commit 2e53ac10dd
9 changed files with 245 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ CHANGELOG
* Move debug service decorating to compiler pass to cover user-defined services
* Add `TraceableAgent`
* Add `TraceableStore`
0.5
---

View File

@@ -225,6 +225,7 @@ return static function (ContainerConfigurator $container): void {
tagged_iterator('ai.traceable_message_store'),
tagged_iterator('ai.traceable_chat'),
tagged_iterator('ai.traceable_agent'),
tagged_iterator('ai.traceable_store'),
])
->tag('data_collector')

View File

@@ -15,6 +15,7 @@ use Symfony\AI\AiBundle\Profiler\TraceableAgent;
use Symfony\AI\AiBundle\Profiler\TraceableChat;
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
use Symfony\AI\AiBundle\Profiler\TraceableStore;
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
use Symfony\Component\Clock\ClockInterface;
use Symfony\Component\DependencyInjection\Compiler\CompilerPassInterface;
@@ -24,7 +25,7 @@ use Symfony\Component\DependencyInjection\Reference;
use function Symfony\Component\String\u;
class DebugCompilerPass implements CompilerPassInterface
final class DebugCompilerPass implements CompilerPassInterface
{
public function process(ContainerBuilder $container): void
{
@@ -87,5 +88,15 @@ class DebugCompilerPass implements CompilerPassInterface
$suffix = u($agent)->afterLast('.')->toString();
$container->setDefinition('ai.traceable_agent.'.$suffix, $traceableAgentDefinition);
}
foreach (array_keys($container->findTaggedServiceIds('ai.store')) as $store) {
$traceableStoreDefinition = (new Definition(TraceableStore::class))
->setDecoratedService($store, priority: -1024)
->setArguments([new Reference('.inner')])
->addTag('ai.traceable_store')
->addTag('kernel.reset', ['method' => 'reset']);
$suffix = u($store)->afterLast('.')->toString();
$container->setDefinition('ai.traceable_store.'.$suffix, $traceableStoreDefinition);
}
}
}

View File

@@ -26,6 +26,7 @@ use Symfony\Component\HttpKernel\DataCollector\LateDataCollectorInterface;
* @phpstan-import-type MessageStoreData from TraceableMessageStore
* @phpstan-import-type ChatData from TraceableChat
* @phpstan-import-type AgentData from TraceableAgent
* @phpstan-import-type StoreData from TraceableStore
*
* @phpstan-type CollectedPlatformCallData array{
* model: string,
@@ -62,12 +63,18 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
*/
private readonly array $agents;
/**
* @var TraceableStore[]
*/
private readonly array $stores;
/**
* @param iterable<TraceablePlatform> $platforms
* @param iterable<TraceableToolbox> $toolboxes
* @param iterable<TraceableMessageStore> $messageStores
* @param iterable<TraceableChat> $chats
* @param iterable<TraceableAgent> $agents
* @param iterable<TraceableStore> $stores
*/
public function __construct(
iterable $platforms,
@@ -75,12 +82,14 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
iterable $messageStores,
iterable $chats,
iterable $agents,
iterable $stores,
) {
$this->platforms = iterator_to_array($platforms);
$this->toolboxes = iterator_to_array($toolboxes);
$this->messageStores = iterator_to_array($messageStores);
$this->chats = iterator_to_array($chats);
$this->agents = iterator_to_array($agents);
$this->stores = iterator_to_array($stores);
}
public function collect(Request $request, Response $response, ?\Throwable $exception = null): void
@@ -97,6 +106,7 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
'messages' => array_merge(...array_map(static fn (TraceableMessageStore $messageStore): array => $messageStore->calls, $this->messageStores)),
'chats' => array_merge(...array_map(static fn (TraceableChat $chat): array => $chat->calls, $this->chats)),
'agents' => array_merge(...array_map(static fn (TraceableAgent $agent): array => $agent->calls, $this->agents)),
'stores' => array_merge(...array_map(static fn (TraceableStore $store): array => $store->calls, $this->stores)),
];
}
@@ -158,6 +168,14 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
return $this->data['agents'] ?? [];
}
/**
* @return StoreData[]
*/
public function getStores(): array
{
return $this->data['stores'] ?? [];
}
/**
* @return Tool[]
*/

View File

@@ -0,0 +1,90 @@
<?php
/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/
namespace Symfony\AI\AiBundle\Profiler;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\Query\QueryInterface;
use Symfony\AI\Store\StoreInterface;
use Symfony\Component\Clock\ClockInterface;
use Symfony\Component\Clock\MonotonicClock;
use Symfony\Contracts\Service\ResetInterface;
/**
* @author Guillaume Loulier <personal@guillaumeloulier.fr>
*
* @phpstan-type StoreData array{
* method: string,
* documents?: VectorDocument|VectorDocument[],
* query?: QueryInterface,
* ids?: string[]|string,
* options?: array<string, mixed>,
* called_at: \DateTimeImmutable,
* }
*/
final class TraceableStore implements StoreInterface, ResetInterface
{
/**
* @var StoreData[]
*/
public array $calls = [];
public function __construct(
private readonly StoreInterface $store,
private readonly ClockInterface $clock = new MonotonicClock(),
) {
}
public function add(VectorDocument|array $documents): void
{
$this->calls[] = [
'method' => 'add',
'documents' => $documents,
'called_at' => $this->clock->now(),
];
$this->store->add($documents);
}
public function query(QueryInterface $query, array $options = []): iterable
{
$this->calls[] = [
'method' => 'query',
'query' => $query,
'options' => $options,
'called_at' => $this->clock->now(),
];
return $this->store->query($query, $options);
}
public function remove(array|string $ids, array $options = []): void
{
$this->calls[] = [
'method' => 'remove',
'ids' => $ids,
'options' => $options,
'called_at' => $this->clock->now(),
];
$this->store->remove($ids, $options);
}
public function supports(string $queryClass): bool
{
return $this->store->supports($queryClass);
}
public function reset(): void
{
$this->calls = [];
}
}

View File

@@ -36,6 +36,10 @@
<b class="label">Chats sent</b>
<span class="sf-toolbar-status">{{ collector.chats|length }}</span>
</div>
<div class="sf-toolbar-info-piece">
<b class="label">Stores used</b>
<span class="sf-toolbar-status">{{ collector.stores|length }}</span>
</div>
</div>
{% endset %}

View File

@@ -17,12 +17,13 @@ use Symfony\AI\AiBundle\Profiler\TraceableAgent;
use Symfony\AI\AiBundle\Profiler\TraceableChat;
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
use Symfony\AI\AiBundle\Profiler\TraceableStore;
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
use Symfony\Component\Clock\ClockInterface;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\DependencyInjection\Reference;
class DebugCompilerPassTest extends TestCase
final class DebugCompilerPassTest extends TestCase
{
public function testProcessAddsTraceableDefinitionsInDebug()
{
@@ -34,6 +35,7 @@ class DebugCompilerPassTest extends TestCase
$container->register('ai.chat.main', \stdClass::class)->addTag('ai.chat');
$container->register('ai.toolbox.my_agent', \stdClass::class)->addTag('ai.toolbox');
$container->register('ai.agent.my_agent', \stdClass::class)->addTag('ai.agent');
$container->register('ai.store.store', \stdClass::class)->addTag('ai.store');
(new DebugCompilerPass())->process($container);
@@ -71,6 +73,13 @@ class DebugCompilerPassTest extends TestCase
$this->assertEquals([new Reference('.inner')], $traceableAgent->getArguments());
$this->assertTrue($traceableAgent->hasTag('ai.traceable_agent'));
$this->assertSame([['method' => 'reset']], $traceableAgent->getTag('kernel.reset'));
$traceableStore = $container->getDefinition('ai.traceable_store.store');
$this->assertSame(TraceableStore::class, $traceableStore->getClass());
$this->assertSame(['ai.store.store', null, -1024], $traceableStore->getDecoratedService());
$this->assertEquals([new Reference('.inner')], $traceableStore->getArguments());
$this->assertTrue($traceableStore->hasTag('ai.traceable_store'));
$this->assertSame([['method' => 'reset']], $traceableStore->getTag('kernel.reset'));
}
public function testProcessSkipsWhenDebugDisabled()
@@ -83,6 +92,7 @@ class DebugCompilerPassTest extends TestCase
$container->register('ai.chat.main', \stdClass::class)->addTag('ai.chat');
$container->register('ai.toolbox.my_agent', \stdClass::class)->addTag('ai.toolbox');
$container->register('ai.agent.my_agent', \stdClass::class)->addTag('ai.agent');
$container->register('ai.store.store', \stdClass::class)->addTag('ai.store');
(new DebugCompilerPass())->process($container);
@@ -91,5 +101,6 @@ class DebugCompilerPassTest extends TestCase
$this->assertFalse($container->hasDefinition('ai.traceable_chat.main'));
$this->assertFalse($container->hasDefinition('ai.traceable_toolbox.my_agent'));
$this->assertFalse($container->hasDefinition('ai.traceable_agent.my_agent'));
$this->assertFalse($container->hasDefinition('ai.traceable_store.store'));
}
}

View File

@@ -19,6 +19,7 @@ use Symfony\AI\AiBundle\Profiler\TraceableAgent;
use Symfony\AI\AiBundle\Profiler\TraceableChat;
use Symfony\AI\AiBundle\Profiler\TraceableMessageStore;
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
use Symfony\AI\AiBundle\Profiler\TraceableStore;
use Symfony\AI\Chat\Chat;
use Symfony\AI\Chat\InMemory\Store as InMemoryStore;
use Symfony\AI\Platform\Message\Content\Text;
@@ -31,8 +32,12 @@ use Symfony\AI\Platform\Result\DeferredResult;
use Symfony\AI\Platform\Result\RawResultInterface;
use Symfony\AI\Platform\Result\StreamResult;
use Symfony\AI\Platform\Result\TextResult;
use Symfony\AI\Platform\Vector\Vector;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\InMemory\Store;
use Symfony\Component\Clock\MockClock;
use Symfony\Component\Clock\MonotonicClock;
use Symfony\Component\Uid\Uuid;
class DataCollectorTest extends TestCase
{
@@ -48,7 +53,7 @@ class DataCollectorTest extends TestCase
$result = $traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => false]);
$this->assertSame('Assistant response', $result->asText());
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], []);
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], [], []);
$dataCollector->lateCollect();
$this->assertCount(1, $dataCollector->getPlatformCalls());
@@ -72,7 +77,7 @@ class DataCollectorTest extends TestCase
$result = $traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => true]);
$this->assertSame('Assistant response', implode('', iterator_to_array($result->asStream())));
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], []);
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], [], []);
$dataCollector->lateCollect();
$this->assertCount(1, $dataCollector->getPlatformCalls());
@@ -96,7 +101,7 @@ class DataCollectorTest extends TestCase
// Invoke but do NOT consume the stream
$traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => true]);
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], []);
$dataCollector = new DataCollector([$traceablePlatform], [], [], [], [], []);
$dataCollector->lateCollect();
$this->assertCount(1, $dataCollector->getPlatformCalls());
@@ -136,7 +141,7 @@ class DataCollectorTest extends TestCase
Message::ofUser('Hello World'),
));
$dataCollector = new DataCollector([], [], [$traceableMessageStore], [], []);
$dataCollector = new DataCollector([], [], [$traceableMessageStore], [], [], []);
$dataCollector->lateCollect();
$calls = $dataCollector->getMessages();
@@ -159,7 +164,7 @@ class DataCollectorTest extends TestCase
$traceableChat->submit(Message::ofUser('Hello World'));
$dataCollector = new DataCollector([], [], [], [$traceableChat], []);
$dataCollector = new DataCollector([], [], [], [$traceableChat], [], []);
$dataCollector->lateCollect();
$calls = $dataCollector->getChats();
@@ -172,7 +177,7 @@ class DataCollectorTest extends TestCase
public function testGetNameReturnsShortName()
{
$dataCollector = new DataCollector([], [], [], [], []);
$dataCollector = new DataCollector([], [], [], [], [], []);
$name = $dataCollector->getName();
@@ -188,7 +193,7 @@ class DataCollectorTest extends TestCase
yield from [];
})();
$dataCollector = new DataCollector([], $generator, [], [], []);
$dataCollector = new DataCollector([], $generator, [], [], [], []);
$dataCollector->lateCollect();
$this->assertSame([], $dataCollector->getTools());
@@ -209,7 +214,7 @@ class DataCollectorTest extends TestCase
$traceableAgent->call($messageBag);
$dataCollector = new DataCollector([], [], [], [], [$traceableAgent]);
$dataCollector = new DataCollector([], [], [], [], [$traceableAgent], []);
$dataCollector->lateCollect();
$this->assertCount(1, $dataCollector->getAgents());
@@ -220,4 +225,15 @@ class DataCollectorTest extends TestCase
'called_at' => $clock->now(),
], $dataCollector->getAgents()[0]);
}
public function testCollectsDataForStores()
{
$traceableStore = new TraceableStore(new Store());
$traceableStore->add(new VectorDocument(Uuid::v7()->toRfc4122(), new Vector([0.1, 0.2, 0.3])));
$dataCollector = new DataCollector([], [], [], [], [], [$traceableStore]);
$dataCollector->lateCollect();
$this->assertCount(1, $dataCollector->getStores());
}
}

View File

@@ -0,0 +1,83 @@
<?php
/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/
namespace Symfony\AI\AiBundle\Tests\Profiler;
use PHPUnit\Framework\TestCase;
use Symfony\AI\AiBundle\Profiler\TraceableStore;
use Symfony\AI\Platform\Vector\Vector;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\InMemory\Store;
use Symfony\AI\Store\Query\VectorQuery;
use Symfony\Component\Clock\MockClock;
use Symfony\Component\Uid\Uuid;
final class TraceableStoreTest extends TestCase
{
public function testStoreCanRetrieveDataOnNewDocuments()
{
$clock = new MockClock('2020-01-01 10:00:00');
$traceableStore = new TraceableStore(new Store(), $clock);
$document = new VectorDocument(Uuid::v7()->toRfc4122(), new Vector([0.1, 0.2, 0.3]));
$traceableStore->add($document);
$this->assertEquals([
[
'method' => 'add',
'documents' => $document,
'called_at' => $clock->now(),
],
], $traceableStore->calls);
}
public function testStoreCanRetrieveDataOnQuery()
{
$clock = new MockClock('2020-01-01 10:00:00');
$traceableStore = new TraceableStore(new Store(), $clock);
$query = new VectorQuery(new Vector([0.1, 0.2, 0.3]));
$traceableStore->query($query);
$this->assertEquals([
[
'method' => 'query',
'query' => $query,
'options' => [],
'called_at' => $clock->now(),
],
], $traceableStore->calls);
}
public function testStoreCanRetrieveDataOnRemove()
{
$clock = new MockClock('2020-01-01 10:00:00');
$traceableStore = new TraceableStore(new Store(), $clock);
$uuid = Uuid::v7()->toRfc4122();
$traceableStore->remove([$uuid]);
$this->assertEquals([
[
'method' => 'remove',
'ids' => [$uuid],
'options' => [],
'called_at' => $clock->now(),
],
], $traceableStore->calls);
}
}