209 lines
4.9 KiB
Go
209 lines
4.9 KiB
Go
package windsurf
|
||
|
||
import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"sort"
	"sync"
	"time"
)
|
||
|
||
// poolTTL is the maximum idle time (measured from an entry's LastAccess) after
// which a pooled conversation is considered stale and discarded.
const poolTTL time.Duration = 30 * time.Minute

// poolMax caps the number of pooled conversations; beyond it the least
// recently accessed entries are evicted.
const poolMax = 500
|
||
|
||
// ConversationEntry is one pooled upstream conversation that a later request
// may resume.
type ConversationEntry struct {
	// CascadeID and SessionID identify the upstream conversation.
	// NOTE(review): exact semantics are defined by the upstream client; confirm.
	CascadeID, SessionID string
	// LSPort is the port of the language-server instance that owns the
	// conversation; InvalidateFor matches on it.
	LSPort int
	// APIKey is the upstream account the conversation is bound to;
	// InvalidateFor matches on it.
	APIKey string
	// CreatedAt is stamped by the first Checkin if still zero. LastAccess is
	// refreshed by every Checkin and drives TTL expiry and LRU eviction.
	CreatedAt, LastAccess time.Time
}
|
||
|
||
// ConversationPool caches ConversationEntry values keyed by a conversation
// fingerprint, so a follow-up request can resume the same upstream
// conversation. Entries idle longer than poolTTL are discarded, and the pool
// is capped at poolMax entries by evicting the least recently accessed ones.
// Construct it with NewConversationPool: the zero value has a nil map, and
// Checkin on it would panic.
type ConversationPool struct {
	// mu guards pool and stats.
	mu sync.Mutex
	// pool maps fingerprint -> entry. Checkout removes the entry it returns,
	// so a stored entry is handed to at most one Checkout caller.
	pool map[string]*ConversationEntry
	// stats accumulates hit/miss/store/eviction/expiry counters.
	stats poolStats
}
|
||
|
||
// poolStats holds cumulative counters for a ConversationPool. In this file
// they are only modified while ConversationPool.mu is held. The json tags
// suggest the struct is serialized for a stats endpoint elsewhere.
// NOTE(review): confirm against callers.
type poolStats struct {
	// Hits counts Checkout calls that returned a live entry.
	Hits int `json:"hits"`
	// Misses counts Checkout calls that returned nil: empty fingerprint,
	// unknown fingerprint, or an entry that had exceeded poolTTL.
	Misses int `json:"misses"`
	// Stores counts Checkin calls that stored an entry.
	Stores int `json:"stores"`
	// Evictions counts entries dropped by pruneLocked to enforce poolMax.
	Evictions int `json:"evictions"`
	// Expired counts entries dropped for exceeding poolTTL, whether found
	// stale by Checkout or swept by pruneLocked.
	Expired int `json:"expired"`
}
|
||
|
||
func NewConversationPool() *ConversationPool {
|
||
cp := &ConversationPool{
|
||
pool: make(map[string]*ConversationEntry),
|
||
}
|
||
go cp.pruneLoop()
|
||
return cp
|
||
}
|
||
|
||
func (cp *ConversationPool) Checkout(fingerprint string) *ConversationEntry {
|
||
if fingerprint == "" {
|
||
cp.mu.Lock()
|
||
cp.stats.Misses++
|
||
cp.mu.Unlock()
|
||
return nil
|
||
}
|
||
cp.mu.Lock()
|
||
defer cp.mu.Unlock()
|
||
entry, ok := cp.pool[fingerprint]
|
||
if !ok {
|
||
cp.stats.Misses++
|
||
return nil
|
||
}
|
||
delete(cp.pool, fingerprint)
|
||
if time.Since(entry.LastAccess) > poolTTL {
|
||
cp.stats.Expired++
|
||
cp.stats.Misses++
|
||
return nil
|
||
}
|
||
cp.stats.Hits++
|
||
return entry
|
||
}
|
||
|
||
func (cp *ConversationPool) Checkin(fingerprint string, entry *ConversationEntry) {
|
||
if fingerprint == "" || entry == nil {
|
||
return
|
||
}
|
||
now := time.Now()
|
||
cp.mu.Lock()
|
||
defer cp.mu.Unlock()
|
||
if entry.CreatedAt.IsZero() {
|
||
entry.CreatedAt = now
|
||
}
|
||
entry.LastAccess = now
|
||
cp.pool[fingerprint] = entry
|
||
cp.stats.Stores++
|
||
cp.pruneLocked(now)
|
||
}
|
||
|
||
func (cp *ConversationPool) InvalidateFor(apiKey string, lsPort int) int {
|
||
cp.mu.Lock()
|
||
defer cp.mu.Unlock()
|
||
dropped := 0
|
||
for fp, e := range cp.pool {
|
||
if (apiKey != "" && e.APIKey == apiKey) || (lsPort > 0 && e.LSPort == lsPort) {
|
||
delete(cp.pool, fp)
|
||
dropped++
|
||
}
|
||
}
|
||
return dropped
|
||
}
|
||
|
||
func (cp *ConversationPool) pruneLocked(now time.Time) {
|
||
for fp, e := range cp.pool {
|
||
if now.Sub(e.LastAccess) > poolTTL {
|
||
delete(cp.pool, fp)
|
||
cp.stats.Expired++
|
||
}
|
||
}
|
||
if len(cp.pool) <= poolMax {
|
||
return
|
||
}
|
||
// LRU eviction: find oldest entries
|
||
type fpTime struct {
|
||
fp string
|
||
t time.Time
|
||
}
|
||
entries := make([]fpTime, 0, len(cp.pool))
|
||
for fp, e := range cp.pool {
|
||
entries = append(entries, fpTime{fp, e.LastAccess})
|
||
}
|
||
// Simple sort by time
|
||
for i := 0; i < len(entries)-1; i++ {
|
||
for j := i + 1; j < len(entries); j++ {
|
||
if entries[j].t.Before(entries[i].t) {
|
||
entries[i], entries[j] = entries[j], entries[i]
|
||
}
|
||
}
|
||
}
|
||
toDrop := len(entries) - poolMax
|
||
for i := 0; i < toDrop; i++ {
|
||
delete(cp.pool, entries[i].fp)
|
||
cp.stats.Evictions++
|
||
}
|
||
}
|
||
|
||
func (cp *ConversationPool) pruneLoop() {
|
||
ticker := time.NewTicker(5 * time.Minute)
|
||
defer ticker.Stop()
|
||
for range ticker.C {
|
||
cp.mu.Lock()
|
||
cp.pruneLocked(time.Now())
|
||
cp.mu.Unlock()
|
||
}
|
||
}
|
||
|
||
// FingerprintBefore computes the fingerprint for resuming a conversation.
|
||
// Hash only user/tool turns (excluding the last one) for lookup.
|
||
// apiKey 参与 hash:cascade_id 绑定具体上游账号/LS,不同账号即使消息一致也不能复用,
|
||
// 否则 failover 切号后命中旧 cascade 会触发 "panel state not found"。
|
||
func FingerprintBefore(messages []ChatMessage, modelKey, apiKey string) string {
|
||
turns := stableTurns(messages)
|
||
if len(turns) < 2 {
|
||
return ""
|
||
}
|
||
return hashFingerprint(modelKey, apiKey, turns[:len(turns)-1])
|
||
}
|
||
|
||
// FingerprintAfter computes the fingerprint after a successful turn.
|
||
func FingerprintAfter(messages []ChatMessage, modelKey, apiKey string) string {
|
||
turns := stableTurns(messages)
|
||
if len(turns) == 0 {
|
||
return ""
|
||
}
|
||
return hashFingerprint(modelKey, apiKey, turns)
|
||
}
|
||
|
||
func stableTurns(messages []ChatMessage) []ChatMessage {
|
||
var turns []ChatMessage
|
||
for _, m := range messages {
|
||
if m.Role == "user" || m.Role == "tool" {
|
||
turns = append(turns, m)
|
||
}
|
||
}
|
||
return turns
|
||
}
|
||
|
||
func hashFingerprint(modelKey, apiKey string, turns []ChatMessage) string {
|
||
type canonicalImage struct {
|
||
MimeType string `json:"mime_type"`
|
||
SHA256 string `json:"sha256"`
|
||
ByteLen int `json:"byte_len"`
|
||
Caption string `json:"caption,omitempty"`
|
||
}
|
||
type canonical struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
Images []canonicalImage `json:"images,omitempty"`
|
||
}
|
||
cans := make([]canonical, len(turns))
|
||
for i, t := range turns {
|
||
c := canonical{Role: t.Role, Content: t.Content}
|
||
// 指纹只使用 ImageDigests,永不使用 base64。
|
||
// 若 ImageDigests 为空则 canonical.Images 也省略(保持向后兼容:无图 hash 与旧版本一致)。
|
||
if len(t.ImageDigests) > 0 {
|
||
c.Images = make([]canonicalImage, len(t.ImageDigests))
|
||
for j, d := range t.ImageDigests {
|
||
c.Images[j] = canonicalImage{
|
||
MimeType: d.MimeType,
|
||
SHA256: d.SHA256,
|
||
ByteLen: d.ByteLen,
|
||
Caption: d.Caption,
|
||
}
|
||
}
|
||
}
|
||
cans[i] = c
|
||
}
|
||
data, _ := json.Marshal(cans)
|
||
h := sha256.Sum256([]byte(fmt.Sprintf("%s\x00\x00%s\x00\x00%s", modelKey, apiKey, data)))
|
||
return fmt.Sprintf("%x", h)
|
||
}
|