Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions caddytest/integration/reverseproxy_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package integration

import (
"bufio"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"os"
"runtime"
"strings"
Expand Down Expand Up @@ -562,3 +567,129 @@ func TestReverseProxyHealthCheckUnixSocketWithoutPort(t *testing.T) {

tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!")
}

func TestReverseProxyWebSocketUpgradeUnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}

f, err := os.CreateTemp("", "*.sock")
if err != nil {
t.Fatalf("failed to create temporary socket file: %v", err)
}
_ = os.Remove(f.Name())
socketName := f.Name()

backend := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path != "/ws" {
http.NotFound(w, req)
return
}

if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
http.Error(w, "missing websocket upgrade headers", http.StatusBadRequest)
return
}

wsKey := req.Header.Get("Sec-WebSocket-Key")
if wsKey == "" {
http.Error(w, "missing Sec-WebSocket-Key", http.StatusBadRequest)
return
}

hj, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "hijacker not supported", http.StatusInternalServerError)
return
}

conn, brw, err := hj.Hijack()
if err != nil {
return
}
defer conn.Close()

_, _ = brw.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
_, _ = brw.WriteString("Upgrade: websocket\r\n")
_, _ = brw.WriteString("Connection: Upgrade\r\n")
_, _ = brw.WriteString("Sec-WebSocket-Accept: " + computeWebSocketAccept(wsKey) + "\r\n")
_, _ = brw.WriteString("\r\n")
_ = brw.Flush()
}),
}

unixListener, err := net.Listen("unix", socketName)
if err != nil {
t.Fatalf("failed to listen on unix socket: %v", err)
}
go backend.Serve(unixListener)
t.Cleanup(func() {
_ = backend.Close()
_ = unixListener.Close()
_ = os.Remove(socketName)
})
runtime.Gosched()

tester := caddytest.NewTester(t)
tester.InitServer(fmt.Sprintf(`
{
skip_install_trust
admin localhost:2999
http_port 9080
https_port 9443
grace_period 1ns
}
http://localhost:9080 {
reverse_proxy unix/%s
}
`, socketName), "caddyfile")

conn, err := net.Dial("tcp", "127.0.0.1:9080")
if err != nil {
t.Fatalf("failed to dial caddy listener: %v", err)
}
defer conn.Close()

wsKey := "dGhlIHNhbXBsZSBub25jZQ=="
request := strings.Join([]string{
"GET /ws HTTP/1.1",
"Host: localhost:9080",
"Connection: Upgrade",
"Upgrade: websocket",
"Sec-WebSocket-Version: 13",
"Sec-WebSocket-Key: " + wsKey,
"",
"",
}, "\r\n")

if _, err := io.WriteString(conn, request); err != nil {
t.Fatalf("failed to send websocket handshake request: %v", err)
}

tpr := textproto.NewReader(bufio.NewReader(conn))
statusLine, err := tpr.ReadLine()
if err != nil {
t.Fatalf("failed reading handshake status line: %v", err)
}
if !strings.Contains(statusLine, "101") || !strings.Contains(strings.ToLower(statusLine), "switching protocols") {
t.Fatalf("unexpected status line: %q", statusLine)
}

headers, err := tpr.ReadMIMEHeader()
if err != nil {
t.Fatalf("failed reading handshake headers: %v", err)
}
if !strings.EqualFold(headers.Get("Upgrade"), "websocket") {
t.Fatalf("unexpected Upgrade header: %q", headers.Get("Upgrade"))
}
if !strings.Contains(strings.ToLower(headers.Get("Connection")), "upgrade") {
t.Fatalf("unexpected Connection header: %q", headers.Get("Connection"))
}
}

func computeWebSocketAccept(wsKey string) string {
h := sha1.Sum([]byte(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
return base64.StdEncoding.EncodeToString(h[:])
}
Loading