mirror of
https://github.com/symfony/ai.git
synced 2026-03-23 23:42:18 +01:00
feat: add option to keep tool messages (#323)
Adds an option to the Toolbox/ChainProcessor to keep tool messages by avoiding to clone the original messageBag Fixes #321 Co-authored-by: Philip Heimböck <philip.heimboeck@abacus.ch>
This commit is contained in:
28
README.md
28
README.md
@@ -347,6 +347,34 @@ $eventDispatcher->addListener(ToolCallsExecuted::class, function (ToolCallsExecu
|
||||
});
|
||||
```
|
||||
|
||||
#### Keeping Tool Messages
|
||||
|
||||
Sometimes you might wish to keep the tool messages (`AssistantMessage` containing the `toolCalls` and `ToolCallMessage` containing the response) in the context.
|
||||
Enable the `keepToolMessages` flag of the toolbox' `ChainProcessor` to ensure those messages will be added to your `MessageBag`.
|
||||
|
||||
```php
|
||||
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
|
||||
use PhpLlm\LlmChain\Chain\Toolbox\Toolbox;
|
||||
|
||||
// Platform & LLM instantiation
|
||||
$messages = new MessageBag(
|
||||
Message::forSystem(<<<PROMPT
|
||||
Please answer all user questions only using the similary_search tool. Do not add information and if you cannot
|
||||
find an answer, say so.
|
||||
PROMPT),
|
||||
Message::ofUser('...') // The user's question.
|
||||
);
|
||||
|
||||
$yourTool = new YourTool();
|
||||
|
||||
$toolbox = Toolbox::create($yourTool);
|
||||
$toolProcessor = new ChainProcessor($toolbox, keepToolMessages: true);
|
||||
|
||||
$chain = new Chain($platform, $llm, inputProcessor: [$toolProcessor], outputProcessor: [$toolProcessor]);
|
||||
$response = $chain->call($messages);
|
||||
// $messages will now include the tool messages
|
||||
```
|
||||
|
||||
#### Code Examples (with built-in tools)
|
||||
|
||||
1. [Brave Tool](examples/toolbox/brave.php)
|
||||
|
||||
@@ -33,6 +33,7 @@ final class ChainProcessor implements InputProcessorInterface, OutputProcessorIn
|
||||
private readonly ToolboxInterface $toolbox,
|
||||
private readonly ToolResultConverter $resultConverter = new ToolResultConverter(),
|
||||
private readonly ?EventDispatcherInterface $eventDispatcher = null,
|
||||
private readonly bool $keepToolMessages = false,
|
||||
) {
|
||||
}
|
||||
|
||||
@@ -86,7 +87,7 @@ final class ChainProcessor implements InputProcessorInterface, OutputProcessorIn
|
||||
private function handleToolCallsCallback(Output $output): \Closure
|
||||
{
|
||||
return function (ToolCallResponse $response, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResponseInterface {
|
||||
$messages = clone $output->messages;
|
||||
$messages = $this->keepToolMessages ? $output->messages : clone $output->messages;
|
||||
|
||||
if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->content) {
|
||||
$messages->add($streamedAssistantResponse);
|
||||
|
||||
@@ -4,13 +4,19 @@ declare(strict_types=1);
|
||||
|
||||
namespace PhpLlm\LlmChain\Tests\Chain\Toolbox;
|
||||
|
||||
use PhpLlm\LlmChain\Chain\ChainInterface;
|
||||
use PhpLlm\LlmChain\Chain\Exception\MissingModelSupportException;
|
||||
use PhpLlm\LlmChain\Chain\Input;
|
||||
use PhpLlm\LlmChain\Chain\Output;
|
||||
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
|
||||
use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface;
|
||||
use PhpLlm\LlmChain\Platform\Capability;
|
||||
use PhpLlm\LlmChain\Platform\Message\AssistantMessage;
|
||||
use PhpLlm\LlmChain\Platform\Message\MessageBag;
|
||||
use PhpLlm\LlmChain\Platform\Message\ToolCallMessage;
|
||||
use PhpLlm\LlmChain\Platform\Model;
|
||||
use PhpLlm\LlmChain\Platform\Response\ToolCall;
|
||||
use PhpLlm\LlmChain\Platform\Response\ToolCallResponse;
|
||||
use PhpLlm\LlmChain\Platform\Tool\ExecutionReference;
|
||||
use PhpLlm\LlmChain\Platform\Tool\Tool;
|
||||
use PHPUnit\Framework\Attributes\CoversClass;
|
||||
@@ -20,7 +26,10 @@ use PHPUnit\Framework\TestCase;
|
||||
|
||||
#[CoversClass(ChainProcessor::class)]
|
||||
#[UsesClass(Input::class)]
|
||||
#[UsesClass(Output::class)]
|
||||
#[UsesClass(Tool::class)]
|
||||
#[UsesClass(ToolCall::class)]
|
||||
#[UsesClass(ToolCallResponse::class)]
|
||||
#[UsesClass(ExecutionReference::class)]
|
||||
#[UsesClass(MessageBag::class)]
|
||||
#[UsesClass(MissingModelSupportException::class)]
|
||||
@@ -87,4 +96,54 @@ class ChainProcessorTest extends TestCase
|
||||
|
||||
$chainProcessor->processInput($input);
|
||||
}
|
||||
|
||||
#[Test]
|
||||
public function processOutputWithToolCallResponseKeepingMessages(): void
|
||||
{
|
||||
$toolbox = $this->createMock(ToolboxInterface::class);
|
||||
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
|
||||
|
||||
$model = new Model('gpt-4', [Capability::TOOL_CALLING]);
|
||||
|
||||
$messageBag = new MessageBag();
|
||||
|
||||
$response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
|
||||
|
||||
$chain = $this->createStub(ChainInterface::class);
|
||||
|
||||
$chainProcessor = new ChainProcessor($toolbox, keepToolMessages: true);
|
||||
$chainProcessor->setChain($chain);
|
||||
|
||||
$output = new Output($model, $response, $messageBag, []);
|
||||
|
||||
$chainProcessor->processOutput($output);
|
||||
|
||||
self::assertCount(2, $messageBag);
|
||||
self::assertInstanceOf(AssistantMessage::class, $messageBag->getMessages()[0]);
|
||||
self::assertInstanceOf(ToolCallMessage::class, $messageBag->getMessages()[1]);
|
||||
}
|
||||
|
||||
#[Test]
|
||||
public function processOutputWithToolCallResponseForgettingMessages(): void
|
||||
{
|
||||
$toolbox = $this->createMock(ToolboxInterface::class);
|
||||
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
|
||||
|
||||
$model = new Model('gpt-4', [Capability::TOOL_CALLING]);
|
||||
|
||||
$messageBag = new MessageBag();
|
||||
|
||||
$response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
|
||||
|
||||
$chain = $this->createStub(ChainInterface::class);
|
||||
|
||||
$chainProcessor = new ChainProcessor($toolbox, keepToolMessages: false);
|
||||
$chainProcessor->setChain($chain);
|
||||
|
||||
$output = new Output($model, $response, $messageBag, []);
|
||||
|
||||
$chainProcessor->processOutput($output);
|
||||
|
||||
self::assertCount(0, $messageBag);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user