356 lines
14 KiB
Python
356 lines
14 KiB
Python
from __future__ import annotations
|
||
|
||
from datetime import datetime
|
||
from typing import Any, Dict, Iterable, Optional, Type
|
||
|
||
from app.models.company import (
|
||
BaseCompanyModel,
|
||
BossCompany,
|
||
CompanyCleaningQueue,
|
||
QcwyCompany,
|
||
ZhilianCompany,
|
||
)
|
||
|
||
|
||
# Source keys handled by this module; the same three keys are accepted by
# _model_for_source() and extract_company_fields().
COMPANY_SOURCES = {"boss", "qcwy", "zhilian"}
# Queue status values named as terminal.
# NOTE(review): presumably rows in these states are not picked up again -
# confirm against the queue worker, which is not visible from this module.
QUEUE_TERMINAL_STATUSES = {"done", "failed"}
|
||
|
||
|
||
def normalize_company_id(source: str, company_id: str) -> str:
    """Return ``company_id`` as a whitespace-stripped string.

    Falsy ids (``None``, ``""``, ``0``) become ``""``.  For the ``"qcwy"``
    source only, a leading ``co`` prefix (any letter case) is removed when
    the remainder consists solely of digits; every other value is returned
    unchanged apart from the stripping.
    """
    text = str(company_id or "").strip()
    if source != "qcwy":
        return text

    digits = text[2:]
    has_co_prefix = text.lower().startswith("co")
    return digits if has_co_prefix and digits.isdigit() else text
|
||
|
||
|
||
def _pick_first(data: dict[str, Any], *keys: str) -> Optional[Any]:
    """Return the first value among ``keys`` that is neither ``None`` nor ``""``.

    Keys are tried in the order given.  Other falsy values such as ``0`` or
    ``False`` count as present and are returned.  ``None`` is returned when
    no key yields a usable value (or when no keys are passed).
    """
    candidates = (data.get(name) for name in keys)
    return next((item for item in candidates if item not in (None, "")), None)
|
||
|
||
|
||
def _nested_get(data: dict[str, Any], *path: str) -> Any:
    """Walk ``path`` through nested dicts and return the value reached.

    Returns ``None`` as soon as a step has to be taken on something that is
    not a dict, or when a key is absent.  With an empty ``path`` the input
    itself is returned.
    """
    node: Any = data
    for step in path:
        if isinstance(node, dict):
            node = node.get(step)
        else:
            return None
    return node
|
||
|
||
|
||
def _clean_text(value: Any) -> Optional[str]:
    """Convert ``value`` to a stripped string, mapping blanks to ``None``.

    ``None`` and values whose string form is empty or whitespace-only yield
    ``None``; anything else yields ``str(value).strip()``.
    """
    stripped = "" if value is None else str(value).strip()
    return stripped if stripped else None
|
||
|
||
|
||
def _model_for_source(source: str) -> Type[BaseCompanyModel]:
    """Return the company model class registered for ``source``.

    Raises:
        ValueError: if ``source`` is not one of ``"boss"``, ``"qcwy"`` or
            ``"zhilian"``.
    """
    registry: dict[str, Type[BaseCompanyModel]] = {
        "boss": BossCompany,
        "qcwy": QcwyCompany,
        "zhilian": ZhilianCompany,
    }
    model = registry.get(source)
    if model is None:
        raise ValueError(f"unsupported source: {source}")
    return model
