169 lines
5.8 KiB
Python
169 lines
5.8 KiB
Python
"""百度搜索源 — 通过百度千帆官方 API"""
|
||
|
||
import json
|
||
import time
|
||
import requests
|
||
from providers.base import BaseProvider, SearchResult
|
||
|
||
|
||
class BaiduProvider(BaseProvider):
    """Search provider backed by the Baidu Qianfan AI-search HTTP API.

    Two modes are supported, selected at construction time:

    * ``'web'`` - plain web search (fast); registered as ``baidu``.
    * ``'intelligent'`` - AI-generated answer plus cited references;
      registered as ``baidu-intelligent`` and disabled for auto selection.

    All network and parsing failures are treated as "no results": the public
    ``search`` method returns an empty list rather than raising.
    """

    name = 'baidu'
    display_name = '百度搜索'
    needs_api_key = True
    enabled = True
    priority = 10  # preferred provider in auto mode

    # Longest content excerpt (in characters) stored on a SearchResult.
    _MAX_CONTENT_CHARS = 500

    def __init__(self, config: dict, mode='web'):
        """Configure the provider from ``config['baidu']``.

        Args:
            config: Application configuration. The optional ``'baidu'``
                section may provide ``api_key``, ``intelligent_url`` and
                ``web_search_url``.
            mode: ``'web'`` for web search (fast) or ``'intelligent'`` for
                AI-generated search with analysis.
        """
        super().__init__(config)
        self._mode = mode
        if mode == 'intelligent':
            self.name = 'baidu-intelligent'
            self.display_name = '百度智能检索'
            self.priority = 21
            self.enabled = False  # manual selection only; excluded from auto

        # ``or {}`` also covers a config that has ``baidu: null``.
        bc = config.get('baidu') or {}
        self.api_key = bc.get('api_key')
        self.intelligent_url = bc.get(
            'intelligent_url',
            'https://qianfan.baidubce.com/v2/ai_search/chat/completions',
        )
        self.web_search_url = bc.get(
            'web_search_url',
            'https://qianfan.baidubce.com/v2/ai_search/web_search',
        )

    def is_available(self) -> bool:
        """Return True when an API key has been configured."""
        return bool(self.api_key)

    def search(self, query: str, max_results: int = 10) -> list:
        """Run a search in the configured mode.

        Args:
            query: The user's search text.
            max_results: Upper bound on the number of results returned.

        Returns:
            A list of ``SearchResult`` objects; empty when no API key is set,
            when ``max_results`` is not positive, or when the request fails.
        """
        # Reject non-positive limits up front for both modes so that no
        # (billed) request is sent and negative values never reach a slice.
        if not self.api_key or max_results <= 0:
            return []

        if self._mode == 'intelligent':
            # Intelligent search is billed per cited reference, so keep at
            # most 3 of them to save quota.
            # NOTE(review): the cap is only applied when trimming the
            # response; the request payload carries no result limit.
            return self._intelligent_search(query, min(max_results, 3))
        return self._web_search(query, max_results)

    def _auth_headers(self) -> dict:
        """Build the bearer-token JSON headers shared by both endpoints."""
        return {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json',
        }

    def _post_json(self, url: str, payload: dict, timeout: int):
        """POST ``payload`` as JSON and return the decoded response object.

        Returns ``None`` on a transport error, a non-200 status, a body that
        is not valid JSON, or a JSON body that is not an object.
        """
        try:
            resp = requests.post(
                url,
                json=payload,
                headers=self._auth_headers(),
                timeout=timeout,
            )
            if resp.status_code != 200:
                return None
            data = resp.json()
        except (requests.exceptions.RequestException, ValueError):
            # ValueError covers JSON decode failures, which are not a
            # RequestException subclass in every requests release.
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _first_value(mapping: dict, *keys):
        """Return the first truthy value found under ``keys``, else ``''``."""
        for key in keys:
            value = mapping.get(key)
            if value:
                return value
        return ''

    @staticmethod
    def _references(data: dict, nested_key: str, limit: int) -> list:
        """Return up to ``limit`` reference dicts from a response.

        Looks at top-level ``references`` first and falls back to
        ``result[nested_key]``; anything that is not a list of dicts is
        ignored instead of raising.
        """
        refs = data.get('references')
        if not refs:
            result = data.get('result')
            refs = result.get(nested_key) if isinstance(result, dict) else None
        if not isinstance(refs, list):
            return []
        return [ref for ref in refs[:limit] if isinstance(ref, dict)]

    def _intelligent_search(self, query: str, max_results: int) -> list:
        """Intelligent search: return cited references, or the AI answer.

        When the response carries no usable references, the AI-generated
        answer itself is returned as a single result linking to a Baidu
        search page for ``query``.
        """
        if max_results <= 0:
            return []

        payload = {
            'messages': [{'content': query, 'role': 'user'}],
            'stream': False,
            'model': 'ernie-4.5-turbo-128k',
            'enable_corner_markers': True,
            'enable_deep_search': True,
        }
        data = self._post_json(self.intelligent_url, payload, timeout=60)
        if data is None:
            return []

        # Build results from the cited references.
        results = []
        for ref in self._references(data, 'references', max_results):
            title = self._first_value(ref, 'title', 'name')
            url = self._first_value(ref, 'url', 'link')
            if title and url:
                results.append(SearchResult(
                    title=title,
                    url=url,
                    content=self._first_value(
                        ref, 'summary', 'content', 'snippet'),
                    score=0.8,
                    source=self.name,
                ))
        if results:
            return results

        # No reference links: fall back to the AI answer text.
        try:
            ai_content = data['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            result = data.get('result')
            ai_content = (
                result.get('answer', '') if isinstance(result, dict) else ''
            )

        if ai_content:
            # Function-scope import keeps this block self-contained.
            from urllib.parse import quote_plus

            # Present the answer as a single AI search result. The query is
            # percent-encoded so spaces, '&', '#' and non-ASCII text yield a
            # valid link.
            results.append(SearchResult(
                title=f'百度AI: {query}',
                url=f'https://www.baidu.com/s?wd={quote_plus(query)}',
                content=str(ai_content)[:self._MAX_CONTENT_CHARS],
                score=0.7,
                source=self.name,
            ))
        return results

    def _web_search(self, query: str, max_results: int) -> list:
        """Baidu web search API: return up to ``max_results`` web results."""
        if max_results <= 0:
            return []

        payload = {
            'messages': [{'content': query, 'role': 'user'}],
            'search_source': 'baidu_search_v2',
            'resource_type_filter': [{'type': 'web', 'top_k': max_results}],
        }
        data = self._post_json(self.web_search_url, payload, timeout=25)
        if data is None:
            return []

        # Response format: {"request_id": "...", "references": [...]}
        results = []
        for ref in self._references(data, 'items', max_results):
            title = self._first_value(ref, 'title', 'name')
            url = self._first_value(ref, 'url', 'link')
            if not (title and url):
                continue
            # 'snippet' is the short summary, 'content' is the full text.
            snippet = self._first_value(ref, 'snippet', 'content')
            results.append(SearchResult(
                title=title,
                url=url,
                content=str(snippet)[:self._MAX_CONTENT_CHARS],
                score=0.6,
                source=self.name,
                published_date=self._first_value(
                    ref, 'date', 'published_date'),
            ))
        return results
|