-
Notifications
You must be signed in to change notification settings - Fork 4.4k
feat(oauth): add stdio OAuth 2.1 login core library (1/4) #2704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
92db283
f707dba
27040d5
64afe52
5da25d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| package oauth | ||
|
|
||
| import ( | ||
| "context" | ||
| "embed" | ||
| "fmt" | ||
| "html/template" | ||
| "net" | ||
| "net/http" | ||
| "time" | ||
| ) | ||
|
|
||
| //go:embed templates/*.html | ||
| var templateFS embed.FS | ||
|
|
||
| var ( | ||
| errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html")) | ||
| successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html")) | ||
| ) | ||
|
|
||
| // callbackResult is delivered by the callback server once the browser redirect | ||
| // arrives. Exactly one of code or err is set. | ||
| type callbackResult struct { | ||
| code string | ||
| err error | ||
| } | ||
|
|
||
| // callbackServer is a short-lived local HTTP server that captures the | ||
| // authorization code from the OAuth redirect. | ||
| type callbackServer struct { | ||
| server *http.Server | ||
| listener net.Listener | ||
| redirect string | ||
| results chan callbackResult | ||
| } | ||
|
|
||
| // listenCallback binds the local callback listener. | ||
| // | ||
| // It binds to loopback (127.0.0.1) by default so the callback server is never | ||
| // exposed on other interfaces. bindAll is set only inside a container, where | ||
| // Docker's published-port DNAT delivers traffic to the container's eth0 rather | ||
| // than to loopback; host-side exposure is still constrained by the publish | ||
| // (e.g. -p 127.0.0.1:8085:8085). A native run — even with a fixed port — stays | ||
| // on loopback. | ||
| func listenCallback(port int, bindAll bool) (net.Listener, error) { | ||
| host := "127.0.0.1" | ||
| if bindAll { | ||
| host = "0.0.0.0" | ||
| } | ||
| addr := fmt.Sprintf("%s:%d", host, port) | ||
| listener, err := net.Listen("tcp", addr) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("starting callback listener on %s: %w", addr, err) | ||
| } | ||
| return listener, nil | ||
| } | ||
|
|
||
| // newCallbackServer starts a callback server on listener that validates state | ||
| // and reports the result on a buffered channel. The redirect URI always uses | ||
| // localhost so it matches the value registered on the OAuth/GitHub App. | ||
| func newCallbackServer(listener net.Listener, expectedState string) *callbackServer { | ||
| cs := &callbackServer{ | ||
| server: &http.Server{ReadHeaderTimeout: 10 * time.Second}, // ReadHeaderTimeout guards against Slowloris. | ||
| listener: listener, | ||
| redirect: fmt.Sprintf("http://localhost:%d/callback", listener.Addr().(*net.TCPAddr).Port), | ||
| results: make(chan callbackResult, 1), | ||
| } | ||
| cs.server.Handler = cs.handler(expectedState) | ||
|
|
||
| go func() { | ||
| if err := cs.server.Serve(listener); err != nil && err != http.ErrServerClosed { | ||
| cs.report(callbackResult{err: fmt.Errorf("callback server: %w", err)}) | ||
| } | ||
| }() | ||
|
|
||
| return cs | ||
| } | ||
|
|
||
| // handler renders the callback endpoint. It reports the outcome exactly once and | ||
| // always shows the user a friendly page. | ||
| func (cs *callbackServer) handler(expectedState string) http.Handler { | ||
| mux := http.NewServeMux() | ||
| mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { | ||
| q := r.URL.Query() | ||
|
|
||
| if errCode := q.Get("error"); errCode != "" { | ||
| msg := errCode | ||
| if desc := q.Get("error_description"); desc != "" { | ||
| msg = fmt.Sprintf("%s: %s", errCode, desc) | ||
| } | ||
| cs.report(callbackResult{err: fmt.Errorf("authorization failed: %s", msg)}) | ||
| renderError(w, msg) | ||
| return | ||
| } | ||
|
|
||
| if q.Get("state") != expectedState { | ||
| cs.report(callbackResult{err: fmt.Errorf("state mismatch (possible CSRF)")}) | ||
| renderError(w, "state mismatch") | ||
| return | ||
| } | ||
|
|
||
| code := q.Get("code") | ||
| if code == "" { | ||
| cs.report(callbackResult{err: fmt.Errorf("no authorization code in callback")}) | ||
| renderError(w, "no authorization code received") | ||
| return | ||
| } | ||
|
|
||
| cs.report(callbackResult{code: code}) | ||
| renderSuccess(w) | ||
| }) | ||
| return mux | ||
| } | ||
|
|
||
| // report delivers the first outcome and drops later ones (the channel is | ||
| // buffered for one; subsequent redirect retries must not block the handler). | ||
| func (cs *callbackServer) report(res callbackResult) { | ||
| select { | ||
| case cs.results <- res: | ||
| default: | ||
| } | ||
| } | ||
|
|
||
| // wait blocks for the callback outcome or ctx cancellation, then shuts the | ||
| // server down. It is safe to call once per server. | ||
| func (cs *callbackServer) wait(ctx context.Context) (string, error) { | ||
| defer cs.close() | ||
| select { | ||
| case res := <-cs.results: | ||
| return res.code, res.err | ||
| case <-ctx.Done(): | ||
| return "", ctx.Err() | ||
| } | ||
| } | ||
|
|
||
| func (cs *callbackServer) close() { | ||
| shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
| defer cancel() | ||
| _ = cs.server.Shutdown(shutdownCtx) | ||
| _ = cs.listener.Close() | ||
| } | ||
|
|
||
| func renderSuccess(w http.ResponseWriter) { | ||
| w.Header().Set("Content-Type", "text/html; charset=utf-8") | ||
| if err := successTemplate.Execute(w, nil); err != nil { | ||
| http.Error(w, "internal error", http.StatusInternalServerError) | ||
| } | ||
| } | ||
|
|
||
| // renderError shows the failure page. html/template auto-escapes msg, so a | ||
| // hostile error_description cannot inject markup. | ||
| func renderError(w http.ResponseWriter, msg string) { | ||
| w.Header().Set("Content-Type", "text/html; charset=utf-8") | ||
| if err := errorTemplate.Execute(w, struct{ ErrorMessage string }{ErrorMessage: msg}); err != nil { | ||
| http.Error(w, "internal error", http.StatusInternalServerError) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| package oauth | ||
|
|
||
| import ( | ||
| "net" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| // serveCallback drives the callback handler with the given query string and | ||
| // returns the recorded response and the single reported result. | ||
| func serveCallback(t *testing.T, expectedState, query string) (*httptest.ResponseRecorder, callbackResult) { | ||
| t.Helper() | ||
| cs := &callbackServer{results: make(chan callbackResult, 1)} | ||
| rec := httptest.NewRecorder() | ||
| req := httptest.NewRequest(http.MethodGet, "/callback?"+query, nil) | ||
|
|
||
| cs.handler(expectedState).ServeHTTP(rec, req) | ||
|
|
||
| select { | ||
| case res := <-cs.results: | ||
| return rec, res | ||
| default: | ||
| t.Fatal("handler did not report a result") | ||
| return nil, callbackResult{} | ||
| } | ||
| } | ||
|
|
||
| func TestCallbackHandlerSuccess(t *testing.T) { | ||
| rec, res := serveCallback(t, "state123", "code=the-code&state=state123") | ||
|
|
||
| require.NoError(t, res.err) | ||
| assert.Equal(t, "the-code", res.code) | ||
| assert.Equal(t, http.StatusOK, rec.Code) | ||
| assert.Contains(t, rec.Body.String(), "Authorization Successful") | ||
| } | ||
|
|
||
| func TestCallbackHandlerStateMismatch(t *testing.T) { | ||
| rec, res := serveCallback(t, "expected", "code=the-code&state=attacker") | ||
|
|
||
| require.Error(t, res.err) | ||
| assert.Empty(t, res.code) | ||
| assert.Contains(t, res.err.Error(), "state mismatch") | ||
| assert.Contains(t, rec.Body.String(), "state mismatch") | ||
| } | ||
|
|
||
| func TestCallbackHandlerMissingCode(t *testing.T) { | ||
| _, res := serveCallback(t, "state123", "state=state123") | ||
|
|
||
| require.Error(t, res.err) | ||
| assert.Contains(t, res.err.Error(), "no authorization code") | ||
| } | ||
|
|
||
| func TestCallbackHandlerOAuthError(t *testing.T) { | ||
| _, res := serveCallback(t, "state123", "error=access_denied&error_description=user+said+no") | ||
|
|
||
| require.Error(t, res.err) | ||
| assert.Contains(t, res.err.Error(), "access_denied") | ||
| assert.Contains(t, res.err.Error(), "user said no") | ||
| } | ||
|
|
||
| func TestCallbackHandlerEscapesError(t *testing.T) { | ||
| rec, _ := serveCallback(t, "state123", "error=evil&error_description=%3Cscript%3Ealert(1)%3C%2Fscript%3E") | ||
|
|
||
| body := rec.Body.String() | ||
| assert.NotContains(t, body, "<script>", "error message must be HTML-escaped") | ||
| assert.Contains(t, body, "<script>") | ||
| } | ||
|
|
||
| func TestListenCallbackRandomPortIsLoopback(t *testing.T) { | ||
| listener, err := listenCallback(0, false) | ||
| require.NoError(t, err) | ||
| defer listener.Close() | ||
|
Comment on lines
+73
to
+76
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated — the test helper now calls |
||
|
|
||
| addr, ok := listener.Addr().(*net.TCPAddr) | ||
| require.True(t, ok) | ||
| assert.True(t, addr.IP.IsLoopback(), "default bind must be loopback only, got %s", addr.IP) | ||
| assert.NotZero(t, addr.Port) | ||
| } | ||
|
|
||
| func TestListenCallbackBindAllForContainer(t *testing.T) { | ||
| listener, err := listenCallback(0, true) | ||
| require.NoError(t, err) | ||
| defer listener.Close() | ||
|
|
||
| addr, ok := listener.Addr().(*net.TCPAddr) | ||
| require.True(t, ok) | ||
| assert.True(t, addr.IP.IsUnspecified(), "bindAll must bind all interfaces, got %s", addr.IP) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| package oauth | ||
|
|
||
| import ( | ||
| "errors" | ||
| "fmt" | ||
| "io" | ||
| "os" | ||
| "os/exec" | ||
| "runtime" | ||
| "strings" | ||
| ) | ||
|
|
||
| // errNoDisplay reports that the host has no display server, so no browser can be | ||
| // launched. It is a definitive headless signal (unlike a generic launch error), | ||
| // which lets the flow prefer device authorization — the only channel reachable | ||
| // from a browser on another machine (e.g. a remote SSH session). | ||
| var errNoDisplay = errors.New("no display server detected") | ||
|
|
||
| // openBrowser tries to open url in the user's default browser. It returns an | ||
| // error when no browser can plausibly be launched so the caller can fall back | ||
| // to elicitation. On Linux it treats a headless session (no display server) as | ||
| // unopenable, which is the common case for SSH and containers. | ||
| func openBrowser(url string) error { | ||
| var cmd *exec.Cmd | ||
| switch runtime.GOOS { | ||
| case "linux": | ||
| if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" { | ||
| return errNoDisplay | ||
| } | ||
| cmd = exec.Command("xdg-open", url) | ||
| case "darwin": | ||
| cmd = exec.Command("open", url) | ||
| case "windows": | ||
| cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) | ||
| default: | ||
| return fmt.Errorf("unsupported platform: %s", runtime.GOOS) | ||
| } | ||
|
|
||
| cmd.Stdout = io.Discard | ||
| cmd.Stderr = io.Discard | ||
| if err := cmd.Start(); err != nil { | ||
| return err | ||
| } | ||
| // The launcher (xdg-open/open/rundll32) exits as soon as it hands off to the | ||
| // browser. Reap it asynchronously so it does not linger as a zombie for the | ||
| // lifetime of this long-running server. | ||
| go func() { _ = cmd.Wait() }() | ||
| return nil | ||
| } | ||
|
Comment on lines
+39
to
+49
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch — fixed in f707dba. |
||
|
|
||
| // isRunningInDocker reports whether the process is running inside a Docker (or | ||
| // containerd) container. Detection relies on Linux-specific paths and is always | ||
| // false elsewhere. It is used only to skip a PKCE flow that cannot work: a | ||
| // random callback port inside a container cannot be reached from the host | ||
| // browser, so we go straight to device flow in that case. | ||
| func isRunningInDocker() bool { | ||
| if runtime.GOOS != "linux" { | ||
| return false | ||
| } | ||
| if _, err := os.Stat("/.dockerenv"); err == nil { | ||
| return true | ||
| } | ||
| if data, err := os.ReadFile("/proc/1/cgroup"); err == nil { | ||
| s := string(data) | ||
| if strings.Contains(s, "docker") || strings.Contains(s, "containerd") { | ||
| return true | ||
| } | ||
| } | ||
| return false | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in f707dba.
listenCallbacknow takes an explicitbindAllflag and binds to0.0.0.0only when running inside a container;beginPKCEpassesm.inDocker(). Native runs — even with a fixed callback port — stay on127.0.0.1. (PKCE in a container only happens with a fixed port, since a random port there falls back to device flow, so this is exactly the publish-via-eth0 case that needs all-interfaces.) Call site and test updated accordingly.