JobData/app/services/company_storage.py
2026-03-22 23:22:30 +08:00

356 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Iterable, Optional, Type
from app.models.company import (
BaseCompanyModel,
BossCompany,
CompanyCleaningQueue,
QcwyCompany,
ZhilianCompany,
)
# Recruitment platforms handled by this module; matches the sources accepted by
# _model_for_source and extract_company_fields.
COMPANY_SOURCES = {"zhilian", "qcwy", "boss"}
# Queue statuses considered terminal. Not referenced in this module's visible
# code -- presumably consumed by callers deciding whether a queue row needs work.
QUEUE_TERMINAL_STATUSES = {"failed", "done"}
def normalize_company_id(source: str, company_id: str) -> str:
    """Return *company_id* as a stripped string in its canonical per-source form.

    For the ``qcwy`` source, an ID made of a case-insensitive ``"co"`` prefix
    followed only by digits is reduced to the digits, so ``"CO123"`` and ``"123"``
    map to the same key. Falsy inputs (``None``, ``""``) become ``""``.
    """
    text = str(company_id or "").strip()
    prefix, digits = text[:2], text[2:]
    has_qcwy_prefix = source == "qcwy" and prefix.lower() == "co" and digits.isdigit()
    return digits if has_qcwy_prefix else text
def _pick_first(data: dict[str, Any], *keys: str) -> Optional[Any]:
for key in keys:
value = data.get(key)
if value not in (None, ""):
return value
return None
def _nested_get(data: dict[str, Any], *path: str) -> Any:
current: Any = data
for key in path:
if not isinstance(current, dict):
return None
current = current.get(key)
return current
def _clean_text(value: Any) -> Optional[str]:
if value is None:
return None
text = str(value).strip()
return text or None
def _model_for_source(source: str) -> Type[BaseCompanyModel]:
    """Return the company model class that stores records for *source*.

    Raises:
        ValueError: if *source* is not one of ``"boss"``, ``"qcwy"``, ``"zhilian"``.
    """
    models_by_source: dict[str, Type[BaseCompanyModel]] = {
        "boss": BossCompany,
        "qcwy": QcwyCompany,
        "zhilian": ZhilianCompany,
    }
    model = models_by_source.get(source)
    if model is None:
        raise ValueError(f"unsupported source: {source}")
    return model
