feat(store): allow to sort documents as batches

This commit is contained in:
Guillaume Loulier
2026-03-16 17:44:27 +01:00
parent b662357e31
commit 293df7d90b
4 changed files with 276 additions and 25 deletions

View File

@@ -52,6 +52,34 @@ Available strategies:
* ``ANGULAR_DISTANCE``
* ``CHEBYSHEV_DISTANCE``
Batch Processing
----------------
For large datasets, the distance calculator can score documents in batches
instead of scoring the entire dataset at once. After each batch, only the best
``maxItems`` candidates are kept, so the number of scored candidates held at any
time drops from O(N) to O(maxItems + batchSize)::
use Symfony\AI\Store\Distance\DistanceCalculator;
use Symfony\AI\Store\Distance\DistanceStrategy;
use Symfony\AI\Store\InMemory\Store;
$calculator = new DistanceCalculator(
strategy: DistanceStrategy::COSINE_DISTANCE,
batchSize: 500, // Defaults to 100
);
$store = new Store($calculator);
// Batch processing is used when maxItems is set and at least batchSize documents are scored
$results = $store->query($vector, [
'maxItems' => 10,
]);
.. note::
Batch processing requires ``maxItems`` to be set in the query options and at
least ``batchSize`` documents to score. Otherwise, the calculator falls back to
the standard full-sort behavior: without a limit, all results are needed and no
pruning can occur.
Metadata Filtering
------------------

View File

@@ -19,8 +19,12 @@ use Symfony\AI\Store\Document\VectorDocument;
*/
final class DistanceCalculator
{
/**
 * @param DistanceStrategy $strategy  distance metric used to score each document against the query vector
 * @param positive-int     $batchSize chunk size used by {@see self::calculate()} when it is given a $maxItems limit
 *                                    and at least this many documents; smaller inputs are scored in a single pass
 */
public function __construct(
private readonly DistanceStrategy $strategy = DistanceStrategy::COSINE_DISTANCE,
private readonly int $batchSize = 100,
) {
}
@@ -32,13 +36,21 @@ final class DistanceCalculator
*/
public function calculate(array $documents, Vector $vector, ?int $maxItems = null): array
{
    // Batching is only worthwhile when a result limit allows pruning between
    // chunks and there are enough documents to fill at least one chunk.
    // A non-positive limit is left to the full path, because
    // calculateBatched() requires a positive-int $maxItems.
    if (null !== $maxItems && $maxItems > 0 && $this->batchSize <= \count($documents)) {
        return $this->calculateBatched($documents, $vector, $maxItems);
    }

    // The strategy closure is resolved inside each path, so nothing needs to be prepared here.
    return $this->calculateAgainstAll($documents, $vector, $maxItems);
}
/**
* @param VectorDocument[] $documents
*
* @return VectorDocument[]
*/
private function calculateAgainstAll(array $documents, Vector $vector, ?int $maxItems): array
{
$strategy = $this->resolveStrategy();
$currentEmbeddings = array_map(
static fn (VectorDocument $vectorDocument): array => [
@@ -63,6 +75,65 @@ final class DistanceCalculator
);
}
/**
 * Scores the documents chunk by chunk ({@see self::$batchSize} documents per chunk).
 *
 * After every chunk the freshly scored entries are merged with the survivors of the
 * previous chunks, sorted by ascending distance and trimmed back to $maxItems entries.
 *
 * @param VectorDocument[] $documents
 * @param positive-int     $maxItems
 *
 * @return VectorDocument[]
 */
private function calculateBatched(array $documents, Vector $vector, int $maxItems): array
{
    $distanceOf = $this->resolveStrategy();
    $byAscendingDistance = static fn (array $left, array $right): int => $left['distance'] <=> $right['distance'];

    /** @var array<int, array{distance: float, document: VectorDocument}> $kept */
    $kept = [];

    foreach (array_chunk($documents, $this->batchSize) as $chunk) {
        // Append the scores of the current chunk behind the surviving entries.
        foreach ($chunk as $document) {
            $kept[] = [
                'distance' => $distanceOf($document, $vector),
                'document' => $document,
            ];
        }

        usort($kept, $byAscendingDistance);

        // Drop everything beyond the requested limit before the next chunk is scored.
        if (\count($kept) > $maxItems) {
            $kept = \array_slice($kept, 0, $maxItems);
        }
    }

    $scoredDocuments = [];
    foreach ($kept as $entry) {
        $scoredDocuments[] = $entry['document']->withScore($entry['distance']);
    }

    return $scoredDocuments;
}
/**
 * Returns the distance function matching the strategy this calculator was constructed with.
 *
 * @return \Closure(VectorDocument, Vector): float
 */
private function resolveStrategy(): \Closure
{
    $selected = $this->strategy;

    return match ($selected) {
        DistanceStrategy::CHEBYSHEV_DISTANCE => $this->chebyshevDistance(...),
        DistanceStrategy::MANHATTAN_DISTANCE => $this->manhattanDistance(...),
        DistanceStrategy::EUCLIDEAN_DISTANCE => $this->euclideanDistance(...),
        DistanceStrategy::ANGULAR_DISTANCE => $this->angularDistance(...),
        DistanceStrategy::COSINE_DISTANCE => $this->cosineDistance(...),
    };
}
private function cosineDistance(VectorDocument $embedding, Vector $against): float
{
return 1 - $this->cosineSimilarity($embedding, $against);
@@ -125,9 +196,9 @@ final class DistanceCalculator
);
return array_reduce(
array: $embeddingsAsPower,
callback: static fn (float $value, float $current): float => max($value, $current),
initial: 0.0,
$embeddingsAsPower,
static fn (float $value, float $current): float => max($value, $current),
0.0,
);
}
}

