1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| import requests import json import urllib3 from typing import List, Dict, Any import logging
# Configure the root logger at INFO; the functions in this module log through
# the module-level `logging` helpers.
logging.basicConfig(level=logging.INFO)

# Turn off urllib3 warnings process-wide. Presumably this is here because
# _retrieval() posts with verify=False -- confirm before removing.
requests.packages.urllib3.disable_warnings()
from aishu_anyshare_api import ApiClient

# Both values are fetched once, when the module is imported, and are reused by
# every request: `token` goes into the Authorization header and `host` is the
# URL prefix used by _retrieval().
token = ApiClient.get_global_access_token()
host = ApiClient.get_global_host()

# Path of the retrieval endpoint; _retrieval() appends it to `host`.
retrieval_url = "/api/intelli-search/v1/mf/retrieval"

# Defaults for the optional arguments of retrieval().
DEFAULT_TOP_K = 10
DEFAULT_SCORE_THRESHOLD = 0.1
# NOTE(review): this value is sent in the request payload under "timeout" and
# is also passed to requests.post(timeout=...), which is measured in seconds.
# 30000 looks like milliseconds -- confirm the unit the API expects.
DEFAULT_TIMEOUT_TIMEOUT = 30000
def retrieval(
    query,
    source_ranges,
    history,
    lib_ranges,
    top_k=DEFAULT_TOP_K,
    score_threshold=DEFAULT_SCORE_THRESHOLD,
    timeout=DEFAULT_TIMEOUT_TIMEOUT,
):
    """Recall records for ``query`` and return them deduplicated, best score first.

    Args:
        query: Search text. A falsy value short-circuits to an empty list.
        source_ranges: Source descriptors forwarded to ``build_recall_params``.
        history: Accepted for interface compatibility; not used here.
        lib_ranges: Forwarded to ``build_recall_params``.
        top_k: Forwarded to ``build_recall_params``.
        score_threshold: Forwarded to ``build_recall_params``.
        timeout: Forwarded to ``build_recall_params``.

    Returns:
        The list produced by ``_deduplicate_and_sort``, or ``[]`` when the
        query is empty or no request parameters were built.
    """
    if not query:
        return []

    params = build_recall_params(
        query=query,
        source_ranges=source_ranges,
        lib_ranges=lib_ranges,
        top_k=top_k,
        score_threshold=score_threshold,
        timeout=timeout,
    )
    logging.info(f"召回参数: {params}")

    if params:
        records = _retrieval(params)
        logging.info(f"召回结果: {records}")

        all_records = _deduplicate_and_sort(records)
        logging.info(f"去重排序后结果: {all_records}")
        return all_records

    # Nothing to send: report it and hand back an empty result.
    logging.error("召回参数为空")
    return []
def build_recall_params(query, source_ranges, lib_ranges, top_k, score_threshold,
                        timeout=DEFAULT_TIMEOUT_TIMEOUT):
    """Assemble the request payload for the retrieval endpoint.

    Each entry of ``source_ranges`` is read through ``.get("id")`` and
    ``.get("type")``. Its id is routed by type into one of three lists:
    ``"doc"`` -> ``doc["ranges"]``, ``"wiki"`` -> ``wiki["ids"]``,
    ``"faq"`` -> ``faq["id"]``. Falsy ids, ids already present in the target
    list, and entries of any other type are skipped.

    Args:
        query: Search text, stored under ``"text"``.
        source_ranges: Iterable of source descriptors (see above).
        lib_ranges: Accepted for interface compatibility; not used here.
        top_k: Copied into all three per-source sections.
        score_threshold: Copied into all three per-source sections.
        timeout: Stored under ``"timeout"``.

    Returns:
        A dict with the keys ``text``, ``doc``, ``wiki``, ``faq`` and ``timeout``.
    """
    doc_section = {
        "top_k": top_k,
        "score_threshold": score_threshold,
        "ids": [],
        "ranges": [],
        "search_method": "vector_search",
        "weight_filters": [],
    }
    wiki_section = {
        "top_k": top_k,
        "score_threshold": score_threshold,
        "ranges": [],
        "ids": [],
        "search_method": "hybrid_search",
        "weight_filters": [],
    }
    # NOTE(review): the faq section uses the singular key "id" where the other
    # two use "ids" -- confirm against the API contract.
    faq_section = {
        "top_k": top_k,
        "score_threshold": score_threshold,
        "ranges": [],
        "id": [],
        "search_method": "hybrid_search",
        "weight_filters": [],
        "item_output_detail": {"field": "content"},
    }

    for entry in source_ranges:
        entry_id = entry.get("id")
        entry_type = entry.get("type")

        # Pick the list this source type feeds; unknown types are ignored.
        if entry_type == "doc":
            bucket = doc_section["ranges"]
        elif entry_type == "wiki":
            bucket = wiki_section["ids"]
        elif entry_type == "faq":
            bucket = faq_section["id"]
        else:
            continue

        # Keep first-seen order and avoid duplicates.
        if entry_id and entry_id not in bucket:
            bucket.append(entry_id)

    return {
        "text": query,
        "doc": doc_section,
        "wiki": wiki_section,
        "faq": faq_section,
        "timeout": timeout,
    }
