113 lines
3.5 KiB
Python
113 lines
3.5 KiB
Python
"""搜索路由器 — 多源路由 + 结果去重合并"""
|
||
|
||
import logging
import time

from providers.base import SearchResult
|
||
|
||
|
||
class SearchRouter:
    """Route a query to one or more search providers and merge the results.

    Providers are duck-typed. The router calls ``is_available()``,
    ``search(query, max_results)`` and ``get_status()`` on them and reads
    their ``enabled`` and ``priority`` attributes. Result objects must expose
    ``url``, ``title``, ``score``, ``source`` and ``to_dict()``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, providers: dict):
        """Store the provider registry.

        Args:
            providers: Mapping of ``{name: provider_instance}``.
        """
        self.providers = providers

    def search(self, query: str, source='auto', max_results=10):
        """Run a search and return a JSON-friendly response dict.

        Args:
            query: The search query.
            source: Which provider(s) to use:
                - ``'auto'`` (or empty): try available providers in priority
                  order and use the first one that returns results;
                - ``'tavily'``: a single named provider;
                - ``'tavily,baidu'``: several providers, results merged.
            max_results: Maximum number of results to return.

        Returns:
            A dict with the keys ``query``, ``results`` (list of dicts),
            ``total``, ``source`` and ``elapsed`` (seconds, rounded to two
            decimals). When a single named provider is unknown or
            unavailable, an ``error`` key is added and ``results`` is empty.

        Raises:
            Exception: In single-source mode, whatever the provider's
                ``search`` raises propagates to the caller unchanged.
        """
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # or jump if the wall clock is adjusted mid-request.
        start = time.perf_counter()

        # Tolerate stray whitespace such as ' tavily ' from query strings.
        if isinstance(source, str):
            source = source.strip()

        if not source or source == 'auto':
            results = self._auto_search(query, max_results)
            used_source = results[0].source if results else None
        elif ',' in source:
            names = [s.strip() for s in source.split(',') if s.strip()]
            results = self._multi_search(query, names, max_results)
            used_source = source
        else:
            provider = self.providers.get(source)
            if not provider or not provider.is_available():
                return {
                    'query': query,
                    'results': [],
                    'total': 0,
                    'source': source,
                    'elapsed': round(time.perf_counter() - start, 2),
                    'error': f'搜索源 "{source}" 不可用',
                }
            # A provider returning None is treated as "no results".
            results = provider.search(query, max_results) or []
            used_source = source

        return {
            'query': query,
            'results': [r.to_dict() for r in results],
            'total': len(results),
            'source': used_source,
            'elapsed': round(time.perf_counter() - start, 2),
        }

    def _auto_search(self, query, max_results):
        """Try providers in ascending ``priority`` order; first hit wins.

        Only providers that are both available and enabled are considered.
        A provider that raises or returns nothing is skipped (the failure is
        logged) and the next one is tried.

        Returns:
            The de-duplicated results of the first provider that returned
            any, or an empty list if none did.
        """
        candidates = sorted(
            (
                (name, p)
                for name, p in self.providers.items()
                if p.is_available() and p.enabled
            ),
            key=lambda item: item[1].priority,
        )

        for name, provider in candidates:
            try:
                results = provider.search(query, max_results)
            except Exception:
                # Best-effort fallback: record the failure, try the next one.
                self._logger.warning(
                    "Search provider %r failed; trying next", name, exc_info=True
                )
                continue
            if results:
                return self._dedup(results, max_results)

        return []

    def _multi_search(self, query, sources, max_results):
        """Query several named providers one after another and merge results.

        Unknown or unavailable names are ignored and repeated names are
        queried only once. The per-provider quota is ``max_results`` divided
        by the number of providers actually queried, but never fewer than 3.
        A provider that raises is skipped (the failure is logged).

        Returns:
            The merged, de-duplicated results, at most ``max_results`` items.
        """
        usable = []
        # dict.fromkeys drops repeated names while keeping request order.
        for name in dict.fromkeys(sources):
            provider = self.providers.get(name)
            if provider and provider.is_available():
                usable.append((name, provider))

        if not usable:
            return []

        # Computed once, from the providers that will really be queried, so
        # an unavailable or misspelled name does not shrink the result set.
        per_source = max(max_results // len(usable), 3)

        merged = []
        for name, provider in usable:
            try:
                merged.extend(provider.search(query, per_source) or [])
            except Exception:
                self._logger.warning(
                    "Search provider %r failed; skipping", name, exc_info=True
                )

        return self._dedup(merged, max_results)

    def get_sources(self):
        """Return ``get_status()`` of every provider, ordered by priority."""
        ordered = sorted(self.providers.values(), key=lambda p: p.priority)
        return [p.get_status() for p in ordered]

    @staticmethod
    def _dedup(results, max_results):
        """De-duplicate results and sort them by score, highest first.

        The de-duplication key is the result's ``url``, falling back to its
        ``title``; the first occurrence wins and results with neither are
        dropped. A missing (``None``) score sorts as 0 instead of raising.

        Returns:
            At most ``max_results`` unique results.
        """
        seen = set()
        unique = []
        for r in results:
            key = r.url or r.title
            if key and key not in seen:
                seen.add(key)
                unique.append(r)

        # list.sort is stable, so equal scores keep their arrival order.
        unique.sort(
            key=lambda r: r.score if r.score is not None else 0,
            reverse=True,
        )
        return unique[:max_results]
|