Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added main_mypy.txt
Binary file not shown.
Binary file added pr_mypy.txt
Binary file not shown.
16 changes: 13 additions & 3 deletions src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _is_remote_response(self, event: Event) -> bool:

def _construct_message_parts_from_session(
self, ctx: InvocationContext
) -> tuple[list[A2APart], Optional[str]]:
) -> tuple[list[A2APart], Optional[str], Optional[str]]:
"""Construct A2A message parts from session events.
Args:
Expand All @@ -391,6 +391,7 @@ def _construct_message_parts_from_session(
"""
message_parts: list[A2APart] = []
context_id = None
task_id = None

events_to_process = []
for event in reversed(ctx.session.events):
Expand All @@ -400,6 +401,14 @@ def _construct_message_parts_from_session(
if event.custom_metadata:
metadata = event.custom_metadata
context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")
response_meta = metadata.get(A2A_METADATA_PREFIX + "response", {})
task_state = None
if isinstance(response_meta, dict):
status = response_meta.get("status", {})
if isinstance(status, dict):
task_state = status.get("state")
if task_state in ("input-required", "auth-required"):
task_id = metadata.get(A2A_METADATA_PREFIX + "task_id")
Comment on lines +404 to +411
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue with this implementation is that it is not possible to define if the new user message is a follow up to the input-required event, or a request for a new task -> in this case the task_id should not be set

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The remote agent owns the task,it knows if the task_id is still valid or not. If the user starts a new request, the remote agent can handle it. Forwarding task_id on input-required is the right default behavior. If more control is needed, that can be a separate improvement.

# Historical note: this behavior originally always applied, regardless
# of whether the agent was stateful or stateless. However, only stateful
# agents can be expected to have previous events in the remote session.
Expand Down Expand Up @@ -427,7 +436,7 @@ def _construct_message_parts_from_session(
else:
logger.warning("Failed to convert part to A2A format: %s", part)

return message_parts, context_id
return message_parts, context_id, task_id

async def _handle_a2a_response(
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
Expand Down Expand Up @@ -624,7 +633,7 @@ async def _run_async_impl(
# Create A2A request for function response or regular message
a2a_request = self._create_a2a_request_for_user_function_response(ctx)
if not a2a_request:
message_parts, context_id = self._construct_message_parts_from_session(
message_parts, context_id, task_id = self._construct_message_parts_from_session(
ctx
)

Expand All @@ -645,6 +654,7 @@ async def _run_async_impl(
parts=message_parts,
role="user",
context_id=context_id,
task_id=task_id,
)

logger.debug(build_a2a_request_log(a2a_request))
Expand Down
38 changes: 19 additions & 19 deletions tests/unittests/agents/test_remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def test_construct_message_parts_from_session_success(self):
mock_a2a_part = Mock()
self.mock_genai_part_converter.return_value = mock_a2a_part

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -649,7 +649,7 @@ def test_construct_message_parts_from_session_success_multiple_parts(self):
mock_a2a_part2,
]

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand All @@ -660,7 +660,7 @@ def test_construct_message_parts_from_session_empty_events(self):
"""Test message parts construction with empty events."""
self.mock_session.events = []

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -718,7 +718,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
Expand Down Expand Up @@ -768,7 +768,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 3
Expand Down Expand Up @@ -823,7 +823,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
Expand Down Expand Up @@ -954,7 +954,7 @@ def mock_converter(part):

self.mock_genai_part_converter.side_effect = mock_converter

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -1373,12 +1373,12 @@ def test_construct_message_parts_from_session_success(self):
mock_convert.return_value = mock_event

with patch.object(
self.agent, "_genai_part_converter"
self.agent, "_genai_part_converter"
) as mock_convert_part:
mock_a2a_part = Mock()
mock_convert_part.return_value = mock_a2a_part

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand All @@ -1390,7 +1390,7 @@ def test_construct_message_parts_from_session_empty_events(self):
"""Test message parts construction with empty events."""
self.mock_session.events = []

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -1966,10 +1966,7 @@ async def test_run_async_impl_no_message_parts(self):
with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
mock_construct.return_value = (
[],
None,
) # Tuple with empty parts and no context_id
mock_construct.return_value =([], None, None) # Tuple with empty parts and no context_id

events = []
async for event in self.agent._run_async_impl(self.mock_context):
Expand Down Expand Up @@ -1999,7 +1996,8 @@ async def test_run_async_impl_successful_request(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
) # Tuple with parts and context_id
None,
) # Tuple with parts and context_id , no task_id

# Mock A2A client
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
Expand Down Expand Up @@ -2071,6 +2069,7 @@ async def test_run_async_impl_a2a_client_error(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client that throws an exception
Expand Down Expand Up @@ -2138,6 +2137,7 @@ async def test_run_async_impl_with_meta_provider(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2242,10 +2242,8 @@ async def test_run_async_impl_no_message_parts(self):
with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
mock_construct.return_value = (
[],
None,
) # Tuple with empty parts and no context_id
mock_construct.return_value = ([], None, None)
# Tuple with empty parts and no context_id

events = []
async for event in self.agent._run_async_impl(self.mock_context):
Expand Down Expand Up @@ -2275,6 +2273,7 @@ async def test_run_async_impl_successful_request(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2349,6 +2348,7 @@ async def test_run_async_impl_a2a_client_error(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client that throws an exception
Expand Down