Skip to content

Commit 844487a

Browse files
committed
feat: Update agent interface to support optional auth handler name and improve observability token caching
1 parent 7b4f130 commit 844487a

5 files changed

Lines changed: 138 additions & 15 deletions

File tree

python/perplexity/sample-agent/agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async def invoke_agent(
107107
self,
108108
message: str,
109109
auth: Authorization,
110-
auth_handler_name: str,
110+
auth_handler_name: str | None,
111111
context: TurnContext,
112112
) -> str:
113113
# Log the user identity
@@ -193,7 +193,7 @@ async def _invoke_agent_with_inference_scope(
193193
self,
194194
message: str,
195195
auth: Authorization,
196-
auth_handler_name: str,
196+
auth_handler_name: str | None,
197197
context: TurnContext,
198198
) -> str:
199199
"""invoke_agent wrapped in an InferenceScope for observability."""
@@ -234,7 +234,7 @@ async def invoke_agent_with_scope(
234234
self,
235235
message: str,
236236
auth: Authorization,
237-
auth_handler_name: str,
237+
auth_handler_name: str | None,
238238
context: TurnContext,
239239
) -> str:
240240
# Extract identity from the activity recipient (populated by the platform).

python/perplexity/sample-agent/agent_interface.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from abc import ABC, abstractmethod
9+
from typing import Optional
910
from microsoft_agents.hosting.core import Authorization, TurnContext
1011

1112

@@ -18,14 +19,14 @@ class AgentInterface(ABC):
1819
"""
1920
@abstractmethod
2021
async def invoke_agent(
21-
self, message: str, auth: Authorization, auth_handler_name: str, context: TurnContext
22+
self, message: str, auth: Authorization, auth_handler_name: Optional[str], context: TurnContext
2223
) -> str:
2324
"""Process a user message and return a response."""
2425
pass
2526

2627
@abstractmethod
2728
async def invoke_agent_with_scope(
28-
self, message: str, auth: Authorization, auth_handler_name: str, context: TurnContext
29+
self, message: str, auth: Authorization, auth_handler_name: Optional[str], context: TurnContext
2930
) -> str:
3031
"""Process a user message within an observability scope and return a response."""
3132
pass

python/perplexity/sample-agent/hosting.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,21 @@ async def _typing_loop():
144144
if self.auth_handler_name:
145145
try:
146146
recipient = context.activity.recipient
147-
tenant_id = getattr(recipient, "tenant_id", None) or ""
148-
agent_id = getattr(recipient, "agentic_app_id", None) or ""
147+
tenant_id = (getattr(recipient, "tenant_id", None) or "").strip()
148+
agent_id = (getattr(recipient, "agentic_app_id", None) or "").strip()
149149
obs_token = await self.auth.exchange_token(
150150
context,
151151
scopes=get_observability_authentication_scope(),
152152
auth_handler_id=self.auth_handler_name,
153153
)
154154
if obs_token and obs_token.token:
155-
cache_agentic_token(tenant_id, agent_id, obs_token.token)
156-
logger.info("Agentic token cached for observability exporter")
155+
if tenant_id and agent_id:
156+
cache_agentic_token(tenant_id, agent_id, obs_token.token)
157+
logger.info("Agentic token cached for observability exporter")
158+
else:
159+
logger.info(
160+
"Skipping observability token cache because tenant_id or agent_id is missing"
161+
)
157162
except Exception as token_err:
158163
logger.warning("Failed to exchange/cache observability token: %s", token_err)
159164

python/perplexity/sample-agent/mcp_tool_registration_service.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ async def _connect_server(server_config):
290290
server_url,
291291
exc,
292292
)
293+
try:
294+
await session.close()
295+
except Exception:
296+
pass
293297
return None
294298

295299
results = await _asyncio.gather(
@@ -395,13 +399,17 @@ async def execute_tool(name: str, arguments: dict) -> str:
395399
"Tool '%s' failed after %d attempts — clearing MCP cache",
396400
name, _MCP_MAX_RETRIES + 1,
397401
)
398-
svc._initialized = False
402+
await svc._invalidate_cache()
399403
return f"Error executing tool '{name}': {last_error}"
400404

401405
return execute_tool
402406

403-
async def close(self) -> None:
404-
"""Close all cached MCP sessions (call on server shutdown)."""
407+
async def _invalidate_cache(self) -> None:
408+
"""Close existing MCP sessions and clear all cached state.
409+
410+
Called when retries are exhausted so the next turn reconnects
411+
from scratch instead of appending duplicates.
412+
"""
405413
for s in self._sessions:
406414
try:
407415
await s.close()
@@ -411,3 +419,7 @@ async def close(self) -> None:
411419
self._tool_map.clear()
412420
self._openai_tools.clear()
413421
self._initialized = False
422+
423+
async def close(self) -> None:
424+
"""Close all cached MCP sessions (call on server shutdown)."""
425+
await self._invalidate_cache()

python/perplexity/sample-agent/perplexity_client.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,99 @@
3131
# Timeout (seconds) for a single Perplexity API call.
3232
_PER_ROUND_TIMEOUT = 30
3333

34+
# Tool-selection threshold: when more tools than this are available,
35+
# make a fast preliminary call to pick only the relevant ones.
36+
_TOOL_SELECTION_THRESHOLD = 20
37+
38+
# Maximum tools the selector may return.
39+
_TOOL_SELECTION_MAX = 15
40+
41+
# Timeout (seconds) for the tool-selection call.
42+
_TOOL_SELECTION_TIMEOUT = 15
43+
44+
45+
async def select_relevant_tools(
46+
client: AsyncOpenAI,
47+
model: str,
48+
user_message: str,
49+
all_tools: list[dict],
50+
) -> list[dict]:
51+
"""Use a fast LLM call to pick only the tools relevant to *user_message*.
52+
53+
Returns a filtered subset (≤ ``_TOOL_SELECTION_MAX``) of *all_tools*.
54+
On any failure the full list is returned so the main flow is never blocked.
55+
"""
56+
# Build a compact one-line-per-tool catalog for the selector prompt.
57+
catalog_lines: list[str] = []
58+
for idx, t in enumerate(all_tools):
59+
name = t.get("name", "unknown")
60+
desc = (t.get("description") or "")[:120]
61+
catalog_lines.append(f"{idx}: {name}{desc}")
62+
catalog = "\n".join(catalog_lines)
63+
64+
selector_prompt = (
65+
"Given the user's request, select ONLY the tools needed to fulfill it.\n"
66+
"Return a JSON array of tool index numbers (integers). Include tools that "
67+
"might be needed for follow-up steps (e.g., if creating a document and sharing "
68+
"a link, include both create and share tools).\n"
69+
f"Select at most {_TOOL_SELECTION_MAX} tools. Return ONLY a JSON array like "
70+
"[0, 3, 7], no explanation.\n\n"
71+
f'User request: "{user_message}"\n\n'
72+
f"Available tools:\n{catalog}"
73+
)
74+
75+
try:
76+
resp = await asyncio.wait_for(
77+
client.responses.create(
78+
model=model,
79+
instructions="You are a tool selector. Return ONLY a JSON array of integers.",
80+
input=selector_prompt,
81+
store=False,
82+
),
83+
timeout=_TOOL_SELECTION_TIMEOUT,
84+
)
85+
86+
raw_text = ""
87+
for item in resp.output:
88+
if item.type == "message":
89+
for c in getattr(item, "content", []):
90+
if hasattr(c, "text") and c.text:
91+
raw_text += c.text
92+
if not raw_text:
93+
raw_text = str(resp.output_text or "")
94+
95+
# Strip markdown fences and extract the JSON array.
96+
raw_text = raw_text.strip().strip("`").strip()
97+
if raw_text.startswith("json"):
98+
raw_text = raw_text[4:].strip()
99+
100+
match = re.search(r"\[[\d,\s]+\]", raw_text)
101+
if not match:
102+
logger.warning("Tool selector returned unparseable response — using all tools")
103+
return all_tools
104+
105+
indices: list[int] = json.loads(match.group())
106+
selected = [all_tools[i] for i in indices if 0 <= i < len(all_tools)]
107+
108+
if not selected:
109+
logger.warning("Tool selector returned empty set — using all tools")
110+
return all_tools
111+
112+
logger.info(
113+
"Tool selector narrowed %d → %d tools: %s",
114+
len(all_tools),
115+
len(selected),
116+
[t.get("name") for t in selected],
117+
)
118+
return selected
119+
120+
except asyncio.TimeoutError:
121+
logger.warning("Tool selector timed out (%ds) — using all tools", _TOOL_SELECTION_TIMEOUT)
122+
return all_tools
123+
except Exception as exc:
124+
logger.warning("Tool selector failed (%s) — using all tools", exc)
125+
return all_tools
126+
34127

35128
class PerplexityClient:
36129
"""Async client for Perplexity AI using the Agent API (Responses API)."""
@@ -66,6 +159,11 @@ async def invoke(
66159
"""
67160
logger.info("Invoking Perplexity model=%s (tools=%d)", self.model, len(tools or []))
68161

