mirror of
https://github.com/symfony/ai.git
synced 2026-03-23 23:42:18 +01:00
feat(store): allow to sort documents as batches
This commit is contained in:
@@ -52,6 +52,34 @@ Available strategies:
|
||||
* ``ANGULAR_DISTANCE``
|
||||
* ``CHEBYSHEV_DISTANCE``
|
||||
|
||||
Batch Processing
|
||||
----------------
|
||||
|
||||
For large datasets, the distance calculator can process documents in batches
|
||||
instead of scoring the entire dataset at once. After each batch, only the best
|
||||
candidates are kept, reducing the number of scored candidates held in memory from O(N) to O(maxItems + batchSize)::
|
||||
|
||||
use Symfony\AI\Store\Distance\DistanceCalculator;
|
||||
use Symfony\AI\Store\Distance\DistanceStrategy;
|
||||
use Symfony\AI\Store\InMemory\Store;
|
||||
|
||||
$calculator = new DistanceCalculator(
|
||||
strategy: DistanceStrategy::COSINE_DISTANCE,
|
||||
batchSize: 500, // Defaults to 100
|
||||
);
|
||||
$store = new Store($calculator);
|
||||
|
||||
// Batch processing is activated when maxItems is set and at least batchSize documents have to be scored
|
||||
$results = $store->query($vector, [
|
||||
'maxItems' => 10,
|
||||
]);
|
||||
|
||||
.. note::
|
||||
|
||||
Batch processing requires ``maxItems`` to be set in the query options.
|
||||
Without it, the calculator falls back to the standard full-sort behavior
|
||||
since all results are needed and no pruning can occur. The same fallback
applies when fewer documents than the configured batch size have to be scored.
|
||||
|
||||
Metadata Filtering
|
||||
------------------
|
||||
|
||||
|
||||
@@ -19,8 +19,12 @@ use Symfony\AI\Store\Document\VectorDocument;
|
||||
*/
|
||||
final class DistanceCalculator
|
||||
{
|
||||
/**
 * @param DistanceStrategy $strategy  distance function used to score documents against the query vector
 * @param positive-int     $batchSize when set alongside $maxItems in {@see self::calculate()}, documents are scored in chunks of this size
 *
 * @throws \Symfony\AI\Store\Exception\InvalidArgumentException when $batchSize is lower than 1
 */
public function __construct(
    private readonly DistanceStrategy $strategy = DistanceStrategy::COSINE_DISTANCE,
    private readonly int $batchSize = 100,
) {
    // The batched path hands this value to array_chunk(), which rejects a
    // length below 1. Fail here, at the configuration boundary, instead of
    // on the first query that happens to pass $maxItems.
    if ($batchSize < 1) {
        throw new \Symfony\AI\Store\Exception\InvalidArgumentException(\sprintf('The batch size must be a positive integer, "%d" given.', $batchSize));
    }
}
|
||||
|
||||
@@ -32,13 +36,21 @@ final class DistanceCalculator
|
||||
*/
|
||||
public function calculate(array $documents, Vector $vector, ?int $maxItems = null): array
{
    // Chunked scoring only helps when a result limit allows pruning between
    // chunks and there are at least enough documents to fill one chunk.
    // The distance function itself is resolved by the delegate methods.
    if (null !== $maxItems && $this->batchSize <= \count($documents)) {
        return $this->calculateBatched($documents, $vector, $maxItems);
    }

    // No limit, or fewer documents than one batch: score and sort everything in one pass.
    return $this->calculateAgainstAll($documents, $vector, $maxItems);
}
|
||||
|
||||
/**
|
||||
* @param VectorDocument[] $documents
|
||||
*
|
||||
* @return VectorDocument[]
|
||||
*/
|
||||
private function calculateAgainstAll(array $documents, Vector $vector, ?int $maxItems): array
|
||||
{
|
||||
$strategy = $this->resolveStrategy();
|
||||
|
||||
$currentEmbeddings = array_map(
|
||||
static fn (VectorDocument $vectorDocument): array => [
|
||||
@@ -63,6 +75,65 @@ final class DistanceCalculator
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Scores the documents one chunk of {@see self::$batchSize} at a time and,
 * after every chunk, discards everything but the $maxItems closest candidates.
 *
 * @param VectorDocument[] $documents
 * @param positive-int     $maxItems
 *
 * @return VectorDocument[]
 */
private function calculateBatched(array $documents, Vector $vector, int $maxItems): array
{
    $distanceOf = $this->resolveStrategy();

    // Pairs a document with its distance to the query vector.
    $toCandidate = static fn (VectorDocument $document): array => [
        'distance' => $distanceOf($document, $vector),
        'document' => $document,
    ];
    $closestFirst = static fn (array $left, array $right): int => $left['distance'] <=> $right['distance'];

    /** @var array<int, array{distance: float, document: VectorDocument}> $best */
    $best = [];

    foreach (array_chunk($documents, $this->batchSize) as $chunk) {
        // Append the freshly scored chunk to the survivors of earlier chunks.
        foreach ($chunk as $document) {
            $best[] = $toCandidate($document);
        }

        usort($best, $closestFirst);

        // Prune so that at most $maxItems candidates are carried into the next chunk.
        if (\count($best) > $maxItems) {
            $best = \array_slice($best, 0, $maxItems);
        }
    }

    // Attach each surviving distance to its document as the score.
    $scored = [];
    foreach ($best as $candidate) {
        $scored[] = $candidate['document']->withScore($candidate['distance']);
    }

    return $scored;
}
|
||||
|
||||
/**
 * Returns the private distance method matching the configured strategy as a closure.
 *
 * @return \Closure(VectorDocument, Vector): float
 */
private function resolveStrategy(): \Closure
{
    $distanceFunction = match ($this->strategy) {
        DistanceStrategy::ANGULAR_DISTANCE => $this->angularDistance(...),
        DistanceStrategy::CHEBYSHEV_DISTANCE => $this->chebyshevDistance(...),
        DistanceStrategy::COSINE_DISTANCE => $this->cosineDistance(...),
        DistanceStrategy::EUCLIDEAN_DISTANCE => $this->euclideanDistance(...),
        DistanceStrategy::MANHATTAN_DISTANCE => $this->manhattanDistance(...),
    };

    return $distanceFunction;
}
|
||||
|
||||
private function cosineDistance(VectorDocument $embedding, Vector $against): float
|
||||
{
|
||||
return 1 - $this->cosineSimilarity($embedding, $against);
|
||||
@@ -125,9 +196,9 @@ final class DistanceCalculator
|
||||
);
|
||||
|
||||
return array_reduce(
|
||||
array: $embeddingsAsPower,
|
||||
callback: static fn (float $value, float $current): float => max($value, $current),
|
||||
initial: 0.0,
|
||||
$embeddingsAsPower,
|
||||
static fn (float $value, float $current): float => max($value, $current),
|
||||
0.0,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,9 +24,9 @@ use Symfony\Component\Uid\Uuid;
|
||||
final class DistanceCalculatorTest extends TestCase
|
||||
{
|
||||
/**
|
||||
* @param array<list<float>> $documentVectors
|
||||
* @param list<float> $queryVector
|
||||
* @param list<int> $expectedOrder
|
||||
* @param array<float[]> $documentVectors
|
||||
* @param float[] $queryVector
|
||||
* @param int[] $expectedOrder
|
||||
*/
|
||||
#[TestDox('Calculates distances correctly using $strategy strategy')]
|
||||
#[DataProvider('provideDistanceStrategyTestCases')]
|
||||
@@ -59,7 +59,7 @@ final class DistanceCalculatorTest extends TestCase
|
||||
}
|
||||
|
||||
/**
|
||||
* @return \Generator<string, array{DistanceStrategy, array<list<float>>, list<float>, list<int>}>
|
||||
* @return \Generator<string, array{DistanceStrategy, array<float[]>, float[], int[]}>
|
||||
*/
|
||||
public static function provideDistanceStrategyTestCases(): \Generator
|
||||
{
|
||||
@@ -218,15 +218,13 @@ final class DistanceCalculatorTest extends TestCase
|
||||
{
|
||||
$calculator = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE);
|
||||
|
||||
// Create high-dimensional vectors (100 dimensions)
|
||||
$dimensions = 100;
|
||||
$vector1 = array_fill(0, $dimensions, 0.1);
|
||||
$vector2 = array_fill(0, $dimensions, 0.2);
|
||||
$vector1 = array_fill(0, 100, 0.1);
|
||||
$vector2 = array_fill(0, 100, 0.2);
|
||||
|
||||
$doc1 = new VectorDocument(Uuid::v4(), new Vector($vector1));
|
||||
$doc2 = new VectorDocument(Uuid::v4(), new Vector($vector2));
|
||||
|
||||
$queryVector = new Vector(array_fill(0, $dimensions, 0.15));
|
||||
$queryVector = new Vector(array_fill(0, 100, 0.15));
|
||||
|
||||
$result = $calculator->calculate([$doc1, $doc2], $queryVector);
|
||||
|
||||
@@ -290,10 +288,8 @@ final class DistanceCalculatorTest extends TestCase
|
||||
#[TestDox('Uses cosine distance as default strategy')]
|
||||
public function testDefaultStrategyIsCosineDistance()
|
||||
{
|
||||
// Test that default constructor uses cosine distance
|
||||
$calculator = new DistanceCalculator();
|
||||
|
||||
// Create vectors where cosine distance ordering differs from Euclidean
|
||||
$doc1 = new VectorDocument(Uuid::v4(), new Vector([1.0, 0.0, 0.0]));
|
||||
$doc2 = new VectorDocument(Uuid::v4(), new Vector([100.0, 0.0, 0.0])); // Same direction but different magnitude
|
||||
|
||||
@@ -305,4 +301,155 @@ final class DistanceCalculatorTest extends TestCase
|
||||
// The order might vary but both are equally similar in terms of direction
|
||||
$this->assertCount(2, $result);
|
||||
}
|
||||
|
||||
#[TestDox('Batched calculation returns same top results as full calculation')]
public function testBatchedCalculationReturnsSameResultsAsFull()
{
    // Five 2D points, each tagged with a readable id in its metadata.
    $points = [
        'a' => [0.0, 0.0],
        'b' => [1.0, 0.0],
        'c' => [0.0, 1.0],
        'd' => [1.0, 1.0],
        'e' => [0.5, 0.5],
    ];

    $documents = [];
    foreach ($points as $id => $coordinates) {
        $documents[] = new VectorDocument(Uuid::v4(), new Vector($coordinates), new Metadata(['id' => $id]));
    }

    $origin = new Vector([0.0, 0.0]);
    $idOf = static fn (VectorDocument $document): string => $document->getMetadata()['id'];

    // The default batch size is larger than the document count, so this run is not chunked.
    $reference = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE))
        ->calculate($documents, $origin, 3);

    // A batch size of 2 splits the five documents into three chunks.
    $batched = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 2))
        ->calculate($documents, $origin, 3);

    $this->assertSame(array_map($idOf, $reference), array_map($idOf, $batched));
    $this->assertCount(3, $batched);
}
|
||||
|
||||
#[TestDox('Batched calculation prunes candidates after each batch')]
public function testBatchedCalculationPrunesCandidates()
{
    // Ten documents at x = 0..9, scored in chunks of 3 with a limit of 2:
    // only the two closest candidates survive each chunk.
    $documents = array_map(
        static fn (int $position): VectorDocument => new VectorDocument(
            Uuid::v4(),
            new Vector([(float) $position, 0.0]),
            new Metadata(['id' => (string) $position]),
        ),
        range(0, 9),
    );

    $closest = (new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 3))
        ->calculate($documents, new Vector([0.0, 0.0]), 2);

    $this->assertCount(2, $closest);
    $this->assertSame(
        ['0', '1'],
        array_map(static fn (VectorDocument $document): string => $document->getMetadata()['id'], $closest),
    );
}
|
||||
|
||||
#[TestDox('Batched calculation falls back to full when maxItems is null')]
public function testBatchedCalculationFallsBackWithoutMaxItems()
{
    $documents = [
        new VectorDocument(Uuid::v4(), new Vector([1.0, 0.0]), new Metadata(['id' => 'a'])),
        new VectorDocument(Uuid::v4(), new Vector([0.0, 1.0]), new Metadata(['id' => 'b'])),
    ];

    $calculator = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 1);

    // No maxItems argument: nothing can be pruned, so every document must come back.
    $ranked = $calculator->calculate($documents, new Vector([1.0, 0.0]));

    $this->assertCount(2, $ranked);
    $this->assertSame('a', $ranked[0]->getMetadata()['id']);
    $this->assertSame('b', $ranked[1]->getMetadata()['id']);
}
|
||||
|
||||
#[TestDox('Batched calculation works with batch size larger than document count')]
public function testBatchedCalculationWithLargeBatchSize()
{
    $atOrigin = new VectorDocument(Uuid::v4(), new Vector([0.0, 0.0]), new Metadata(['id' => 'a']));
    $offOrigin = new VectorDocument(Uuid::v4(), new Vector([1.0, 0.0]), new Metadata(['id' => 'b']));

    // The batch size (1000) is far larger than the two documents supplied.
    $calculator = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 1000);
    $nearest = $calculator->calculate([$atOrigin, $offOrigin], new Vector([0.0, 0.0]), 1);

    $this->assertCount(1, $nearest);
    $this->assertSame('a', $nearest[0]->getMetadata()['id']);
}
|
||||
|
||||
#[TestDox('Batched calculation returns empty array for empty documents')]
public function testBatchedCalculationWithEmptyDocuments()
{
    $calculator = new DistanceCalculator(batchSize: 100);
    $noDocuments = [];

    // A result limit is given, but there is nothing to score.
    $ranked = $calculator->calculate($noDocuments, new Vector([1.0, 2.0]), 5);

    $this->assertSame([], $ranked);
}
|
||||
|
||||
#[TestDox('Batched calculation preserves scores')]
public function testBatchedCalculationPreservesScores()
{
    // A 3-4-5 triangle: the point (3, 4) is at Euclidean distance 5 from the origin.
    $document = new VectorDocument(Uuid::v4(), new Vector([3.0, 4.0]));

    $singlePass = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE);
    $chunked = new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE, batchSize: 1);

    $singlePassScore = $singlePass->calculate([$document], new Vector([0.0, 0.0]), 1)[0]->getScore();
    $chunkedScore = $chunked->calculate([$document], new Vector([0.0, 0.0]), 1)[0]->getScore();

    $this->assertSame($singlePassScore, $chunkedScore);
    $this->assertEqualsWithDelta(5.0, $chunkedScore, 0.0001);
}
|
||||
|
||||
/**
 * @param int[] $expectedOrder metadata "index" values expected at each result position
 */