def _retrieval(params):
    """POST ``params`` to the retrieval endpoint and return the ``records`` list.

    The function never raises for transport, HTTP-status or decoding problems;
    each of those is logged and reported to the caller as an empty list.

    Args:
        params: JSON-serialisable request payload.

    Returns:
        The value stored under ``"records"`` in the JSON response, or ``[]``
        when the request fails, the status is not OK, the body is not a JSON
        object, or ``"records"`` is missing or null.
    """
    url = f"{host}{retrieval_url}"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    logging.info(f"请求URL: {url}")

    try:
        # NOTE(review): requests interprets `timeout` in seconds, so 30000 is
        # effectively "no timeout" if the constant is meant as milliseconds.
        # Left unchanged until the intended unit is confirmed.
        # NOTE(review): verify=False disables TLS certificate verification.
        response = requests.post(
            url=url,
            json=params,
            headers=headers,
            timeout=DEFAULT_TIMEOUT_TIMEOUT,
            verify=False,
        )
    except requests.exceptions.RequestException:
        # Route the traceback through logging instead of printing to stderr.
        logging.exception(f"请求异常: url={url}")
        return []

    logging.info(f"响应状态码: {response.status_code}, body={response.text}")

    if not response.ok:
        logging.error(f"请求失败: status={response.status_code}, body={response.text}")
        return []

    try:
        result = response.json()
    except ValueError:
        # Every JSON decode error raised by requests (stdlib json or
        # simplejson backed) is a ValueError subclass.
        logging.error(f"JSON解析失败: {response.text}")
        return []

    if not isinstance(result, dict):
        # A JSON array/scalar has no "records" member to read.
        logging.error(f"响应格式异常: {response.text}")
        return []

    # `or []` also covers an explicit `"records": null`.
    return result.get("records") or []
def _deduplicate_and_sort(records): seen = {} for record in records: record_id = record.get("id") if record_id and record_id not in seen: seen[record_id] = record elif record_id: existing = seen.get(record_id) if existing and record.get("score", 0) > existing.get("score", 0): seen[record_id] = record unique_records = list(seen.values()) unique_records.sort(key=lambda x: x.get("score", 0), reverse=True) return unique_records
def main(query, source_ranges, history, lib_ranges):
    """Run a recall and return the matched records' content as one string.

    Args:
        query: Search text passed to ``retrieval``.
        source_ranges: Passed to ``retrieval``.
        history: Passed to ``retrieval``.
        lib_ranges: Passed to ``retrieval``.

    Returns:
        The ``"content"`` values of the recalled records concatenated in
        result order with no separator; a record without ``"content"``
        contributes a single space. Empty string when nothing was recalled.
    """
    result = retrieval(
        query=query,
        source_ranges=source_ranges,
        history=history,
        lib_ranges=lib_ranges,
    )
    logging.info(f"召回结果: {result}")

    rst = "".join(record.get("content", " ") for record in result)
    logging.info(f"合并后结果: {rst}")
    return rst
|