162+
# When too many tools are registered, use a fast selector call to
163+
# narrow down to just the relevant ones before the main API request.
164+
if tools and len(tools) > _TOOL_SELECTION_THRESHOLD:
165+
tools = await select_relevant_tools(self._client, self.model, user_message, tools)
166+
69167
create_kwargs: dict[str, Any] = {
70168
"model": self.model,
71169
"input": user_message,
@@ -107,7 +205,14 @@ async def invoke(
107205
if ctx:
108206
create_kwargs["input"] = f"{user_message}\n\n{ctx}"
109207
tools = None
110-
response = await self._client.responses.create(**create_kwargs)
208+
try:
209+
response = await asyncio.wait_for(
210+
self._client.responses.create(**create_kwargs),
211+
timeout=_PER_ROUND_TIMEOUT,
212+
)
213+
except asyncio.TimeoutError:
214+
logger.warning("Perplexity API fallback round %d timed out (%ds) — returning partial answer", _round + 1, _PER_ROUND_TIMEOUT)
215+
break
111216
else:
112217
raise
113218

@@ -186,9 +291,9 @@ async def invoke(
186291
arguments = self._enrich_arguments(fc.name, arguments, user_message, tools or [])
187292

188293
logger.info("Executing MCP tool: %s (round %d)", fc.name, _round + 1)
189-
logger.info("Tool arguments: %s", json.dumps(arguments, indent=2, default=str))
294+
logger.debug("Tool arguments: %s", json.dumps(arguments, indent=2, default=str))
190295
result = await tool_executor(fc.name, arguments)
191-
logger.info("Tool result (first 500 chars): %.500s", json.dumps(result, default=str) if not isinstance(result, str) else result)
296+
logger.debug("Tool result (first 500 chars): %.500s", json.dumps(result, default=str) if not isinstance(result, str) else result)
192297

193298
# Track resource creation/finalization generically
194299
tool_lower = fc.name.lower()

0 commit comments

Comments
 (0)