def _extract_boss_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Map a ``boss`` company payload onto the common company columns.

    When ``raw["zpData"]`` is a dict it is used as the payload; otherwise *raw*
    itself is read. Values come from the payload root, its ``companyFullInfoVO``
    section and its ``brandComInfoVO`` section, in the per-field order below.

    Args:
        raw: Raw boss payload (a dict).
        company_id: Already-known company ID; when truthy it wins over the
            ``encryptBrandId`` / ``brandId`` found in the payload.

    Returns:
        Dict of common company columns. ``company_name`` is ``""`` when no name
        is found; other text columns are ``None`` when missing.
    """
    envelope = raw.get("zpData")
    payload = envelope if isinstance(envelope, dict) else raw
    full_info = payload.get("companyFullInfoVO") or {}
    brand_info = payload.get("brandComInfoVO") or {}

    raw_id = company_id or _pick_first(brand_info, "encryptBrandId", "brandId")
    name = (
        _pick_first(payload, "name")
        or _pick_first(full_info, "name", "brandName")
        or _pick_first(brand_info, "brandName")
    )
    logo = _pick_first(full_info, "logo", "brandLogo") or _pick_first(brand_info, "logo", "brandLogo")
    intro = (
        _pick_first(full_info, "introduce", "introduction", "companyDesc")
        or _pick_first(brand_info, "introduce")
    )

    return {
        "source_company_id": normalize_company_id("boss", raw_id),
        "company_name": _clean_text(name) or "",
        "company_type": _clean_text(
            _pick_first(full_info, "typeName") or _pick_first(brand_info, "brandIndustry")
        ),
        "industry": _clean_text(
            _pick_first(brand_info, "industryName") or _pick_first(full_info, "industry")
        ),
        "company_size": _clean_text(
            _pick_first(brand_info, "scaleName") or _pick_first(full_info, "scaleName")
        ),
        "financing_stage": _clean_text(
            _pick_first(brand_info, "stageName") or _pick_first(full_info, "stageName")
        ),
        "city": _clean_text(_pick_first(full_info, "cityName", "city")),
        "address": _clean_text(_pick_first(full_info, "address", "addressInfo")),
        "website": _clean_text(_pick_first(full_info, "website")),
        "logo_url": _clean_text(logo),
        "description": _clean_text(intro),
    }
def _extract_qcwy_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Map a ``qcwy`` company payload onto the common company columns.

    Fields are read from the payload root first, with the ``coinfo`` section
    (when it is a dict) as a fallback for most columns.

    Args:
        raw: Raw qcwy payload (a dict).
        company_id: Already-known company ID; when truthy it wins over
            ``companyId`` / ``coId`` / ``coinfo.coid`` found in the payload.

    Returns:
        Dict of common company columns. ``company_name`` is ``""`` when no name
        is found; other text columns are ``None`` when missing.
    """
    coinfo = raw.get("coinfo") if isinstance(raw.get("coinfo"), dict) else {}
    # Guard financingStage the same way as coinfo: only a dict is read via
    # _pick_first; a bare string is taken as the stage label itself, and any
    # other type is ignored instead of aborting the whole extraction.
    financing_value = raw.get("financingStage")
    financing = financing_value if isinstance(financing_value, dict) else {}
    financing_label = financing_value if isinstance(financing_value, str) else None
    return {
        "source_company_id": normalize_company_id(
            "qcwy",
            company_id or _pick_first(raw, "companyId", "coId") or _nested_get(raw, "coinfo", "coid"),
        ),
        "company_name": _clean_text(
            _pick_first(raw, "companyName", "fullCompanyName", "companyNameEn")
            or _pick_first(coinfo, "coname", "brandName")
        ) or "",
        "company_type": _clean_text(
            _pick_first(raw, "companyTypeString", "orgTypeName") or _pick_first(coinfo, "cotype")
        ),
        "industry": _clean_text(
            _pick_first(raw, "industryName", "companyIndustryType1Str")
            or _pick_first(coinfo, "indtype1", "indtype2", "coIndustryText")
        ),
        "company_size": _clean_text(
            _pick_first(raw, "companySizeString", "companySize", "orgSizeName")
            or _pick_first(coinfo, "cosize")
        ),
        "financing_stage": _clean_text(
            _pick_first(financing, "name")
            or financing_label
            or _pick_first(raw, "financingStageName")
        ),
        "city": _clean_text(
            _pick_first(raw, "cityName", "jobAreaString", "workCity") or _pick_first(coinfo, "areaString")
        ),
        "address": _clean_text(
            _pick_first(raw, "address", "location")
            or _nested_get(raw, "workLocation", "workAddress")
            or _pick_first(coinfo, "caddr")
        ),
        "website": _clean_text(_pick_first(raw, "companyUrl", "companyHref") or _pick_first(coinfo, "webUrl")),
        "logo_url": _clean_text(_pick_first(raw, "companyLogo") or _pick_first(coinfo, "logourl")),
        "description": _clean_text(
            _pick_first(raw, "companyDesc", "company_desc", "description")
            or _nested_get(raw, "campusRootOrgInfo", "description")
            # coinfo.coinfo: the nested section's own "coinfo" text field
            or _pick_first(coinfo, "coinfo")
        ),
    }