#[TestDox('Batched calculation works with $strategy strategy')]
#[DataProvider('provideBatchedStrategyTestCases')]
public function testBatchedCalculationWithDifferentStrategies(DistanceStrategy $strategy, array $expectedOrder)
{
    // The list position of each coordinate set doubles as the document's "index" metadata.
    $coordinates = [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.5, 0.5, 0.707],
    ];

    $documents = [];
    foreach ($coordinates as $index => $values) {
        $documents[] = new VectorDocument(Uuid::v4(), new Vector($values), new Metadata(['index' => $index]));
    }

    // Four documents in chunks of 2 with a limit of 2.
    $ranked = (new DistanceCalculator($strategy, batchSize: 2))
        ->calculate($documents, new Vector([1.0, 0.0, 0.0]), 2);

    $this->assertCount(2, $ranked);

    foreach ($expectedOrder as $position => $expectedIndex) {
        $this->assertSame($expectedIndex, $ranked[$position]->getMetadata()['index']);
    }
}
|
||||
|
||||
/**
 * Yields one named data set per covered strategy; each expects the same top-two order.
 *
 * @return \Generator<string, array{DistanceStrategy, int[]}>
 */
public static function provideBatchedStrategyTestCases(): \Generator
{
    $strategies = [
        'cosine distance batched' => DistanceStrategy::COSINE_DISTANCE,
        'euclidean distance batched' => DistanceStrategy::EUCLIDEAN_DISTANCE,
        'manhattan distance batched' => DistanceStrategy::MANHATTAN_DISTANCE,
    ];

    foreach ($strategies as $label => $strategy) {
        yield $label => [$strategy, [0, 3]];
    }
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,11 @@ use Symfony\AI\Store\Distance\DistanceCalculator;
|
||||
use Symfony\AI\Store\Distance\DistanceStrategy;
|
||||
use Symfony\AI\Store\Document\Metadata;
|
||||
use Symfony\AI\Store\Document\VectorDocument;
|
||||
use Symfony\AI\Store\Exception\InvalidArgumentException;
|
||||
use Symfony\AI\Store\Exception\UnsupportedQueryTypeException;
|
||||
use Symfony\AI\Store\InMemory\Store;
|
||||
use Symfony\AI\Store\Query\HybridQuery;
|
||||
use Symfony\AI\Store\Query\QueryInterface;
|
||||
use Symfony\AI\Store\Query\TextQuery;
|
||||
use Symfony\AI\Store\Query\VectorQuery;
|
||||
use Symfony\Component\Uid\Uuid;
|
||||
@@ -364,7 +367,7 @@ final class StoreTest extends TestCase
|
||||
new VectorDocument($id, new Vector([0.1, 0.1, 0.5])),
|
||||
]);
|
||||
|
||||
$this->expectException(\Symfony\AI\Store\Exception\InvalidArgumentException::class);
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
$this->expectExceptionMessage('No supported options.');
|
||||
|
||||
$store->remove((string) $id, ['unsupported' => 'option']);
|
||||
@@ -453,8 +456,9 @@ final class StoreTest extends TestCase
|
||||
|
||||
public function testHybridQueryThrowsExceptionForInvalidSemanticRatio()
|
||||
{
|
||||
$this->expectException(\Symfony\AI\Store\Exception\InvalidArgumentException::class);
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
$this->expectExceptionMessage('Semantic ratio must be between 0.0 and 1.0');
|
||||
$this->expectExceptionCode(0);
|
||||
|
||||
new HybridQuery(new Vector([0.1, 0.2, 0.3]), 'test', 1.5);
|
||||
}
|
||||
@@ -464,11 +468,12 @@ final class StoreTest extends TestCase
|
||||
$store = new Store();
|
||||
|
||||
// Create a mock query type that InMemory store doesn't support
|
||||
$unsupportedQuery = new class implements \Symfony\AI\Store\Query\QueryInterface {
|
||||
$unsupportedQuery = new class implements QueryInterface {
|
||||
};
|
||||
|
||||
$this->expectException(\Symfony\AI\Store\Exception\UnsupportedQueryTypeException::class);
|
||||
$this->expectException(UnsupportedQueryTypeException::class);
|
||||
$this->expectExceptionMessageMatches('/not supported/');
|
||||
$this->expectExceptionCode(0);
|
||||
|
||||
$store->query($unsupportedQuery);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user