View File

@@ -24,9 +24,9 @@ use Symfony\Component\Uid\Uuid;
final class DistanceCalculatorTest extends TestCase
{
/**
* @param array<list<float>> $documentVectors
* @param list<float> $queryVector
* @param list<int> $expectedOrder
* @param array<float[]> $documentVectors
* @param float[] $queryVector
* @param int[] $expectedOrder
*/
#[TestDox('Calculates distances correctly using $strategy strategy')]
#[DataProvider('provideDistanceStrategyTestCases')]
@@ -59,7 +59,7 @@ final class DistanceCalculatorTest extends TestCase
}
/**
* @return \Generator<string, array{DistanceStrategy, array<list<float>>, list<float>, list<int>}>
* @return \Generator<string, array{DistanceStrategy, array<float[]>, float[], int[]}>
*/
public static function provideDistanceStrategyTestCases(): \Generator
{
@@ -218,15 +218,13 @@ final class DistanceCalculatorTest extends TestCase
{
$calculator = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE);
// Create high-dimensional vectors (100 dimensions)
$dimensions = 100;
$vector1 = array_fill(0, $dimensions, 0.1);
$vector2 = array_fill(0, $dimensions, 0.2);
$vector1 = array_fill(0, 100, 0.1);
$vector2 = array_fill(0, 100, 0.2);
$doc1 = new VectorDocument(Uuid::v4(), new Vector($vector1));
$doc2 = new VectorDocument(Uuid::v4(), new Vector($vector2));
$queryVector = new Vector(array_fill(0, $dimensions, 0.15));
$queryVector = new Vector(array_fill(0, 100, 0.15));
$result = $calculator->calculate([$doc1, $doc2], $queryVector);
@@ -290,10 +288,8 @@ final class DistanceCalculatorTest extends TestCase
#[TestDox('Uses cosine distance as default strategy')]
public function testDefaultStrategyIsCosineDistance()
{
// Test that default constructor uses cosine distance
$calculator = new DistanceCalculator();
// Create vectors where cosine distance ordering differs from Euclidean
$doc1 = new VectorDocument(Uuid::v4(), new Vector([1.0, 0.0, 0.0]));
$doc2 = new VectorDocument(Uuid::v4(), new Vector([100.0, 0.0, 0.0])); // Same direction but different magnitude
@@ -305,4 +301,155 @@ final class DistanceCalculatorTest extends TestCase
// The order might vary but both are equally similar in terms of direction
$this->assertCount(2, $result);
}
#[TestDox('Batched calculation returns same top results as full calculation')]
public function testBatchedCalculationReturnsSameResultsAsFull()
{
    // Five labelled points; the query sits on the origin.
    $coordinatesById = [
        'a' => [0.0, 0.0],
        'b' => [1.0, 0.0],
        'c' => [0.0, 1.0],
        'd' => [1.0, 1.0],
        'e' => [0.5, 0.5],
    ];

    $documents = [];
    foreach ($coordinatesById as $id => $coordinates) {
        $documents[] = new VectorDocument(Uuid::v4(), new Vector($coordinates), new Metadata(['id' => $id]));
    }

    $origin = new Vector([0.0, 0.0]);

    // The default batch size exceeds the document count, so this calculator scores everything at once.
    $expected = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE))
        ->calculate($documents, $origin, 3);
    // A batch size of 2 forces the same documents through the chunked path.
    $actual = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 2))
        ->calculate($documents, $origin, 3);

    $toId = static fn (VectorDocument $doc): string => $doc->getMetadata()['id'];

    $this->assertSame(array_map($toId, $expected), array_map($toId, $actual));
    $this->assertCount(3, $actual);
}
#[TestDox('Batched calculation prunes candidates after each batch')]
public function testBatchedCalculationPrunesCandidates()
{
    // Ten points on the x-axis (x = 0..9), chunked by 3 with a limit of 2:
    // after every chunk only the two closest candidates survive.
    $documents = array_map(
        static fn (int $x): VectorDocument => new VectorDocument(
            Uuid::v4(),
            new Vector([(float) $x, 0.0]),
            new Metadata(['id' => (string) $x]),
        ),
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    );

    $nearest = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 3))
        ->calculate($documents, new Vector([0.0, 0.0]), 2);

    $this->assertCount(2, $nearest);

    $nearestIds = array_map(static fn (VectorDocument $doc): string => $doc->getMetadata()['id'], $nearest);
    $this->assertSame(['0', '1'], $nearestIds);
}
#[TestDox('Batched calculation falls back to full when maxItems is null')]
public function testBatchedCalculationFallsBackWithoutMaxItems()
{
    $documents = [
        new VectorDocument(Uuid::v4(), new Vector([1.0, 0.0]), new Metadata(['id' => 'a'])),
        new VectorDocument(Uuid::v4(), new Vector([0.0, 1.0]), new Metadata(['id' => 'b'])),
    ];

    // The batch size is small enough to trigger batching, but no limit is passed.
    $calculator = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 1);
    $result = $calculator->calculate($documents, new Vector([1.0, 0.0]));

    // Nothing may be pruned without a limit: both documents come back, closest first.
    $this->assertCount(2, $result);
    foreach (['a', 'b'] as $position => $expectedId) {
        $this->assertSame($expectedId, $result[$position]->getMetadata()['id']);
    }
}
#[TestDox('Batched calculation works with batch size larger than document count')]
public function testBatchedCalculationWithLargeBatchSize()
{
    $origin = new Vector([0.0, 0.0]);

    $atOrigin = new VectorDocument(Uuid::v4(), new Vector([0.0, 0.0]), new Metadata(['id' => 'a']));
    $offOrigin = new VectorDocument(Uuid::v4(), new Vector([1.0, 0.0]), new Metadata(['id' => 'b']));

    // Two documents against a batch size of 1000: the limit of 1 must still be honoured.
    $calculator = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 1000);
    $closest = $calculator->calculate([$atOrigin, $offOrigin], $origin, 1);

    $this->assertCount(1, $closest);
    $this->assertSame('a', $closest[0]->getMetadata()['id']);
}
#[TestDox('Batched calculation returns empty array for empty documents')]
public function testBatchedCalculationWithEmptyDocuments()
{
    // An empty input holds fewer documents than the batch size, so calculate()
    // answers through its non-batched path; the outcome must be an empty list.
    $noDocuments = [];
    $query = new Vector([1.0, 2.0]);

    $result = (new DistanceCalculator(batchSize: 100))->calculate($noDocuments, $query, 5);

    $this->assertSame([], $result);
}
#[TestDox('Batched calculation preserves scores')]
public function testBatchedCalculationPreservesScores()
{
    // A 3-4-5 triangle: the point (3, 4) lies at Euclidean distance 5 from the origin.
    $documents = [new VectorDocument(Uuid::v4(), new Vector([3.0, 4.0]))];
    $origin = new Vector([0.0, 0.0]);

    $unbatched = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE))
        ->calculate($documents, $origin, 1);
    $batched = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 1))
        ->calculate($documents, $origin, 1);

    $batchedScore = $batched[0]->getScore();

    $this->assertSame($unbatched[0]->getScore(), $batchedScore);
    $this->assertEqualsWithDelta(5.0, $batched[0]->getScore(), 0.0001);
}
/**
 * @param int[] $expectedOrder metadata "index" values expected at each result position, closest first
 */
