diff --git a/main_mypy.txt b/main_mypy.txt new file mode 100644 index 0000000000..2ceca275cd Binary files /dev/null and b/main_mypy.txt differ diff --git a/pr_mypy.txt b/pr_mypy.txt new file mode 100644 index 0000000000..969b039c44 Binary files /dev/null and b/pr_mypy.txt differ diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 6072a5ddcb..f16d11ed3d 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -379,7 +379,7 @@ def _is_remote_response(self, event: Event) -> bool: def _construct_message_parts_from_session( self, ctx: InvocationContext - ) -> tuple[list[A2APart], Optional[str]]: + ) -> tuple[list[A2APart], Optional[str], Optional[str]]: """Construct A2A message parts from session events. Args: @@ -391,6 +391,7 @@ def _construct_message_parts_from_session( """ message_parts: list[A2APart] = [] context_id = None + task_id = None events_to_process = [] for event in reversed(ctx.session.events): @@ -400,6 +401,14 @@ def _construct_message_parts_from_session( if event.custom_metadata: metadata = event.custom_metadata context_id = metadata.get(A2A_METADATA_PREFIX + "context_id") + response_meta = metadata.get(A2A_METADATA_PREFIX + "response", {}) + task_state = None + if isinstance(response_meta, dict): + status = response_meta.get("status", {}) + if isinstance(status, dict): + task_state = status.get("state") + if task_state in ("input-required", "auth-required"): + task_id = metadata.get(A2A_METADATA_PREFIX + "task_id") # Historical note: this behavior originally always applied, regardless # of whether the agent was stateful or stateless. However, only stateful # agents can be expected to have previous events in the remote session. @@ -427,7 +436,7 @@ def _construct_message_parts_from_session( else: logger.warning("Failed to convert part to A2A format: %s", part) - return message_parts, context_id + return message_parts, context_id, task_id async def _handle_a2a_response( self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext @@ -624,7 +633,7 @@ async def _run_async_impl( # Create A2A request for function response or regular message a2a_request = self._create_a2a_request_for_user_function_response(ctx) if not a2a_request: - message_parts, context_id = self._construct_message_parts_from_session( + message_parts, context_id, task_id = self._construct_message_parts_from_session( ctx ) @@ -645,6 +654,7 @@ async def _run_async_impl( parts=message_parts, role="user", context_id=context_id, + task_id=task_id, ) logger.debug(build_a2a_request_log(a2a_request)) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 0f1ce896a3..c0ba0b89f6 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -615,7 +615,7 @@ def test_construct_message_parts_from_session_success(self): mock_a2a_part = Mock() self.mock_genai_part_converter.return_value = mock_a2a_part - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -649,7 +649,7 @@ def test_construct_message_parts_from_session_success_multiple_parts(self): mock_a2a_part2, ] - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -660,7 +660,7 @@ def test_construct_message_parts_from_session_empty_events(self): """Test message parts construction with empty events.""" self.mock_session.events = [] - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -718,7 +718,7 @@ def mock_converter(part): "google.adk.agents.remote_a2a_agent._present_other_agent_message" ) as mock_present: mock_present.side_effect = lambda event: event - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) assert len(parts) == 1 @@ -768,7 +768,7 @@ def mock_converter(part): "google.adk.agents.remote_a2a_agent._present_other_agent_message" ) as mock_present: mock_present.side_effect = lambda event: event - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) assert len(parts) == 3 @@ -823,7 +823,7 @@ def mock_converter(part): "google.adk.agents.remote_a2a_agent._present_other_agent_message" ) as mock_present: mock_present.side_effect = lambda event: event - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) assert len(parts) == 1 @@ -954,7 +954,7 @@ def mock_converter(part): self.mock_genai_part_converter.side_effect = mock_converter - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -1373,12 +1373,12 @@ def test_construct_message_parts_from_session_success(self): mock_convert.return_value = mock_event with patch.object( - self.agent, "_genai_part_converter" + self.agent, "_genai_part_converter" ) as mock_convert_part: mock_a2a_part = Mock() mock_convert_part.return_value = mock_a2a_part - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -1390,7 +1390,7 @@ def test_construct_message_parts_from_session_empty_events(self): """Test message parts construction with empty events.""" self.mock_session.events = [] - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, task_id = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -1966,10 +1966,7 @@ async def test_run_async_impl_no_message_parts(self): with patch.object( self.agent, "_construct_message_parts_from_session" ) as mock_construct: - mock_construct.return_value = ( - [], - None, - ) # Tuple with empty parts and no context_id + mock_construct.return_value =([], None, None) # Tuple with empty parts and no context_id events = [] async for event in self.agent._run_async_impl(self.mock_context): @@ -1999,7 +1996,8 @@ async def test_run_async_impl_successful_request(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", - ) # Tuple with parts and context_id + None, + ) # Tuple with parts and context_id , no task_id # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) @@ -2071,6 +2069,7 @@ async def test_run_async_impl_a2a_client_error(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client that throws an exception @@ -2138,6 +2137,7 @@ async def test_run_async_impl_with_meta_provider(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client @@ -2242,10 +2242,8 @@ async def test_run_async_impl_no_message_parts(self): with patch.object( self.agent, "_construct_message_parts_from_session" ) as mock_construct: - mock_construct.return_value = ( - [], - None, - ) # Tuple with empty parts and no context_id + mock_construct.return_value = ([], None, None) + # Tuple with empty parts and no context_id events = [] async for event in self.agent._run_async_impl(self.mock_context): @@ -2275,6 +2273,7 @@ async def test_run_async_impl_successful_request(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client @@ -2349,6 +2348,7 @@ async def test_run_async_impl_a2a_client_error(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client that throws an exception