|
||
|
||
|
||
def _extract_boss_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Flatten a ``boss`` company payload into the shared company field dict.

    The payload is read from ``raw["zpData"]`` when that is a dict, otherwise
    from ``raw`` itself.  The ``brandComInfoVO`` and ``companyFullInfoVO``
    sections are treated as empty when they are missing or not dicts, so a
    malformed section yields ``None`` fields instead of an ``AttributeError``.

    Args:
        raw: Decoded response for one company.
        company_id: Id supplied by the caller; when falsy, the id is taken
            from ``encryptBrandId`` / ``brandId`` of the brand section.

    Returns:
        A dict with ``source_company_id`` (normalized string, possibly
        ``""``), ``company_name`` (``""`` when not found) and the optional
        text fields ``company_type``, ``industry``, ``company_size``,
        ``financing_stage``, ``city``, ``address``, ``website``,
        ``logo_url`` and ``description`` (each ``None`` when not found).
    """
    zp_data = raw.get("zpData")
    payload = zp_data if isinstance(zp_data, dict) else raw

    # Same type guard as for zpData: ``or {}`` alone lets a truthy non-dict
    # (string, list, ...) through, which then breaks ``_pick_first``.
    brand = payload.get("brandComInfoVO")
    if not isinstance(brand, dict):
        brand = {}
    company_full = payload.get("companyFullInfoVO")
    if not isinstance(company_full, dict):
        company_full = {}

    return {
        "source_company_id": normalize_company_id(
            "boss",
            company_id or _pick_first(brand, "encryptBrandId", "brandId"),
        ),
        "company_name": _clean_text(
            _pick_first(payload, "name")
            or _pick_first(company_full, "name", "brandName")
            or _pick_first(brand, "brandName")
        ) or "",
        "company_type": _clean_text(
            _pick_first(company_full, "typeName") or _pick_first(brand, "brandIndustry")
        ),
        "industry": _clean_text(
            _pick_first(brand, "industryName") or _pick_first(company_full, "industry")
        ),
        "company_size": _clean_text(
            _pick_first(brand, "scaleName") or _pick_first(company_full, "scaleName")
        ),
        "financing_stage": _clean_text(
            _pick_first(brand, "stageName") or _pick_first(company_full, "stageName")
        ),
        "city": _clean_text(_pick_first(company_full, "cityName", "city")),
        "address": _clean_text(_pick_first(company_full, "address", "addressInfo")),
        "website": _clean_text(_pick_first(company_full, "website")),
        "logo_url": _clean_text(
            _pick_first(company_full, "logo", "brandLogo")
            or _pick_first(brand, "logo", "brandLogo")
        ),
        "description": _clean_text(
            _pick_first(company_full, "introduce", "introduction", "companyDesc")
            or _pick_first(brand, "introduce")
        ),
    }
|
||
|
||
|
||
def _extract_qcwy_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Flatten a ``qcwy`` company payload into the shared company field dict.

    Top-level keys of ``raw`` are preferred; the nested ``coinfo`` section is
    used as a fallback.  ``coinfo`` and ``financingStage`` are treated as
    empty when they are missing or not dicts, so a scalar ``financingStage``
    no longer raises ``AttributeError``.

    Args:
        raw: Decoded response for one company.
        company_id: Id supplied by the caller; when falsy, the id is taken
            from ``companyId`` / ``coId`` or ``coinfo.coid``.

    Returns:
        A dict with ``source_company_id`` (normalized string, possibly
        ``""``), ``company_name`` (``""`` when not found) and the optional
        text fields ``company_type``, ``industry``, ``company_size``,
        ``financing_stage``, ``city``, ``address``, ``website``,
        ``logo_url`` and ``description`` (each ``None`` when not found).
    """
    # Guard both nested sections the same way; previously only ``coinfo``
    # was type-checked and a non-dict ``financingStage`` crashed the lookup.
    financing = raw.get("financingStage")
    if not isinstance(financing, dict):
        financing = {}
    coinfo = raw.get("coinfo")
    if not isinstance(coinfo, dict):
        coinfo = {}

    return {
        "source_company_id": normalize_company_id(
            "qcwy",
            company_id
            or _pick_first(raw, "companyId", "coId")
            or _nested_get(raw, "coinfo", "coid"),
        ),
        "company_name": _clean_text(
            _pick_first(raw, "companyName", "fullCompanyName", "companyNameEn")
            or _pick_first(coinfo, "coname", "brandName")
        ) or "",
        "company_type": _clean_text(
            _pick_first(raw, "companyTypeString", "orgTypeName")
            or _pick_first(coinfo, "cotype")
        ),
        "industry": _clean_text(
            _pick_first(raw, "industryName", "companyIndustryType1Str")
            or _pick_first(coinfo, "indtype1", "indtype2", "coIndustryText")
        ),
        "company_size": _clean_text(
            _pick_first(raw, "companySizeString", "companySize", "orgSizeName")
            or _pick_first(coinfo, "cosize")
        ),
        "financing_stage": _clean_text(
            _pick_first(financing, "name") or _pick_first(raw, "financingStageName")
        ),
        "city": _clean_text(
            _pick_first(raw, "cityName", "jobAreaString", "workCity")
            or _pick_first(coinfo, "areaString")
        ),
        "address": _clean_text(
            _pick_first(raw, "address", "location")
            or _nested_get(raw, "workLocation", "workAddress")
            or _pick_first(coinfo, "caddr")
        ),
        "website": _clean_text(
            _pick_first(raw, "companyUrl", "companyHref") or _pick_first(coinfo, "webUrl")
        ),
        "logo_url": _clean_text(
            _pick_first(raw, "companyLogo") or _pick_first(coinfo, "logourl")
        ),
        "description": _clean_text(
            _pick_first(raw, "companyDesc", "company_desc", "description")
            or _nested_get(raw, "campusRootOrgInfo", "description")
            or _pick_first(coinfo, "coinfo")
        ),
    }
