sub2api/backend/internal/repository/transport_audit.go
win 435ae221bc
Some checks failed
CI / test (push) Failing after 1m32s
CI / golangci-lint (push) Failing after 31s
Security Scan / backend-security (push) Failing after 1m32s
Security Scan / frontend-security (push) Failing after 9s
x
2026-04-16 19:11:47 +08:00

160 lines
3.4 KiB
Go

package repository
import (
"crypto/tls"
"net/http"
"net/http/httptrace"
"os"
"strings"
"time"
"log/slog"
)
// transportAuditEnv names the environment variable that switches transport
// auditing on. transportAuditEnabled treats the values "1", "true", "yes" and
// "on" (case-insensitive, surrounding whitespace ignored) as enabled; any other
// value, or an unset variable, leaves auditing off.
const transportAuditEnv = "SUB2API_DEBUG_TRANSPORT_AUDIT"
// transportAuditState collects the timing and connection-reuse details of a
// single audited round trip. RoundTrip sets start before forwarding the
// request; the remaining fields mirror httptrace.GotConnInfo and keep their
// zero values if the GotConn hook never fires.
type transportAuditState struct {
	start           time.Time     // when the audited RoundTrip began
	reused, wasIdle bool          // GotConnInfo.Reused and GotConnInfo.WasIdle
	idleTime        time.Duration // GotConnInfo.IdleTime
}
// transportAuditRoundTripper is an http.RoundTripper that forwards every
// request to base and emits a debug-level slog entry describing it.
// wrapTransportAuditIfEnabled constructs it when the transportAuditEnv
// environment variable holds a truthy value.
type transportAuditRoundTripper struct {
	base  http.RoundTripper // transport that actually performs the request
	label string            // caller-supplied tag included in every log entry
}
// transportAuditEnabled reports whether the transportAuditEnv environment
// variable is set to a truthy value ("1", "true", "yes" or "on"), compared
// case-insensitively after trimming surrounding whitespace. The variable is
// read on every call, so no result is cached.
func transportAuditEnabled() bool {
	raw := os.Getenv(transportAuditEnv)
	normalized := strings.ToLower(strings.TrimSpace(raw))
	return normalized == "1" ||
		normalized == "true" ||
		normalized == "yes" ||
		normalized == "on"
}
// wrapTransportAuditIfEnabled returns base wrapped in a
// transportAuditRoundTripper when auditing is enabled (see
// transportAuditEnabled). The wrapper is tagged with label after trimming
// surrounding whitespace.
//
// base is returned unchanged when it is nil or when auditing is disabled; the
// environment is not consulted at all for a nil base.
func wrapTransportAuditIfEnabled(base http.RoundTripper, label string) http.RoundTripper {
	if base != nil && transportAuditEnabled() {
		audited := &transportAuditRoundTripper{
			base:  base,
			label: strings.TrimSpace(label),
		}
		return audited
	}
	return base
}
// transportAuditError is the error RoundTrip returns when it is called with a
// nil request or on a wrapper that has no underlying transport.
type transportAuditError string

func (e transportAuditError) Error() string { return string(e) }

