diff --git a/rt/rs/security/sso/oidc/pom.xml b/rt/rs/security/sso/oidc/pom.xml
index 5fcda3ae890..a808fdf37fe 100644
--- a/rt/rs/security/sso/oidc/pom.xml
+++ b/rt/rs/security/sso/oidc/pom.xml
@@ -38,7 +38,7 @@
- -javaagent:${org.apache.openjpa:openjpa:jar}
+ -javaagent:${org.apache.openjpa:openjpa:jar} -javaagent:${org.mockito:mockito-core:jar}
@@ -64,6 +64,12 @@
junit
test
+
+ org.mockito
+ mockito-core
+ ${cxf.mockito.version}
+ test
+
org.hsqldb
hsqldb
diff --git a/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilter.java b/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilter.java
index 547e91cd5c3..b08d907fe64 100644
--- a/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilter.java
+++ b/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilter.java
@@ -115,8 +115,33 @@ private MultivaluedMap toRequestState(ContainerRequestContext rc
rc.setEntityStream(new ByteArrayInputStream(StringUtils.toBytesUTF8(body)));
}
+ // The "state" carried here is read back by the sign-in completion service and returned
+ // as a redirect Location, so a caller-supplied value collides with the redirect query
+ // the filter itself writes. Anything that is not within this application's own origin is
+ // dropped, otherwise completion would become an open redirect.
+ String location = requestState.getFirst("state");
+ if (location != null && !isSameOrigin(rc, location)) {
+ requestState.remove("state");
+ }
return requestState;
}
+ private boolean isSameOrigin(ContainerRequestContext rc, String location) {
+ final URI uri;
+ try {
+ uri = URI.create(location);
+ } catch (IllegalArgumentException ex) {
+ return false;
+ }
+ if (uri.getScheme() == null && uri.getAuthority() == null) {
+ // a path-only reference is resolved by the browser against the current request
+ return true;
+ }
+ URI base = rc.getUriInfo().getAbsolutePath();
+ return uri.getScheme() != null
+ && uri.getScheme().equalsIgnoreCase(base.getScheme())
+ && uri.getAuthority() != null
+ && uri.getAuthority().equalsIgnoreCase(base.getAuthority());
+ }
public void setRedirectUri(String redirectUri) {
this.redirectUri = redirectUri;
}
diff --git a/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilterTest.java b/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilterTest.java
new file mode 100644
index 00000000000..8b2eeba7729
--- /dev/null
+++ b/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationFilterTest.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.cxf.rs.security.oidc.rp;
+
+import java.lang.reflect.Method;
+import java.net.URI;
+
+import jakarta.ws.rs.container.ContainerRequestContext;
+import jakarta.ws.rs.core.MultivaluedHashMap;
+import jakarta.ws.rs.core.MultivaluedMap;
+import jakarta.ws.rs.core.UriInfo;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class OidcRpAuthenticationFilterTest {
+
+ private static final URI ABSOLUTE_PATH = URI.create("https://app.example.com:8080/services/rp/complete");
+
+ @Test
+ public void testDropsCrossOriginState() {
+ MultivaluedMap state = requestState("https://evil.example.com/phish");
+ assertFalse(state.containsKey("state"));
+ }
+
+ @Test
+ public void testDropsProtocolRelativeState() {
+ MultivaluedMap state = requestState("//evil.example.com/phish");
+ assertFalse(state.containsKey("state"));
+ }
+
+ @Test
+ public void testDropsUserinfoHostConfusionState() {
+ MultivaluedMap state = requestState("https://app.example.com:8080@evil.example.com/phish");
+ assertFalse(state.containsKey("state"));
+ }
+
+ @Test
+ public void testDropsNoAuthority() {
+ MultivaluedMap state = requestState("http:/");
+ assertFalse(state.containsKey("state"));
+ }
+
+ @Test
+ public void testKeepsSameOriginState() {
+ MultivaluedMap state = requestState("https://app.example.com:8080/services/protected");
+ assertTrue(state.containsKey("state"));
+ assertEquals("https://app.example.com:8080/services/protected", state.getFirst("state"));
+ }
+
+ @Test
+ public void testKeepsRelativeState() {
+ MultivaluedMap state = requestState("/services/protected");
+ assertTrue(state.containsKey("state"));
+ assertEquals("/services/protected", state.getFirst("state"));
+ }
+
+ private MultivaluedMap requestState(String stateLocation) {
+ MultivaluedMap query = new MultivaluedHashMap<>();
+ query.putSingle("state", stateLocation);
+
+ UriInfo uriInfo = mock(UriInfo.class);
+ when(uriInfo.getQueryParameters(true)).thenReturn(query);
+ when(uriInfo.getAbsolutePath()).thenReturn(ABSOLUTE_PATH);
+
+ ContainerRequestContext rc = mock(ContainerRequestContext.class);
+ when(rc.getUriInfo()).thenReturn(uriInfo);
+ when(rc.getMediaType()).thenReturn(null);
+
+ return invokeToRequestState(new OidcRpAuthenticationFilter(), rc);
+ }
+
+ @SuppressWarnings("unchecked")
+ private static MultivaluedMap invokeToRequestState(OidcRpAuthenticationFilter filter,
+ ContainerRequestContext rc) {
+ try {
+ Method method = OidcRpAuthenticationFilter.class.getDeclaredMethod("toRequestState",
+ ContainerRequestContext.class);
+ method.setAccessible(true);
+ return (MultivaluedMap)method.invoke(filter, rc);
+ } catch (ReflectiveOperationException ex) {
+ throw new IllegalStateException(ex);
+ }
+ }
+}