|
||
|
||
|
||
def _extract_zhilian_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Flatten a ``zhilian`` company payload into the shared company field dict.

    The payload is read from ``raw["data"]`` when that is a dict, otherwise
    from ``raw`` itself.  The ``companyBase`` and ``detailedCompany``
    sections are treated as empty when they are missing or not dicts, so a
    malformed section yields ``None`` fields instead of an ``AttributeError``.

    Args:
        raw: Decoded response for one company.
        company_id: Id supplied by the caller; when falsy, the id is taken
            from ``companyNumber`` / ``number`` of either section.

    Returns:
        A dict with ``source_company_id`` (normalized string, possibly
        ``""``), ``company_name`` (``""`` when not found) and the optional
        text fields ``company_type``, ``industry``, ``company_size``,
        ``financing_stage``, ``city``, ``address``, ``website``,
        ``logo_url`` and ``description`` (each ``None`` when not found).
    """
    inner = raw.get("data")
    data = inner if isinstance(inner, dict) else raw

    # Same type guard as for ``data``: ``or {}`` alone lets a truthy
    # non-dict through, which then breaks ``_pick_first``.
    company_base = data.get("companyBase")
    if not isinstance(company_base, dict):
        company_base = {}
    detailed_company = data.get("detailedCompany")
    if not isinstance(detailed_company, dict):
        detailed_company = {}

    return {
        "source_company_id": normalize_company_id(
            "zhilian",
            company_id
            or _pick_first(company_base, "companyNumber", "number")
            or _pick_first(detailed_company, "companyNumber", "number"),
        ),
        "company_name": _clean_text(
            _pick_first(company_base, "companyName") or _pick_first(data, "companyName")
        ) or "",
        "company_type": _clean_text(
            _pick_first(company_base, "companyTypeName", "companyType")
            or _pick_first(detailed_company, "companyTypeName")
        ),
        "industry": _clean_text(
            _pick_first(company_base, "industryName")
            or _pick_first(detailed_company, "industryName")
        ),
        "company_size": _clean_text(
            _pick_first(company_base, "companySize", "companySizeString")
            or _pick_first(detailed_company, "companySize")
        ),
        "financing_stage": _clean_text(
            _pick_first(company_base, "financingStageName")
            or _nested_get(company_base, "financingStage", "name")
            or _nested_get(detailed_company, "financingStage", "name")
        ),
        "city": _clean_text(
            _pick_first(company_base, "cityName") or _pick_first(detailed_company, "cityName")
        ),
        "address": _clean_text(
            _pick_first(company_base, "address") or _pick_first(detailed_company, "address")
        ),
        "website": _clean_text(_pick_first(company_base, "companyUrl", "website")),
        "logo_url": _clean_text(_pick_first(company_base, "logoUrl", "companyLogo")),
        "description": _clean_text(
            _pick_first(company_base, "companyDescWithHtml", "companyDesc")
            or _pick_first(detailed_company, "companyDescription", "companyDesc")
        ),
    }
|
||
|
||
|
||
def extract_company_fields(source: str, raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Dispatch to the field extractor that matches ``source``.

    Args:
        source: ``"boss"``, ``"qcwy"`` or ``"zhilian"``.
        raw: Decoded payload handed to the extractor unchanged.
        company_id: Caller-supplied id handed to the extractor unchanged.

    Returns:
        The field dict produced by the selected extractor.

    Raises:
        ValueError: if ``source`` matches none of the supported keys.
    """
    extractors = (
        ("boss", _extract_boss_fields),
        ("qcwy", _extract_qcwy_fields),
        ("zhilian", _extract_zhilian_fields),
    )
    for key, extractor in extractors:
        if source == key:
            return extractor(raw, company_id)
    raise ValueError(f"unsupported source: {source}")