#[TestDox('Batched calculation works with $strategy strategy')]
#[DataProvider('provideBatchedStrategyTestCases')]
public function testBatchedCalculationWithDifferentStrategies(DistanceStrategy $strategy, array $expectedOrder)
{
    // Three unit axes plus one mixed vector; the query equals the first axis.
    $coordinates = [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.5, 0.5, 0.707],
    ];

    $documents = [];
    foreach ($coordinates as $index => $values) {
        $documents[] = new VectorDocument(Uuid::v4(), new Vector($values), new Metadata(['index' => $index]));
    }

    // Four documents with a batch size of 2 are scored in two chunks.
    $result = (new DistanceCalculator($strategy, batchSize: 2))
        ->calculate($documents, new Vector([1.0, 0.0, 0.0]), 2);

    $this->assertCount(2, $result);

    foreach ($expectedOrder as $position => $expectedIndex) {
        $this->assertSame($expectedIndex, $result[$position]->getMetadata()['index']);
    }
}
/**
 * One dataset per distance strategy, so every strategy is exercised through the batched path.
 *
 * Each dataset holds the strategy and the metadata "index" values expected in the result,
 * closest first. The consuming test queries with [1, 0, 0]: document 0 is identical to the
 * query, and document 3 ([0.5, 0.5, 0.707]) is the next closest.
 *
 * NOTE(review): the angular and Chebyshev expectations assume the textbook definitions
 * (angle derived from cosine similarity; largest absolute coordinate difference) - confirm
 * against the calculator's implementations.
 *
 * @return \Generator<string, array{DistanceStrategy, int[]}>
 */
