160 lines
3.4 KiB
Go
160 lines
3.4 KiB
Go
package repository
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"net/http"
|
|
"net/http/httptrace"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"log/slog"
|
|
)
|
|
|
|
// transportAuditEnv is the name of the environment variable read by
// transportAuditEnabled. After trimming whitespace and lower-casing, the
// values "1", "true", "yes" and "on" turn transport audit logging on; any
// other value (including unset) leaves it off.
const transportAuditEnv = "SUB2API_DEBUG_TRANSPORT_AUDIT"
|
|
|
|
// transportAuditState collects the per-request measurements that
// transportAuditRoundTripper.RoundTrip reports in its debug log record.
type transportAuditState struct {
	// start is the instant RoundTrip began handling the request; elapsed
	// time is measured from it.
	start time.Time
	// reused and wasIdle are copied from httptrace.GotConnInfo.Reused and
	// httptrace.GotConnInfo.WasIdle when the GotConn hook fires.
	reused, wasIdle bool
	// idleTime is copied from httptrace.GotConnInfo.IdleTime.
	idleTime time.Duration
}
|
|
|
|
// transportAuditRoundTripper is an http.RoundTripper decorator that delegates
// every request to base and writes a debug log record about the exchange.
type transportAuditRoundTripper struct {
	// base is the wrapped transport that actually performs the request.
	base http.RoundTripper

	// label is emitted as the "label" attribute of each log record so that
	// records from different wrapped transports can be told apart.
	label string
}
|
|
|
|
func transportAuditEnabled() bool {
|
|
switch strings.ToLower(strings.TrimSpace(os.Getenv(transportAuditEnv))) {
|
|
case "1", "true", "yes", "on":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func wrapTransportAuditIfEnabled(base http.RoundTripper, label string) http.RoundTripper {
|
|
if base == nil || !transportAuditEnabled() {
|
|
return base
|
|
}
|
|
return &transportAuditRoundTripper{
|
|
base: base,
|
|
label: strings.TrimSpace(label),
|
|
}
|
|
}
|
|
|
|
func (rt *transportAuditRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
if rt == nil || rt.base == nil || req == nil {
|
|
return nil, http.ErrUseLastResponse
|
|
}
|
|
|
|
state := &transportAuditState{start: time.Now()}
|
|
trace := &httptrace.ClientTrace{
|
|
GotConn: func(info httptrace.GotConnInfo) {
|
|
state.reused = info.Reused
|
|
state.wasIdle = info.WasIdle
|
|
state.idleTime = info.IdleTime
|
|
},
|
|
}
|
|
|
|
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
|
resp, err := rt.base.RoundTrip(req)
|
|
if err != nil {
|
|
slog.Debug("transport_audit_error",
|
|
"label", rt.label,
|
|
"method", req.Method,
|
|
"url", safeAuditURL(req),
|
|
"elapsed_ms", time.Since(state.start).Milliseconds(),
|
|
"conn_reused", state.reused,
|
|
"conn_was_idle", state.wasIdle,
|
|
"conn_idle_ms", state.idleTime.Milliseconds(),
|
|
"error", err.Error(),
|
|
)
|
|
return nil, err
|
|
}
|
|
|
|
tlsState := resp.TLS
|
|
if tlsState == nil && req.TLS != nil {
|
|
tlsState = req.TLS
|
|
}
|
|
|
|
slog.Debug("transport_audit",
|
|
"label", rt.label,
|
|
"method", req.Method,
|
|
"url", safeAuditURL(req),
|
|
"status", resp.StatusCode,
|
|
"proto", strings.TrimSpace(resp.Proto),
|
|
"elapsed_ms", time.Since(state.start).Milliseconds(),
|
|
"conn_reused", state.reused,
|
|
"conn_was_idle", state.wasIdle,
|
|
"conn_idle_ms", state.idleTime.Milliseconds(),
|
|
"tls_version", tlsVersionString(tlsState),
|
|
"tls_cipher", tlsCipherSuiteString(tlsState),
|
|
"tls_alpn", tlsNegotiatedProtocol(tlsState),
|
|
"tls_resumed", tlsDidResume(tlsState),
|
|
"tls_server_name", tlsServerName(tlsState),
|
|
)
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// safeAuditURL renders the request URL for logging with every part that may
// carry secrets removed: userinfo (user name and password), the query string,
// and the fragment. Only scheme, host and path remain.
//
// It returns "" when req or req.URL is nil. The request's own URL is never
// modified; all changes are made on a copy.
func safeAuditURL(req *http.Request) string {
	if req == nil || req.URL == nil {
		return ""
	}

	u := *req.URL
	// Credentials embedded as "user:password@host" would otherwise be
	// printed verbatim by URL.String.
	u.User = nil
	u.RawQuery = ""
	// Without this a URL parsed from "path?" would still render a dangling
	// "?" even though the query is empty.
	u.ForceQuery = false
	u.Fragment = ""
	u.RawFragment = ""
	return u.String()
}
|
|
|
|
// tlsVersionString maps the negotiated TLS version in state to a short label
// ("TLS1.0" through "TLS1.3"). It returns "" when state is nil or its Version
// is zero, and "unknown" for any other unrecognised version number.
func tlsVersionString(state *tls.ConnectionState) string {
	if state == nil || state.Version == 0 {
		return ""
	}

	switch state.Version {
	case tls.VersionTLS13:
		return "TLS1.3"
	case tls.VersionTLS12:
		return "TLS1.2"
	case tls.VersionTLS11:
		return "TLS1.1"
	case tls.VersionTLS10:
		return "TLS1.0"
	}
	return "unknown"
}
|
|
|
|
// tlsCipherSuiteString returns the name tls.CipherSuiteName gives to the
// cipher suite recorded in state, or "" when state is nil or no suite is set
// (CipherSuite == 0).
func tlsCipherSuiteString(state *tls.ConnectionState) string {
	if state == nil {
		return ""
	}
	suite := state.CipherSuite
	if suite == 0 {
		return ""
	}
	return tls.CipherSuiteName(suite)
}
|
|
|
|
// tlsNegotiatedProtocol returns state.NegotiatedProtocol with surrounding
// whitespace removed, or "" when state is nil.
func tlsNegotiatedProtocol(state *tls.ConnectionState) string {
	var proto string
	if state != nil {
		proto = strings.TrimSpace(state.NegotiatedProtocol)
	}
	return proto
}
|
|
|
|
// tlsDidResume reports state.DidResume, treating a nil state as false.
func tlsDidResume(state *tls.ConnectionState) bool {
	return state != nil && state.DidResume
}
|
|
|
|
// tlsServerName returns state.ServerName with surrounding whitespace removed,
// or "" when state is nil.
func tlsServerName(state *tls.ConnectionState) string {
	if state != nil {
		return strings.TrimSpace(state.ServerName)
	}
	return ""
}
|