|
||
|
||
|
||
class CompanyStorageService:
    """Database helpers for per-source company records and the cleaning queue.

    Company ids are passed through ``normalize_company_id`` before they are
    used in a query or stored.
    """

    @staticmethod
    def company_model(source: str) -> Type[BaseCompanyModel]:
        """Return the company model class for ``source``.

        Raises:
            ValueError: if ``source`` is not supported.
        """
        return _model_for_source(source)

    @staticmethod
    def _normalized_ids(source: str, company_ids: Iterable[str]) -> list[str]:
        """Normalize ``company_ids``, dropping blanks and duplicates (order kept)."""
        normalized = (normalize_company_id(source, item) for item in company_ids if item)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(value for value in normalized if value))

    async def get_existing_company_ids(self, source: str, company_ids: Iterable[str]) -> set[str]:
        """Return the subset of ``company_ids`` already stored for ``source``.

        Ids are normalized first; the returned set contains normalized ids.
        No query is issued when nothing usable remains after normalization.
        """
        normalized_ids = self._normalized_ids(source, company_ids)
        if not normalized_ids:
            return set()
        model = self.company_model(source)
        rows = await model.filter(source_company_id__in=normalized_ids).values_list(
            "source_company_id", flat=True
        )
        return set(rows)

    async def get_all_company_ids(self, source: str) -> set[str]:
        """Return the ids of every company already stored for ``source``.

        Used to exclude known companies from ClickHouse queries.
        """
        model = self.company_model(source)
        rows = await model.all().values_list("source_company_id", flat=True)
        return set(rows)

    async def get_existing_queue_ids(self, source: str, company_ids: Iterable[str]) -> set[str]:
        """Return the subset of ``company_ids`` that already have a queue row.

        Ids are normalized first; the returned set contains normalized ids.
        No query is issued when nothing usable remains after normalization.
        """
        normalized_ids = self._normalized_ids(source, company_ids)
        if not normalized_ids:
            return set()
        rows = await CompanyCleaningQueue.filter(
            source=source, company_id__in=normalized_ids
        ).values_list("company_id", flat=True)
        return set(rows)

    async def enqueue_company(
        self, source: str, company_id: str, company_name: str = ""
    ) -> tuple[CompanyCleaningQueue, bool]:
        """Fetch or create the queue row for one company.

        A new row starts as ``"pending"`` with zeroed counters.  When the row
        already exists and a non-empty ``company_name`` differs from the
        stored one, only the name is updated.

        Returns:
            ``(queue_row, created)`` where ``created`` is ``True`` if the row
            was inserted by this call.
        """
        normalized_id = normalize_company_id(source, company_id)
        defaults = {
            "company_name": company_name or "",
            "status": "pending",
            "error_msg": "",
            "retry_count": 0,
            "started_at": None,
            "finished_at": None,
            "jobs_fetched": 0,
            "jobs_stored": 0,
            "jobs_duplicate": 0,
            "jobs_failed": 0,
            "jobs_error_msg": "",
        }
        queue, created = await CompanyCleaningQueue.get_or_create(
            source=source,
            company_id=normalized_id,
            defaults=defaults,
        )
        if not created and company_name and queue.company_name != company_name:
            queue.company_name = company_name
            await queue.save(update_fields=["company_name", "updated_at"])
        return queue, created

    async def enqueue_companies(self, source: str, companies: Iterable[dict[str, str]]) -> int:
        """Enqueue several companies and return how many rows were created.

        Each item is read for ``"company_id"`` and ``"company_name"``.  Items
        whose id is missing or blank after normalization are skipped, so no
        queue row with an empty ``company_id`` is created.
        """
        created_count = 0
        for item in companies:
            company_id = item.get("company_id") or ""
            if not normalize_company_id(source, company_id):
                # Nothing to key the queue row on; matches the blank-id
                # filtering done by the lookup methods.
                continue
            _, created = await self.enqueue_company(
                source=source,
                company_id=company_id,
                company_name=item.get("company_name", "") or "",
            )
            if created:
                created_count += 1
        return created_count

    async def get_company_record(self, source: str, company_id: str) -> Optional[BaseCompanyModel]:
        """Return the stored company for ``company_id``, or ``None`` if absent."""
        normalized_id = normalize_company_id(source, company_id)
        model = self.company_model(source)
        return await model.get_or_none(source_company_id=normalized_id)

    async def upsert_company(
        self,
        source: str,
        raw_data: dict[str, Any],
        *,
        company_id: Optional[str] = None,
    ) -> dict[str, Any]:
        """Create or update the company record described by ``raw_data``.

        Fields are extracted with ``extract_company_fields``; the raw payload
        is stored in ``raw_json`` and ``last_crawled_at`` is set to the
        current local time.  ``first_crawled_at`` is set only on creation.

        Args:
            source: Source key selecting the extractor and the model.
            raw_data: Decoded payload for one company.
            company_id: Optional id; when omitted the extractor falls back to
                ids found inside the payload.

        Returns:
            A dict with ``success``, ``created``, ``company_id``,
            ``company_name``, a ``data_summary`` dict and the saved
            ``record``.

        Raises:
            ValueError: if the source is unsupported, or no company id or
                company name can be determined.
        """
        normalized_id = normalize_company_id(source, company_id or "")
        fields = extract_company_fields(source, raw_data, normalized_id)
        # The extractor may have filled the id in from the payload.
        normalized_id = fields["source_company_id"]
        if not normalized_id:
            raise ValueError(f"missing normalized company id for source={source}")
        if not fields["company_name"]:
            raise ValueError(f"missing company name for source={source} company_id={normalized_id}")

        model = self.company_model(source)
        record = await model.get_or_none(source_company_id=normalized_id)
        now = datetime.now()
        payload = {
            **fields,
            "raw_json": raw_data,
            "last_crawled_at": now,
        }

        if record:
            for key, value in payload.items():
                setattr(record, key, value)
            await record.save()
            created = False
        else:
            record = await model.create(
                **payload,
                first_crawled_at=now,
            )
            created = True

        return {
            "success": True,
            "created": created,
            "company_id": normalized_id,
            "company_name": record.company_name,
            "data_summary": {
                "source": source,
                "company_id": normalized_id,
                "company_name": record.company_name,
                "created": created,
            },
            "record": record,
        }

    async def mark_queue_processing(self, queue: CompanyCleaningQueue) -> None:
        """Mark ``queue`` as ``"processing"`` and reset its run state.

        Clears the error messages, zeroes the job counters, sets
        ``started_at`` to the current local time and clears ``finished_at``.
        """
        queue.status = "processing"
        queue.error_msg = ""
        queue.started_at = datetime.now()
        queue.finished_at = None
        queue.jobs_fetched = 0
        queue.jobs_stored = 0
        queue.jobs_duplicate = 0
        queue.jobs_failed = 0
        queue.jobs_error_msg = ""
        await queue.save(
            update_fields=[
                "status",
                "error_msg",
                "started_at",
                "finished_at",
                "jobs_fetched",
                "jobs_stored",
                "jobs_duplicate",
                "jobs_failed",
                "jobs_error_msg",
                "updated_at",
            ]
        )

    async def mark_queue_result(
        self,
        queue: CompanyCleaningQueue,
        *,
        status: str,
        error_msg: str = "",
        increment_retry: bool = False,
        jobs_summary: Optional[dict[str, Any]] = None,
    ) -> None:
        """Record the outcome of a processing attempt on ``queue``.

        Args:
            queue: Queue row to update and save.
            status: New status value, stored as given.
            error_msg: Error text; falsy values are stored as ``""``.
            increment_retry: When true, ``retry_count`` is increased by one.
            jobs_summary: Optional counters read from the keys
                ``"jobs_fetched"``, ``"stored_success"``, ``"duplicate"``,
                ``"failed"`` and ``"error"``.  When falsy, the job fields on
                ``queue`` are left as they are.
        """
        queue.status = status
        queue.error_msg = error_msg or ""
        queue.finished_at = datetime.now()
        if jobs_summary:
            queue.jobs_fetched = int(jobs_summary.get("jobs_fetched") or 0)
            queue.jobs_stored = int(jobs_summary.get("stored_success") or 0)
            queue.jobs_duplicate = int(jobs_summary.get("duplicate") or 0)
            queue.jobs_failed = int(jobs_summary.get("failed") or 0)
            queue.jobs_error_msg = jobs_summary.get("error") or ""
        if increment_retry:
            # Treat a null counter as 0 instead of raising TypeError.
            queue.retry_count = (queue.retry_count or 0) + 1
        await queue.save(
            update_fields=[
                # NOTE(review): company_name is not modified here; presumably
                # it is listed so a name assigned by the caller gets
                # persisted - confirm against callers.
                "company_name",
                "status",
                "error_msg",
                "retry_count",
                "finished_at",
                "jobs_fetched",
                "jobs_stored",
                "jobs_duplicate",
                "jobs_failed",
                "jobs_error_msg",
                "updated_at",
            ]
        )
|
||
|
||
|
||
# Module-level CompanyStorageService instance created at import time.
company_storage = CompanyStorageService()
|