// RoundTrip implements http.RoundTripper. It forwards req to rt.base and
// records connection-reuse details through an httptrace.ClientTrace GotConn
// hook. It then logs a "transport_audit" debug entry on success, or a
// "transport_audit_error" entry on failure. The response or error from
// rt.base is returned unchanged.
//
// elapsed_ms measures the time until rt.base.RoundTrip returns. The response
// body is not read here, so body transfer time is not included.
func (rt *transportAuditRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, transportAuditError("transport audit: nil request")
	}
	if rt == nil || rt.base == nil {
		// The RoundTripper contract requires the request body to be closed
		// even when an error is returned. A Close error is irrelevant here
		// because the request is being rejected anyway.
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, transportAuditError("transport audit: nil base transport")
	}

	state := &transportAuditState{start: time.Now()}
	trace := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			state.reused = info.Reused
			state.wasIdle = info.WasIdle
			state.idleTime = info.IdleTime
		},
	}
	// WithContext returns a shallow copy, so the caller's request is left
	// untouched. WithClientTrace keeps any hooks already present on the context.
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	resp, err := rt.base.RoundTrip(req)
	if err != nil {
		slog.Debug("transport_audit_error",
			"label", rt.label,
			"method", req.Method,
			"url", safeAuditURL(req),
			"elapsed_ms", time.Since(state.start).Milliseconds(),
			"conn_reused", state.reused,
			"conn_was_idle", state.wasIdle,
			"conn_idle_ms", state.idleTime.Milliseconds(),
			"error", err.Error(),
		)
		return nil, err
	}

	// Only resp.TLS describes the upstream connection this transport used.
	// req.TLS is a server-side field that the HTTP client ignores. On a
	// forwarded inbound request it describes the downstream client's
	// connection, so it must not be reported as the upstream TLS state.
	tlsState := resp.TLS
	slog.Debug("transport_audit",
		"label", rt.label,
		"method", req.Method,
		"url", safeAuditURL(req),
		"status", resp.StatusCode,
		"proto", strings.TrimSpace(resp.Proto),
		"elapsed_ms", time.Since(state.start).Milliseconds(),
		"conn_reused", state.reused,
		"conn_was_idle", state.wasIdle,
		"conn_idle_ms", state.idleTime.Milliseconds(),
		"tls_version", tlsVersionString(tlsState),
		"tls_cipher", tlsCipherSuiteString(tlsState),
		"tls_alpn", tlsNegotiatedProtocol(tlsState),
		"tls_resumed", tlsDidResume(tlsState),
		"tls_server_name", tlsServerName(tlsState),
	)
	return resp, nil
}
// safeAuditURL returns req's URL in a form that is safe to write to logs.
// The query string, fragment and any userinfo (user:password@) are removed,
// so API keys or credentials embedded in the URL are not leaked. It returns ""
// when req or req.URL is nil. req.URL itself is never modified.
func safeAuditURL(req *http.Request) string {
	if req == nil || req.URL == nil {
		return ""
	}
	// Copy the struct; only top-level fields are overwritten below, so the
	// shared *Userinfo pointed to by req.URL is not touched.
	u := *req.URL
	u.User = nil
	u.RawQuery = ""
	u.ForceQuery = false
	u.Fragment = ""
	u.RawFragment = ""
	return u.String()
}
// tlsVersionString renders the negotiated TLS protocol version as a compact
// label: "TLS1.0", "TLS1.1", "TLS1.2" or "TLS1.3". It returns "" when state is
// nil or carries no version, and "unknown" for any other version number.
func tlsVersionString(state *tls.ConnectionState) string {
	if state == nil || state.Version == 0 {
		return ""
	}
	switch state.Version {
	case tls.VersionTLS13:
		return "TLS1.3"
	case tls.VersionTLS12:
		return "TLS1.2"
	case tls.VersionTLS11:
		return "TLS1.1"
	case tls.VersionTLS10:
		return "TLS1.0"
	}
	return "unknown"
}
// tlsCipherSuiteString returns the standard name of the negotiated cipher
// suite, as reported by tls.CipherSuiteName. It returns "" when state is nil
// or no cipher suite was recorded.
func tlsCipherSuiteString(state *tls.ConnectionState) string {
	if state != nil && state.CipherSuite != 0 {
		return tls.CipherSuiteName(state.CipherSuite)
	}
	return ""
}
// tlsNegotiatedProtocol returns the ALPN protocol negotiated on the
// connection (for example "h2"), with surrounding whitespace trimmed. It
// returns "" when state is nil.
func tlsNegotiatedProtocol(state *tls.ConnectionState) string {
	var proto string
	if state != nil {
		proto = strings.TrimSpace(state.NegotiatedProtocol)
	}
	return proto
}
// tlsDidResume reports whether the TLS session was resumed from an earlier
// connection. A nil state is reported as not resumed.
func tlsDidResume(state *tls.ConnectionState) bool {
	return state != nil && state.DidResume
}
// tlsServerName returns the server name recorded in the TLS connection
// state, with surrounding whitespace trimmed. It returns "" when state is nil.
func tlsServerName(state *tls.ConnectionState) string {
	if state != nil {
		return strings.TrimSpace(state.ServerName)
	}
	return ""
}