public static function provideBatchedStrategyTestCases(): \Generator
{
    yield 'cosine distance batched' => [
        DistanceStrategy::COSINE_DISTANCE,
        [0, 3],
    ];

    yield 'angular distance batched' => [
        DistanceStrategy::ANGULAR_DISTANCE,
        [0, 3],
    ];

    yield 'euclidean distance batched' => [
        DistanceStrategy::EUCLIDEAN_DISTANCE,
        [0, 3],
    ];

    yield 'manhattan distance batched' => [
        DistanceStrategy::MANHATTAN_DISTANCE,
        [0, 3],
    ];

    yield 'chebyshev distance batched' => [
        DistanceStrategy::CHEBYSHEV_DISTANCE,
        [0, 3],
    ];
}
}

View File

@@ -17,8 +17,11 @@ use Symfony\AI\Store\Distance\DistanceCalculator;
use Symfony\AI\Store\Distance\DistanceStrategy;
use Symfony\AI\Store\Document\Metadata;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\Exception\InvalidArgumentException;
use Symfony\AI\Store\Exception\UnsupportedQueryTypeException;
use Symfony\AI\Store\InMemory\Store;
use Symfony\AI\Store\Query\HybridQuery;
use Symfony\AI\Store\Query\QueryInterface;
use Symfony\AI\Store\Query\TextQuery;
use Symfony\AI\Store\Query\VectorQuery;
use Symfony\Component\Uid\Uuid;
@@ -364,7 +367,7 @@ final class StoreTest extends TestCase
new VectorDocument($id, new Vector([0.1, 0.1, 0.5])),
]);
$this->expectException(\Symfony\AI\Store\Exception\InvalidArgumentException::class);
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('No supported options.');
$store->remove((string) $id, ['unsupported' => 'option']);
@@ -453,8 +456,9 @@ final class StoreTest extends TestCase
public function testHybridQueryThrowsExceptionForInvalidSemanticRatio()
{
    // The expectation is registered once, through the imported class name.
    $this->expectException(InvalidArgumentException::class);
    $this->expectExceptionMessage('Semantic ratio must be between 0.0 and 1.0');
    $this->expectExceptionCode(0);

    // A ratio of 1.5 is outside the range named in the expected message, so construction must fail.
    new HybridQuery(new Vector([0.1, 0.2, 0.3]), 'test', 1.5);
}
@@ -464,11 +468,12 @@ final class StoreTest extends TestCase
$store = new Store();
// Create a mock query type that InMemory store doesn't support
$unsupportedQuery = new class implements \Symfony\AI\Store\Query\QueryInterface {
$unsupportedQuery = new class implements QueryInterface {
};
$this->expectException(\Symfony\AI\Store\Exception\UnsupportedQueryTypeException::class);
$this->expectException(UnsupportedQueryTypeException::class);
$this->expectExceptionMessageMatches('/not supported/');
$this->expectExceptionCode(0);
$store->query($unsupportedQuery);
}