sub2api/backend/internal/pkg/windsurf/conversation_pool.go

209 lines
4.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package windsurf
import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"sort"
	"sync"
	"time"
)
// Pool limits shared by Checkout, Checkin and pruneLocked.
const (
	// poolTTL is how long an entry may sit unused (measured from LastAccess)
	// before it is treated as expired.
	poolTTL = time.Minute * 30
	// poolMax caps the number of pooled entries; pruneLocked evicts the least
	// recently used entries beyond this count.
	poolMax = 500
)
// ConversationEntry is one pooled upstream conversation that can be resumed.
type ConversationEntry struct {
	CascadeID  string // upstream cascade identifier bound to the owning account/LS
	SessionID  string
	LSPort     int       // LS port owning the cascade; matched by InvalidateFor when > 0
	APIKey     string    // upstream account key owning the cascade; matched by InvalidateFor
	CreatedAt  time.Time // set by Checkin only if still zero
	LastAccess time.Time // refreshed by every Checkin; drives TTL expiry and LRU eviction
}
// ConversationPool caches resumable conversations keyed by a fingerprint
// string (presumably produced by FingerprintBefore/FingerprintAfter — confirm
// at call sites). Entries are dropped once LastAccess is older than poolTTL,
// and the pool is capped at poolMax entries with least-recently-used eviction.
//
// Construct it with NewConversationPool: the zero value has a nil map, and
// Checkin would panic writing to it.
type ConversationPool struct {
	mu    sync.Mutex // guards pool and stats
	pool  map[string]*ConversationEntry
	stats poolStats
}
// poolStats holds cumulative counters for a ConversationPool. Every access in
// this file happens with ConversationPool.mu held.
// NOTE(review): the json tags suggest these counters are serialized by code
// outside this view (e.g. a stats accessor) — confirm.
type poolStats struct {
	Hits      int `json:"hits"`      // Checkout returned a live entry
	Misses    int `json:"misses"`    // Checkout returned nil (empty key, absent, or expired)
	Stores    int `json:"stores"`    // accepted Checkin calls
	Evictions int `json:"evictions"` // entries dropped by LRU to stay within poolMax
	Expired   int `json:"expired"`   // entries dropped for exceeding poolTTL
}
// NewConversationPool returns an empty pool and starts its background pruning
// goroutine (see pruneLoop).
//
// NOTE(review): that goroutine has no stop mechanism, so every call leaks one
// goroutine and keeps the returned pool reachable forever. This is fine for a
// single process-lifetime instance; confirm that is how it is used.
func NewConversationPool() *ConversationPool {
	p := new(ConversationPool)
	p.pool = make(map[string]*ConversationEntry)
	go p.pruneLoop()
	return p
}
// Checkout removes the entry stored under fingerprint and hands it to the
// caller. It returns nil and records a miss in three cases:
//   - fingerprint is empty;
//   - nothing is stored under it;
//   - the stored entry's LastAccess is older than poolTTL. This case also
//     records an expiry, and the stale entry is discarded.
//
// A returned entry is no longer in the pool. The caller must Checkin it again
// for it to be reusable.
func (cp *ConversationPool) Checkout(fingerprint string) *ConversationEntry {
	cp.mu.Lock()
	defer cp.mu.Unlock()

	if fingerprint == "" {
		cp.stats.Misses++
		return nil
	}

	entry, found := cp.pool[fingerprint]
	if found {
		// Remove unconditionally: a live entry now belongs to the caller,
		// and a stale one should not linger until the next prune.
		delete(cp.pool, fingerprint)
	}

	switch {
	case !found:
		cp.stats.Misses++
		return nil
	case time.Since(entry.LastAccess) > poolTTL:
		cp.stats.Expired++
		cp.stats.Misses++
		return nil
	default:
		cp.stats.Hits++
		return entry
	}
}
// Checkin stores entry under fingerprint, replacing any entry already there.
// It stamps LastAccess with the current time (and CreatedAt too, if it is
// still zero), counts a store, then prunes expired and over-capacity entries.
// Calls with an empty fingerprint or a nil entry are ignored.
func (cp *ConversationPool) Checkin(fingerprint string, entry *ConversationEntry) {
	if entry == nil || fingerprint == "" {
		return
	}
	stamp := time.Now()

	cp.mu.Lock()
	defer cp.mu.Unlock()

	entry.LastAccess = stamp
	if entry.CreatedAt.IsZero() {
		entry.CreatedAt = stamp
	}
	cp.pool[fingerprint] = entry
	cp.stats.Stores++
	// Prune on every store so the map never stays above poolMax.
	cp.pruneLocked(stamp)
}
// InvalidateFor drops every pooled entry that matches either condition:
//   - its APIKey equals apiKey (only when apiKey is non-empty);
//   - its LSPort equals lsPort (only when lsPort > 0).
//
// It returns the number of entries removed. Removals are not recorded in the
// pool stats.
func (cp *ConversationPool) InvalidateFor(apiKey string, lsPort int) int {
	matches := func(e *ConversationEntry) bool {
		if apiKey != "" && e.APIKey == apiKey {
			return true
		}
		return lsPort > 0 && e.LSPort == lsPort
	}

	cp.mu.Lock()
	defer cp.mu.Unlock()

	removed := 0
	for key, e := range cp.pool {
		if !matches(e) {
			continue
		}
		// Deleting the current key while ranging over a map is safe in Go.
		delete(cp.pool, key)
		removed++
	}
	return removed
}
// pruneLocked runs in two passes:
//  1. It drops entries whose LastAccess is more than poolTTL before now,
//     using the same rule as Checkout.
//  2. If the pool still holds more than poolMax entries, it evicts the least
//     recently used ones (oldest LastAccess first) until exactly poolMax
//     remain. Ties on LastAccess are broken arbitrarily.
//
// The caller must hold cp.mu.
func (cp *ConversationPool) pruneLocked(now time.Time) {
	for fp, e := range cp.pool {
		if now.Sub(e.LastAccess) > poolTTL {
			delete(cp.pool, fp)
			cp.stats.Expired++
		}
	}

	excess := len(cp.pool) - poolMax
	if excess <= 0 {
		return
	}

	type fpTime struct {
		fp string
		t  time.Time
	}
	entries := make([]fpTime, 0, len(cp.pool))
	for fp, e := range cp.pool {
		entries = append(entries, fpTime{fp, e.LastAccess})
	}
	// This path runs on every Checkin that pushes the pool past poolMax, so
	// use an O(n log n) sort rather than a hand-rolled O(n²) exchange sort.
	sort.Slice(entries, func(i, j int) bool {
		return entries[i].t.Before(entries[j].t)
	})
	for _, victim := range entries[:excess] {
		delete(cp.pool, victim.fp)
		cp.stats.Evictions++
	}
}
// pruneLoop runs pruneLocked under cp.mu every five minutes. This reclaims
// expired entries even when no Checkin arrives to trigger pruning.
//
// It never returns, because a Ticker's channel is never closed. The goroutine
// therefore lives as long as the process.
func (cp *ConversationPool) pruneLoop() {
	const interval = 5 * time.Minute
	tick := time.NewTicker(interval)
	defer tick.Stop()
	for {
		<-tick.C
		cp.mu.Lock()
		cp.pruneLocked(time.Now())
		cp.mu.Unlock()
	}
}
// FingerprintBefore computes the lookup key for resuming a conversation. It
// hashes every user/tool turn except the last one, so it equals the
// FingerprintAfter key of a request (same modelKey and apiKey) whose user/tool
// turns were exactly that prefix. It returns "" when there are fewer than two
// user/tool turns.
//
// apiKey is part of the hash because a cascade_id is bound to a specific
// upstream account/LS. Different accounts must never reuse each other's
// cascade, even when their messages are identical. Otherwise, after a failover
// switches accounts, a hit on the old cascade triggers "panel state not found".
func FingerprintBefore(messages []ChatMessage, modelKey, apiKey string) string {
	turns := stableTurns(messages)
	last := len(turns) - 1
	if last < 1 {
		return ""
	}
	return hashFingerprint(modelKey, apiKey, turns[:last])
}
// FingerprintAfter computes the key under which a conversation is stored after
// a successful turn: the hash of all user/tool turns in messages, together
// with modelKey and apiKey. It returns "" when messages contain no user or
// tool turn.
func FingerprintAfter(messages []ChatMessage, modelKey, apiKey string) string {
	if turns := stableTurns(messages); len(turns) > 0 {
		return hashFingerprint(modelKey, apiKey, turns)
	}
	return ""
}
// stableTurns returns, in their original order, the messages whose Role is
// "user" or "tool". All other roles are excluded from fingerprints. The result
// is nil when no message qualifies.
func stableTurns(messages []ChatMessage) []ChatMessage {
	var kept []ChatMessage
	for i := range messages {
		switch messages[i].Role {
		case "user", "tool":
			kept = append(kept, messages[i])
		}
	}
	return kept
}
// hashFingerprint returns the lowercase hex SHA-256 of three parts joined by
// "\x00\x00": modelKey, apiKey, and a canonical JSON encoding of turns. The
// JSON keeps only each turn's role, content and image digests.
//
// Only ImageDigests feed the hash, never base64 image data. When a turn has no
// digests, the "images" key is omitted entirely. This keeps backward
// compatibility: an image-free turn hashes the same as in older versions.
func hashFingerprint(modelKey, apiKey string, turns []ChatMessage) string {
	// Field order and json tags define the hashed bytes; do not reorder.
	type canonicalImage struct {
		MimeType string `json:"mime_type"`
		SHA256   string `json:"sha256"`
		ByteLen  int    `json:"byte_len"`
		Caption  string `json:"caption,omitempty"`
	}
	type canonical struct {
		Role    string           `json:"role"`
		Content string           `json:"content"`
		Images  []canonicalImage `json:"images,omitempty"`
	}

	payload := make([]canonical, 0, len(turns))
	for _, turn := range turns {
		var imgs []canonicalImage // stays nil without digests, so "images" is omitted
		for _, d := range turn.ImageDigests {
			imgs = append(imgs, canonicalImage{
				MimeType: d.MimeType,
				SHA256:   d.SHA256,
				ByteLen:  d.ByteLen,
				Caption:  d.Caption,
			})
		}
		payload = append(payload, canonical{Role: turn.Role, Content: turn.Content, Images: imgs})
	}
	// Marshal cannot fail here: every field is a string, an int, or a slice of
	// such structs.
	encoded, _ := json.Marshal(payload)

	// hash.Hash.Write never returns an error.
	digest := sha256.New()
	digest.Write([]byte(modelKey))
	digest.Write([]byte("\x00\x00"))
	digest.Write([]byte(apiKey))
	digest.Write([]byte("\x00\x00"))
	digest.Write(encoded)
	return fmt.Sprintf("%x", digest.Sum(nil))
}