Files
search-hub/hub/router.py

113 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""搜索路由器 — 多源路由 + 结果去重合并"""
import logging
import time
from concurrent.futures import ThreadPoolExecutor

from providers.base import SearchResult
class SearchRouter:
    """Search router: dispatches a query to one or more providers and merges results.

    Providers are used through ``search(query, max_results)``,
    ``is_available()``, ``get_status()`` and the attributes ``priority`` and
    ``enabled``. Result objects are used through ``url``, ``title``,
    ``score``, ``source`` and ``to_dict()``.
    """

    # Individual provider failures are logged instead of being silently dropped.
    _logger = logging.getLogger(__name__)

    def __init__(self, providers: dict):
        """
        Args:
            providers: mapping ``{name: provider_instance}``.
        """
        self.providers = providers

    def search(self, query: str, source='auto', max_results=10):
        """Unified search entry point.

        ``source`` values:
            - ``'auto'`` (or empty/None): try available, enabled providers in
              ascending ``priority`` order until one returns results.
            - ``'tavily'``: query a single named provider.
            - ``'tavily,baidu'``: query several providers and merge results.

        Returns:
            dict with keys ``query``, ``results`` (list of ``to_dict()``
            payloads), ``total``, ``source`` and ``elapsed`` (seconds, rounded
            to 2 decimals). When a single named provider is unavailable or
            raises, ``results`` is empty and an ``error`` key is added.
        """
        # perf_counter is monotonic, so elapsed time cannot go negative or jump
        # when the wall clock is adjusted.
        start = time.perf_counter()

        def elapsed():
            return round(time.perf_counter() - start, 2)

        def failure(message):
            return {
                'query': query,
                'results': [],
                'total': 0,
                'source': source,
                'elapsed': elapsed(),
                'error': message,
            }

        if not source or source == 'auto':
            results = self._auto_search(query, max_results)
            used_source = results[0].source if results else None
        elif ',' in source:
            # dict.fromkeys drops repeated names while keeping their order, so
            # "a,a" does not query the same provider twice.
            sources = list(dict.fromkeys(
                s.strip() for s in source.split(',') if s.strip()
            ))
            results = self._multi_search(query, sources, max_results)
            used_source = source
        else:
            provider = self.providers.get(source)
            if not provider or not provider.is_available():
                return failure(f'搜索源 "{source}" 不可用')
            try:
                raw = provider.search(query, max_results)
            except Exception:
                # Same response shape as the "unavailable" case; the exception
                # details go to the log rather than to the client.
                self._logger.exception('search source %r failed', source)
                return failure(f'搜索源 "{source}" 搜索失败')
            # Tolerate a None return and providers that over-deliver.
            results = list(raw or [])[:max_results]
            used_source = source

        return {
            'query': query,
            'results': [r.to_dict() for r in results],
            'total': len(results),
            'source': used_source,
            'elapsed': elapsed(),
        }

    def _auto_search(self, query, max_results):
        """Fallback search: first provider (by ascending ``priority``) with results wins.

        Only providers that are both available and enabled are tried. A
        provider that raises (or whose results cannot be de-duplicated) is
        logged and skipped. Returns ``[]`` when no provider yields results.
        """
        candidates = sorted(
            (
                (name, provider)
                for name, provider in self.providers.items()
                if provider.is_available() and provider.enabled
            ),
            key=lambda item: item[1].priority,
        )
        for name, provider in candidates:
            try:
                results = provider.search(query, max_results)
                if results:
                    return self._dedup(results, max_results)
            except Exception:
                self._logger.warning(
                    'auto search: provider %r failed, trying next',
                    name, exc_info=True,
                )
        return []

    def _multi_search(self, query, sources, max_results):
        """Query several providers concurrently and merge their results.

        Unknown or unavailable source names are skipped. The ``max_results``
        budget is split evenly across the providers actually queried, with a
        floor of 3 per provider. Results are concatenated in ``sources`` order
        before de-duplication, so equal scores keep that order. A provider
        that raises is logged and skipped.
        """
        active = []
        for name in sources:
            provider = self.providers.get(name)
            if provider and provider.is_available():
                active.append((name, provider))
        if not active:
            return []

        # Divide by the providers really queried, not by every requested name.
        per_source = max(max_results // len(active), 3)

        all_results = []
        # Provider calls are presumably network-bound, so running them in
        # threads overlaps their latency.
        with ThreadPoolExecutor(max_workers=len(active)) as pool:
            futures = [
                (name, pool.submit(provider.search, query, per_source))
                for name, provider in active
            ]
            # Collect in submission order to keep the merge deterministic.
            for name, future in futures:
                try:
                    all_results.extend(future.result())
                except Exception:
                    self._logger.warning(
                        'multi search: provider %r failed', name, exc_info=True,
                    )
        return self._dedup(all_results, max_results)

    def get_sources(self):
        """Return ``get_status()`` of every provider, ordered by ascending ``priority``."""
        return [
            p.get_status()
            for p in sorted(self.providers.values(), key=lambda p: p.priority)
        ]

    @staticmethod
    def _dedup(results, max_results):
        """De-duplicate by URL (falling back to title), sort by score, truncate.

        The first occurrence of each key is kept. Results with neither a URL
        nor a title are dropped. The sort is stable and descending by
        ``score``, so ties keep their input order. At most ``max_results``
        items are returned.
        """
        seen = set()
        unique = []
        for r in results:
            key = r.url or r.title
            if key and key not in seen:
                seen.add(key)
                unique.append(r)
        unique.sort(key=lambda r: r.score, reverse=True)
        return unique[:max_results]