diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index 4eba076e67a..7e339cec3c0 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -154,6 +154,7 @@ class Connection : public EvbufCallbackBase { bool IsAdmin() const { return is_admin_; } void BecomeAdmin() { is_admin_ = true; } void BecomeUser() { is_admin_ = false; } + void InitDefaultNamespace() { SetNamespace(kDefaultNamespace); } std::string GetNamespace() const { return ns_; } void SetNamespace(std::string ns) { ns_ = std::move(ns); } diff --git a/src/server/worker.cc b/src/server/worker.cc index 45eceb70577..c36c84ff9dd 100644 --- a/src/server/worker.cc +++ b/src/server/worker.cc @@ -176,9 +176,18 @@ void Worker::newTCPConnection(evconnlistener *listener, evutil_socket_t fd, [[ma } #endif auto conn = new redis::Connection(bev, this); + if (srv->GetConfig()->requirepass.empty()) { + conn->BecomeAdmin(); + conn->InitDefaultNamespace(); + } conn->SetCB(bev); bufferevent_enable(bev, EV_READ); + if (auto s = util::GetPeerAddr(fd)) { + auto [ip, port] = std::move(*s); + conn->SetAddr(ip, port); + } + s = AddConnection(conn); if (!s.IsOK()) { std::string err_msg = redis::Error({Status::NotOK, s.Msg()}); @@ -190,11 +199,6 @@ void Worker::newTCPConnection(evconnlistener *listener, evutil_socket_t fd, [[ma return; } - if (auto s = util::GetPeerAddr(fd)) { - auto [ip, port] = std::move(*s); - conn->SetAddr(ip, port); - } - if (rate_limit_group_) { bufferevent_add_to_rate_limit_group(bev, rate_limit_group_); } @@ -210,6 +214,10 @@ void Worker::newUnixSocketConnection(evconnlistener *listener, evutil_socket_t f bufferevent *bev = bufferevent_socket_new(base, fd, ev_thread_safe_flags); auto conn = new redis::Connection(bev, this); + if (srv->GetConfig()->requirepass.empty()) { + conn->BecomeAdmin(); + conn->InitDefaultNamespace(); + } conn->SetCB(bev); bufferevent_enable(bev, EV_READ); diff --git a/tests/gocase/integration/replication/replication_test.go b/tests/gocase/integration/replication/replication_test.go index 6e291ecc590..9b8f900eccb 100644 --- a/tests/gocase/integration/replication/replication_test.go +++ b/tests/gocase/integration/replication/replication_test.go @@ -329,7 +329,7 @@ func TestReplicationWithLimitSpeed(t *testing.T) { require.Eventually(t, func() bool { return slave.LogFileMatches(t, ".*skip count: 1.*") }, 50*time.Second, 1000*time.Millisecond) - util.WaitForSync(t, slaveClient) + util.WaitForOffsetSync(t, masterClient, slaveClient, 50*time.Second) require.Equal(t, "b", slaveClient.Get(ctx, "a").Val()) }) } diff --git a/tests/gocase/unit/auth/auth_test.go b/tests/gocase/unit/auth/auth_test.go index 562281487e9..ff5f5f37d05 100644 --- a/tests/gocase/unit/auth/auth_test.go +++ b/tests/gocase/unit/auth/auth_test.go @@ -21,7 +21,11 @@ package auth import ( "context" + "fmt" + "net" + "regexp" "testing" + "time" "github.com/apache/kvrocks/tests/gocase/util" "github.com/stretchr/testify/require" @@ -39,6 +43,30 @@ func TestNoAuth(t *testing.T) { r := rdb.Do(ctx, "AUTH", "foo") require.ErrorContains(t, r.Err(), "no password") }) + + t.Run("Connections accepted before requirepass is set remain usable", func(t *testing.T) { + idleConn := srv.NewTCPClient() + defer func() { require.NoError(t, idleConn.Close()) }() + + _, idlePort, err := net.SplitHostPort(idleConn.LocalAddr().String()) + require.NoError(t, err) + + idleConnPattern := regexp.MustCompile(fmt.Sprintf(`(?:^| )addr=[^ ]*:%s(?: |$)`, idlePort)) + require.Eventually(t, func() bool { + return idleConnPattern.MatchString(rdb.ClientList(ctx).Val()) + }, 5*time.Second, 10*time.Millisecond) + + require.NoError(t, rdb.ConfigSet(ctx, "requirepass", "foobar").Err()) + + require.NoError(t, idleConn.WriteArgs("PING")) + idleConn.MustRead(t, "+PONG") + + newConn := srv.NewTCPClient() + defer func() { require.NoError(t, newConn.Close()) }() + + require.NoError(t, newConn.WriteArgs("PING")) + newConn.MustRead(t, "-NOAUTH Authentication required.") + }) } func TestAuth(t *testing.T) { diff --git a/tests/gocase/util/tcp_client.go b/tests/gocase/util/tcp_client.go index 3cd9de2b1a9..dfc0f6cbcfd 100644 --- a/tests/gocase/util/tcp_client.go +++ b/tests/gocase/util/tcp_client.go @@ -59,6 +59,10 @@ func (c *TCPClient) Close() error { return c.c.Close() } +func (c *TCPClient) LocalAddr() net.Addr { + return c.c.LocalAddr() +} + func (c *TCPClient) ReadLine() (string, error) { r, err := c.r.ReadString('\n') if err != nil {