96 lines
2.4 KiB
Go
96 lines
2.4 KiB
Go
package repository
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptrace"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestWrapTransportAuditIfEnabledDisabled(t *testing.T) {
|
|
t.Setenv(transportAuditEnv, "")
|
|
|
|
base := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
if httptrace.ContextClientTrace(r.Context()) != nil {
|
|
t.Fatalf("unexpected client trace when audit is disabled")
|
|
}
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Proto: "HTTP/1.1",
|
|
Body: io.NopCloser(strings.NewReader("ok")),
|
|
Request: r,
|
|
}, nil
|
|
})
|
|
|
|
wrapped := wrapTransportAuditIfEnabled(base, "plain")
|
|
if _, ok := wrapped.(*transportAuditRoundTripper); ok {
|
|
t.Fatalf("expected base transport when audit disabled")
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "https://api.anthropic.com/v1/messages?beta=true", nil)
|
|
if err != nil {
|
|
t.Fatalf("new request: %v", err)
|
|
}
|
|
if _, err := wrapped.RoundTrip(req); err != nil {
|
|
t.Fatalf("round trip: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestWrapTransportAuditIfEnabledEnabled(t *testing.T) {
|
|
t.Setenv(transportAuditEnv, "1")
|
|
|
|
base := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
if httptrace.ContextClientTrace(r.Context()) == nil {
|
|
t.Fatalf("expected client trace when audit is enabled")
|
|
}
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Proto: "HTTP/1.1",
|
|
Body: io.NopCloser(strings.NewReader("ok")),
|
|
Request: r,
|
|
}, nil
|
|
})
|
|
|
|
wrapped := wrapTransportAuditIfEnabled(base, "tlsfp")
|
|
if _, ok := wrapped.(*transportAuditRoundTripper); !ok {
|
|
t.Fatalf("expected wrapped transport when audit enabled")
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "https://api.anthropic.com/v1/messages?beta=true", nil)
|
|
if err != nil {
|
|
t.Fatalf("new request: %v", err)
|
|
}
|
|
if _, err := wrapped.RoundTrip(req); err != nil {
|
|
t.Fatalf("round trip: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTransportAuditEnabled(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
raw string
|
|
want bool
|
|
}{
|
|
{name: "empty", raw: "", want: false},
|
|
{name: "one", raw: "1", want: true},
|
|
{name: "true", raw: "true", want: true},
|
|
{name: "on", raw: "on", want: true},
|
|
{name: "no", raw: "no", want: false},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.raw == "" {
|
|
_ = os.Unsetenv(transportAuditEnv)
|
|
} else {
|
|
t.Setenv(transportAuditEnv, tc.raw)
|
|
}
|
|
if got := transportAuditEnabled(); got != tc.want {
|
|
t.Fatalf("transportAuditEnabled() = %v, want %v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|