def _extract_zhilian_fields(raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Map a ``zhilian`` company payload onto the common company columns.

    When ``raw["data"]`` is a dict it is used as the payload; otherwise *raw*
    itself is read. Most columns prefer the ``companyBase`` section and fall back
    to ``detailedCompany``.

    Args:
        raw: Raw zhilian payload (a dict).
        company_id: Already-known company ID; when truthy it wins over the
            ``companyNumber`` / ``number`` found in the payload.

    Returns:
        Dict of common company columns. ``company_name`` is ``""`` when no name
        is found; other text columns are ``None`` when missing.
    """
    envelope = raw.get("data")
    data = envelope if isinstance(envelope, dict) else raw
    base = data.get("companyBase") or {}
    detail = data.get("detailedCompany") or {}

    def base_then_detail(base_keys: tuple[str, ...], detail_keys: tuple[str, ...]) -> Any:
        # companyBase wins; detailedCompany is consulted only when base yields nothing truthy.
        return _pick_first(base, *base_keys) or _pick_first(detail, *detail_keys)

    raw_id = company_id or base_then_detail(("companyNumber", "number"), ("companyNumber", "number"))
    stage = (
        _pick_first(base, "financingStageName")
        or _nested_get(base, "financingStage", "name")
        or _nested_get(detail, "financingStage", "name")
    )

    return {
        "source_company_id": normalize_company_id("zhilian", raw_id),
        "company_name": _clean_text(
            _pick_first(base, "companyName") or _pick_first(data, "companyName")
        ) or "",
        "company_type": _clean_text(
            base_then_detail(("companyTypeName", "companyType"), ("companyTypeName",))
        ),
        "industry": _clean_text(base_then_detail(("industryName",), ("industryName",))),
        "company_size": _clean_text(
            base_then_detail(("companySize", "companySizeString"), ("companySize",))
        ),
        "financing_stage": _clean_text(stage),
        "city": _clean_text(base_then_detail(("cityName",), ("cityName",))),
        "address": _clean_text(base_then_detail(("address",), ("address",))),
        "website": _clean_text(_pick_first(base, "companyUrl", "website")),
        "logo_url": _clean_text(_pick_first(base, "logoUrl", "companyLogo")),
        "description": _clean_text(
            base_then_detail(
                ("companyDescWithHtml", "companyDesc"),
                ("companyDescription", "companyDesc"),
            )
        ),
    }
def extract_company_fields(source: str, raw: dict[str, Any], company_id: str) -> dict[str, Any]:
    """Dispatch *raw* to the field extractor for *source*.

    Args:
        source: One of ``"boss"``, ``"qcwy"``, ``"zhilian"``.
        raw: Raw company payload from that platform.
        company_id: Already-known company ID, or ``""`` to derive it from *raw*.

    Returns:
        The common company column dict produced by the source's extractor.

    Raises:
        ValueError: if *source* is not supported.
    """
    if source == "boss":
        extractor = _extract_boss_fields
    elif source == "qcwy":
        extractor = _extract_qcwy_fields
    elif source == "zhilian":
        extractor = _extract_zhilian_fields
    else:
        raise ValueError(f"unsupported source: {source}")
    return extractor(raw, company_id)
class CompanyStorageService:
    """Store per-source company records and manage ``CompanyCleaningQueue`` rows.

    The service keeps no per-instance state; every method works through the
    model classes imported at module level.
    """

    @staticmethod
    def company_model(source: str) -> Type[BaseCompanyModel]:
        """Return the company model class for *source*.

        Raises:
            ValueError: if *source* is not a supported platform.
        """
        return _model_for_source(source)

    @staticmethod
    def _normalize_ids(source: str, company_ids: Iterable[str]) -> list[str]:
        """Normalize *company_ids* for *source*, dropping blanks and duplicates (order kept).

        Falsy items are skipped before normalization; values that normalize to
        ``""`` (e.g. whitespace-only IDs) are dropped afterwards so they never
        reach the ``__in`` query.
        """
        normalized = (normalize_company_id(source, item) for item in company_ids if item)
        return list(dict.fromkeys(value for value in normalized if value))

    async def get_existing_company_ids(self, source: str, company_ids: Iterable[str]) -> set[str]:
        """Return which of *company_ids* (after normalization) already have a company record."""
        normalized_ids = self._normalize_ids(source, company_ids)
        if not normalized_ids:
            return set()
        model = self.company_model(source)
        rows = await model.filter(source_company_id__in=normalized_ids).values_list(
            "source_company_id", flat=True
        )
        return set(rows)

    async def get_all_company_ids(self, source: str) -> set[str]:
        """Return every company ID already stored for this platform (used to exclude them from ClickHouse queries)."""
        model = self.company_model(source)
        rows = await model.all().values_list("source_company_id", flat=True)
        return set(rows)

    async def get_existing_queue_ids(self, source: str, company_ids: Iterable[str]) -> set[str]:
        """Return which of *company_ids* (after normalization) already have a queue row for *source*."""
        normalized_ids = self._normalize_ids(source, company_ids)
        if not normalized_ids:
            return set()
        rows = await CompanyCleaningQueue.filter(
            source=source, company_id__in=normalized_ids
        ).values_list("company_id", flat=True)
        return set(rows)

    async def enqueue_company(
        self, source: str, company_id: str, company_name: str = ""
    ) -> tuple[CompanyCleaningQueue, bool]:
        """Create a pending queue row for the company, or refresh an existing row's name.

        An existing row only gets ``company_name`` updated (when a different,
        non-empty name is given); its status and counters are left untouched.

        Returns:
            ``(queue_row, created)``.

        Raises:
            ValueError: if *company_id* normalizes to an empty string.
        """
        normalized_id = normalize_company_id(source, company_id)
        # A blank ID would create a queue row that can never be processed.
        if not normalized_id:
            raise ValueError(f"missing normalized company id for source={source}")
        defaults = {
            "company_name": company_name or "",
            "status": "pending",
            "error_msg": "",
            "retry_count": 0,
            "started_at": None,
            "finished_at": None,
            "jobs_fetched": 0,
            "jobs_stored": 0,
            "jobs_duplicate": 0,
            "jobs_failed": 0,
            "jobs_error_msg": "",
        }
        queue, created = await CompanyCleaningQueue.get_or_create(
            source=source,
            company_id=normalized_id,
            defaults=defaults,
        )
        if not created and company_name and queue.company_name != company_name:
            queue.company_name = company_name
            await queue.save(update_fields=["company_name", "updated_at"])
        return queue, created

    async def enqueue_companies(self, source: str, companies: Iterable[dict[str, str]]) -> int:
        """Enqueue each ``{"company_id": ..., "company_name": ...}`` item; return how many rows were created.

        Items whose ID is missing or normalizes to ``""`` are skipped rather than
        aborting the whole batch.
        """
        created_count = 0
        for item in companies:
            company_id = item.get("company_id", "")
            if not normalize_company_id(source, company_id):
                continue
            _, created = await self.enqueue_company(
                source=source,
                company_id=company_id,
                company_name=item.get("company_name", "") or "",
            )
            if created:
                created_count += 1
        return created_count

    async def get_company_record(self, source: str, company_id: str) -> Optional[BaseCompanyModel]:
        """Return the stored company record for *company_id* (normalized), or ``None``."""
        normalized_id = normalize_company_id(source, company_id)
        model = self.company_model(source)
        return await model.get_or_none(source_company_id=normalized_id)

    async def upsert_company(
        self,
        source: str,
        raw_data: dict[str, Any],
        *,
        company_id: Optional[str] = None,
    ) -> dict[str, Any]:
        """Insert or update the company record built from *raw_data*.

        The extracted columns, the raw payload (``raw_json``) and
        ``last_crawled_at`` are written on every call; ``first_crawled_at`` is
        set only when a new record is created.

        Returns:
            Dict with ``success``, ``created``, ``company_id``, ``company_name``,
            a ``data_summary`` dict, and the model instance under ``record``.

        Raises:
            ValueError: if *source* is unsupported, or no company ID / company
                name can be determined.
        """
        normalized_id = normalize_company_id(source, company_id or "")
        fields = extract_company_fields(source, raw_data, normalized_id)
        # The extractor may derive the ID from the payload when none was passed in.
        normalized_id = fields["source_company_id"]
        if not normalized_id:
            raise ValueError(f"missing normalized company id for source={source}")
        if not fields["company_name"]:
            raise ValueError(f"missing company name for source={source} company_id={normalized_id}")
        model = self.company_model(source)
        record = await model.get_or_none(source_company_id=normalized_id)
        now = datetime.now()
        payload = {
            **fields,
            "raw_json": raw_data,
            "last_crawled_at": now,
        }
        if record:
            for key, value in payload.items():
                setattr(record, key, value)
            await record.save()
            created = False
        else:
            record = await model.create(
                **payload,
                first_crawled_at=now,
            )
            created = True
        return {
            "success": True,
            "created": created,
            "company_id": normalized_id,
            "company_name": record.company_name,
            "data_summary": {
                "source": source,
                "company_id": normalized_id,
                "company_name": record.company_name,
                "created": created,
            },
            "record": record,
        }

    async def mark_queue_processing(self, queue: CompanyCleaningQueue) -> None:
        """Set *queue* to ``"processing"``: stamp ``started_at``, clear ``finished_at``, errors and job counters."""
        queue.status = "processing"
        queue.error_msg = ""
        queue.started_at = datetime.now()
        queue.finished_at = None
        queue.jobs_fetched = 0
        queue.jobs_stored = 0
        queue.jobs_duplicate = 0
        queue.jobs_failed = 0
        queue.jobs_error_msg = ""
        await queue.save(
            update_fields=[
                "status",
                "error_msg",
                "started_at",
                "finished_at",
                "jobs_fetched",
                "jobs_stored",
                "jobs_duplicate",
                "jobs_failed",
                "jobs_error_msg",
                "updated_at",
            ]
        )

    async def mark_queue_result(
        self,
        queue: CompanyCleaningQueue,
        *,
        status: str,
        error_msg: str = "",
        increment_retry: bool = False,
        jobs_summary: Optional[dict[str, Any]] = None,
    ) -> None:
        """Record the outcome of processing *queue* and stamp ``finished_at``.

        When *jobs_summary* is non-empty, its ``jobs_fetched``, ``stored_success``,
        ``duplicate``, ``failed`` and ``error`` entries are copied into the job
        counter columns (missing/falsy counts become 0). ``retry_count`` is bumped
        when *increment_retry* is true.
        """
        queue.status = status
        queue.error_msg = error_msg or ""
        queue.finished_at = datetime.now()
        if jobs_summary:
            queue.jobs_fetched = int(jobs_summary.get("jobs_fetched") or 0)
            queue.jobs_stored = int(jobs_summary.get("stored_success") or 0)
            queue.jobs_duplicate = int(jobs_summary.get("duplicate") or 0)
            queue.jobs_failed = int(jobs_summary.get("failed") or 0)
            queue.jobs_error_msg = jobs_summary.get("error") or ""
        if increment_retry:
            queue.retry_count += 1
        await queue.save(
            update_fields=[
                # NOTE(review): company_name is not modified here; presumably saved so a
                # name the caller set on *queue* is persisted -- confirm before removing.
                "company_name",
                "status",
                "error_msg",
                "retry_count",
                "finished_at",
                "jobs_fetched",
                "jobs_stored",
                "jobs_duplicate",
                "jobs_failed",
                "jobs_error_msg",
                "updated_at",
            ]
        )


# Shared module-level instance; the service holds no per-instance state.
company_storage = CompanyStorageService()