diff --git a/go.mod b/go.mod index ad236e09e..91b81bc54 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/github/gh-ost go 1.25.9 require ( + github.com/DataDog/datadog-go/v5 v5.8.3 github.com/go-ini/ini v1.67.0 github.com/go-mysql-org/go-mysql v1.11.0 github.com/go-sql-driver/mysql v1.8.1 diff --git a/go.sum b/go.sum index 39fd9cd48..e39b587e8 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,11 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DataDog/datadog-go/v5 v5.8.3 h1:s58CUJ9s8lezjhTNJO/SxkPBv2qZjS3ktpRSqGF5n0s= +github.com/DataDog/datadog-go/v5 v5.8.3/go.mod h1:K9kcYBlxkcPP8tvvjZZKs/m1edNAUFzBbdpTUKfCsuw= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= @@ -54,6 +57,7 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= @@ -120,14 +124,21 @@ github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXY github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 h1:oI+RNwuC9jF2g2lP0u0cVEEZrc/AYBCuFdvwrLWM/6Q= github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07/go.mod h1:yFdBgwXP24JziuRl2NMUahT7nGLNOKi1SIiFxMttVD4= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/testcontainers/testcontainers-go v0.37.0 h1:L2Qc0vkTw2EHWQ08djon0D2uw7Z/PtHS/QzZZ5Ra/hg= @@ -140,6 +151,7 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= @@ -183,29 +195,38 @@ golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -221,6 +242,7 @@ golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/go/base/context.go b/go/base/context.go index 617e5bb13..0e46d4232 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -19,6 +19,7 @@ import ( uuid "github.com/google/uuid" + "github.com/github/gh-ost/go/metrics" "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" "github.com/openark/golib/log" @@ -176,6 +177,9 @@ type MigrationContext struct { CutOverType CutOver ReplicaServerId uint + // Number of workers used by the trx coordinator + NumWorkers int + Hostname string AssumeMasterHostname string ApplierTimeZone string @@ -237,6 +241,8 @@ type MigrationContext struct { AbortError error abortMutex *sync.Mutex + Metrics *metrics.Client + OriginalTableColumnsOnApplier *sql.ColumnList OriginalTableColumns *sql.ColumnList OriginalTableVirtualColumns *sql.ColumnList diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 189a5f399..672f71263 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -1,6 +1,6 @@ /* Copyright 2022 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE + See https://github.com/github/gh-ost/blob/master/LICENSE */ package binlog @@ -11,7 +11,6 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" - "github.com/github/gh-ost/go/sql" "time" @@ -57,7 +56,7 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader { // ConnectBinlogStreamer func (gmr *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordinates) (err error) { if coordinates.IsEmpty() { - return gmr.migrationContext.Log.Errorf("empty coordinates at ConnectBinlogStreamer()") + return gmr.migrationContext.Log.Errorf("Empty coordinates at ConnectBinlogStreamer()") } gmr.currentCoordinatesMutex.Lock() @@ -85,53 +84,17 @@ func (gmr *GoMySQLReader) GetCurrentBinlogCoordinates() mysql.BinlogCoordinates return gmr.currentCoordinates.Clone() } -func (gmr *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEvent *replication.RowsEvent, entriesChannel chan<- *BinlogEntry) error { - currentCoords := gmr.GetCurrentBinlogCoordinates() - dml := ToEventDML(ev.Header.EventType.String()) - if dml == NotDML { - return fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String()) - } - for i, row := range rowsEvent.Rows { - if dml == UpdateDML && i%2 == 1 { - // An update has two rows (WHERE+SET) - // We do both at the same time - continue +// StreamEvents reads binlog events and sends them to the given channel. +// It is blocking and should be executed in a goroutine. +func (gmr *GoMySQLReader) StreamEvents(ctx context.Context, canStopStreaming func() bool, eventChannel chan<- *replication.BinlogEvent) error { + for { + if canStopStreaming() { + return nil } - binlogEntry := NewBinlogEntryAt(currentCoords) - binlogEntry.DmlEvent = NewBinlogDMLEvent( - string(rowsEvent.Table.Schema), - string(rowsEvent.Table.Table), - dml, - ) - switch dml { - case InsertDML: - { - binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(row) - } - case UpdateDML: - { - binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) - binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(rowsEvent.Rows[i+1]) - } - case DeleteDML: - { - binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) - } + if err := ctx.Err(); err != nil { + return err } - - // The channel will do the throttling. Whoever is reading from the channel - // decides whether action is taken synchronously (meaning we wait before - // next iteration) or asynchronously (we keep pushing more events) - // In reality, reads will be synchronous - entriesChannel <- binlogEntry - } - return nil -} - -// StreamEvents -func (gmr *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error { - for !canStopStreaming() { - ev, err := gmr.binlogStreamer.GetEvent(context.Background()) + ev, err := gmr.binlogStreamer.GetEvent(ctx) if err != nil { return err } @@ -153,45 +116,38 @@ func (gmr *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesChan switch event := ev.Event.(type) { case *replication.GTIDEvent: - if !gmr.migrationContext.UseGTIDs { - continue - } - sid, err := uuid.FromBytes(event.SID) - if err != nil { - return err - } - gmr.currentCoordinatesMutex.Lock() - if gmr.LastTrxCoords != nil { - gmr.currentCoordinates = gmr.LastTrxCoords.Clone() + if gmr.migrationContext.UseGTIDs { + sid, err := uuid.FromBytes(event.SID) + if err != nil { + return err + } + gmr.currentCoordinatesMutex.Lock() + if gmr.LastTrxCoords != nil { + gmr.currentCoordinates = gmr.LastTrxCoords.Clone() + } + coords := gmr.currentCoordinates.(*mysql.GTIDBinlogCoordinates) + trxGset := gomysql.NewUUIDSet(sid, gomysql.Interval{Start: event.GNO, Stop: event.GNO + 1}) + coords.GTIDSet.AddSet(trxGset) + gmr.currentCoordinatesMutex.Unlock() } - coords := gmr.currentCoordinates.(*mysql.GTIDBinlogCoordinates) - trxGset := gomysql.NewUUIDSet(sid, gomysql.Interval{Start: event.GNO, Stop: event.GNO + 1}) - coords.GTIDSet.AddSet(trxGset) - gmr.currentCoordinatesMutex.Unlock() case *replication.RotateEvent: - if gmr.migrationContext.UseGTIDs { - continue + if !gmr.migrationContext.UseGTIDs { + gmr.currentCoordinatesMutex.Lock() + coords := gmr.currentCoordinates.(*mysql.FileBinlogCoordinates) + coords.LogFile = string(event.NextLogName) + gmr.migrationContext.Log.Infof("rotate to next log from %s:%d to %s", coords.LogFile, int64(ev.Header.LogPos), event.NextLogName) + gmr.currentCoordinatesMutex.Unlock() } - gmr.currentCoordinatesMutex.Lock() - coords := gmr.currentCoordinates.(*mysql.FileBinlogCoordinates) - coords.LogFile = string(event.NextLogName) - gmr.migrationContext.Log.Infof("rotate to next log from %s:%d to %s", coords.LogFile, int64(ev.Header.LogPos), event.NextLogName) - gmr.currentCoordinatesMutex.Unlock() case *replication.XIDEvent: if gmr.migrationContext.UseGTIDs { gmr.LastTrxCoords = &mysql.GTIDBinlogCoordinates{GTIDSet: event.GSet.(*gomysql.MysqlGTIDSet)} } else { gmr.LastTrxCoords = gmr.currentCoordinates.Clone() } - case *replication.RowsEvent: - if err := gmr.handleRowsEvent(ev, event, entriesChannel); err != nil { - return err - } } - } - gmr.migrationContext.Log.Debugf("done streaming events") - return nil + eventChannel <- ev + } } func (gmr *GoMySQLReader) Close() error { diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 567137fd5..0ea866bd0 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -13,9 +13,11 @@ import ( "os/signal" "regexp" "syscall" + "time" "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/logic" + "github.com/github/gh-ost/go/metrics" "github.com/github/gh-ost/go/sql" _ "github.com/go-sql-driver/mysql" "github.com/openark/golib/log" @@ -25,6 +27,20 @@ import ( var AppVersion, GitCommit string +type statsdTagList []string + +func (s *statsdTagList) String() string { + if s == nil || len(*s) == 0 { + return "" + } + return fmt.Sprint([]string(*s)) +} + +func (s *statsdTagList) Set(value string) error { + *s = append(*s, value) + return nil +} + // acceptSignals registers for OS signals func acceptSignals(migrationContext *base.MigrationContext) { c := make(chan os.Signal, 1) @@ -112,6 +128,7 @@ func main() { flag.BoolVar(&migrationContext.PanicOnWarnings, "panic-on-warnings", false, "Panic when SQL warnings are encountered when copying a batch indicating data loss") cutOverLockTimeoutSeconds := flag.Int64("cut-over-lock-timeout-seconds", 3, "Max number of seconds to hold locks on tables while attempting to cut-over (retry attempted when lock exceeds timeout) or attempting instant DDL") niceRatio := flag.Float64("nice-ratio", 0, "force being 'nice', imply sleep time per chunk time; range: [0.0..100.0]. Example values: 0 is aggressive. 1: for every 1ms spent copying rows, sleep additional 1ms (effectively doubling runtime); 0.7: for every 10ms spend in a rowcopy chunk, spend 7ms sleeping immediately after") + flag.IntVar(&migrationContext.NumWorkers, "workers", 8, "Number of concurrent workers for applying DML events. Each worker uses one goroutine.") maxLagMillis := flag.Int64("max-lag-millis", 1500, "replication lag at which to throttle operation") replicationLagQuery := flag.String("replication-lag-query", "", "Deprecated. gh-ost uses an internal, subsecond resolution query") @@ -156,6 +173,10 @@ func main() { criticalLoad := flag.String("critical-load", "", "Comma delimited status-name=threshold, same format as --max-load. When status exceeds threshold, app panics and quits") flag.Int64Var(&migrationContext.CriticalLoadIntervalMilliseconds, "critical-load-interval-millis", 0, "When 0, migration immediately bails out upon meeting critical-load. When non-zero, a second check is done after given interval, and migration only bails out if 2nd check still meets critical load") flag.Int64Var(&migrationContext.CriticalLoadHibernateSeconds, "critical-load-hibernate-seconds", 0, "When non-zero, critical-load does not panic and bail out; instead, gh-ost goes into hibernation for the specified duration. It will not read/write anything from/to any server") + statsdAddr := flag.String("statsd-addr", "", "StatsD endpoint (host:port or unix socket); empty disables StatsD") + var statsdTags statsdTagList + flag.Var(&statsdTags, "statsd-tags", "global StatsD tags applied to every metric (repeatable), format key:value. Example: --statsd-tags 'env:prod,service:my-service'") + runtimeMetricsInterval := flag.Int("runtime-metrics-interval", 10, "Seconds between Go runtime memory/GC gauge samples (requires --statsd-addr); 0 disables") quiet := flag.Bool("quiet", false, "quiet") verbose := flag.Bool("verbose", false, "verbose") debug := flag.Bool("debug", false, "debug mode (very verbose)") @@ -375,6 +396,17 @@ func main() { log.Infof("starting gh-ost %+v (git commit: %s)", AppVersion, GitCommit) acceptSignals(migrationContext) + metricsClient, metricsErr := metrics.NewClient(*statsdAddr, []string(statsdTags), "gh_ost.") + if metricsErr != nil { + log.Fatalf("metrics: %v", metricsErr) + } + defer func() { _ = metricsClient.Close() }() + migrationContext.Metrics = metricsClient + metricsClient.Count("startup", 1) + if *runtimeMetricsInterval > 0 { + metrics.StartGoRuntimeReporter(migrationContext.GetContext(), metricsClient, time.Duration(*runtimeMetricsInterval)*time.Second) + } + migrator := logic.NewMigrator(migrationContext, AppVersion) var err error if migrationContext.Revert { diff --git a/go/logic/applier.go b/go/logic/applier.go index b49e131b8..06c7bfdbe 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -109,12 +109,14 @@ func (apl *Applier) compileMigrationKeyWarningRegex() (*regexp.Regexp, error) { return migrationKeyRegex, nil } -func (apl *Applier) InitDBConnections() (err error) { +func (apl *Applier) InitDBConnections(maxConns int) (err error) { applierUri := apl.connectionConfig.GetDBUri(apl.migrationContext.DatabaseName) uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri) if apl.db, _, err = mysql.GetDB(apl.migrationContext.Uuid, uriWithMulti); err != nil { return err } + apl.db.SetMaxOpenConns(maxConns) + apl.db.SetMaxIdleConns(maxConns) singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) if apl.singletonDB, _, err = mysql.GetDB(apl.migrationContext.Uuid, singletonApplierUri); err != nil { return err diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go index 6d7ba42f4..7496fa5f3 100644 --- a/go/logic/applier_test.go +++ b/go/logic/applier_test.go @@ -333,7 +333,7 @@ func (suite *ApplierTestSuite) TestInitDBConnections() { applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) mysqlVersion, _ := strings.CutPrefix(testMysqlContainerImage, "mysql:") @@ -374,7 +374,7 @@ func (suite *ApplierTestSuite) TestApplyDMLEventQueries() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) dmlEvents := []*binlog.BinlogDMLEvent{ @@ -431,7 +431,7 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTables() { applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.ValidateOrDropExistingTables() @@ -463,7 +463,7 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTablesWithGhostTableExi applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.ValidateOrDropExistingTables() @@ -494,7 +494,7 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTablesWithGhostTableExi applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.ValidateOrDropExistingTables() @@ -531,7 +531,7 @@ func (suite *ApplierTestSuite) TestCreateGhostTable() { applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.CreateGhostTable() @@ -583,7 +583,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsInApplyIterationInsertQuerySuc suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, item_id) VALUES (123456, 42);", getTestTableName())) @@ -673,7 +673,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsInApplyIterationInsertQueryFai } applier := NewApplier(migrationContext) - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) err = applier.CreateChangelogTable() @@ -740,7 +740,7 @@ func (suite *ApplierTestSuite) TestWriteCheckpoint() { applier := NewApplier(migrationContext) - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) err = applier.CreateChangelogTable() @@ -822,7 +822,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsWithDuplicateKeyOnNonMigration suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table (simulating bulk copy phase) @@ -911,7 +911,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsWithDuplicateCompositeUniqueKe suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table (simulating bulk copy phase) @@ -1013,7 +1013,7 @@ func (suite *ApplierTestSuite) TestUpdateModifyingUniqueKeyWithDuplicateOnOtherI suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Setup: Insert initial rows into ghost table @@ -1108,7 +1108,7 @@ func (suite *ApplierTestSuite) TestNormalUpdateWithPanicOnWarnings() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Setup: Insert initial rows into ghost table @@ -1188,7 +1188,7 @@ func (suite *ApplierTestSuite) TestDuplicateOnMigrationKeyAllowedInBinlogReplay( suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table (simulating bulk copy phase) @@ -1279,7 +1279,7 @@ func (suite *ApplierTestSuite) TestRegexMetacharactersInIndexName() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows @@ -1381,7 +1381,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsDisabled() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table @@ -1470,7 +1470,7 @@ func (suite *ApplierTestSuite) TestMultipleDMLEventsInBatch() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table diff --git a/go/logic/coordinator.go b/go/logic/coordinator.go new file mode 100644 index 000000000..dc865b4cd --- /dev/null +++ b/go/logic/coordinator.go @@ -0,0 +1,669 @@ +package logic + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + "errors" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/binlog" + "github.com/github/gh-ost/go/mysql" + "github.com/github/gh-ost/go/sql" + "github.com/go-mysql-org/go-mysql/replication" + drivermysql "github.com/go-sql-driver/mysql" +) + +type Coordinator struct { + migrationContext *base.MigrationContext + + binlogReader *binlog.GoMySQLReader + + onChangelogEvent func(dmlEvent *binlog.BinlogDMLEvent) error + + applier *Applier + + throttler *Throttler + + // Atomic counter for number of active workers (not in workerQueue) + busyWorkers atomic.Int64 + + // Mutex to protect the fields below + mu sync.Mutex + + // list of workers + workers []*Worker + + // The low water mark. We maintain that all transactions with + // sequence number <= lowWaterMark have been completed. + lowWaterMark int64 + + // This is a map of completed jobs by their sequence numbers. + // This is used when updating the low water mark. + // It records the binlog coordinates of the completed transaction. + completedJobs map[int64]struct{} + + // These are the jobs that are waiting for a previous job to complete. + // They are indexed by the sequence number of the job they are waiting for. + waitingJobs map[int64][]chan struct{} + + events chan *replication.BinlogEvent + + workerQueue chan *Worker + + // fatalErr stores the first fatal error from any worker goroutine. + fatalErr error + fatalErrMu sync.Mutex + // failedCh is closed on the first fatal worker error; all blocking + // coordinator and worker operations select on this to unblock. + failedCh chan struct{} + + finishedMigrating atomic.Bool +} + +// Worker takes jobs from the Coordinator and applies the job's DML events. +type Worker struct { + id int + coordinator *Coordinator + eventQueue chan *replication.BinlogEvent + + executedJobs atomic.Int64 + dmlEventsApplied atomic.Int64 + waitTimeNs atomic.Int64 + busyTimeNs atomic.Int64 +} + +type stats struct { + dmlRate float64 + trxRate float64 + + // Number of DML events applied + dmlEventsApplied int64 + + // Number of transactions processed + executedJobs int64 + + // Time spent applying DML events + busyTime time.Duration + + // Time spent waiting on transaction dependecies + // or waiting on events to arrive in queue. + waitTime time.Duration +} + +// isRetryableError returns true for MySQL errors that are safe to retry +// (deadlock and lock wait timeout). +func isRetryableError(err error) bool { + var mysqlErr *drivermysql.MySQLError + if errors.As(err, &mysqlErr) { + switch mysqlErr.Number { + case 1213, 1205: // deadlock, lock wait timeout + return true + } + } + return false +} + +// setFatalError records the first fatal error and closes failedCh so +// all blocking operations in the coordinator and workers unblock. +func (c *Coordinator) setFatalError(err error) { + c.fatalErrMu.Lock() + defer c.fatalErrMu.Unlock() + if c.fatalErr == nil { + c.fatalErr = err + close(c.failedCh) + } +} + +// getFatalError returns the first fatal error, or nil. +func (c *Coordinator) getFatalError() error { + c.fatalErrMu.Lock() + defer c.fatalErrMu.Unlock() + return c.fatalErr +} + +func (w *Worker) ProcessEvents() error { + databaseName := w.coordinator.migrationContext.DatabaseName + originalTableName := w.coordinator.migrationContext.OriginalTableName + changelogTableName := w.coordinator.migrationContext.GetChangelogTableName() + + for { + if w.coordinator.finishedMigrating.Load() { + return nil + } + + // Wait for first event (GTID), interruptible by fatal error + waitStart := time.Now() + var ev *replication.BinlogEvent + select { + case ev = <-w.eventQueue: + case <-w.coordinator.failedCh: + return fmt.Errorf("aborting: %w", w.coordinator.getFatalError()) + } + w.waitTimeNs.Add(time.Since(waitStart).Nanoseconds()) + + // Verify this is a GTID Event + gtidEvent, ok := ev.Event.(*replication.GTIDEvent) + if !ok { + w.coordinator.migrationContext.Log.Debugf("Received unexpected event: %v\n", ev) + } + + // Dependency wait is done by the coordinator before dispatch + // (coordinator-side scheduling, matching MySQL applier semantics). + + // Process the transaction + var changelogEvent *binlog.BinlogDMLEvent + var txErr error + dmlEvents := make([]*binlog.BinlogDMLEvent, 0, int(atomic.LoadInt64(&w.coordinator.migrationContext.DMLBatchSize))) + events: + for { + // wait for next event in the transaction + waitStart := time.Now() + var ev *replication.BinlogEvent + select { + case ev = <-w.eventQueue: + case <-w.coordinator.failedCh: + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("aborting: %w", w.coordinator.getFatalError()) + } + w.waitTimeNs.Add(time.Since(waitStart).Nanoseconds()) + + if ev == nil { + break events + } + + switch binlogEvent := ev.Event.(type) { + case *replication.RowsEvent: + dml := binlog.ToEventDML(ev.Header.EventType.String()) + if dml == binlog.NotDML { + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String()) + } + + if !strings.EqualFold(databaseName, string(binlogEvent.Table.Schema)) { + continue + } + + if !strings.EqualFold(originalTableName, string(binlogEvent.Table.Table)) && !strings.EqualFold(changelogTableName, string(binlogEvent.Table.Table)) { + continue + } + + for i, row := range binlogEvent.Rows { + if dml == binlog.UpdateDML && i%2 == 1 { + // An update has two rows (WHERE+SET) + // We do both at the same time + continue + } + dmlEvent := binlog.NewBinlogDMLEvent( + string(binlogEvent.Table.Schema), + string(binlogEvent.Table.Table), + dml, + ) + switch dml { + case binlog.InsertDML: + { + dmlEvent.NewColumnValues = sql.ToColumnValues(row) + } + case binlog.UpdateDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + dmlEvent.NewColumnValues = sql.ToColumnValues(binlogEvent.Rows[i+1]) + } + case binlog.DeleteDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + } + } + + if strings.EqualFold(changelogTableName, string(binlogEvent.Table.Table)) { + changelogEvent = dmlEvent + } else { + dmlEvents = append(dmlEvents, dmlEvent) + + if len(dmlEvents) == cap(dmlEvents) { + if err := w.applyDMLEvents(dmlEvents); err != nil { + txErr = err + break events + } + dmlEvents = dmlEvents[:0] + } + } + } + case *replication.XIDEvent: + if len(dmlEvents) > 0 { + if err := w.applyDMLEvents(dmlEvents); err != nil { + txErr = err + break events + } + } + + w.executedJobs.Add(1) + break events + } + } + + if txErr != nil { + // Fatal: DML failed after retries. Decrement busyWorkers + // since we won't reach the normal cleanup path below. + w.coordinator.busyWorkers.Add(-1) + return txErr + } + + w.coordinator.MarkTransactionCompleted(gtidEvent.SequenceNumber, int64(ev.Header.LogPos), int64(ev.Header.EventSize)) + + // Did we see a changelog event? + // Handle it now + if changelogEvent != nil { + // wait for all transactions before this point + clWaitCh := w.coordinator.WaitForTransaction(gtidEvent.SequenceNumber - 1) + if clWaitCh != nil { + waitStart := time.Now() + select { + case <-clWaitCh: + case <-w.coordinator.failedCh: + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("aborting: %w", w.coordinator.getFatalError()) + } + w.waitTimeNs.Add(time.Since(waitStart).Nanoseconds()) + } + w.coordinator.HandleChangeLogEvent(changelogEvent) + } + + w.coordinator.workerQueue <- w + w.coordinator.busyWorkers.Add(-1) + } +} + +func (w *Worker) applyDMLEvents(dmlEvents []*binlog.BinlogDMLEvent) error { + if w.coordinator.throttler != nil { + w.coordinator.throttler.throttle(nil) + } + // Deadlocks between parallel workers are expected due to InnoDB gap locks + // on secondary indexes. Use a generous retry limit with jittered backoff + // to handle contention between workers. + const maxDeadlockRetries = 100 + var err error + for attempt := 0; attempt < maxDeadlockRetries; attempt++ { + if attempt > 0 { + // Jittered exponential backoff: base * 2^min(attempt,7) + random jitter + base := time.Duration(10) * time.Millisecond + backoff := base * (1 << min(attempt, 7)) + jitter := time.Duration(rand.Int63n(int64(backoff))) + time.Sleep(backoff + jitter) + } + busyStart := time.Now() + err = w.coordinator.applier.ApplyDMLEventQueries(dmlEvents) + w.busyTimeNs.Add(time.Since(busyStart).Nanoseconds()) + if err == nil { + w.dmlEventsApplied.Add(int64(len(dmlEvents))) + return nil + } + if !isRetryableError(err) { + return err + } + if attempt > 0 && attempt%10 == 0 { + w.coordinator.migrationContext.Log.Infof("Worker %d: DML batch retry attempt %d after deadlock", w.id, attempt) + } + } + return fmt.Errorf("DML batch failed after %d deadlock retries: %w", maxDeadlockRetries, err) +} + +func NewCoordinator(migrationContext *base.MigrationContext, applier *Applier, throttler *Throttler, onChangelogEvent func(dmlEvent *binlog.BinlogDMLEvent) error) *Coordinator { + return &Coordinator{ + migrationContext: migrationContext, + + onChangelogEvent: onChangelogEvent, + + throttler: throttler, + + binlogReader: binlog.NewGoMySQLReader(migrationContext), + + lowWaterMark: -1, + completedJobs: make(map[int64]struct{}), + waitingJobs: make(map[int64][]chan struct{}), + + events: make(chan *replication.BinlogEvent, 1000), + failedCh: make(chan struct{}), + } +} + +func (c *Coordinator) StartStreaming(ctx context.Context, coords mysql.BinlogCoordinates, canStopStreaming func() bool) error { + err := c.binlogReader.ConnectBinlogStreamer(coords) + if err != nil { + return err + } + defer c.binlogReader.Close() + + var retries int64 + for { + if err := ctx.Err(); err != nil { + return err + } + if canStopStreaming() { + return nil + } + if err := c.binlogReader.StreamEvents(ctx, canStopStreaming, c.events); err != nil { + if errors.Is(err, context.Canceled) { + return err + } + + c.migrationContext.Log.Infof("StreamEvents encountered unexpected error: %+v", err) + c.migrationContext.MarkPointOfInterest() + + if retries >= c.migrationContext.MaxRetries() { + return fmt.Errorf("%d successive failures in streamer reconnect at coordinates %+v", retries, coords) + } + c.migrationContext.Log.Infof("Reconnecting... Will resume at %+v", coords) + + // We reconnect from the event that was last emitted to the stream. + // This ensures we don't miss any events, and we don't process any events twice. + // Processing events twice messes up the transaction tracking and + // will cause data corruption. + coords := c.binlogReader.GetCurrentBinlogCoordinates() + if err := c.binlogReader.ConnectBinlogStreamer(coords); err != nil { + return err + } + retries += 1 + } + } +} + +func (c *Coordinator) ProcessEventsUntilNextChangelogEvent() (*binlog.BinlogDMLEvent, error) { + databaseName := c.migrationContext.DatabaseName + changelogTableName := c.migrationContext.GetChangelogTableName() + + for ev := range c.events { + switch binlogEvent := ev.Event.(type) { + case *replication.RowsEvent: + dml := binlog.ToEventDML(ev.Header.EventType.String()) + if dml == binlog.NotDML { + return nil, fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String()) + } + + if !strings.EqualFold(databaseName, string(binlogEvent.Table.Schema)) { + continue + } + + if !strings.EqualFold(changelogTableName, string(binlogEvent.Table.Table)) { + continue + } + + for i, row := range binlogEvent.Rows { + if dml == binlog.UpdateDML && i%2 == 1 { + // An update has two rows (WHERE+SET) + // We do both at the same time + continue + } + dmlEvent := binlog.NewBinlogDMLEvent( + string(binlogEvent.Table.Schema), + string(binlogEvent.Table.Table), + dml, + ) + switch dml { + case binlog.InsertDML: + { + dmlEvent.NewColumnValues = sql.ToColumnValues(row) + } + case binlog.UpdateDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + dmlEvent.NewColumnValues = sql.ToColumnValues(binlogEvent.Rows[i+1]) + } + case binlog.DeleteDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + } + } + + return dmlEvent, nil + } + } + } + + //nolint:nilnil + return nil, nil +} + +// ProcessEventsUntilDrained reads binlog events and sends them to the workers to process. +// It exits when the event queue is empty and all the workers are returned to the workerQueue. +func (c *Coordinator) ProcessEventsUntilDrained() error { + for { + // Check for fatal worker error first + select { + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + default: + } + + select { + // Read events from the binlog and submit them to the next worker + case ev := <-c.events: + { + if c.finishedMigrating.Load() { + return nil + } + + switch binlogEvent := ev.Event.(type) { + case *replication.GTIDEvent: + c.mu.Lock() + if c.lowWaterMark < 0 && binlogEvent.SequenceNumber > 0 { + c.lowWaterMark = binlogEvent.SequenceNumber - 1 + } + c.mu.Unlock() + + // Coordinator-side dependency wait: don't schedule this + // transaction until all its dependencies are complete. + // This matches MySQL's replication applier coordinator + // semantics (schedule iff lwm >= lastCommitted). + waitChannel := c.WaitForTransaction(binlogEvent.LastCommitted) + if waitChannel != nil { + select { + case <-waitChannel: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + } + case *replication.RotateEvent: + c.migrationContext.Log.Infof("rotate to next log in %s", binlogEvent.NextLogName) + // Binlog rotation resets sequence numbers. We must + // drain all workers (old file) and reset the lwm so + // that dependency checking uses the new file's + // sequence number space. + c.mu.Lock() + needsReset := c.lowWaterMark >= 0 + c.mu.Unlock() + if needsReset { + for c.busyWorkers.Load() > 0 { + select { + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + default: + } + time.Sleep(time.Millisecond) + } + c.mu.Lock() + c.lowWaterMark = -1 + c.completedJobs = make(map[int64]struct{}) + c.waitingJobs = make(map[int64][]chan struct{}) + c.mu.Unlock() + } + continue + default: // ignore all other events + continue + } + + // Acquire a worker, interruptible by fatal error + var worker *Worker + select { + case worker = <-c.workerQueue: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + c.busyWorkers.Add(1) + + // Send GTID to worker, interruptible + select { + case worker.eventQueue <- ev: + case <-c.failedCh: + c.busyWorkers.Add(-1) + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + + ev = <-c.events + + switch binlogEvent := ev.Event.(type) { + case *replication.QueryEvent: + if bytes.Equal([]byte("BEGIN"), binlogEvent.Query) { + } else { + worker.eventQueue <- nil + continue + } + default: + worker.eventQueue <- nil + continue + } + + events: + for { + ev = <-c.events + switch ev.Event.(type) { + case *replication.RowsEvent: + select { + case worker.eventQueue <- ev: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + case *replication.XIDEvent: + select { + case worker.eventQueue <- ev: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + + // We're done with this transaction + break events + } + } + } + + // No events in the queue. Check if all workers are sleeping now + default: + { + select { + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + default: + } + if c.busyWorkers.Load() == 0 { + return nil + } + } + } + } +} + +func (c *Coordinator) InitializeWorkers(count int) { + c.workerQueue = make(chan *Worker, count) + for i := 0; i < count; i++ { + w := &Worker{id: i, coordinator: c, eventQueue: make(chan *replication.BinlogEvent, 1000)} + + c.mu.Lock() + c.workers = append(c.workers, w) + c.mu.Unlock() + + c.workerQueue <- w + go func() { + if err := w.ProcessEvents(); err != nil { + c.migrationContext.Log.Errorf("Worker %d fatal error: %v", w.id, err) + c.setFatalError(err) + } + }() + } +} + +// GetWorkerStats collects profiling stats for ProcessEvents from each worker. +func (c *Coordinator) GetWorkerStats() []stats { + c.mu.Lock() + defer c.mu.Unlock() + statSlice := make([]stats, 0, len(c.workers)) + for _, w := range c.workers { + stat := stats{} + stat.dmlEventsApplied = w.dmlEventsApplied.Load() + stat.executedJobs = w.executedJobs.Load() + stat.busyTime = time.Duration(w.busyTimeNs.Load()) + stat.waitTime = time.Duration(w.waitTimeNs.Load()) + if stat.busyTime.Milliseconds() > 0 { + stat.dmlRate = 1000.0 * float64(stat.dmlEventsApplied) / float64(stat.busyTime.Milliseconds()) + stat.trxRate = 1000.0 * float64(stat.executedJobs) / float64(stat.busyTime.Milliseconds()) + } + statSlice = append(statSlice, stat) + } + return statSlice +} + +func (c *Coordinator) WaitForTransaction(lastCommitted int64) chan struct{} { + c.mu.Lock() + defer c.mu.Unlock() + + if lastCommitted <= c.lowWaterMark { + return nil + } + + // Buffered so MarkTransactionCompleted never blocks if the waiter + // already exited (e.g. via failedCh). + waitChannel := make(chan struct{}, 1) + c.waitingJobs[lastCommitted] = append(c.waitingJobs[lastCommitted], waitChannel) + + return waitChannel +} + +func (c *Coordinator) HandleChangeLogEvent(event *binlog.BinlogDMLEvent) { + c.mu.Lock() + defer c.mu.Unlock() + c.onChangelogEvent(event) +} + +func (c *Coordinator) MarkTransactionCompleted(sequenceNumber, logPos, eventSize int64) { + var channelsToNotify []chan struct{} + + func() { + c.mu.Lock() + defer c.mu.Unlock() + + // Mark the job as completed + c.completedJobs[sequenceNumber] = struct{}{} + + // Then, update the low water mark if possible + for { + if _, ok := c.completedJobs[c.lowWaterMark+1]; ok { + c.lowWaterMark++ + delete(c.completedJobs, c.lowWaterMark) + } else { + break + } + } + channelsToNotify = make([]chan struct{}, 0) + + // Schedule any jobs that were waiting for this job to complete or for the low watermark + for waitingForSequenceNumber, channels := range c.waitingJobs { + if waitingForSequenceNumber <= c.lowWaterMark { + channelsToNotify = append(channelsToNotify, channels...) + delete(c.waitingJobs, waitingForSequenceNumber) + } + } + }() + + for _, waitChannel := range channelsToNotify { + waitChannel <- struct{}{} + } +} + +func (c *Coordinator) Teardown() { + c.finishedMigrating.Store(true) +} diff --git a/go/logic/coordinator_test.go b/go/logic/coordinator_test.go new file mode 100644 index 000000000..b38b7ff28 --- /dev/null +++ b/go/logic/coordinator_test.go @@ -0,0 +1,382 @@ +package logic + +import ( + "context" + gosql "database/sql" + "fmt" + "math/rand/v2" + "os" + "testing" + "time" + + "path/filepath" + "runtime" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/binlog" + "github.com/github/gh-ost/go/mysql" + "github.com/github/gh-ost/go/sql" + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + "golang.org/x/sync/errgroup" +) + +type CoordinatorTestSuite struct { + suite.Suite + + mysqlContainer testcontainers.Container + db *gosql.DB + concurrentTransactions int + transactionsPerWorker int + transactionSize int +} + +func (suite *CoordinatorTestSuite) SetupSuite() { + ctx := context.Background() + req := testcontainers.ContainerRequest{ + Image: "mysql:8.0.40", + Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root-password"}, + WaitingFor: wait.ForListeningPort("3306/tcp"), + } + + mysqlContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + suite.Require().NoError(err) + + suite.mysqlContainer = mysqlContainer + + dsn, err := GetDSN(ctx, mysqlContainer) + suite.Require().NoError(err) + + db, err := gosql.Open("mysql", dsn) + suite.Require().NoError(err) + + suite.db = db + suite.concurrentTransactions = 8 + suite.transactionsPerWorker = 1000 + suite.transactionSize = 10 + + db.SetMaxOpenConns(suite.concurrentTransactions) +} + +func (suite *CoordinatorTestSuite) SetupTest() { + ctx := context.Background() + _, err := suite.db.ExecContext(ctx, "RESET MASTER") + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, "SET @@GLOBAL.binlog_transaction_dependency_tracking = WRITESET") + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET @@GLOBAL.max_connections = %d", suite.concurrentTransactions*2)) + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, "CREATE DATABASE test") + suite.Require().NoError(err) +} + +func (suite *CoordinatorTestSuite) TearDownTest() { + ctx := context.Background() + _, err := suite.db.ExecContext(ctx, "DROP DATABASE test") + suite.Require().NoError(err) +} + +func (suite *CoordinatorTestSuite) TeardownSuite() { + ctx := context.Background() + + suite.Assert().NoError(suite.db.Close()) + suite.Assert().NoError(suite.mysqlContainer.Terminate(ctx)) +} + +func (suite *CoordinatorTestSuite) TestApplyDML() { + ctx := context.Background() + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + _ = os.Remove("/tmp/gh-ost.sock") + + _, err = suite.db.Exec("CREATE TABLE test.gh_ost_test (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(255)) ENGINE=InnoDB") + suite.Require().NoError(err) + + _, err = suite.db.Exec("CREATE TABLE test._gh_ost_test_gho (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(255))") + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "gh_ost_test" + migrationContext.AlterStatement = "ALTER TABLE gh_ost_test ENGINE=InnoDB" + migrationContext.AllowedRunningOnMaster = true + migrationContext.ReplicaServerId = 99999 + migrationContext.HeartbeatIntervalMilliseconds = 100 + migrationContext.ThrottleHTTPIntervalMillis = 100 + migrationContext.DMLBatchSize = 10 + + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + + migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.GhostTableColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: "PRIMARY", + Columns: *sql.NewColumnList([]string{"id"}), + IsAutoIncrement: true, + } + + migrationContext.SetConnectionConfig("innodb") + migrationContext.SkipPortValidation = true + migrationContext.NumWorkers = 4 + + //nolint:dogsled + _, filename, _, _ := runtime.Caller(0) + migrationContext.ServeSocketFile = filepath.Join(filepath.Dir(filename), "../../tmp/gh-ost.sock") + + applier := NewApplier(migrationContext) + err = applier.InitDBConnections(migrationContext.NumWorkers) + suite.Require().NoError(err) + + err = applier.prepareQueries() + suite.Require().NoError(err) + + err = applier.CreateChangelogTable() + suite.Require().NoError(err) + + g, _ := errgroup.WithContext(ctx) + for i := range suite.concurrentTransactions { + g.Go(func() error { + r := rand.New(rand.NewPCG(uint64(0), uint64(i))) + maxID := int64(1) + for range suite.transactionsPerWorker { + tx, txErr := suite.db.Begin() + if txErr != nil { + return txErr + } + + // generate random write queries + for range r.IntN(suite.transactionSize) + 1 { + switch r.IntN(5) { + case 0: + _, txErr = tx.Exec(fmt.Sprintf("DELETE FROM test.gh_ost_test WHERE id=%d", r.Int64N(maxID))) + if txErr != nil { + return txErr + } + case 1, 2: + _, txErr = tx.Exec(fmt.Sprintf("UPDATE test.gh_ost_test SET name='test-%d' WHERE id=%d", r.Int(), r.Int64N(maxID))) + if txErr != nil { + return txErr + } + default: + res, txErr := tx.Exec(fmt.Sprintf("INSERT INTO test.gh_ost_test (name) VALUES ('test-%d')", r.Int())) + if txErr != nil { + return txErr + } + lastID, err := res.LastInsertId() + if err != nil { + return err + } + maxID = lastID + 1 + } + } + txErr = tx.Commit() + if txErr != nil { + return txErr + } + } + return nil + }) + } + + _, err = applier.WriteChangelogState("completed") + suite.Require().NoError(err) + + ctx, cancel := context.WithCancel(context.Background()) + + coord := NewCoordinator(migrationContext, applier, nil, + func(dmlEvent *binlog.BinlogDMLEvent) error { + fmt.Printf("Received Changelog DML event: %+v\n", dmlEvent) + fmt.Printf("Rowdata: %v - %v\n", dmlEvent.NewColumnValues, dmlEvent.WhereColumnValues) + + cancel() + + return nil + }) + coord.applier = applier + coord.InitializeWorkers(4) + + streamCtx, cancelStreaming := context.WithCancel(context.Background()) + canStopStreaming := func() bool { + return streamCtx.Err() != nil + } + go func() { + streamErr := coord.StartStreaming(streamCtx, &mysql.FileBinlogCoordinates{ + LogFile: "binlog.000001", + LogPos: int64(4), + }, canStopStreaming) + suite.Require().Equal(context.Canceled, streamErr) + }() + + // Give streamer some time to start + time.Sleep(1 * time.Second) + + startAt := time.Now() + + for { + if ctx.Err() != nil { + cancelStreaming() + break + } + + err = coord.ProcessEventsUntilDrained() + suite.Require().NoError(err) + } + + //err = g.Wait() + //suite.Require().NoError(err) + g.Wait() // there will be deadlock errors + + fmt.Printf("Time taken: %s\n", time.Since(startAt)) + + result, err := suite.db.Exec(`SELECT * FROM ( + SELECT t1.id, + CRC32(CONCAT_WS(';',t1.id,t1.name)) AS checksum1, + CRC32(CONCAT_WS(';',t2.id,t2.name)) AS checksum2 + FROM test.gh_ost_test t1 + LEFT JOIN test._gh_ost_test_gho t2 + ON t1.id = t2.id +) AS checksums +WHERE checksums.checksum1 != checksums.checksum2`) + suite.Require().NoError(err) + + count, err := result.RowsAffected() + suite.Require().NoError(err) + suite.Require().Zero(count) +} + +func TestCoordinator(t *testing.T) { + suite.Run(t, new(CoordinatorTestSuite)) +} + +// TestRotationResetsLowWaterMark is a deterministic unit test verifying that +// after a simulated binlog rotation the coordinator's lowWaterMark is reset +// so that transactions from the new file are properly ordered. +// This is the regression test for the root cause of the MTR data inconsistency: +// MySQL's logical clock (last_committed, sequence_number) resets per-binlog-file, +// but without resetting lwm, post-rotation transactions with small lastCommitted +// values would pass the WaitForTransaction check against the stale high lwm. +func TestRotationResetsLowWaterMark(t *testing.T) { + // Simulate a coordinator that has processed transactions from the first binlog file. + c := &Coordinator{ + lowWaterMark: -1, + completedJobs: make(map[int64]struct{}), + waitingJobs: make(map[int64][]chan struct{}), + failedCh: make(chan struct{}), + } + + // --- First binlog file: sequence numbers 1..5 --- + + // Initialize lwm (simulates first GTID event setting lwm = seqNo - 1 = 0) + c.mu.Lock() + c.lowWaterMark = 0 + c.mu.Unlock() + + // Complete transactions 1 through 5 + for seq := int64(1); seq <= 5; seq++ { + c.MarkTransactionCompleted(seq, 0, 0) + } + + // Verify lwm advanced to 5 + c.mu.Lock() + if c.lowWaterMark != 5 { + t.Fatalf("expected lwm=5 after completing seqs 1-5, got %d", c.lowWaterMark) + } + c.mu.Unlock() + + // A transaction with lastCommitted=3 should pass immediately (3 <= 5) + ch := c.WaitForTransaction(3) + if ch != nil { + t.Fatal("expected WaitForTransaction(3) to return nil when lwm=5") + } + + // --- Simulate binlog rotation: reset coordinator state --- + // This is what the RotateEvent handler does after draining workers. + c.mu.Lock() + c.lowWaterMark = -1 + c.completedJobs = make(map[int64]struct{}) + c.waitingJobs = make(map[int64][]chan struct{}) + c.mu.Unlock() + + // --- Second binlog file: sequence numbers restart at 1 --- + + // Initialize lwm for new file (first GTID sets lwm = seqNo - 1 = 0) + c.mu.Lock() + c.lowWaterMark = 0 + c.mu.Unlock() + + // BUG SCENARIO (before fix): if lwm was still 5 from the old file, + // WaitForTransaction(3) would return nil → tx executes out of order. + // After fix: lwm=0, so WaitForTransaction(3) must block. + ch = c.WaitForTransaction(3) + if ch == nil { + t.Fatal("expected WaitForTransaction(3) to block when lwm=0 in new binlog file, but it returned nil (stale lwm bug!)") + } + + // Complete transactions 1, 2, 3 in the new file + c.MarkTransactionCompleted(1, 0, 0) + c.MarkTransactionCompleted(2, 0, 0) + c.MarkTransactionCompleted(3, 0, 0) + + // Now the wait channel should be notified (lwm advances to 3) + select { + case <-ch: + // success + case <-time.After(time.Second): + t.Fatal("WaitForTransaction(3) was not notified after completing seqs 1-3") + } + + // Verify lwm is now 3 + c.mu.Lock() + if c.lowWaterMark != 3 { + t.Fatalf("expected lwm=3, got %d", c.lowWaterMark) + } + c.mu.Unlock() +} + +// TestBufferedWaitChannelNoDeadlock verifies that if a waiter exits early +// (e.g., via failedCh), MarkTransactionCompleted does not block forever. +func TestBufferedWaitChannelNoDeadlock(t *testing.T) { + c := &Coordinator{ + lowWaterMark: 0, + completedJobs: make(map[int64]struct{}), + waitingJobs: make(map[int64][]chan struct{}), + failedCh: make(chan struct{}), + } + + // Create a waiter for lastCommitted=3 + ch := c.WaitForTransaction(3) + if ch == nil { + t.Fatal("expected a wait channel") + } + + // Simulate the waiter exiting early (not reading from ch) + // This mimics what happens when a worker exits via failedCh. + + // MarkTransactionCompleted should NOT block even though nobody reads ch + done := make(chan struct{}) + go func() { + c.MarkTransactionCompleted(1, 0, 0) + c.MarkTransactionCompleted(2, 0, 0) + c.MarkTransactionCompleted(3, 0, 0) + close(done) + }() + + select { + case <-done: + // success — did not deadlock + case <-time.After(2 * time.Second): + t.Fatal("MarkTransactionCompleted deadlocked because wait channel is unbuffered") + } +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 96aadd672..f1f6bad61 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -19,6 +19,7 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" + gomysql "github.com/go-mysql-org/go-mysql/mysql" "github.com/openark/golib/sqlutils" ) @@ -951,6 +952,37 @@ func (isp *Inspector) readChangelogState(hint string) (string, error) { return result, err } +// readCurrentBinlogCoordinates reads master status from hooked server +func (insp *Inspector) readCurrentBinlogCoordinates() (mysql.BinlogCoordinates, error) { + var coords mysql.BinlogCoordinates + query := fmt.Sprintf(`show /* gh-ost readCurrentBinlogCoordinates */ %s`, mysql.ReplicaTermFor(insp.migrationContext.InspectorMySQLVersion, "master status")) + foundMasterStatus := false + err := sqlutils.QueryRowsMap(insp.db, query, func(m sqlutils.RowMap) error { + if insp.migrationContext.UseGTIDs { + execGtidSet := m.GetString("Executed_Gtid_Set") + gtidSet, err := gomysql.ParseMysqlGTIDSet(execGtidSet) + if err != nil { + return err + } + coords = &mysql.GTIDBinlogCoordinates{GTIDSet: gtidSet.(*gomysql.MysqlGTIDSet)} + } else { + coords = &mysql.FileBinlogCoordinates{ + LogFile: m.GetString("File"), + LogPos: m.GetInt64("Position"), + } + } + foundMasterStatus = true + return nil + }) + if err != nil { + return nil, err + } + if !foundMasterStatus { + return nil, fmt.Errorf("got no results from SHOW MASTER STATUS, bailing out") + } + return coords, nil +} + func (isp *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { isp.migrationContext.Log.Infof("Recursively searching for replication master") visitedKeys := mysql.NewInstanceKeyMap() diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 90fa8c509..04bcaa0dd 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -11,11 +11,12 @@ import ( "fmt" "io" "math" - "os" "strings" "sync/atomic" "time" + "os" + "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" "github.com/github/gh-ost/go/mysql" @@ -48,23 +49,6 @@ type lockProcessedStruct struct { state string coords mysql.BinlogCoordinates } - -type applyEventStruct struct { - writeFunc *tableWriteFunc - dmlEvent *binlog.BinlogDMLEvent - coords mysql.BinlogCoordinates -} - -func newApplyEventStructByFunc(writeFunc *tableWriteFunc) *applyEventStruct { - result := &applyEventStruct{writeFunc: writeFunc} - return result -} - -func newApplyEventStructByDML(dmlEntry *binlog.BinlogEntry) *applyEventStruct { - result := &applyEventStruct{dmlEvent: dmlEntry.DmlEvent, coords: dmlEntry.Coordinates} - return result -} - type PrintStatusRule int const ( @@ -81,7 +65,6 @@ type Migrator struct { parser *sql.AlterTableParser inspector *Inspector applier *Applier - eventsStreamer *EventsStreamer server *Server throttler *Throttler hooksExecutor base.Hooks @@ -96,10 +79,10 @@ type Migrator struct { rowCopyCompleteFlag int64 // copyRowsQueue should not be buffered; if buffered some non-damaging but // excessive work happens at the end of the iteration as new copy-jobs arrive before realizing the copy is complete - copyRowsQueue chan tableWriteFunc - applyEventsQueue chan *applyEventStruct + copyRowsQueue chan tableWriteFunc finishedMigrating int64 + trxCoordinator *Coordinator } func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { @@ -108,20 +91,16 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { hooks = NewHooksExecutor(context) } migrator := &Migrator{ - appVersion: appVersion, - hooksExecutor: hooks, - migrationContext: context, - parser: sql.NewAlterTableParser(), - ghostTableMigrated: make(chan bool), - firstThrottlingCollected: make(chan bool, 3), - rowCopyComplete: make(chan error), - // Buffered with capacity 1; the send uses overwrite-oldest semantics - // to prevent both deadlock (see https://github.com/github/gh-ost/pull/1637) - // and OOM when MaxRetries() is extremely large. + appVersion: appVersion, + hooksExecutor: hooks, + migrationContext: context, + parser: sql.NewAlterTableParser(), + ghostTableMigrated: make(chan bool, 1), + firstThrottlingCollected: make(chan bool, 3), + rowCopyComplete: make(chan error), allEventsUpToLockProcessed: make(chan *lockProcessedStruct, 1), copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), finishedMigrating: 0, } return migrator @@ -129,10 +108,10 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { // sleepWhileTrue sleeps indefinitely until the given function returns 'false' // (or fails with error) -func (mgtr *Migrator) sleepWhileTrue(operation func() (bool, error)) error { +func (mig *Migrator) sleepWhileTrue(operation func() (bool, error)) error { for { // Check for abort before continuing - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } shouldSleep, err := operation() @@ -146,29 +125,29 @@ func (mgtr *Migrator) sleepWhileTrue(operation func() (bool, error)) error { } } -func (mgtr *Migrator) retryBatchCopyWithHooks(operation func() error, notFatalHint ...bool) (err error) { +func (mig *Migrator) retryBatchCopyWithHooks(operation func() error, notFatalHint ...bool) (err error) { wrappedOperation := func() error { if err := operation(); err != nil { - mgtr.hooksExecutor.OnBatchCopyRetry(err.Error()) + mig.hooksExecutor.OnBatchCopyRetry(err.Error()) return err } return nil } - return mgtr.retryOperation(wrappedOperation, notFatalHint...) + return mig.retryOperation(wrappedOperation, notFatalHint...) } // retryOperation attempts up to `count` attempts at running given function, // exiting as soon as it returns with non-error. -func (mgtr *Migrator) retryOperation(operation func() error, notFatalHint ...bool) (err error) { - maxRetries := int(mgtr.migrationContext.MaxRetries()) +func (mig *Migrator) retryOperation(operation func() error, notFatalHint ...bool) (err error) { + maxRetries := int(mig.migrationContext.MaxRetries()) for i := 0; i < maxRetries; i++ { if i != 0 { // sleep after previous iteration RetrySleepFn(1 * time.Second) } // Check for abort/context cancellation before each retry - if abortErr := mgtr.checkAbort(); abortErr != nil { + if abortErr := mig.checkAbort(); abortErr != nil { return abortErr } err = operation() @@ -178,7 +157,7 @@ func (mgtr *Migrator) retryOperation(operation func() error, notFatalHint ...boo // Check if this is an unrecoverable error (data consistency issues won't resolve on retry) if strings.Contains(err.Error(), "warnings detected") { if len(notFatalHint) == 0 { - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } return err } @@ -186,7 +165,7 @@ func (mgtr *Migrator) retryOperation(operation func() error, notFatalHint ...boo } if len(notFatalHint) == 0 { // Use helper to prevent deadlock if listenOnPanicAbort already exited - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } return err } @@ -196,9 +175,9 @@ func (mgtr *Migrator) retryOperation(operation func() error, notFatalHint ...boo // as soon as the function returns with non-error, or as soon as `MaxRetries` // attempts are reached. Wait intervals between attempts obey a maximum of // `ExponentialBackoffMaxInterval`. -func (mgtr *Migrator) retryOperationWithExponentialBackoff(operation func() error, notFatalHint ...bool) (err error) { - maxRetries := int(mgtr.migrationContext.MaxRetries()) - maxInterval := mgtr.migrationContext.ExponentialBackoffMaxInterval +func (mig *Migrator) retryOperationWithExponentialBackoff(operation func() error, notFatalHint ...bool) (err error) { + maxRetries := int(mig.migrationContext.MaxRetries()) + maxInterval := mig.migrationContext.ExponentialBackoffMaxInterval for i := 0; i < maxRetries; i++ { interval := math.Min( float64(maxInterval), @@ -209,7 +188,7 @@ func (mgtr *Migrator) retryOperationWithExponentialBackoff(operation func() erro RetrySleepFn(time.Duration(interval) * time.Second) } // Check for abort/context cancellation before each retry - if abortErr := mgtr.checkAbort(); abortErr != nil { + if abortErr := mig.checkAbort(); abortErr != nil { return abortErr } err = operation() @@ -219,128 +198,96 @@ func (mgtr *Migrator) retryOperationWithExponentialBackoff(operation func() erro // Check if this is an unrecoverable error (data consistency issues won't resolve on retry) if strings.Contains(err.Error(), "warnings detected") { if len(notFatalHint) == 0 { - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } return err } } if len(notFatalHint) == 0 { // Use helper to prevent deadlock if listenOnPanicAbort already exited - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } return err } // consumeRowCopyComplete blocks on the rowCopyComplete channel once, and then // consumes and drops any further incoming events that may be left hanging. -func (mgtr *Migrator) consumeRowCopyComplete() { - select { - case err := <-mgtr.rowCopyComplete: - if err != nil { - // Abort synchronously to ensure checkAbort() sees the error immediately - mgtr.abort(err) - // Don't mark row copy as complete if there was an error - return - } - case <-mgtr.migrationContext.GetContext().Done(): - // Abort cancelled the context +func (mig *Migrator) consumeRowCopyComplete() { + if err := <-mig.rowCopyComplete; err != nil { + // Abort synchronously to ensure checkAbort() sees the error immediately + mig.abort(err) + // Don't mark row copy as complete if there was an error return } - atomic.StoreInt64(&mgtr.rowCopyCompleteFlag, 1) - mgtr.migrationContext.MarkRowCopyEndTime() + atomic.StoreInt64(&mig.rowCopyCompleteFlag, 1) + mig.migrationContext.MarkRowCopyEndTime() go func() { - for err := range mgtr.rowCopyComplete { + for err := range mig.rowCopyComplete { if err != nil { // Abort synchronously to ensure the error is stored immediately - mgtr.abort(err) + mig.abort(err) return } } }() } -func (mgtr *Migrator) canStopStreaming() bool { - return atomic.LoadInt64(&mgtr.migrationContext.CutOverCompleteFlag) != 0 +func (mig *Migrator) canStopStreaming() bool { + return atomic.LoadInt64(&mig.migrationContext.CutOverCompleteFlag) != 0 } // onChangelogEvent is called when a binlog event operation on the changelog table is intercepted. -func (mgtr *Migrator) onChangelogEvent(dmlEntry *binlog.BinlogEntry) (err error) { +func (mig *Migrator) onChangelogEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { // Hey, I created the changelog table, I know the type of columns it has! - switch hint := dmlEntry.DmlEvent.NewColumnValues.StringColumn(2); hint { + switch hint := dmlEvent.NewColumnValues.StringColumn(2); hint { case "state": - return mgtr.onChangelogStateEvent(dmlEntry) + return mig.onChangelogStateEvent(dmlEvent) case "heartbeat": - return mgtr.onChangelogHeartbeatEvent(dmlEntry) + return mig.onChangelogHeartbeatEvent(dmlEvent) default: return nil } } -func (mgtr *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err error) { - changelogStateString := dmlEntry.DmlEvent.NewColumnValues.StringColumn(3) +func (mig *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { + changelogStateString := dmlEvent.NewColumnValues.StringColumn(3) changelogState := ReadChangelogState(changelogStateString) - mgtr.migrationContext.Log.Infof("Intercepted changelog state %s", changelogState) + mig.migrationContext.Log.Infof("Intercepted changelog state %s", changelogState) switch changelogState { case Migrated, ReadMigrationRangeValues: // no-op event case GhostTableMigrated: // Use helper to prevent deadlock if migration aborts before receiver is ready - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.ghostTableMigrated, true) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.ghostTableMigrated, true) case AllEventsUpToLockProcessed: - lps := &lockProcessedStruct{ - state: changelogStateString, - coords: dmlEntry.Coordinates.Clone(), - } - var applyEventFunc tableWriteFunc = func() error { - // Non-blocking send with overwrite-oldest semantics: if the buffer is - // full (receiver timed out on a previous attempt), drain the stale - // message first so the current sentinel is always delivered. This - // prevents both goroutine leaks (the original PR #1637 issue) and OOM - // when MaxRetries() is very large. - select { - case mgtr.allEventsUpToLockProcessed <- lps: - default: - // Buffer full — drain the stale value, then send the current one. - select { - case <-mgtr.allEventsUpToLockProcessed: - default: - } - select { - case mgtr.allEventsUpToLockProcessed <- lps: - default: - // Concurrent drain by another goroutine or receiver; the current - // value is no longer needed since a newer sentinel will follow. - } - } - return nil - } // at this point we know all events up to lock have been read from the streamer, // because the streamer works sequentially. So those events are either already handled, - // or have event functions in applyEventsQueue. - // So as not to create a potential deadlock, we write this func to applyEventsQueue - // asynchronously, understanding it doesn't really matter. + // or are being processed by the coordinator. + // So as not to create a potential deadlock, we send this asynchronously. go func() { - // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.applyEventsQueue, newApplyEventStructByFunc(&applyEventFunc)) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.allEventsUpToLockProcessed, &lockProcessedStruct{ + state: changelogStateString, + coords: mig.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates(), + }) }() default: return fmt.Errorf("unknown changelog state: %+v", changelogState) } - mgtr.migrationContext.Log.Infof("Handled changelog state %s", changelogState) + mig.migrationContext.Log.Infof("Handled changelog state %s", changelogState) return nil } -func (mgtr *Migrator) onChangelogHeartbeatEvent(dmlEntry *binlog.BinlogEntry) (err error) { - changelogHeartbeatString := dmlEntry.DmlEvent.NewColumnValues.StringColumn(3) +func (mig *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { + changelogHeartbeatString := dmlEvent.NewColumnValues.StringColumn(3) heartbeatTime, err := time.Parse(time.RFC3339Nano, changelogHeartbeatString) if err != nil { - return mgtr.migrationContext.Log.Errore(err) + return mig.migrationContext.Log.Errore(err) } else { - mgtr.migrationContext.SetLastHeartbeatOnChangelogTime(heartbeatTime) - mgtr.applier.CurrentCoordinatesMutex.Lock() - mgtr.applier.CurrentCoordinates = dmlEntry.Coordinates - mgtr.applier.CurrentCoordinatesMutex.Unlock() + mig.migrationContext.SetLastHeartbeatOnChangelogTime(heartbeatTime) + mig.applier.CurrentCoordinatesMutex.Lock() + mig.applier.CurrentCoordinates = mig.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates() + mig.applier.CurrentCoordinatesMutex.Unlock() return nil } } @@ -348,68 +295,68 @@ func (mgtr *Migrator) onChangelogHeartbeatEvent(dmlEntry *binlog.BinlogEntry) (e // abort stores the error, cancels the context, and logs the abort. // This is the common abort logic used by both listenOnPanicAbort and // consumeRowCopyComplete to ensure consistent error handling. -func (mgtr *Migrator) abort(err error) { +func (mig *Migrator) abort(err error) { // Store the error for Migrate() to return - mgtr.migrationContext.SetAbortError(err) + mig.migrationContext.SetAbortError(err) // Cancel the context to signal all goroutines to stop - mgtr.migrationContext.CancelContext() + mig.migrationContext.CancelContext() // Log the error (but don't panic or exit) - mgtr.migrationContext.Log.Errorf("migration aborted: %v", err) + mig.migrationContext.Log.Errorf("Migration aborted: %v", err) } // listenOnPanicAbort listens for fatal errors and initiates graceful shutdown -func (mgtr *Migrator) listenOnPanicAbort() { - err := <-mgtr.migrationContext.PanicAbort - mgtr.abort(err) +func (mig *Migrator) listenOnPanicAbort() { + err := <-mig.migrationContext.PanicAbort + mig.abort(err) } // validateAlterStatement validates the `alter` statement meets criteria. // At this time this means: // - column renames are approved // - no table rename allowed -func (mgtr *Migrator) validateAlterStatement() (err error) { - if mgtr.parser.IsRenameTable() { +func (mig *Migrator) validateAlterStatement() (err error) { + if mig.parser.IsRenameTable() { return ErrMigratorUnsupportedRenameAlter } - if mgtr.parser.HasNonTrivialRenames() && !mgtr.migrationContext.SkipRenamedColumns { - mgtr.migrationContext.ColumnRenameMap = mgtr.parser.GetNonTrivialRenames() - if !mgtr.migrationContext.ApproveRenamedColumns { - return fmt.Errorf("gh-ost believes the ALTER statement renames columns, as follows: %v; as precaution, you are asked to confirm gh-ost is correct, and provide with `--approve-renamed-columns`, and we're all happy. Or you can skip renamed columns via `--skip-renamed-columns`, in which case column data may be lost", mgtr.parser.GetNonTrivialRenames()) + if mig.parser.HasNonTrivialRenames() && !mig.migrationContext.SkipRenamedColumns { + mig.migrationContext.ColumnRenameMap = mig.parser.GetNonTrivialRenames() + if !mig.migrationContext.ApproveRenamedColumns { + return fmt.Errorf("gh-ost believes the ALTER statement renames columns, as follows: %v; as precaution, you are asked to confirm gh-ost is correct, and provide with `--approve-renamed-columns`, and we're all happy. Or you can skip renamed columns via `--skip-renamed-columns`, in which case column data may be lost", mig.parser.GetNonTrivialRenames()) } - mgtr.migrationContext.Log.Infof("alter statement has column(s) renamed. gh-ost finds the following renames: %v; --approve-renamed-columns is given and so migration proceeds.", mgtr.parser.GetNonTrivialRenames()) + mig.migrationContext.Log.Infof("Alter statement has column(s) renamed. gh-ost finds the following renames: %v; --approve-renamed-columns is given and so migration proceeds.", mig.parser.GetNonTrivialRenames()) } - mgtr.migrationContext.DroppedColumnsMap = mgtr.parser.DroppedColumnsMap() + mig.migrationContext.DroppedColumnsMap = mig.parser.DroppedColumnsMap() return nil } -func (mgtr *Migrator) countTableRows() (err error) { - if !mgtr.migrationContext.CountTableRows { +func (mig *Migrator) countTableRows() (err error) { + if !mig.migrationContext.CountTableRows { // Not counting; we stay with an estimate return nil } - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really counting table rows") + if mig.migrationContext.Noop { + mig.migrationContext.Log.Debugf("Noop operation; not really counting table rows") return nil } countRowsFunc := func(ctx context.Context) error { - if err := mgtr.inspector.CountTableRows(ctx); err != nil { + if err := mig.inspector.CountTableRows(ctx); err != nil { return err } - if err := mgtr.hooksExecutor.OnRowCountComplete(); err != nil { + if err := mig.hooksExecutor.OnRowCountComplete(); err != nil { return err } return nil } - if mgtr.migrationContext.ConcurrentCountTableRows { + if mig.migrationContext.ConcurrentCountTableRows { // store a cancel func so we can stop this query before a cut over rowCountContext, rowCountCancel := context.WithCancel(context.Background()) - mgtr.migrationContext.SetCountTableRowsCancelFunc(rowCountCancel) + mig.migrationContext.SetCountTableRowsCancelFunc(rowCountCancel) - mgtr.migrationContext.Log.Infof("As instructed, counting rows in the background; meanwhile I will use an estimated count, and will update it later on") + mig.migrationContext.Log.Infof("As instructed, counting rows in the background; meanwhile I will use an estimated count, and will update it later on") go countRowsFunc(rowCountContext) // and we ignore errors, because this turns to be a background job @@ -418,30 +365,30 @@ func (mgtr *Migrator) countTableRows() (err error) { return countRowsFunc(context.Background()) } -func (mgtr *Migrator) createFlagFiles() (err error) { - if mgtr.migrationContext.PostponeCutOverFlagFile != "" { - if !base.FileExists(mgtr.migrationContext.PostponeCutOverFlagFile) { - if err := base.TouchFile(mgtr.migrationContext.PostponeCutOverFlagFile); err != nil { - return mgtr.migrationContext.Log.Errorf("--postpone-cut-over-flag-file indicated by gh-ost is unable to create said file: %s", err.Error()) +func (mig *Migrator) createFlagFiles() (err error) { + if mig.migrationContext.PostponeCutOverFlagFile != "" { + if !base.FileExists(mig.migrationContext.PostponeCutOverFlagFile) { + if err := base.TouchFile(mig.migrationContext.PostponeCutOverFlagFile); err != nil { + return mig.migrationContext.Log.Errorf("--postpone-cut-over-flag-file indicated by gh-ost is unable to create said file: %s", err.Error()) } - mgtr.migrationContext.Log.Infof("Created postpone-cut-over-flag-file: %s", mgtr.migrationContext.PostponeCutOverFlagFile) + mig.migrationContext.Log.Infof("Created postpone-cut-over-flag-file: %s", mig.migrationContext.PostponeCutOverFlagFile) } } return nil } // checkAbort returns abort error if migration was aborted -func (mgtr *Migrator) checkAbort() error { - if abortErr := mgtr.migrationContext.GetAbortError(); abortErr != nil { +func (mig *Migrator) checkAbort() error { + if abortErr := mig.migrationContext.GetAbortError(); abortErr != nil { return abortErr } - ctx := mgtr.migrationContext.GetContext() + ctx := mig.migrationContext.GetContext() if ctx != nil { select { case <-ctx.Done(): // Context cancelled but no abort error stored yet - if abortErr := mgtr.migrationContext.GetAbortError(); abortErr != nil { + if abortErr := mig.migrationContext.GetAbortError(); abortErr != nil { return abortErr } return ctx.Err() @@ -453,214 +400,236 @@ func (mgtr *Migrator) checkAbort() error { } // Migrate executes the complete migration logic. This is *the* major gh-ost function. -func (mgtr *Migrator) Migrate() (err error) { - mgtr.migrationContext.Log.Infof("Migrating %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) - mgtr.migrationContext.StartTime = time.Now() - mgtr.migrationContext.SetLastHeartbeatOnChangelogTime(mgtr.migrationContext.StartTime) +func (mig *Migrator) Migrate() (err error) { + mig.migrationContext.Log.Infof("Migrating %s.%s", sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.OriginalTableName)) + mig.migrationContext.StartTime = time.Now() // Ensure context is cancelled on exit (cleanup) - defer mgtr.migrationContext.CancelContext() + defer mig.migrationContext.CancelContext() - if mgtr.migrationContext.Hostname, err = os.Hostname(); err != nil { + if mig.migrationContext.Hostname, err = os.Hostname(); err != nil { return err } - go mgtr.listenOnPanicAbort() + go mig.listenOnPanicAbort() - if err := mgtr.hooksExecutor.OnStartup(); err != nil { + if err := mig.hooksExecutor.OnStartup(); err != nil { return err } - if err := mgtr.parser.ParseAlterStatement(mgtr.migrationContext.AlterStatement); err != nil { + if err := mig.parser.ParseAlterStatement(mig.migrationContext.AlterStatement); err != nil { return err } - if err := mgtr.validateAlterStatement(); err != nil { + if err := mig.validateAlterStatement(); err != nil { return err } // After this point, we'll need to teardown anything that's been started // so we don't leave things hanging around - defer mgtr.teardown() + defer mig.teardown() - if err := mgtr.initiateInspector(); err != nil { + if err := mig.initiateInspector(); err != nil { return err } - if err := mgtr.checkAbort(); err != nil { + + mig.trxCoordinator = NewCoordinator(mig.migrationContext, mig.applier, mig.throttler, mig.onChangelogEvent) + + if err := mig.checkAbort(); err != nil { return err } // If we are resuming, we will initiateStreaming later when we know // the binlog coordinates to resume streaming from. // If not resuming, the streamer must be initiated before the applier, // so that the "GhostTableMigrated" event gets processed. - if !mgtr.migrationContext.Resume { - if err := mgtr.initiateStreaming(); err != nil { + if !mig.migrationContext.Resume { + if err := mig.initiateStreaming(); err != nil { return err } - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } } - if err := mgtr.initiateApplier(); err != nil { + if err := mig.initiateApplier(); err != nil { return err } - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } - if err := mgtr.createFlagFiles(); err != nil { + + mig.trxCoordinator.applier = mig.applier + + if err := mig.createFlagFiles(); err != nil { return err } // In MySQL 8.0 (and possibly earlier) some DDL statements can be applied instantly. // Attempt to do this if AttemptInstantDDL is set. - if mgtr.migrationContext.AttemptInstantDDL { - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really attempting instant DDL") + if mig.migrationContext.AttemptInstantDDL { + if mig.migrationContext.Noop { + mig.migrationContext.Log.Debugf("Noop operation; not really attempting instant DDL") } else { - mgtr.migrationContext.Log.Infof("Attempting to execute alter with ALGORITHM=INSTANT") - if err := mgtr.applier.AttemptInstantDDL(); err == nil { - if err := mgtr.finalCleanup(); err != nil { + mig.migrationContext.Log.Infof("Attempting to execute alter with ALGORITHM=INSTANT") + if err := mig.applier.AttemptInstantDDL(); err == nil { + if err := mig.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.OnSuccess(true); err != nil { + if err := mig.hooksExecutor.OnSuccess(true); err != nil { return err } - mgtr.migrationContext.Log.Infof("Success! table %s.%s migrated instantly", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) + mig.migrationContext.Log.Infof("Success! table %s.%s migrated instantly", sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.OriginalTableName)) return nil } else { - mgtr.migrationContext.Log.Infof("ALGORITHM=INSTANT not supported for this operation, proceeding with original algorithm: %s", err) + mig.migrationContext.Log.Infof("ALGORITHM=INSTANT not supported for this operation, proceeding with original algorithm: %s", err) } } } - initialLag, _ := mgtr.inspector.getReplicationLag() - if !mgtr.migrationContext.Resume { - mgtr.migrationContext.Log.Infof("Waiting for ghost table to be migrated. Current lag is %+v", initialLag) - <-mgtr.ghostTableMigrated - mgtr.migrationContext.Log.Debugf("ghost table migrated") + mig.migrationContext.Log.Infof("starting %d applier workers", mig.migrationContext.NumWorkers) + mig.trxCoordinator.InitializeWorkers(mig.migrationContext.NumWorkers) + + initialLag, _ := mig.inspector.getReplicationLag() + if !mig.migrationContext.Resume { + mig.migrationContext.Log.Infof("Waiting for ghost table to be migrated. Current lag is %+v", initialLag) + + waitForGhostTable: + for { + select { + case <-mig.ghostTableMigrated: + break waitForGhostTable + default: + dmlEvent, err := mig.trxCoordinator.ProcessEventsUntilNextChangelogEvent() + if err != nil { + return err + } + + mig.onChangelogEvent(dmlEvent) + } + } + + mig.migrationContext.Log.Debugf("ghost table migrated") } // Yay! We now know the Ghost and Changelog tables are good to examine! // When running on replica, this means the replica has those tables. When running // on master this is always true, of course, and yet it also implies this knowledge // is in the binlogs. - if err := mgtr.inspector.inspectOriginalAndGhostTables(); err != nil { + if err := mig.inspector.inspectOriginalAndGhostTables(); err != nil { return err } // We can prepare some of the queries on the applier - if err := mgtr.applier.prepareQueries(); err != nil { + if err := mig.applier.prepareQueries(); err != nil { return err } // inspectOriginalAndGhostTables must be called before creating checkpoint table. - if mgtr.migrationContext.Checkpoint && !mgtr.migrationContext.Resume { - if err := mgtr.applier.CreateCheckpointTable(); err != nil { - mgtr.migrationContext.Log.Errorf("unable to create checkpoint table, see further error details") + if mig.migrationContext.Checkpoint && !mig.migrationContext.Resume { + if err := mig.applier.CreateCheckpointTable(); err != nil { + mig.migrationContext.Log.Errorf("Unable to create checkpoint table, see further error details.") } } - if mgtr.migrationContext.Resume { - lastCheckpoint, err := mgtr.applier.ReadLastCheckpoint() + if mig.migrationContext.Resume { + lastCheckpoint, err := mig.applier.ReadLastCheckpoint() if err != nil { - return mgtr.migrationContext.Log.Errorf("no checkpoint found, unable to resume: %+v", err) + return mig.migrationContext.Log.Errorf("no checkpoint found, unable to resume: %+v", err) } - mgtr.migrationContext.Log.Infof("Resuming from checkpoint coords=%+v range_min=%+v range_max=%+v iteration=%d", + mig.migrationContext.Log.Infof("Resuming from checkpoint coords=%+v range_min=%+v range_max=%+v iteration=%d", lastCheckpoint.LastTrxCoords, lastCheckpoint.IterationRangeMin.String(), lastCheckpoint.IterationRangeMax.String(), lastCheckpoint.Iteration) - mgtr.migrationContext.MigrationIterationRangeMinValues = lastCheckpoint.IterationRangeMin - mgtr.migrationContext.MigrationIterationRangeMaxValues = lastCheckpoint.IterationRangeMax - mgtr.migrationContext.Iteration = lastCheckpoint.Iteration - mgtr.migrationContext.TotalRowsCopied = lastCheckpoint.RowsCopied - mgtr.migrationContext.TotalDMLEventsApplied = lastCheckpoint.DMLApplied - mgtr.migrationContext.InitialStreamerCoords = lastCheckpoint.LastTrxCoords - if err := mgtr.initiateStreaming(); err != nil { + mig.migrationContext.MigrationIterationRangeMinValues = lastCheckpoint.IterationRangeMin + mig.migrationContext.MigrationIterationRangeMaxValues = lastCheckpoint.IterationRangeMax + mig.migrationContext.Iteration = lastCheckpoint.Iteration + mig.migrationContext.TotalRowsCopied = lastCheckpoint.RowsCopied + mig.migrationContext.TotalDMLEventsApplied = lastCheckpoint.DMLApplied + mig.migrationContext.InitialStreamerCoords = lastCheckpoint.LastTrxCoords + if err := mig.initiateStreaming(); err != nil { return err } } // Validation complete! We're good to execute this migration - if err := mgtr.hooksExecutor.OnValidated(); err != nil { + if err := mig.hooksExecutor.OnValidated(); err != nil { return err } - if err := mgtr.initiateServer(); err != nil { + if err := mig.initiateServer(); err != nil { return err } - defer mgtr.server.RemoveSocketFile() + defer mig.server.RemoveSocketFile() - if err := mgtr.countTableRows(); err != nil { - return err - } - if err := mgtr.addDMLEventsListener(); err != nil { + if err := mig.countTableRows(); err != nil { return err } - if err := mgtr.applier.ReadMigrationRangeValues(); err != nil { + + if err := mig.applier.ReadMigrationRangeValues(); err != nil { return err } - mgtr.initiateThrottler() + mig.initiateThrottler() - if err := mgtr.hooksExecutor.OnBeforeRowCopy(); err != nil { + if err := mig.hooksExecutor.OnBeforeRowCopy(); err != nil { return err } go func() { - if err := mgtr.executeWriteFuncs(); err != nil { + if err := mig.executeWriteFuncs(); err != nil { // Send error to PanicAbort to trigger abort - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } }() - go mgtr.iterateChunks() - mgtr.migrationContext.MarkRowCopyStartTime() - go mgtr.initiateStatus() - if mgtr.migrationContext.Checkpoint { - go mgtr.checkpointLoop() + go mig.iterateChunks() + mig.migrationContext.MarkRowCopyStartTime() + go mig.initiateStatus() + if mig.migrationContext.Checkpoint { + go mig.checkpointLoop() } - mgtr.migrationContext.Log.Debugf("Operating until row copy is complete") - mgtr.consumeRowCopyComplete() - mgtr.migrationContext.Log.Infof("Row copy complete") + mig.migrationContext.Log.Debugf("Operating until row copy is complete") + mig.consumeRowCopyComplete() + mig.migrationContext.Log.Infof("Row copy complete") // Check if row copy was aborted due to error - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } - if err := mgtr.hooksExecutor.OnRowCopyComplete(); err != nil { + if err := mig.hooksExecutor.OnRowCopyComplete(); err != nil { return err } - mgtr.printStatus(ForcePrintStatusRule) + mig.printStatus(ForcePrintStatusRule) + mig.printWorkerStats() - if mgtr.migrationContext.IsCountingTableRows() { - mgtr.migrationContext.Log.Info("stopping query for exact row count, because that can accidentally lock out the cut over") - mgtr.migrationContext.CancelTableRowsCount() + if mig.migrationContext.IsCountingTableRows() { + mig.migrationContext.Log.Info("stopping query for exact row count, because that can accidentally lock out the cut over") + mig.migrationContext.CancelTableRowsCount() } - if err := mgtr.hooksExecutor.OnBeforeCutOver(); err != nil { + if err := mig.hooksExecutor.OnBeforeCutOver(); err != nil { return err } var retrier func(func() error, ...bool) error - if mgtr.migrationContext.CutOverExponentialBackoff { - retrier = mgtr.retryOperationWithExponentialBackoff + if mig.migrationContext.CutOverExponentialBackoff { + retrier = mig.retryOperationWithExponentialBackoff } else { - retrier = mgtr.retryOperation + retrier = mig.retryOperation } - if err := retrier(mgtr.cutOver); err != nil { + if err := retrier(mig.cutOver); err != nil { return err } - atomic.StoreInt64(&mgtr.migrationContext.CutOverCompleteFlag, 1) + atomic.StoreInt64(&mig.migrationContext.CutOverCompleteFlag, 1) - if mgtr.migrationContext.Checkpoint && !mgtr.migrationContext.Noop { - cutoverChk, err := mgtr.CheckpointAfterCutOver() + if mig.migrationContext.Checkpoint && !mig.migrationContext.Noop { + cutoverChk, err := mig.CheckpointAfterCutOver() if err != nil { - mgtr.migrationContext.Log.Warningf("failed to checkpoint after cutover: %+v", err) + mig.migrationContext.Log.Warningf("failed to checkpoint after cutover: %+v", err) } else { - mgtr.migrationContext.Log.Infof("checkpoint success after cutover at coords=%+v", cutoverChk.LastTrxCoords.DisplayString()) + mig.migrationContext.Log.Infof("checkpoint success after cutover at coords=%+v", cutoverChk.LastTrxCoords.DisplayString()) } } - if err := mgtr.finalCleanup(); err != nil { + if err := mig.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.OnSuccess(false); err != nil { + if err := mig.hooksExecutor.OnSuccess(false); err != nil { return err } - mgtr.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) + mig.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.OriginalTableName)) // Final check for abort before declaring success - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } return nil @@ -669,144 +638,140 @@ func (mgtr *Migrator) Migrate() (err error) { // Revert reverts a migration that previously completed by applying all DML events that happened // after the original cutover, then doing another cutover to swap the tables back. // The steps are similar to Migrate(), but without row copying. -func (mgtr *Migrator) Revert() error { - mgtr.migrationContext.Log.Infof("Reverting %s.%s from %s.%s", - sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName), - sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OldTableName)) - mgtr.migrationContext.StartTime = time.Now() - mgtr.migrationContext.SetLastHeartbeatOnChangelogTime(mgtr.migrationContext.StartTime) +func (mig *Migrator) Revert() error { + mig.migrationContext.Log.Infof("Reverting %s.%s from %s.%s", + sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.OriginalTableName), + sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.OldTableName)) + mig.migrationContext.StartTime = time.Now() // Ensure context is cancelled on exit (cleanup) - defer mgtr.migrationContext.CancelContext() + defer mig.migrationContext.CancelContext() var err error - if mgtr.migrationContext.Hostname, err = os.Hostname(); err != nil { + if mig.migrationContext.Hostname, err = os.Hostname(); err != nil { return err } - go mgtr.listenOnPanicAbort() + go mig.listenOnPanicAbort() - if err := mgtr.hooksExecutor.OnStartup(); err != nil { + if err := mig.hooksExecutor.OnStartup(); err != nil { return err } - if err := mgtr.validateAlterStatement(); err != nil { + if err := mig.validateAlterStatement(); err != nil { return err } - defer mgtr.teardown() + defer mig.teardown() - if err := mgtr.initiateInspector(); err != nil { + if err := mig.initiateInspector(); err != nil { return err } - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } - if err := mgtr.initiateApplier(); err != nil { + if err := mig.initiateApplier(); err != nil { return err } - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } - if err := mgtr.createFlagFiles(); err != nil { + if err := mig.createFlagFiles(); err != nil { return err } - if err := mgtr.inspector.inspectOriginalAndGhostTables(); err != nil { + if err := mig.inspector.inspectOriginalAndGhostTables(); err != nil { return err } - if err := mgtr.applier.prepareQueries(); err != nil { + if err := mig.applier.prepareQueries(); err != nil { return err } - lastCheckpoint, err := mgtr.applier.ReadLastCheckpoint() + lastCheckpoint, err := mig.applier.ReadLastCheckpoint() if err != nil { - return mgtr.migrationContext.Log.Errorf("no checkpoint found, unable to revert: %+v", err) + return mig.migrationContext.Log.Errorf("no checkpoint found, unable to revert: %+v", err) } if !lastCheckpoint.IsCutover { - return mgtr.migrationContext.Log.Errorf("last checkpoint is not after cutover, unable to revert: coords=%+v time=%+v", lastCheckpoint.LastTrxCoords, lastCheckpoint.Timestamp) - } - mgtr.migrationContext.InitialStreamerCoords = lastCheckpoint.LastTrxCoords - mgtr.migrationContext.TotalRowsCopied = lastCheckpoint.RowsCopied - mgtr.migrationContext.MigrationIterationRangeMinValues = lastCheckpoint.IterationRangeMin - mgtr.migrationContext.MigrationIterationRangeMaxValues = lastCheckpoint.IterationRangeMax - if err := mgtr.initiateStreaming(); err != nil { - return err + return mig.migrationContext.Log.Errorf("Last checkpoint is not after cutover, unable to revert: coords=%+v time=%+v", lastCheckpoint.LastTrxCoords, lastCheckpoint.Timestamp) } - if err := mgtr.checkAbort(); err != nil { + mig.migrationContext.InitialStreamerCoords = lastCheckpoint.LastTrxCoords + mig.migrationContext.TotalRowsCopied = lastCheckpoint.RowsCopied + mig.migrationContext.MigrationIterationRangeMinValues = lastCheckpoint.IterationRangeMin + mig.migrationContext.MigrationIterationRangeMaxValues = lastCheckpoint.IterationRangeMax + if err := mig.initiateStreaming(); err != nil { return err } - if err := mgtr.hooksExecutor.OnValidated(); err != nil { + if err := mig.checkAbort(); err != nil { return err } - if err := mgtr.initiateServer(); err != nil { + if err := mig.hooksExecutor.OnValidated(); err != nil { return err } - defer mgtr.server.RemoveSocketFile() - if err := mgtr.addDMLEventsListener(); err != nil { + if err := mig.initiateServer(); err != nil { return err } + defer mig.server.RemoveSocketFile() - mgtr.initiateThrottler() - go mgtr.initiateStatus() + mig.initiateThrottler() + go mig.initiateStatus() go func() { - if err := mgtr.executeDMLWriteFuncs(); err != nil { + if err := mig.executeWriteFuncs(); err != nil { // Send error to PanicAbort to trigger abort - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } }() - mgtr.printStatus(ForcePrintStatusRule) + mig.printStatus(ForcePrintStatusRule) var retrier func(func() error, ...bool) error - if mgtr.migrationContext.CutOverExponentialBackoff { - retrier = mgtr.retryOperationWithExponentialBackoff + if mig.migrationContext.CutOverExponentialBackoff { + retrier = mig.retryOperationWithExponentialBackoff } else { - retrier = mgtr.retryOperation + retrier = mig.retryOperation } - if err := mgtr.hooksExecutor.OnBeforeCutOver(); err != nil { + if err := mig.hooksExecutor.OnBeforeCutOver(); err != nil { return err } - if err := retrier(mgtr.cutOver); err != nil { + if err := retrier(mig.cutOver); err != nil { return err } - atomic.StoreInt64(&mgtr.migrationContext.CutOverCompleteFlag, 1) - if err := mgtr.finalCleanup(); err != nil { + atomic.StoreInt64(&mig.migrationContext.CutOverCompleteFlag, 1) + if err := mig.finalCleanup(); err != nil { return nil } - if err := mgtr.hooksExecutor.OnSuccess(false); err != nil { + if err := mig.hooksExecutor.OnSuccess(false); err != nil { return err } - mgtr.migrationContext.Log.Infof("Done reverting %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.OriginalTableName)) + mig.migrationContext.Log.Infof("Done reverting %s.%s", sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.OriginalTableName)) return nil } // ExecOnFailureHook executes the onFailure hook, and this method is provided as the only external // hook access point -func (mgtr *Migrator) ExecOnFailureHook() (err error) { - return mgtr.hooksExecutor.OnFailure() +func (mig *Migrator) ExecOnFailureHook() (err error) { + return mig.hooksExecutor.OnFailure() } -func (mgtr *Migrator) handleCutOverResult(cutOverError error) (err error) { - if mgtr.migrationContext.TestOnReplica { +func (mig *Migrator) handleCutOverResult(cutOverError error) (err error) { + if mig.migrationContext.TestOnReplica { // We're merely testing, we don't want to keep this state. Rollback the renames as possible - mgtr.applier.RenameTablesRollback() + mig.applier.RenameTablesRollback() } if cutOverError == nil { return nil } // Only on error: - if mgtr.migrationContext.TestOnReplica { + if mig.migrationContext.TestOnReplica { // With `--test-on-replica` we stop replication thread, and then proceed to use // the same cut-over phase as the master would use. That means we take locks // and swap the tables. // The difference is that we will later swap the tables back. - if err := mgtr.hooksExecutor.OnStartReplication(); err != nil { - return mgtr.migrationContext.Log.Errore(err) + if err := mig.hooksExecutor.OnStartReplication(); err != nil { + return mig.migrationContext.Log.Errore(err) } - if mgtr.migrationContext.TestOnReplicaSkipReplicaStop { - mgtr.migrationContext.Log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not starting replication.") + if mig.migrationContext.TestOnReplicaSkipReplicaStop { + mig.migrationContext.Log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not starting replication.") } else { - mgtr.migrationContext.Log.Debugf("testing on replica. Starting replication IO thread after cut-over failure") - if err := mgtr.retryOperation(mgtr.applier.StartReplication); err != nil { - return mgtr.migrationContext.Log.Errore(err) + mig.migrationContext.Log.Debugf("testing on replica. Starting replication IO thread after cut-over failure") + if err := mig.retryOperation(mig.applier.StartReplication); err != nil { + return mig.migrationContext.Log.Errore(err) } } } @@ -815,42 +780,42 @@ func (mgtr *Migrator) handleCutOverResult(cutOverError error) (err error) { // cutOver performs the final step of migration, based on migration // type (on replica? atomic? safe?) -func (mgtr *Migrator) cutOver() (err error) { - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really swapping tables") +func (mig *Migrator) cutOver() (err error) { + if mig.migrationContext.Noop { + mig.migrationContext.Log.Debugf("Noop operation; not really swapping tables") return nil } - mgtr.migrationContext.MarkPointOfInterest() - mgtr.throttler.throttle(func() { - mgtr.migrationContext.Log.Debugf("throttling before swapping tables") + mig.migrationContext.MarkPointOfInterest() + mig.throttler.throttle(func() { + mig.migrationContext.Log.Debugf("throttling before swapping tables") }) - mgtr.migrationContext.MarkPointOfInterest() - mgtr.migrationContext.Log.Debugf("checking for cut-over postpone") - if err := mgtr.sleepWhileTrue( + mig.migrationContext.MarkPointOfInterest() + mig.migrationContext.Log.Debugf("checking for cut-over postpone") + if err := mig.sleepWhileTrue( func() (bool, error) { - heartbeatLag := mgtr.migrationContext.TimeSinceLastHeartbeatOnChangelog() - maxLagMillisecondsThrottle := time.Duration(atomic.LoadInt64(&mgtr.migrationContext.MaxLagMillisecondsThrottleThreshold)) * time.Millisecond - cutOverLockTimeout := time.Duration(mgtr.migrationContext.CutOverLockTimeoutSeconds) * time.Second + heartbeatLag := mig.migrationContext.TimeSinceLastHeartbeatOnChangelog() + maxLagMillisecondsThrottle := time.Duration(atomic.LoadInt64(&mig.migrationContext.MaxLagMillisecondsThrottleThreshold)) * time.Millisecond + cutOverLockTimeout := time.Duration(mig.migrationContext.CutOverLockTimeoutSeconds) * time.Second if heartbeatLag > maxLagMillisecondsThrottle || heartbeatLag > cutOverLockTimeout { - mgtr.migrationContext.Log.Debugf("current HeartbeatLag (%.2fs) is too high, it needs to be less than both --max-lag-millis (%.2fs) and --cut-over-lock-timeout-seconds (%.2fs) to continue", heartbeatLag.Seconds(), maxLagMillisecondsThrottle.Seconds(), cutOverLockTimeout.Seconds()) + mig.migrationContext.Log.Debugf("current HeartbeatLag (%.2fs) is too high, it needs to be less than both --max-lag-millis (%.2fs) and --cut-over-lock-timeout-seconds (%.2fs) to continue", heartbeatLag.Seconds(), maxLagMillisecondsThrottle.Seconds(), cutOverLockTimeout.Seconds()) return true, nil } - if mgtr.migrationContext.PostponeCutOverFlagFile == "" { + if mig.migrationContext.PostponeCutOverFlagFile == "" { return false, nil } - if atomic.LoadInt64(&mgtr.migrationContext.UserCommandedUnpostponeFlag) > 0 { - atomic.StoreInt64(&mgtr.migrationContext.UserCommandedUnpostponeFlag, 0) + if atomic.LoadInt64(&mig.migrationContext.UserCommandedUnpostponeFlag) > 0 { + atomic.StoreInt64(&mig.migrationContext.UserCommandedUnpostponeFlag, 0) return false, nil } - if base.FileExists(mgtr.migrationContext.PostponeCutOverFlagFile) { + if base.FileExists(mig.migrationContext.PostponeCutOverFlagFile) { // Postpone file defined and exists! - if atomic.LoadInt64(&mgtr.migrationContext.IsPostponingCutOver) == 0 { - if err := mgtr.hooksExecutor.OnBeginPostponed(); err != nil { + if atomic.LoadInt64(&mig.migrationContext.IsPostponingCutOver) == 0 { + if err := mig.hooksExecutor.OnBeginPostponed(); err != nil { return true, err } } - atomic.StoreInt64(&mgtr.migrationContext.IsPostponingCutOver, 1) + atomic.StoreInt64(&mig.migrationContext.IsPostponingCutOver, 1) return true, nil } return false, nil @@ -858,80 +823,80 @@ func (mgtr *Migrator) cutOver() (err error) { ); err != nil { return err } - atomic.StoreInt64(&mgtr.migrationContext.IsPostponingCutOver, 0) - mgtr.migrationContext.MarkPointOfInterest() - mgtr.migrationContext.Log.Debugf("checking for cut-over postpone: complete") + atomic.StoreInt64(&mig.migrationContext.IsPostponingCutOver, 0) + mig.migrationContext.MarkPointOfInterest() + mig.migrationContext.Log.Debugf("checking for cut-over postpone: complete") - if mgtr.migrationContext.TestOnReplica { + if mig.migrationContext.TestOnReplica { // With `--test-on-replica` we stop replication thread, and then proceed to use // the same cut-over phase as the master would use. That means we take locks // and swap the tables. // The difference is that we will later swap the tables back. - if err := mgtr.hooksExecutor.OnStopReplication(); err != nil { + if err := mig.hooksExecutor.OnStopReplication(); err != nil { return err } - if mgtr.migrationContext.TestOnReplicaSkipReplicaStop { - mgtr.migrationContext.Log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not stopping replication.") + if mig.migrationContext.TestOnReplicaSkipReplicaStop { + mig.migrationContext.Log.Warningf("--test-on-replica-skip-replica-stop enabled, we are not stopping replication.") } else { - mgtr.migrationContext.Log.Debugf("testing on replica. Stopping replication IO thread") - if err := mgtr.retryOperation(mgtr.applier.StopReplication); err != nil { + mig.migrationContext.Log.Debugf("testing on replica. Stopping replication IO thread") + if err := mig.retryOperation(mig.applier.StopReplication); err != nil { return err } } } - switch mgtr.migrationContext.CutOverType { + switch mig.migrationContext.CutOverType { case base.CutOverAtomic: // Atomic solution: we use low timeout and multiple attempts. But for // each failed attempt, we throttle until replication lag is back to normal - err = mgtr.atomicCutOver() + err = mig.atomicCutOver() case base.CutOverTwoStep: - err = mgtr.cutOverTwoStep() + err = mig.cutOverTwoStep() default: - return mgtr.migrationContext.Log.Fatalf("Unknown cut-over type: %d; should never get here!", mgtr.migrationContext.CutOverType) + return mig.migrationContext.Log.Fatalf("Unknown cut-over type: %d; should never get here!", mig.migrationContext.CutOverType) } - mgtr.handleCutOverResult(err) + mig.handleCutOverResult(err) return err } // Inject the "AllEventsUpToLockProcessed" state hint, wait for it to appear in the binary logs, // make sure the queue is drained. -func (mgtr *Migrator) waitForEventsUpToLock() error { - timeout := time.NewTimer(time.Second * time.Duration(mgtr.migrationContext.CutOverLockTimeoutSeconds)) +func (mig *Migrator) waitForEventsUpToLock() error { + timeout := time.NewTimer(time.Second * time.Duration(mig.migrationContext.CutOverLockTimeoutSeconds)) - mgtr.migrationContext.MarkPointOfInterest() + mig.migrationContext.MarkPointOfInterest() waitForEventsUpToLockStartTime := time.Now() allEventsUpToLockProcessedChallenge := fmt.Sprintf("%s:%d", string(AllEventsUpToLockProcessed), waitForEventsUpToLockStartTime.UnixNano()) - mgtr.migrationContext.Log.Infof("Writing changelog state: %+v", allEventsUpToLockProcessedChallenge) - if _, err := mgtr.applier.WriteChangelogState(allEventsUpToLockProcessedChallenge); err != nil { + mig.migrationContext.Log.Infof("Writing changelog state: %+v", allEventsUpToLockProcessedChallenge) + if _, err := mig.applier.WriteChangelogState(allEventsUpToLockProcessedChallenge); err != nil { return err } - mgtr.migrationContext.Log.Infof("Waiting for events up to lock") - atomic.StoreInt64(&mgtr.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 1) + mig.migrationContext.Log.Infof("Waiting for events up to lock") + atomic.StoreInt64(&mig.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 1) var lockProcessed *lockProcessedStruct for found := false; !found; { select { case <-timeout.C: { - return mgtr.migrationContext.Log.Errorf("timeout while waiting for events up to lock") + return mig.migrationContext.Log.Errorf("Timeout while waiting for events up to lock") } - case lockProcessed = <-mgtr.allEventsUpToLockProcessed: + case lockProcessed = <-mig.allEventsUpToLockProcessed: { if lockProcessed.state == allEventsUpToLockProcessedChallenge { - mgtr.migrationContext.Log.Infof("Waiting for events up to lock: got %s", lockProcessed.state) + mig.migrationContext.Log.Infof("Waiting for events up to lock: got %s", lockProcessed.state) found = true - mgtr.lastLockProcessed = lockProcessed + mig.lastLockProcessed = lockProcessed } else { - mgtr.migrationContext.Log.Infof("Waiting for events up to lock: skipping %s", lockProcessed.state) + mig.migrationContext.Log.Infof("Waiting for events up to lock: skipping %s", lockProcessed.state) } } } } waitForEventsUpToLockDuration := time.Since(waitForEventsUpToLockStartTime) - mgtr.migrationContext.Log.Infof("Done waiting for events up to lock; duration=%+v", waitForEventsUpToLockDuration) - mgtr.printStatus(ForcePrintStatusAndHintRule) + mig.migrationContext.Log.Infof("Done waiting for events up to lock; duration=%+v", waitForEventsUpToLockDuration) + mig.printStatus(ForcePrintStatusAndHintRule) return nil } @@ -940,92 +905,92 @@ func (mgtr *Migrator) waitForEventsUpToLock() error { // what's left of last DML entries, and **non-atomically** swap original->old, then new->original. // There is a point in time where the "original" table does not exist and queries are non-blocked // and failing. -func (mgtr *Migrator) cutOverTwoStep() (err error) { - atomic.StoreInt64(&mgtr.migrationContext.InCutOverCriticalSectionFlag, 1) - defer atomic.StoreInt64(&mgtr.migrationContext.InCutOverCriticalSectionFlag, 0) - atomic.StoreInt64(&mgtr.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 0) +func (mig *Migrator) cutOverTwoStep() (err error) { + atomic.StoreInt64(&mig.migrationContext.InCutOverCriticalSectionFlag, 1) + defer atomic.StoreInt64(&mig.migrationContext.InCutOverCriticalSectionFlag, 0) + atomic.StoreInt64(&mig.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 0) - if err := mgtr.retryOperation(mgtr.applier.LockOriginalTable); err != nil { + if err := mig.retryOperation(mig.applier.LockOriginalTable); err != nil { return err } - if err := mgtr.retryOperation(mgtr.waitForEventsUpToLock); err != nil { + if err := mig.retryOperation(mig.waitForEventsUpToLock); err != nil { return err } // If we need to create triggers we need to do it here (only create part) - if mgtr.migrationContext.IncludeTriggers && len(mgtr.migrationContext.Triggers) > 0 { - if err := mgtr.retryOperation(mgtr.applier.CreateTriggersOnGhost); err != nil { + if mig.migrationContext.IncludeTriggers && len(mig.migrationContext.Triggers) > 0 { + if err := mig.retryOperation(mig.applier.CreateTriggersOnGhost); err != nil { return err } } - if err := mgtr.retryOperation(mgtr.applier.SwapTablesQuickAndBumpy); err != nil { + if err := mig.retryOperation(mig.applier.SwapTablesQuickAndBumpy); err != nil { return err } - if err := mgtr.retryOperation(mgtr.applier.UnlockTables); err != nil { + if err := mig.retryOperation(mig.applier.UnlockTables); err != nil { return err } - lockAndRenameDuration := mgtr.migrationContext.RenameTablesEndTime.Sub(mgtr.migrationContext.LockTablesStartTime) - renameDuration := mgtr.migrationContext.RenameTablesEndTime.Sub(mgtr.migrationContext.RenameTablesStartTime) - mgtr.migrationContext.Log.Debugf("Lock & rename duration: %s (rename only: %s). During mgtr time, queries on %s were locked or failing", lockAndRenameDuration, renameDuration, sql.EscapeName(mgtr.migrationContext.OriginalTableName)) + lockAndRenameDuration := mig.migrationContext.RenameTablesEndTime.Sub(mig.migrationContext.LockTablesStartTime) + renameDuration := mig.migrationContext.RenameTablesEndTime.Sub(mig.migrationContext.RenameTablesStartTime) + mig.migrationContext.Log.Debugf("Lock & rename duration: %s (rename only: %s). During this time, queries on %s were locked or failing", lockAndRenameDuration, renameDuration, sql.EscapeName(mig.migrationContext.OriginalTableName)) return nil } // atomicCutOver -func (mgtr *Migrator) atomicCutOver() (err error) { - atomic.StoreInt64(&mgtr.migrationContext.InCutOverCriticalSectionFlag, 1) - defer atomic.StoreInt64(&mgtr.migrationContext.InCutOverCriticalSectionFlag, 0) +func (mig *Migrator) atomicCutOver() (err error) { + atomic.StoreInt64(&mig.migrationContext.InCutOverCriticalSectionFlag, 1) + defer atomic.StoreInt64(&mig.migrationContext.InCutOverCriticalSectionFlag, 0) okToUnlockTable := make(chan bool, 4) defer func() { okToUnlockTable <- true }() - atomic.StoreInt64(&mgtr.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 0) + atomic.StoreInt64(&mig.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 0) lockOriginalSessionIdChan := make(chan int64, 2) tableLocked := make(chan error, 2) tableUnlocked := make(chan error, 2) var renameLockSessionId int64 go func() { - if err := mgtr.applier.AtomicCutOverMagicLock(lockOriginalSessionIdChan, tableLocked, okToUnlockTable, tableUnlocked, &renameLockSessionId); err != nil { - mgtr.migrationContext.Log.Errore(err) + if err := mig.applier.AtomicCutOverMagicLock(lockOriginalSessionIdChan, tableLocked, okToUnlockTable, tableUnlocked, &renameLockSessionId); err != nil { + mig.migrationContext.Log.Errore(err) } }() if err := <-tableLocked; err != nil { - return mgtr.migrationContext.Log.Errore(err) + return mig.migrationContext.Log.Errore(err) } lockOriginalSessionId := <-lockOriginalSessionIdChan - mgtr.migrationContext.Log.Infof("Session locking original & magic tables is %+v", lockOriginalSessionId) + mig.migrationContext.Log.Infof("Session locking original & magic tables is %+v", lockOriginalSessionId) // At this point we know the original table is locked. // We know any newly incoming DML on original table is blocked. - if err := mgtr.waitForEventsUpToLock(); err != nil { - return mgtr.migrationContext.Log.Errore(err) + if err := mig.waitForEventsUpToLock(); err != nil { + return mig.migrationContext.Log.Errore(err) } // If we need to create triggers we need to do it here (only create part) - if mgtr.migrationContext.IncludeTriggers && len(mgtr.migrationContext.Triggers) > 0 { - if err := mgtr.applier.CreateTriggersOnGhost(); err != nil { - return mgtr.migrationContext.Log.Errore(err) + if mig.migrationContext.IncludeTriggers && len(mig.migrationContext.Triggers) > 0 { + if err := mig.applier.CreateTriggersOnGhost(); err != nil { + return mig.migrationContext.Log.Errore(err) } } // Step 2 // We now attempt an atomic RENAME on original & ghost tables, and expect it to block. - mgtr.migrationContext.RenameTablesStartTime = time.Now() + mig.migrationContext.RenameTablesStartTime = time.Now() var tableRenameKnownToHaveFailed int64 renameSessionIdChan := make(chan int64, 2) tablesRenamed := make(chan error, 2) go func() { - if err := mgtr.applier.AtomicCutoverRename(renameSessionIdChan, tablesRenamed); err != nil { + if err := mig.applier.AtomicCutoverRename(renameSessionIdChan, tablesRenamed); err != nil { // Abort! Release the lock atomic.StoreInt64(&tableRenameKnownToHaveFailed, 1) okToUnlockTable <- true } }() renameSessionId := <-renameSessionIdChan - mgtr.migrationContext.Log.Infof("Session renaming tables is %+v", renameSessionId) + mig.migrationContext.Log.Infof("Session renaming tables is %+v", renameSessionId) waitForRename := func() error { if atomic.LoadInt64(&tableRenameKnownToHaveFailed) == 1 { @@ -1033,22 +998,22 @@ func (mgtr *Migrator) atomicCutOver() (err error) { // it won't show up in PROCESSLIST, no point in waiting return nil } - return mgtr.applier.ExpectProcess(renameSessionId, "metadata lock", "rename") + return mig.applier.ExpectProcess(renameSessionId, "metadata lock", "rename") } // Wait for the RENAME to appear in PROCESSLIST - if err := mgtr.retryOperation(waitForRename, true); err != nil { + if err := mig.retryOperation(waitForRename, true); err != nil { // Abort! Release the lock okToUnlockTable <- true return err } if atomic.LoadInt64(&tableRenameKnownToHaveFailed) == 0 { - mgtr.migrationContext.Log.Infof("Found atomic RENAME to be blocking, as expected. Double checking the lock is still in place (though I don't strictly have to)") + mig.migrationContext.Log.Infof("Found atomic RENAME to be blocking, as expected. Double checking the lock is still in place (though I don't strictly have to)") } - if err := mgtr.applier.ExpectUsedLock(lockOriginalSessionId); err != nil { + if err := mig.applier.ExpectUsedLock(lockOriginalSessionId); err != nil { // Abort operation. Just make sure to drop the magic table. - return mgtr.migrationContext.Log.Errore(err) + return mig.migrationContext.Log.Errore(err) } - mgtr.migrationContext.Log.Infof("Connection holding lock on original table still exists") + mig.migrationContext.Log.Infof("Connection holding lock on original table still exists") // Now that we've found the RENAME blocking, AND the locking connection still alive, // we know it is safe to proceed to release the lock @@ -1058,33 +1023,36 @@ func (mgtr *Migrator) atomicCutOver() (err error) { // BAM! magic table dropped, original table lock is released // -> RENAME released -> queries on original are unblocked. if err := <-tableUnlocked; err != nil { - return mgtr.migrationContext.Log.Errore(err) + return mig.migrationContext.Log.Errore(err) } if err := <-tablesRenamed; err != nil { - return mgtr.migrationContext.Log.Errore(err) + return mig.migrationContext.Log.Errore(err) } - mgtr.migrationContext.RenameTablesEndTime = time.Now() + mig.migrationContext.RenameTablesEndTime = time.Now() // ooh nice! We're actually truly and thankfully done - lockAndRenameDuration := mgtr.migrationContext.RenameTablesEndTime.Sub(mgtr.migrationContext.LockTablesStartTime) - mgtr.migrationContext.Log.Infof("Lock & rename duration: %s. During mgtr time, queries on %s were blocked", lockAndRenameDuration, sql.EscapeName(mgtr.migrationContext.OriginalTableName)) + lockAndRenameDuration := mig.migrationContext.RenameTablesEndTime.Sub(mig.migrationContext.LockTablesStartTime) + mig.migrationContext.Log.Infof("Lock & rename duration: %s. During this time, queries on %s were blocked", lockAndRenameDuration, sql.EscapeName(mig.migrationContext.OriginalTableName)) return nil } // initiateServer begins listening on unix socket/tcp for incoming interactive commands -func (mgtr *Migrator) initiateServer() (err error) { - var f printStatusFunc = func(rule PrintStatusRule, writer io.Writer) { - mgtr.printStatus(rule, writer) +func (mig *Migrator) initiateServer() (err error) { + var printStatus printStatusFunc = func(rule PrintStatusRule, writer io.Writer) { + mig.printStatus(rule, writer) } - mgtr.server = NewServer(mgtr.migrationContext, mgtr.hooksExecutor, f) - if err := mgtr.server.BindSocketFile(); err != nil { + var printWorkers printWorkersFunc = func(writer io.Writer) { + mig.printWorkerStats(writer) + } + mig.server = NewServer(mig.migrationContext, mig.hooksExecutor, printStatus, printWorkers) + if err := mig.server.BindSocketFile(); err != nil { return err } - if err := mgtr.server.BindTCPPort(); err != nil { + if err := mig.server.BindTCPPort(); err != nil { return err } - go mgtr.server.Serve() + go mig.server.Serve() return nil } @@ -1095,59 +1063,59 @@ func (mgtr *Migrator) initiateServer() (err error) { // - schema validation // - heartbeat // When `--allow-on-master` is supplied, the inspector is actually the master. -func (mgtr *Migrator) initiateInspector() (err error) { - mgtr.inspector = NewInspector(mgtr.migrationContext) - if err := mgtr.inspector.InitDBConnections(); err != nil { +func (mig *Migrator) initiateInspector() (err error) { + mig.inspector = NewInspector(mig.migrationContext) + if err := mig.inspector.InitDBConnections(); err != nil { return err } - if err := mgtr.inspector.ValidateOriginalTable(); err != nil { + if err := mig.inspector.ValidateOriginalTable(); err != nil { return err } - if err := mgtr.inspector.InspectOriginalTable(); err != nil { + if err := mig.inspector.InspectOriginalTable(); err != nil { return err } // So far so good, table is accessible and valid. // Let's get master connection config - if mgtr.migrationContext.AssumeMasterHostname == "" { + if mig.migrationContext.AssumeMasterHostname == "" { // No forced master host; detect master - if mgtr.migrationContext.ApplierConnectionConfig, err = mgtr.inspector.getMasterConnectionConfig(); err != nil { + if mig.migrationContext.ApplierConnectionConfig, err = mig.inspector.getMasterConnectionConfig(); err != nil { return err } - mgtr.migrationContext.Log.Infof("Master found to be %+v", *mgtr.migrationContext.ApplierConnectionConfig.ImpliedKey) + mig.migrationContext.Log.Infof("Master found to be %+v", *mig.migrationContext.ApplierConnectionConfig.ImpliedKey) } else { // Forced master host. - key, err := mysql.ParseInstanceKey(mgtr.migrationContext.AssumeMasterHostname) + key, err := mysql.ParseInstanceKey(mig.migrationContext.AssumeMasterHostname) if err != nil { return err } - mgtr.migrationContext.ApplierConnectionConfig = mgtr.migrationContext.InspectorConnectionConfig.DuplicateCredentials(*key) - if mgtr.migrationContext.CliMasterUser != "" { - mgtr.migrationContext.ApplierConnectionConfig.User = mgtr.migrationContext.CliMasterUser + mig.migrationContext.ApplierConnectionConfig = mig.migrationContext.InspectorConnectionConfig.DuplicateCredentials(*key) + if mig.migrationContext.CliMasterUser != "" { + mig.migrationContext.ApplierConnectionConfig.User = mig.migrationContext.CliMasterUser } - if mgtr.migrationContext.CliMasterPassword != "" { - mgtr.migrationContext.ApplierConnectionConfig.Password = mgtr.migrationContext.CliMasterPassword + if mig.migrationContext.CliMasterPassword != "" { + mig.migrationContext.ApplierConnectionConfig.Password = mig.migrationContext.CliMasterPassword } - if err := mgtr.migrationContext.ApplierConnectionConfig.RegisterTLSConfig(); err != nil { + if err := mig.migrationContext.ApplierConnectionConfig.RegisterTLSConfig(); err != nil { return err } - mgtr.migrationContext.Log.Infof("Master forced to be %+v", *mgtr.migrationContext.ApplierConnectionConfig.ImpliedKey) + mig.migrationContext.Log.Infof("Master forced to be %+v", *mig.migrationContext.ApplierConnectionConfig.ImpliedKey) } // validate configs - if mgtr.migrationContext.TestOnReplica || mgtr.migrationContext.MigrateOnReplica { - if mgtr.migrationContext.InspectorIsAlsoApplier() { + if mig.migrationContext.TestOnReplica || mig.migrationContext.MigrateOnReplica { + if mig.migrationContext.InspectorIsAlsoApplier() { return fmt.Errorf("instructed to --test-on-replica or --migrate-on-replica, but the server we connect to doesn't seem to be a replica") } - mgtr.migrationContext.Log.Infof("--test-on-replica or --migrate-on-replica given. Will not execute on master %+v but rather on replica %+v itself", - *mgtr.migrationContext.ApplierConnectionConfig.ImpliedKey, *mgtr.migrationContext.InspectorConnectionConfig.ImpliedKey, + mig.migrationContext.Log.Infof("--test-on-replica or --migrate-on-replica given. Will not execute on master %+v but rather on replica %+v itself", + *mig.migrationContext.ApplierConnectionConfig.ImpliedKey, *mig.migrationContext.InspectorConnectionConfig.ImpliedKey, ) - mgtr.migrationContext.ApplierConnectionConfig = mgtr.migrationContext.InspectorConnectionConfig.Duplicate() - if mgtr.migrationContext.GetThrottleControlReplicaKeys().Len() == 0 { - mgtr.migrationContext.AddThrottleControlReplicaKey(mgtr.migrationContext.InspectorConnectionConfig.Key) + mig.migrationContext.ApplierConnectionConfig = mig.migrationContext.InspectorConnectionConfig.Duplicate() + if mig.migrationContext.GetThrottleControlReplicaKeys().Len() == 0 { + mig.migrationContext.AddThrottleControlReplicaKey(mig.migrationContext.InspectorConnectionConfig.Key) } - } else if mgtr.migrationContext.InspectorIsAlsoApplier() && !mgtr.migrationContext.AllowedRunningOnMaster { + } else if mig.migrationContext.InspectorIsAlsoApplier() && !mig.migrationContext.AllowedRunningOnMaster { return ErrMigrationNotAllowedOnMaster } - if err := mgtr.inspector.validateLogSlaveUpdates(); err != nil { + if err := mig.inspector.validateLogSlaveUpdates(); err != nil { return err } @@ -1155,20 +1123,20 @@ func (mgtr *Migrator) initiateInspector() (err error) { } // initiateStatus sets and activates the printStatus() ticker -func (mgtr *Migrator) initiateStatus() { - mgtr.printStatus(ForcePrintStatusAndHintRule) +func (mig *Migrator) initiateStatus() { + mig.printStatus(ForcePrintStatusAndHintRule) ticker := time.NewTicker(time.Second) defer ticker.Stop() var previousCount int64 for range ticker.C { - if atomic.LoadInt64(&mgtr.finishedMigrating) > 0 { + if atomic.LoadInt64(&mig.finishedMigrating) > 0 { return } - go mgtr.printStatus(HeuristicPrintStatusRule) - totalCopied := atomic.LoadInt64(&mgtr.migrationContext.TotalRowsCopied) + go mig.printStatus(HeuristicPrintStatusRule) + totalCopied := atomic.LoadInt64(&mig.migrationContext.TotalRowsCopied) if previousCount > 0 { copiedThisLoop := totalCopied - previousCount - atomic.StoreInt64(&mgtr.migrationContext.EtaRowsPerSecond, copiedThisLoop) + atomic.StoreInt64(&mig.migrationContext.EtaRowsPerSecond, copiedThisLoop) } previousCount = totalCopied } @@ -1178,101 +1146,101 @@ func (mgtr *Migrator) initiateStatus() { // to keep in mind; such as the name of migrated table, throttle params etc. // This gets printed at beginning and end of migration, every 10 minutes throughout // migration, and as response to the "status" interactive command. -func (mgtr *Migrator) printMigrationStatusHint(writers ...io.Writer) { +func (mig *Migrator) printMigrationStatusHint(writers ...io.Writer) { w := io.MultiWriter(writers...) fmt.Fprintf(w, "# Migrating %s.%s; Ghost table is %s.%s\n", - sql.EscapeName(mgtr.migrationContext.DatabaseName), - sql.EscapeName(mgtr.migrationContext.OriginalTableName), - sql.EscapeName(mgtr.migrationContext.DatabaseName), - sql.EscapeName(mgtr.migrationContext.GetGhostTableName()), + sql.EscapeName(mig.migrationContext.DatabaseName), + sql.EscapeName(mig.migrationContext.OriginalTableName), + sql.EscapeName(mig.migrationContext.DatabaseName), + sql.EscapeName(mig.migrationContext.GetGhostTableName()), ) fmt.Fprintf(w, "# Migrating %+v; inspecting %+v; executing on %+v\n", - *mgtr.applier.connectionConfig.ImpliedKey, - *mgtr.inspector.connectionConfig.ImpliedKey, - mgtr.migrationContext.Hostname, + *mig.applier.connectionConfig.ImpliedKey, + *mig.inspector.connectionConfig.ImpliedKey, + mig.migrationContext.Hostname, ) fmt.Fprintf(w, "# Migration started at %+v\n", - mgtr.migrationContext.StartTime.Format(time.RubyDate), + mig.migrationContext.StartTime.Format(time.RubyDate), ) - maxLoad := mgtr.migrationContext.GetMaxLoad() - criticalLoad := mgtr.migrationContext.GetCriticalLoad() + maxLoad := mig.migrationContext.GetMaxLoad() + criticalLoad := mig.migrationContext.GetCriticalLoad() fmt.Fprintf(w, "# chunk-size: %+v; max-lag-millis: %+vms; dml-batch-size: %+v; max-load: %s; critical-load: %s; nice-ratio: %f\n", - atomic.LoadInt64(&mgtr.migrationContext.ChunkSize), - atomic.LoadInt64(&mgtr.migrationContext.MaxLagMillisecondsThrottleThreshold), - atomic.LoadInt64(&mgtr.migrationContext.DMLBatchSize), + atomic.LoadInt64(&mig.migrationContext.ChunkSize), + atomic.LoadInt64(&mig.migrationContext.MaxLagMillisecondsThrottleThreshold), + atomic.LoadInt64(&mig.migrationContext.DMLBatchSize), maxLoad.String(), criticalLoad.String(), - mgtr.migrationContext.GetNiceRatio(), + mig.migrationContext.GetNiceRatio(), ) - if mgtr.migrationContext.ThrottleFlagFile != "" { + if mig.migrationContext.ThrottleFlagFile != "" { setIndicator := "" - if base.FileExists(mgtr.migrationContext.ThrottleFlagFile) { + if base.FileExists(mig.migrationContext.ThrottleFlagFile) { setIndicator = "[set]" } fmt.Fprintf(w, "# throttle-flag-file: %+v %+v\n", - mgtr.migrationContext.ThrottleFlagFile, setIndicator, + mig.migrationContext.ThrottleFlagFile, setIndicator, ) } - if mgtr.migrationContext.ThrottleAdditionalFlagFile != "" { + if mig.migrationContext.ThrottleAdditionalFlagFile != "" { setIndicator := "" - if base.FileExists(mgtr.migrationContext.ThrottleAdditionalFlagFile) { + if base.FileExists(mig.migrationContext.ThrottleAdditionalFlagFile) { setIndicator = "[set]" } fmt.Fprintf(w, "# throttle-additional-flag-file: %+v %+v\n", - mgtr.migrationContext.ThrottleAdditionalFlagFile, setIndicator, + mig.migrationContext.ThrottleAdditionalFlagFile, setIndicator, ) } - if throttleQuery := mgtr.migrationContext.GetThrottleQuery(); throttleQuery != "" { + if throttleQuery := mig.migrationContext.GetThrottleQuery(); throttleQuery != "" { fmt.Fprintf(w, "# throttle-query: %+v\n", throttleQuery, ) } - if throttleControlReplicaKeys := mgtr.migrationContext.GetThrottleControlReplicaKeys(); throttleControlReplicaKeys.Len() > 0 { + if throttleControlReplicaKeys := mig.migrationContext.GetThrottleControlReplicaKeys(); throttleControlReplicaKeys.Len() > 0 { fmt.Fprintf(w, "# throttle-control-replicas count: %+v\n", throttleControlReplicaKeys.Len(), ) } - if mgtr.migrationContext.PostponeCutOverFlagFile != "" { + if mig.migrationContext.PostponeCutOverFlagFile != "" { setIndicator := "" - if base.FileExists(mgtr.migrationContext.PostponeCutOverFlagFile) { + if base.FileExists(mig.migrationContext.PostponeCutOverFlagFile) { setIndicator = "[set]" } fmt.Fprintf(w, "# postpone-cut-over-flag-file: %+v %+v\n", - mgtr.migrationContext.PostponeCutOverFlagFile, setIndicator, + mig.migrationContext.PostponeCutOverFlagFile, setIndicator, ) } - if mgtr.migrationContext.PanicFlagFile != "" { + if mig.migrationContext.PanicFlagFile != "" { fmt.Fprintf(w, "# panic-flag-file: %+v\n", - mgtr.migrationContext.PanicFlagFile, + mig.migrationContext.PanicFlagFile, ) } fmt.Fprintf(w, "# Serving on unix socket: %+v\n", - mgtr.migrationContext.ServeSocketFile, + mig.migrationContext.ServeSocketFile, ) - if mgtr.migrationContext.ServeTCPPort != 0 { - fmt.Fprintf(w, "# Serving on TCP port: %+v\n", mgtr.migrationContext.ServeTCPPort) + if mig.migrationContext.ServeTCPPort != 0 { + fmt.Fprintf(w, "# Serving on TCP port: %+v\n", mig.migrationContext.ServeTCPPort) } } // getProgressPercent returns an estimate of migration progess as a percent. -func (mgtr *Migrator) getProgressPercent(rowsEstimate int64) (progressPct float64) { +func (mig *Migrator) getProgressPercent(rowsEstimate int64) (progressPct float64) { progressPct = 100.0 if rowsEstimate > 0 { - progressPct *= float64(mgtr.migrationContext.GetTotalRowsCopied()) / float64(rowsEstimate) + progressPct *= float64(mig.migrationContext.GetTotalRowsCopied()) / float64(rowsEstimate) } return progressPct } // getMigrationETA returns the estimated duration of the migration -func (mgtr *Migrator) getMigrationETA(rowsEstimate int64) (eta string, duration time.Duration) { +func (mig *Migrator) getMigrationETA(rowsEstimate int64) (eta string, duration time.Duration) { duration = time.Duration(base.ETAUnknown) - progressPct := mgtr.getProgressPercent(rowsEstimate) + progressPct := mig.getProgressPercent(rowsEstimate) if progressPct >= 100.0 { duration = 0 } else if progressPct >= 0.1 { - totalRowsCopied := mgtr.migrationContext.GetTotalRowsCopied() - etaRowsPerSecond := atomic.LoadInt64(&mgtr.migrationContext.EtaRowsPerSecond) + totalRowsCopied := mig.migrationContext.GetTotalRowsCopied() + etaRowsPerSecond := atomic.LoadInt64(&mig.migrationContext.EtaRowsPerSecond) var etaSeconds float64 // If there is data available on our current row-copies-per-second rate, use it. // Otherwise we can fallback to the total elapsed time and extrapolate. @@ -1282,7 +1250,7 @@ func (mgtr *Migrator) getMigrationETA(rowsEstimate int64) (eta string, duration remainingRows := float64(rowsEstimate) - float64(totalRowsCopied) etaSeconds = remainingRows / float64(etaRowsPerSecond) } else { - elapsedRowCopySeconds := mgtr.migrationContext.ElapsedRowCopyTime().Seconds() + elapsedRowCopySeconds := mig.migrationContext.ElapsedRowCopyTime().Seconds() totalExpectedSeconds := elapsedRowCopySeconds * float64(rowsEstimate) / float64(totalRowsCopied) etaSeconds = totalExpectedSeconds - elapsedRowCopySeconds } @@ -1306,22 +1274,22 @@ func (mgtr *Migrator) getMigrationETA(rowsEstimate int64) (eta string, duration } // getMigrationStateAndETA returns the state and eta of the migration. -func (mgtr *Migrator) getMigrationStateAndETA(rowsEstimate int64) (state, eta string, etaDuration time.Duration) { - eta, etaDuration = mgtr.getMigrationETA(rowsEstimate) +func (mig *Migrator) getMigrationStateAndETA(rowsEstimate int64) (state, eta string, etaDuration time.Duration) { + eta, etaDuration = mig.getMigrationETA(rowsEstimate) state = "migrating" - if atomic.LoadInt64(&mgtr.migrationContext.CountingRowsFlag) > 0 && !mgtr.migrationContext.ConcurrentCountTableRows { + if atomic.LoadInt64(&mig.migrationContext.CountingRowsFlag) > 0 && !mig.migrationContext.ConcurrentCountTableRows { state = "counting rows" - } else if atomic.LoadInt64(&mgtr.migrationContext.IsPostponingCutOver) > 0 { + } else if atomic.LoadInt64(&mig.migrationContext.IsPostponingCutOver) > 0 { eta = "due" state = "postponing cut-over" - } else if isThrottled, throttleReason, _ := mgtr.migrationContext.IsThrottled(); isThrottled { + } else if isThrottled, throttleReason, _ := mig.migrationContext.IsThrottled(); isThrottled { state = fmt.Sprintf("throttled, %s", throttleReason) } return state, eta, etaDuration } // shouldPrintStatus returns true when the migrator is due to print status info. -func (mgtr *Migrator) shouldPrintStatus(rule PrintStatusRule, elapsedSeconds int64, etaDuration time.Duration) (shouldPrint bool) { +func (mig *Migrator) shouldPrintStatus(rule PrintStatusRule, elapsedSeconds int64, etaDuration time.Duration) (shouldPrint bool) { if rule != HeuristicPrintStatusRule { return true } @@ -1335,7 +1303,7 @@ func (mgtr *Migrator) shouldPrintStatus(rule PrintStatusRule, elapsedSeconds int shouldPrint = (elapsedSeconds%5 == 0) } else if elapsedSeconds <= 180 { shouldPrint = (elapsedSeconds%5 == 0) - } else if mgtr.migrationContext.TimeSincePointOfInterest().Seconds() <= 60 { + } else if mig.migrationContext.TimeSincePointOfInterest().Seconds() <= 60 { shouldPrint = (elapsedSeconds%5 == 0) } else { shouldPrint = (elapsedSeconds%30 == 0) @@ -1345,7 +1313,7 @@ func (mgtr *Migrator) shouldPrintStatus(rule PrintStatusRule, elapsedSeconds int } // shouldPrintMigrationStatus returns true when the migrator is due to print the migration status hint -func (mgtr *Migrator) shouldPrintMigrationStatusHint(rule PrintStatusRule, elapsedSeconds int64) (shouldPrint bool) { +func (mig *Migrator) shouldPrintMigrationStatusHint(rule PrintStatusRule, elapsedSeconds int64) (shouldPrint bool) { if elapsedSeconds%600 == 0 { shouldPrint = true } else if rule == ForcePrintStatusAndHintRule { @@ -1354,59 +1322,82 @@ func (mgtr *Migrator) shouldPrintMigrationStatusHint(rule PrintStatusRule, elaps return shouldPrint } +// printWorkerStats prints cumulative stats from the trxCoordinator workers. +func (mig *Migrator) printWorkerStats(writers ...io.Writer) { + writers = append(writers, os.Stdout) + mw := io.MultiWriter(writers...) + + busyWorkers := mig.trxCoordinator.busyWorkers.Load() + totalWorkers := cap(mig.trxCoordinator.workerQueue) + fmt.Fprintf(mw, "# %d/%d workers are busy\n", busyWorkers, totalWorkers) + + stats := mig.trxCoordinator.GetWorkerStats() + for id, stat := range stats { + fmt.Fprintf(mw, + "Worker %d; Waited: %s; Busy: %s; DML Applied: %d (%.2f/s), Trx Applied: %d (%.2f/s)\n", + id, + base.PrettifyDurationOutput(stat.waitTime), + base.PrettifyDurationOutput(stat.busyTime), + stat.dmlEventsApplied, + stat.dmlRate, + stat.executedJobs, + stat.trxRate) + } +} + // printStatus prints the progress status, and optionally additionally detailed // dump of configuration. // `rule` indicates the type of output expected. // By default the status is written to standard output, but other writers can // be used as well. -func (mgtr *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { +func (mig *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { if rule == NoPrintStatusRule { return } writers = append(writers, os.Stdout) - elapsedTime := mgtr.migrationContext.ElapsedTime() + elapsedTime := mig.migrationContext.ElapsedTime() elapsedSeconds := int64(elapsedTime.Seconds()) - totalRowsCopied := mgtr.migrationContext.GetTotalRowsCopied() - rowsEstimate := atomic.LoadInt64(&mgtr.migrationContext.RowsEstimate) + atomic.LoadInt64(&mgtr.migrationContext.RowsDeltaEstimate) - if atomic.LoadInt64(&mgtr.rowCopyCompleteFlag) == 1 { + totalRowsCopied := mig.migrationContext.GetTotalRowsCopied() + rowsEstimate := atomic.LoadInt64(&mig.migrationContext.RowsEstimate) + atomic.LoadInt64(&mig.migrationContext.RowsDeltaEstimate) + if atomic.LoadInt64(&mig.rowCopyCompleteFlag) == 1 { // Done copying rows. The totalRowsCopied value is the de-facto number of rows, // and there is no further need to keep updating the value. rowsEstimate = totalRowsCopied } // we take the opportunity to update migration context with progressPct - progressPct := mgtr.getProgressPercent(rowsEstimate) - mgtr.migrationContext.SetProgressPct(progressPct) + progressPct := mig.getProgressPercent(rowsEstimate) + mig.migrationContext.SetProgressPct(progressPct) // Before status, let's see if we should print a nice reminder for what exactly we're doing here. - if mgtr.shouldPrintMigrationStatusHint(rule, elapsedSeconds) { - mgtr.printMigrationStatusHint(writers...) + if mig.shouldPrintMigrationStatusHint(rule, elapsedSeconds) { + mig.printMigrationStatusHint(writers...) } // Get state + ETA - state, eta, etaDuration := mgtr.getMigrationStateAndETA(rowsEstimate) - mgtr.migrationContext.SetETADuration(etaDuration) + state, eta, etaDuration := mig.getMigrationStateAndETA(rowsEstimate) + mig.migrationContext.SetETADuration(etaDuration) - if !mgtr.shouldPrintStatus(rule, elapsedSeconds, etaDuration) { + if !mig.shouldPrintStatus(rule, elapsedSeconds, etaDuration) { return } - currentBinlogCoordinates := mgtr.eventsStreamer.GetCurrentBinlogCoordinates() + currentBinlogCoordinates := mig.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates() status := fmt.Sprintf("Copy: %d/%d %.1f%%; Applied: %d; Backlog: %d/%d; Time: %+v(total), %+v(copy); streamer: %+v; Lag: %.2fs, HeartbeatLag: %.2fs, State: %s; ETA: %s", totalRowsCopied, rowsEstimate, progressPct, - atomic.LoadInt64(&mgtr.migrationContext.TotalDMLEventsApplied), - len(mgtr.applyEventsQueue), cap(mgtr.applyEventsQueue), - base.PrettifyDurationOutput(elapsedTime), base.PrettifyDurationOutput(mgtr.migrationContext.ElapsedRowCopyTime()), + atomic.LoadInt64(&mig.migrationContext.TotalDMLEventsApplied), + len(mig.trxCoordinator.events), cap(mig.trxCoordinator.events), + base.PrettifyDurationOutput(elapsedTime), base.PrettifyDurationOutput(mig.migrationContext.ElapsedRowCopyTime()), currentBinlogCoordinates.DisplayString(), - mgtr.migrationContext.GetCurrentLagDuration().Seconds(), - mgtr.migrationContext.TimeSinceLastHeartbeatOnChangelog().Seconds(), + mig.migrationContext.GetCurrentLagDuration().Seconds(), + mig.migrationContext.TimeSinceLastHeartbeatOnChangelog().Seconds(), state, eta, ) - mgtr.applier.WriteChangelog( - fmt.Sprintf("copy iteration %d at %d", mgtr.migrationContext.GetIteration(), time.Now().Unix()), + mig.applier.WriteChangelog( + fmt.Sprintf("copy iteration %d at %d", mig.migrationContext.GetIteration(), time.Now().Unix()), state, ) w := io.MultiWriter(writers...) @@ -1418,175 +1409,152 @@ func (mgtr *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { // fmt.Sprintf. So, the argument of every function called on the DefaultLogger object // migrationContext.Log will eventually pass through fmt.Sprintf, and thus the '%' character // needs to be escaped. - mgtr.migrationContext.Log.Info(strings.Replace(status, "%", "%%", 1)) + mig.migrationContext.Log.Info(strings.Replace(status, "%", "%%", 1)) - hooksStatusIntervalSec := mgtr.migrationContext.HooksStatusIntervalSec + hooksStatusIntervalSec := mig.migrationContext.HooksStatusIntervalSec if hooksStatusIntervalSec > 0 && elapsedSeconds%hooksStatusIntervalSec == 0 { - mgtr.hooksExecutor.OnStatus(status) + mig.hooksExecutor.OnStatus(status) } } // initiateStreaming begins streaming of binary log events and registers listeners for such events -func (mgtr *Migrator) initiateStreaming() error { - mgtr.eventsStreamer = NewEventsStreamer(mgtr.migrationContext) - if err := mgtr.eventsStreamer.InitDBConnections(); err != nil { +func (mig *Migrator) initiateStreaming() error { + initialCoords, err := mig.inspector.readCurrentBinlogCoordinates() + if err != nil { return err } - mgtr.eventsStreamer.AddListener( - false, - mgtr.migrationContext.DatabaseName, - mgtr.migrationContext.GetChangelogTableName(), - func(dmlEntry *binlog.BinlogEntry) error { - return mgtr.onChangelogEvent(dmlEntry) - }, - ) go func() { - mgtr.migrationContext.Log.Debugf("Beginning streaming") - err := mgtr.eventsStreamer.StreamEvents(mgtr.canStopStreaming) + mig.migrationContext.Log.Debugf("Beginning streaming at coordinates: %+v", initialCoords) + ctx := context.TODO() + err := mig.trxCoordinator.StartStreaming(ctx, initialCoords, mig.canStopStreaming) if err != nil { // Use helper to prevent deadlock if listenOnPanicAbort already exited - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.migrationContext.PanicAbort, err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.migrationContext.PanicAbort, err) } - mgtr.migrationContext.Log.Debugf("Done streaming") + mig.migrationContext.Log.Debugf("Done streaming") }() go func() { ticker := time.NewTicker(time.Second) defer ticker.Stop() for range ticker.C { - if atomic.LoadInt64(&mgtr.finishedMigrating) > 0 { + if atomic.LoadInt64(&mig.finishedMigrating) > 0 { return } - mgtr.migrationContext.SetRecentBinlogCoordinates(mgtr.eventsStreamer.GetCurrentBinlogCoordinates()) + mig.migrationContext.SetRecentBinlogCoordinates(mig.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates()) } }() return nil } -// addDMLEventsListener begins listening for binlog events on the original table, -// and creates & enqueues a write task per such event. -func (mgtr *Migrator) addDMLEventsListener() error { - err := mgtr.eventsStreamer.AddListener( - false, - mgtr.migrationContext.DatabaseName, - mgtr.migrationContext.OriginalTableName, - func(dmlEntry *binlog.BinlogEntry) error { - // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits - // This is critical because this callback blocks the event streamer - return base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.applyEventsQueue, newApplyEventStructByDML(dmlEntry)) - }, - ) - return err -} - // initiateThrottler kicks in the throttling collection and the throttling checks. -func (mgtr *Migrator) initiateThrottler() { - mgtr.throttler = NewThrottler(mgtr.migrationContext, mgtr.applier, mgtr.inspector, mgtr.appVersion) - - go mgtr.throttler.initiateThrottlerCollection(mgtr.firstThrottlingCollected) - mgtr.migrationContext.Log.Infof("Waiting for first throttle metrics to be collected") - <-mgtr.firstThrottlingCollected // replication lag - <-mgtr.firstThrottlingCollected // HTTP status - <-mgtr.firstThrottlingCollected // other, general metrics - mgtr.migrationContext.Log.Infof("First throttle metrics collected") - go mgtr.throttler.initiateThrottlerChecks() +func (mig *Migrator) initiateThrottler() { + mig.throttler = NewThrottler(mig.migrationContext, mig.applier, mig.inspector, mig.appVersion) + + go mig.throttler.initiateThrottlerCollection(mig.firstThrottlingCollected) + mig.migrationContext.Log.Infof("Waiting for first throttle metrics to be collected") + <-mig.firstThrottlingCollected // replication lag + <-mig.firstThrottlingCollected // HTTP status + <-mig.firstThrottlingCollected // other, general metrics + mig.migrationContext.Log.Infof("First throttle metrics collected") + go mig.throttler.initiateThrottlerChecks() } -func (mgtr *Migrator) initiateApplier() error { - mgtr.applier = NewApplier(mgtr.migrationContext) - if err := mgtr.applier.InitDBConnections(); err != nil { +func (mig *Migrator) initiateApplier() error { + mig.applier = NewApplier(mig.migrationContext) + if err := mig.applier.InitDBConnections(mig.migrationContext.NumWorkers); err != nil { return err } - if mgtr.migrationContext.Revert { - if err := mgtr.applier.CreateChangelogTable(); err != nil { - mgtr.migrationContext.Log.Errorf("unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") + if mig.migrationContext.Revert { + if err := mig.applier.CreateChangelogTable(); err != nil { + mig.migrationContext.Log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") return err } - } else if !mgtr.migrationContext.Resume { - if err := mgtr.applier.ValidateOrDropExistingTables(); err != nil { + } else if !mig.migrationContext.Resume { + if err := mig.applier.ValidateOrDropExistingTables(); err != nil { return err } - if err := mgtr.applier.CreateChangelogTable(); err != nil { - mgtr.migrationContext.Log.Errorf("unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") + if err := mig.applier.CreateChangelogTable(); err != nil { + mig.migrationContext.Log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") return err } - if err := mgtr.applier.CreateGhostTable(); err != nil { - mgtr.migrationContext.Log.Errorf("unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out") + if err := mig.applier.CreateGhostTable(); err != nil { + mig.migrationContext.Log.Errorf("Unable to create ghost table, see further error details. Perhaps a previous migration failed without dropping the table? Bailing out") return err } - if err := mgtr.applier.AlterGhost(); err != nil { - mgtr.migrationContext.Log.Errorf("unable to ALTER ghost table, see further error details. Bailing out") + if err := mig.applier.AlterGhost(); err != nil { + mig.migrationContext.Log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") return err } - if mgtr.migrationContext.OriginalTableAutoIncrement > 0 && !mgtr.parser.IsAutoIncrementDefined() { + if mig.migrationContext.OriginalTableAutoIncrement > 0 && !mig.parser.IsAutoIncrementDefined() { // Original table has AUTO_INCREMENT value and the -alter statement does not indicate any override, // so we should copy AUTO_INCREMENT value onto our ghost table. - if err := mgtr.applier.AlterGhostAutoIncrement(); err != nil { - mgtr.migrationContext.Log.Errorf("unable to ALTER ghost table AUTO_INCREMENT value, see further error details. Bailing out") + if err := mig.applier.AlterGhostAutoIncrement(); err != nil { + mig.migrationContext.Log.Errorf("Unable to ALTER ghost table AUTO_INCREMENT value, see further error details. Bailing out") return err } } - if _, err := mgtr.applier.WriteChangelogState(string(GhostTableMigrated)); err != nil { + if _, err := mig.applier.WriteChangelogState(string(GhostTableMigrated)); err != nil { return err } } // ensure performance_schema.metadata_locks is available. - if err := mgtr.applier.StateMetadataLockInstrument(); err != nil { - mgtr.migrationContext.Log.Warning("unable to enable metadata lock instrument, see further error details") + if err := mig.applier.StateMetadataLockInstrument(); err != nil { + mig.migrationContext.Log.Warning("Unable to enable metadata lock instrument, see further error details.") } - if !mgtr.migrationContext.IsOpenMetadataLockInstruments { - if !mgtr.migrationContext.SkipMetadataLockCheck { - return mgtr.migrationContext.Log.Errorf("bailing out because metadata lock instrument not enabled. Use --skip-metadata-lock-check if you wish to proceed without. See https://github.com/github/gh-ost/pull/1536 for details") + if !mig.migrationContext.IsOpenMetadataLockInstruments { + if !mig.migrationContext.SkipMetadataLockCheck { + return mig.migrationContext.Log.Errorf("Bailing out because metadata lock instrument not enabled. Use --skip-metadata-lock-check if you wish to proceed without. See https://github.com/github/gh-ost/pull/1536 for details.") } - mgtr.migrationContext.Log.Warning("proceeding without metadata lock check. There is a small chance of data loss if another session accesses the ghost table during cut-over. See https://github.com/github/gh-ost/pull/1536 for details") + mig.migrationContext.Log.Warning("Proceeding without metadata lock check. There is a small chance of data loss if another session accesses the ghost table during cut-over. See https://github.com/github/gh-ost/pull/1536 for details.") } - go mgtr.applier.InitiateHeartbeat() + go mig.applier.InitiateHeartbeat() return nil } // iterateChunks iterates the existing table rows, and generates a copy task of // a chunk of rows onto the ghost table. -func (mgtr *Migrator) iterateChunks() error { +func (mig *Migrator) iterateChunks() error { terminateRowIteration := func(err error) error { - _ = base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.rowCopyComplete, err) - return mgtr.migrationContext.Log.Errore(err) + _ = base.SendWithContext(mig.migrationContext.GetContext(), mig.rowCopyComplete, err) + return mig.migrationContext.Log.Errore(err) } - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really copying data") + if mig.migrationContext.Noop { + mig.migrationContext.Log.Debugf("Noop operation; not really copying data") return terminateRowIteration(nil) } - if mgtr.migrationContext.MigrationRangeMinValues == nil { - mgtr.migrationContext.Log.Debugf("No rows found in table. Rowcopy will be implicitly empty") + if mig.migrationContext.MigrationRangeMinValues == nil { + mig.migrationContext.Log.Debugf("No rows found in table. Rowcopy will be implicitly empty") return terminateRowIteration(nil) } var hasNoFurtherRangeFlag int64 // Iterate per chunk: for { - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return terminateRowIteration(err) } - if atomic.LoadInt64(&mgtr.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { + if atomic.LoadInt64(&mig.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { // Done // There's another such check down the line return nil } copyRowsFunc := func() error { - mgtr.migrationContext.SetNextIterationRangeMinValues() + mig.migrationContext.SetNextIterationRangeMinValues() // Copy task: applyCopyRowsFunc := func() error { - if atomic.LoadInt64(&mgtr.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { + if atomic.LoadInt64(&mig.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { // Done. // There's another such check down the line return nil } // When hasFurtherRange is false, original table might be write locked and CalculateNextIterationRangeEndValues would hangs forever - hasFurtherRange, err := mgtr.applier.CalculateNextIterationRangeEndValues() + hasFurtherRange, err := mig.applier.CalculateNextIterationRangeEndValues() if err != nil { return err // wrapping call will retry } @@ -1594,7 +1562,7 @@ func (mgtr *Migrator) iterateChunks() error { atomic.StoreInt64(&hasNoFurtherRangeFlag, 1) return terminateRowIteration(nil) } - if atomic.LoadInt64(&mgtr.rowCopyCompleteFlag) == 1 { + if atomic.LoadInt64(&mig.rowCopyCompleteFlag) == 1 { // No need for more writes. // This is the de-facto place where we avoid writing in the event of completed cut-over. // There could _still_ be a race condition, but that's as close as we can get. @@ -1605,44 +1573,44 @@ func (mgtr *Migrator) iterateChunks() error { // _ghost_ table, which no longer exists. So, bothering error messages and all, but no damage. return nil } - _, rowsAffected, _, err := mgtr.applier.ApplyIterationInsertQuery() + _, rowsAffected, _, err := mig.applier.ApplyIterationInsertQuery() if err != nil { return err // wrapping call will retry } - if mgtr.migrationContext.PanicOnWarnings { - if len(mgtr.migrationContext.MigrationLastInsertSQLWarnings) > 0 { - for _, warning := range mgtr.migrationContext.MigrationLastInsertSQLWarnings { - mgtr.migrationContext.Log.Infof("ApplyIterationInsertQuery has SQL warnings! %s", warning) + if mig.migrationContext.PanicOnWarnings { + if len(mig.migrationContext.MigrationLastInsertSQLWarnings) > 0 { + for _, warning := range mig.migrationContext.MigrationLastInsertSQLWarnings { + mig.migrationContext.Log.Infof("ApplyIterationInsertQuery has SQL warnings! %s", warning) } - joinedWarnings := strings.Join(mgtr.migrationContext.MigrationLastInsertSQLWarnings, "; ") + joinedWarnings := strings.Join(mig.migrationContext.MigrationLastInsertSQLWarnings, "; ") return terminateRowIteration(fmt.Errorf("ApplyIterationInsertQuery failed because of SQL warnings: [%s]", joinedWarnings)) } } - atomic.AddInt64(&mgtr.migrationContext.TotalRowsCopied, rowsAffected) - atomic.AddInt64(&mgtr.migrationContext.Iteration, 1) + atomic.AddInt64(&mig.migrationContext.TotalRowsCopied, rowsAffected) + atomic.AddInt64(&mig.migrationContext.Iteration, 1) return nil } - if err := mgtr.retryBatchCopyWithHooks(applyCopyRowsFunc); err != nil { + if err := mig.retryBatchCopyWithHooks(applyCopyRowsFunc); err != nil { return terminateRowIteration(err) } // record last successfully copied range - mgtr.applier.LastIterationRangeMutex.Lock() - if mgtr.migrationContext.MigrationIterationRangeMinValues != nil && mgtr.migrationContext.MigrationIterationRangeMaxValues != nil { - mgtr.applier.LastIterationRangeMinValues = mgtr.migrationContext.MigrationIterationRangeMinValues.Clone() - mgtr.applier.LastIterationRangeMaxValues = mgtr.migrationContext.MigrationIterationRangeMaxValues.Clone() + mig.applier.LastIterationRangeMutex.Lock() + if mig.migrationContext.MigrationIterationRangeMinValues != nil && mig.migrationContext.MigrationIterationRangeMaxValues != nil { + mig.applier.LastIterationRangeMinValues = mig.migrationContext.MigrationIterationRangeMinValues.Clone() + mig.applier.LastIterationRangeMaxValues = mig.migrationContext.MigrationIterationRangeMaxValues.Clone() } - mgtr.applier.LastIterationRangeMutex.Unlock() + mig.applier.LastIterationRangeMutex.Unlock() return nil } // Enqueue copy operation; to be executed by executeWriteFuncs() // Use helper to prevent deadlock if executeWriteFuncs exits - if err := base.SendWithContext(mgtr.migrationContext.GetContext(), mgtr.copyRowsQueue, copyRowsFunc); err != nil { + if err := base.SendWithContext(mig.migrationContext.GetContext(), mig.copyRowsQueue, copyRowsFunc); err != nil { // Context cancelled, check for abort and exit - if abortErr := mgtr.checkAbort(); abortErr != nil { + if abortErr := mig.checkAbort(); abortErr != nil { return terminateRowIteration(abortErr) } return terminateRowIteration(err) @@ -1650,152 +1618,96 @@ func (mgtr *Migrator) iterateChunks() error { } } -func (mgtr *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { - handleNonDMLEventStruct := func(eventStruct *applyEventStruct) error { - if eventStruct.writeFunc != nil { - if err := mgtr.retryOperation(*eventStruct.writeFunc); err != nil { - return mgtr.migrationContext.Log.Errore(err) - } - } - return nil - } - if eventStruct.dmlEvent == nil { - return handleNonDMLEventStruct(eventStruct) - } - if eventStruct.dmlEvent != nil { - dmlEvents := [](*binlog.BinlogDMLEvent){} - dmlEvents = append(dmlEvents, eventStruct.dmlEvent) - var nonDmlStructToApply *applyEventStruct - - availableEvents := len(mgtr.applyEventsQueue) - batchSize := int(atomic.LoadInt64(&mgtr.migrationContext.DMLBatchSize)) - if availableEvents > batchSize-1 { - // The "- 1" is because we already consumed one event: the original event that led to this function getting called. - // So, if DMLBatchSize==1 we wish to not process any further events - availableEvents = batchSize - 1 - } - for i := 0; i < availableEvents; i++ { - additionalStruct := <-mgtr.applyEventsQueue - if additionalStruct.dmlEvent == nil { - // Not a DML. We don't group this, and we don't batch any further - nonDmlStructToApply = additionalStruct - break - } - dmlEvents = append(dmlEvents, additionalStruct.dmlEvent) - } - // Create a task to apply the DML event; this will be execute by executeWriteFuncs() - var applyEventFunc tableWriteFunc = func() error { - return mgtr.applier.ApplyDMLEventQueries(dmlEvents) - } - if err := mgtr.retryOperation(applyEventFunc); err != nil { - return mgtr.migrationContext.Log.Errore(err) - } - // update applier coordinates - mgtr.applier.CurrentCoordinatesMutex.Lock() - mgtr.applier.CurrentCoordinates = eventStruct.coords - mgtr.applier.CurrentCoordinatesMutex.Unlock() - - if nonDmlStructToApply != nil { - // We pulled DML events from the queue, and then we hit a non-DML event. Wait! - // We need to handle it! - if err := handleNonDMLEventStruct(nonDmlStructToApply); err != nil { - return mgtr.migrationContext.Log.Errore(err) - } - } - } - return nil -} - // Checkpoint attempts to write a checkpoint of the Migrator's current state. // It gets the binlog coordinates of the last received trx and waits until the // applier reaches that trx. At that point it's safe to resume from these coordinates. -func (mgtr *Migrator) Checkpoint(ctx context.Context) (*Checkpoint, error) { - coords := mgtr.eventsStreamer.GetCurrentBinlogCoordinates() - mgtr.applier.LastIterationRangeMutex.Lock() - if mgtr.applier.LastIterationRangeMaxValues == nil || mgtr.applier.LastIterationRangeMinValues == nil { - mgtr.applier.LastIterationRangeMutex.Unlock() +func (mig *Migrator) Checkpoint(ctx context.Context) (*Checkpoint, error) { + coords := mig.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates() + mig.applier.LastIterationRangeMutex.Lock() + if mig.applier.LastIterationRangeMaxValues == nil || mig.applier.LastIterationRangeMinValues == nil { + mig.applier.LastIterationRangeMutex.Unlock() return nil, errors.New("iteration range is empty, not checkpointing") } chk := &Checkpoint{ - Iteration: mgtr.migrationContext.GetIteration(), - IterationRangeMin: mgtr.applier.LastIterationRangeMinValues.Clone(), - IterationRangeMax: mgtr.applier.LastIterationRangeMaxValues.Clone(), + Iteration: mig.migrationContext.GetIteration(), + IterationRangeMin: mig.applier.LastIterationRangeMinValues.Clone(), + IterationRangeMax: mig.applier.LastIterationRangeMaxValues.Clone(), LastTrxCoords: coords, - RowsCopied: atomic.LoadInt64(&mgtr.migrationContext.TotalRowsCopied), - DMLApplied: atomic.LoadInt64(&mgtr.migrationContext.TotalDMLEventsApplied), + RowsCopied: atomic.LoadInt64(&mig.migrationContext.TotalRowsCopied), + DMLApplied: atomic.LoadInt64(&mig.migrationContext.TotalDMLEventsApplied), } - mgtr.applier.LastIterationRangeMutex.Unlock() + mig.applier.LastIterationRangeMutex.Unlock() for { if err := ctx.Err(); err != nil { return nil, err } - mgtr.applier.CurrentCoordinatesMutex.Lock() - if coords.SmallerThanOrEquals(mgtr.applier.CurrentCoordinates) { - id, err := mgtr.applier.WriteCheckpoint(chk) + mig.applier.CurrentCoordinatesMutex.Lock() + if coords.SmallerThanOrEquals(mig.applier.CurrentCoordinates) { + id, err := mig.applier.WriteCheckpoint(chk) chk.Id = id - mgtr.applier.CurrentCoordinatesMutex.Unlock() + mig.applier.CurrentCoordinatesMutex.Unlock() return chk, err } - mgtr.applier.CurrentCoordinatesMutex.Unlock() + mig.applier.CurrentCoordinatesMutex.Unlock() time.Sleep(500 * time.Millisecond) } } // CheckpointAfterCutOver writes a final checkpoint after the cutover completes successfully. -func (mgtr *Migrator) CheckpointAfterCutOver() (*Checkpoint, error) { - if mgtr.lastLockProcessed == nil || mgtr.lastLockProcessed.coords.IsEmpty() { - return nil, mgtr.migrationContext.Log.Errorf("lastLockProcessed coords are empty") +func (mig *Migrator) CheckpointAfterCutOver() (*Checkpoint, error) { + if mig.lastLockProcessed == nil || mig.lastLockProcessed.coords.IsEmpty() { + return nil, mig.migrationContext.Log.Errorf("lastLockProcessed coords are empty") } chk := &Checkpoint{ IsCutover: true, - LastTrxCoords: mgtr.lastLockProcessed.coords, - IterationRangeMin: sql.NewColumnValues(mgtr.migrationContext.UniqueKey.Len()), - IterationRangeMax: sql.NewColumnValues(mgtr.migrationContext.UniqueKey.Len()), - Iteration: mgtr.migrationContext.GetIteration(), - RowsCopied: atomic.LoadInt64(&mgtr.migrationContext.TotalRowsCopied), - DMLApplied: atomic.LoadInt64(&mgtr.migrationContext.TotalDMLEventsApplied), + LastTrxCoords: mig.lastLockProcessed.coords, + IterationRangeMin: sql.NewColumnValues(mig.migrationContext.UniqueKey.Len()), + IterationRangeMax: sql.NewColumnValues(mig.migrationContext.UniqueKey.Len()), + Iteration: mig.migrationContext.GetIteration(), + RowsCopied: atomic.LoadInt64(&mig.migrationContext.TotalRowsCopied), + DMLApplied: atomic.LoadInt64(&mig.migrationContext.TotalDMLEventsApplied), } - mgtr.applier.LastIterationRangeMutex.Lock() - if mgtr.applier.LastIterationRangeMinValues != nil { - chk.IterationRangeMin = mgtr.applier.LastIterationRangeMinValues.Clone() + mig.applier.LastIterationRangeMutex.Lock() + if mig.applier.LastIterationRangeMinValues != nil { + chk.IterationRangeMin = mig.applier.LastIterationRangeMinValues.Clone() } - if mgtr.applier.LastIterationRangeMaxValues != nil { - chk.IterationRangeMax = mgtr.applier.LastIterationRangeMaxValues.Clone() + if mig.applier.LastIterationRangeMaxValues != nil { + chk.IterationRangeMax = mig.applier.LastIterationRangeMaxValues.Clone() } - mgtr.applier.LastIterationRangeMutex.Unlock() + mig.applier.LastIterationRangeMutex.Unlock() - id, err := mgtr.applier.WriteCheckpoint(chk) + id, err := mig.applier.WriteCheckpoint(chk) chk.Id = id return chk, err } -func (mgtr *Migrator) checkpointLoop() { - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really checkpointing") +func (mig *Migrator) checkpointLoop() { + if mig.migrationContext.Noop { + mig.migrationContext.Log.Debugf("Noop operation; not really checkpointing") return } - checkpointInterval := time.Duration(mgtr.migrationContext.CheckpointIntervalSeconds) * time.Second + checkpointInterval := time.Duration(mig.migrationContext.CheckpointIntervalSeconds) * time.Second ticker := time.NewTicker(checkpointInterval) for t := range ticker.C { - if atomic.LoadInt64(&mgtr.finishedMigrating) > 0 || atomic.LoadInt64(&mgtr.migrationContext.CutOverCompleteFlag) > 0 { + if atomic.LoadInt64(&mig.finishedMigrating) > 0 || atomic.LoadInt64(&mig.migrationContext.CutOverCompleteFlag) > 0 { return } - if atomic.LoadInt64(&mgtr.migrationContext.InCutOverCriticalSectionFlag) > 0 { + if atomic.LoadInt64(&mig.migrationContext.InCutOverCriticalSectionFlag) > 0 { continue } - mgtr.migrationContext.Log.Infof("starting checkpoint at %+v", t) + mig.migrationContext.Log.Infof("starting checkpoint at %+v", t) ctx, cancel := context.WithTimeout(context.Background(), checkpointTimeout) - chk, err := mgtr.Checkpoint(ctx) + chk, err := mig.Checkpoint(ctx) if err != nil { if errors.Is(err, context.DeadlineExceeded) { - mgtr.migrationContext.Log.Errorf("checkpoint attempt timed out after %+v", checkpointTimeout) + mig.migrationContext.Log.Errorf("checkpoint attempt timed out after %+v", checkpointTimeout) } else { - mgtr.migrationContext.Log.Errorf("error attempting checkpoint: %+v", err) + mig.migrationContext.Log.Errorf("error attempting checkpoint: %+v", err) } } else { - mgtr.migrationContext.Log.Infof("checkpoint success at coords=%+v range_min=%+v range_max=%+v iteration=%d", + mig.migrationContext.Log.Infof("checkpoint success at coords=%+v range_min=%+v range_max=%+v iteration=%d", chk.LastTrxCoords.DisplayString(), chk.IterationRangeMin.String(), chk.IterationRangeMax.String(), chk.Iteration) } cancel() @@ -1805,126 +1717,94 @@ func (mgtr *Migrator) checkpointLoop() { // executeWriteFuncs writes data via applier: both the rowcopy and the events backlog. // This is where the ghost table gets the data. The function fills the data single-threaded. // Both event backlog and rowcopy events are polled; the backlog events have precedence. -func (mgtr *Migrator) executeWriteFuncs() error { - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really executing write funcs") +func (mig *Migrator) executeWriteFuncs() error { + if mig.migrationContext.Noop { + mig.migrationContext.Log.Debugf("Noop operation; not really executing write funcs") return nil } + for { - if err := mgtr.checkAbort(); err != nil { + if err := mig.checkAbort(); err != nil { return err } - if atomic.LoadInt64(&mgtr.finishedMigrating) > 0 { + if atomic.LoadInt64(&mig.finishedMigrating) > 0 { return nil } - mgtr.throttler.throttle(nil) + mig.throttler.throttle(nil) - // We give higher priority to event processing, then secondary priority to - // rowcopy - select { - case eventStruct := <-mgtr.applyEventsQueue: - { - if err := mgtr.onApplyEventStruct(eventStruct); err != nil { - return err - } - } - default: - { - select { - case copyRowsFunc := <-mgtr.copyRowsQueue: - { - copyRowsStartTime := time.Now() - // Retries are handled within the copyRowsFunc - if err := copyRowsFunc(); err != nil { - return mgtr.migrationContext.Log.Errore(err) - } - if niceRatio := mgtr.migrationContext.GetNiceRatio(); niceRatio > 0 { - copyRowsDuration := time.Since(copyRowsStartTime) - sleepTimeNanosecondFloat64 := niceRatio * float64(copyRowsDuration.Nanoseconds()) - sleepTime := time.Duration(int64(sleepTimeNanosecondFloat64)) * time.Nanosecond - time.Sleep(sleepTime) - } - } - default: - { - // Hmmmmm... nothing in the queue; no events, but also no row copy. - // This is possible upon load. Let's just sleep it over. - mgtr.migrationContext.Log.Debugf("Getting nothing in the write queue. Sleeping...") - time.Sleep(time.Second) - } - } - } - } - } -} - -func (mgtr *Migrator) executeDMLWriteFuncs() error { - if mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Debugf("Noop operation; not really executing DML write funcs") - return nil - } - for { - if atomic.LoadInt64(&mgtr.finishedMigrating) > 0 { - return nil + // We give higher priority to event processing. + // ProcessEventsUntilDrained will process all events in the queue, and then return once no more events are available. + if err := mig.trxCoordinator.ProcessEventsUntilDrained(); err != nil { + return mig.migrationContext.Log.Errore(err) } - mgtr.throttler.throttle(nil) + mig.throttler.throttle(nil) + // And secondary priority to rowcopy select { - case eventStruct := <-mgtr.applyEventsQueue: + case copyRowsFunc := <-mig.copyRowsQueue: { - if err := mgtr.onApplyEventStruct(eventStruct); err != nil { - return err + copyRowsStartTime := time.Now() + // Retries are handled within the copyRowsFunc + if err := copyRowsFunc(); err != nil { + return mig.migrationContext.Log.Errore(err) + } + if niceRatio := mig.migrationContext.GetNiceRatio(); niceRatio > 0 { + copyRowsDuration := time.Since(copyRowsStartTime) + sleepTimeNanosecondFloat64 := niceRatio * float64(copyRowsDuration.Nanoseconds()) + sleepTime := time.Duration(int64(sleepTimeNanosecondFloat64)) * time.Nanosecond + time.Sleep(sleepTime) } } - case <-time.After(time.Second): - continue + default: + { + // Hmmmmm... nothing in the queue; no events, but also no row copy. + // This is possible upon load. Let's just sleep it over. + mig.migrationContext.Log.Debugf("Getting nothing in the write queue. Sleeping...") + time.Sleep(time.Second) + } } } } // finalCleanup takes actions at very end of migration, dropping tables etc. -func (mgtr *Migrator) finalCleanup() error { - atomic.StoreInt64(&mgtr.migrationContext.CleanupImminentFlag, 1) +func (mig *Migrator) finalCleanup() error { + atomic.StoreInt64(&mig.migrationContext.CleanupImminentFlag, 1) - mgtr.migrationContext.Log.Infof("Writing changelog state: %+v", Migrated) - if _, err := mgtr.applier.WriteChangelogState(string(Migrated)); err != nil { + mig.migrationContext.Log.Infof("Writing changelog state: %+v", Migrated) + if _, err := mig.applier.WriteChangelogState(string(Migrated)); err != nil { return err } - if mgtr.migrationContext.Noop { - if createTableStatement, err := mgtr.inspector.showCreateTable(mgtr.migrationContext.GetGhostTableName()); err == nil { - mgtr.migrationContext.Log.Infof("New table structure follows") + if mig.migrationContext.Noop { + if createTableStatement, err := mig.inspector.showCreateTable(mig.migrationContext.GetGhostTableName()); err == nil { + mig.migrationContext.Log.Infof("New table structure follows") fmt.Println(createTableStatement) } else { - mgtr.migrationContext.Log.Errore(err) + mig.migrationContext.Log.Errore(err) } } - if err := mgtr.eventsStreamer.Close(); err != nil { - mgtr.migrationContext.Log.Errore(err) - } - - if err := mgtr.retryOperation(mgtr.applier.DropChangelogTable); err != nil { + if err := mig.retryOperation(mig.applier.DropChangelogTable); err != nil { return err } - if mgtr.migrationContext.OkToDropTable && !mgtr.migrationContext.TestOnReplica { - if err := mgtr.retryOperation(mgtr.applier.DropOldTable); err != nil { + if mig.migrationContext.OkToDropTable && !mig.migrationContext.TestOnReplica { + if err := mig.retryOperation(mig.applier.DropOldTable); err != nil { return err } - if err := mgtr.retryOperation(mgtr.applier.DropCheckpointTable); err != nil { + if err := mig.retryOperation(mig.applier.DropCheckpointTable); err != nil { return err } - } else if !mgtr.migrationContext.Noop { - mgtr.migrationContext.Log.Infof("Am not dropping old table because I want this operation to be as live as possible. If you insist I should do it, please add `--ok-to-drop-table` next time. But I prefer you do not. To drop the old table, issue:") - mgtr.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.GetOldTableName())) - if mgtr.migrationContext.Checkpoint { - mgtr.migrationContext.Log.Infof("Am not dropping checkpoint table without `--ok-to-drop-table`. To drop the checkpoint table, issue:") - mgtr.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(mgtr.migrationContext.DatabaseName), sql.EscapeName(mgtr.migrationContext.GetCheckpointTableName())) + } else if !mig.migrationContext.Noop { + mig.migrationContext.Log.Infof("Am not dropping old table because I want this operation to be as live as possible. If you insist I should do it, please add `--ok-to-drop-table` next time. But I prefer you do not. To drop the old table, issue:") + mig.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.GetOldTableName())) + if mig.migrationContext.Checkpoint { + mig.migrationContext.Log.Infof("Am not dropping checkpoint table without `--ok-to-drop-table`. To drop the checkpoint table, issue:") + mig.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(mig.migrationContext.DatabaseName), sql.EscapeName(mig.migrationContext.GetCheckpointTableName())) } } - if mgtr.migrationContext.Noop { - if err := mgtr.retryOperation(mgtr.applier.DropGhostTable); err != nil { + if mig.migrationContext.Noop { + if err := mig.retryOperation(mig.applier.DropGhostTable); err != nil { return err } } @@ -1932,26 +1812,26 @@ func (mgtr *Migrator) finalCleanup() error { return nil } -func (mgtr *Migrator) teardown() { - atomic.StoreInt64(&mgtr.finishedMigrating, 1) +func (mig *Migrator) teardown() { + atomic.StoreInt64(&mig.finishedMigrating, 1) - if mgtr.inspector != nil { - mgtr.migrationContext.Log.Infof("Tearing down inspector") - mgtr.inspector.Teardown() + if mig.trxCoordinator != nil { + mig.migrationContext.Log.Infof("Tearing down coordinator") + mig.trxCoordinator.Teardown() } - if mgtr.applier != nil { - mgtr.migrationContext.Log.Infof("Tearing down applier") - mgtr.applier.Teardown() + if mig.throttler != nil { + mig.migrationContext.Log.Infof("Tearing down throttler") + mig.throttler.Teardown() } - if mgtr.eventsStreamer != nil { - mgtr.migrationContext.Log.Infof("Tearing down streamer") - mgtr.eventsStreamer.Teardown() + if mig.inspector != nil { + mig.migrationContext.Log.Infof("Tearing down inspector") + mig.inspector.Teardown() } - if mgtr.throttler != nil { - mgtr.migrationContext.Log.Infof("Tearing down throttler") - mgtr.throttler.Teardown() + if mig.applier != nil { + mig.migrationContext.Log.Infof("Tearing down applier") + mig.applier.Teardown() } } diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go deleted file mode 100644 index 8fc48e326..000000000 --- a/go/logic/migrator_test.go +++ /dev/null @@ -1,1484 +0,0 @@ -/* - Copyright 2022 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ - -package logic - -import ( - "bytes" - "context" - gosql "database/sql" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" - - "runtime" - - "github.com/github/gh-ost/go/base" - "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" - "github.com/github/gh-ost/go/sql" - "github.com/testcontainers/testcontainers-go" -) - -func TestMigratorOnChangelogEvent(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - migrator.applier = NewApplier(migrationContext) - - t.Run("heartbeat", func(t *testing.T) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "heartbeat", - "2022-08-16T00:45:10.52Z", - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) - - t.Run("state-AllEventsUpToLockProcessed", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - go func(wg *sync.WaitGroup) { - defer wg.Done() - es := <-migrator.applyEventsQueue - require.NotNil(t, es) - require.NotNil(t, es.writeFunc) - }(&wg) - - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - AllEventsUpToLockProcessed, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - wg.Wait() - }) - - t.Run("state-AllEventsUpToLockProcessed-overwrite-oldest", func(t *testing.T) { - // Simulate the scenario where the receiver (waitForEventsUpToLock) timed out - // and a stale message sits in the channel buffer. The next sentinel must - // overwrite the stale one so the current attempt's message is delivered. - m := NewMigrator(base.NewMigrationContext(), "test") - m.applier = NewApplier(m.migrationContext) - - sendChangelogEvent := func(challenge string) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - challenge, - }) - require.NoError(t, m.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - } - - executeWriteFunc := func() { - es := <-m.applyEventsQueue - require.NotNil(t, es.writeFunc) - require.NoError(t, (*es.writeFunc)()) - } - - // Attempt 1: send sentinel and execute the writeFunc to deliver it - sendChangelogEvent("AllEventsUpToLockProcessed:attempt1") - executeWriteFunc() - - // The message sits unconsumed in allEventsUpToLockProcessed (simulating a timeout) - require.Len(t, m.allEventsUpToLockProcessed, 1) - - // Attempt 2: send a new sentinel — must overwrite the stale one - sendChangelogEvent("AllEventsUpToLockProcessed:attempt2") - executeWriteFunc() - - // The channel should contain exactly the latest message - require.Len(t, m.allEventsUpToLockProcessed, 1) - msg := <-m.allEventsUpToLockProcessed - require.Equal(t, "AllEventsUpToLockProcessed:attempt2", msg.state) - }) - - t.Run("NewMigrator-with-extreme-MaxRetries", func(t *testing.T) { - // Regression test: an extremely large --default-retries value must not - // cause an OOM when creating the migrator. Before the fix, - // allEventsUpToLockProcessed was buffered to MaxRetries(), which tried - // to allocate a ~10 trillion element channel. - ctx := base.NewMigrationContext() - ctx.SetDefaultNumRetries(9999999999999) - require.Equal(t, int64(9999999999999), ctx.MaxRetries()) - - m := NewMigrator(ctx, "test") - require.NotNil(t, m) - require.Equal(t, 1, cap(m.allEventsUpToLockProcessed)) - }) - - t.Run("state-GhostTableMigrated", func(t *testing.T) { - go func() { - require.True(t, <-migrator.ghostTableMigrated) - }() - - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - GhostTableMigrated, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) - - t.Run("state-Migrated", func(t *testing.T) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - Migrated, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) - - t.Run("state-ReadMigrationRangeValues", func(t *testing.T) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - ReadMigrationRangeValues, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) -} - -func TestMigratorValidateStatement(t *testing.T) { - t.Run("add-column", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test ADD test_new VARCHAR(64) NOT NULL`)) - - require.Nil(t, migrator.validateAlterStatement()) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) - - t.Run("drop-column", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test DROP abc`)) - - require.Nil(t, migrator.validateAlterStatement()) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 1) - _, exists := migrator.migrationContext.DroppedColumnsMap["abc"] - require.True(t, exists) - }) - - t.Run("rename-column", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) - - err := migrator.validateAlterStatement() - require.Error(t, err) - require.True(t, strings.HasPrefix(err.Error(), "gh-ost believes the ALTER statement renames columns")) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) - - t.Run("rename-column-approved", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - migrator.migrationContext.ApproveRenamedColumns = true - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) - - require.Nil(t, migrator.validateAlterStatement()) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) - - t.Run("rename-table", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test RENAME TO test_new`)) - - err := migrator.validateAlterStatement() - require.Error(t, err) - require.True(t, errors.Is(err, ErrMigratorUnsupportedRenameAlter)) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) -} - -func TestMigratorCreateFlagFiles(t *testing.T) { - tmpdir, err := os.MkdirTemp("", t.Name()) - if err != nil { - panic(err) - } - defer os.RemoveAll(tmpdir) - - migrationContext := base.NewMigrationContext() - migrationContext.PostponeCutOverFlagFile = filepath.Join(tmpdir, "cut-over.flag") - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.createFlagFiles()) - require.Nil(t, migrator.createFlagFiles()) // twice to test already-exists - - _, err = os.Stat(migrationContext.PostponeCutOverFlagFile) - require.NoError(t, err) -} - -func TestMigratorGetProgressPercent(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - - { - require.Equal(t, float64(100.0), migrator.getProgressPercent(0)) - } - { - migrationContext.TotalRowsCopied = 250 - require.Equal(t, float64(25.0), migrator.getProgressPercent(1000)) - } -} - -func TestMigratorGetMigrationStateAndETA(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - now := time.Now() - migrationContext.RowCopyStartTime = now.Add(-time.Minute) - migrationContext.RowCopyEndTime = now - - { - migrationContext.TotalRowsCopied = 456 - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "migrating", state) - require.Equal(t, "4h29m44s", eta) - require.Equal(t, "4h29m44s", etaDuration.String()) - } - { - // Test using rows-per-second added data. - migrationContext.TotalRowsCopied = 456 - migrationContext.EtaRowsPerSecond = 100 - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "migrating", state) - require.Equal(t, "20m30s", eta) - require.Equal(t, "20m30s", etaDuration.String()) - } - { - migrationContext.TotalRowsCopied = 456 - state, eta, etaDuration := migrator.getMigrationStateAndETA(456) - require.Equal(t, "migrating", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } - { - migrationContext.TotalRowsCopied = 123456 - state, eta, etaDuration := migrator.getMigrationStateAndETA(456) - require.Equal(t, "migrating", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } - { - atomic.StoreInt64(&migrationContext.CountingRowsFlag, 1) - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "counting rows", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } - { - atomic.StoreInt64(&migrationContext.CountingRowsFlag, 0) - atomic.StoreInt64(&migrationContext.IsPostponingCutOver, 1) - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "postponing cut-over", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } -} - -func TestMigratorShouldPrintStatus(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - - require.True(t, migrator.shouldPrintStatus(NoPrintStatusRule, 10, time.Second)) // test 'rule != HeuristicPrintStatusRule' return - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 10, time.Second)) // test 'etaDuration.Seconds() <= 60' - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 90, time.Second)) // test 'etaDuration.Seconds() <= 60' again - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 90, time.Minute)) // test 'etaDuration.Seconds() <= 180' - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 60, 90*time.Second)) // test 'elapsedSeconds <= 180' - require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 61, 90*time.Second)) // test 'elapsedSeconds <= 180' - require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 99, 210*time.Second)) // test 'elapsedSeconds <= 180' - require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 12345, 86400*time.Second)) // test 'else' - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 30030, 86400*time.Second)) // test 'else' again -} - -type MigratorTestSuite struct { - suite.Suite - - mysqlContainer testcontainers.Container - db *gosql.DB -} - -func (suite *MigratorTestSuite) SetupSuite() { - ctx := context.Background() - mysqlContainer, err := testmysql.Run(ctx, - testMysqlContainerImage, - testmysql.WithDatabase(testMysqlDatabase), - testmysql.WithUsername(testMysqlUser), - testmysql.WithPassword(testMysqlPass), - testmysql.WithConfigFile("my.cnf.test"), - ) - suite.Require().NoError(err) - - suite.mysqlContainer = mysqlContainer - dsn, err := mysqlContainer.ConnectionString(ctx) - suite.Require().NoError(err) - - db, err := gosql.Open("mysql", dsn) - suite.Require().NoError(err) - - suite.db = db -} - -func (suite *MigratorTestSuite) TeardownSuite() { - suite.Assert().NoError(suite.db.Close()) - suite.Assert().NoError(testcontainers.TerminateContainer(suite.mysqlContainer)) -} - -func (suite *MigratorTestSuite) SetupTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+testMysqlDatabase) - suite.Require().NoError(err) - - os.Remove("/tmp/gh-ost.sock") -} - -func (suite *MigratorTestSuite) TearDownTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestGhostTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestRevertedTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestOldTableName()) - suite.Require().NoError(err) -} - -func (suite *MigratorTestSuite) TestMigrateEmpty() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(64))", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.InitiallyDropOldTable = true - - migrationContext.AlterStatementOptions = "ADD COLUMN foobar varchar(255), ENGINE=InnoDB" - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - suite.Require().NoError(err) - - // Verify the new column was added - var tableName, createTableSQL string - err = suite.db.QueryRow("SHOW CREATE TABLE "+getTestTableName()).Scan(&tableName, &createTableSQL) - suite.Require().NoError(err) - - suite.Require().Equal("testing", tableName) - suite.Require().Equal("CREATE TABLE `testing` (\n `id` int NOT NULL,\n `name` varchar(64) DEFAULT NULL,\n `foobar` varchar(255) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createTableSQL) - - // Verify the changelog table was claned up - err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_ghc'").Scan(&tableName) - suite.Require().Error(err) - suite.Require().Equal(gosql.ErrNoRows, err) - - // Verify the old table was renamed - err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_del'").Scan(&tableName) - suite.Require().NoError(err) - suite.Require().Equal("_testing_del", tableName) -} - -func (suite *MigratorTestSuite) TestRetryBatchCopyWithHooks() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "CREATE TABLE test.test_retry_batch (id INT PRIMARY KEY AUTO_INCREMENT, name TEXT)") - suite.Require().NoError(err) - - const initStride = 1000 - const totalBatches = 3 - for i := 0; i < totalBatches; i++ { - dataSize := 50 * i - for j := 0; j < initStride; j++ { - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO test.test_retry_batch (name) VALUES ('%s')", strings.Repeat("a", dataSize))) - suite.Require().NoError(err) - } - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET GLOBAL max_binlog_cache_size = %d", 1024*8)) - suite.Require().NoError(err) - defer func() { - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET GLOBAL max_binlog_cache_size = %d", 1024*1024*1024)) - suite.Require().NoError(err) - }() - - tmpDir, err := os.MkdirTemp("", "gh-ost-hooks") - suite.Require().NoError(err) - defer os.RemoveAll(tmpDir) - - hookScript := filepath.Join(tmpDir, "gh-ost-on-batch-copy-retry") - hookContent := `#!/bin/bash -# Mock hook that reduces chunk size on binlog cache error -ERROR_MSG="$GH_OST_LAST_BATCH_COPY_ERROR" -SOCKET_PATH="/tmp/gh-ost.sock" - -if ! [[ "$ERROR_MSG" =~ "max_binlog_cache_size" ]]; then - echo "Nothing to do for error: $ERROR_MSG" - exit 0 -fi - -CHUNK_SIZE=$(echo "chunk-size=?" | nc -U $SOCKET_PATH | tr -d '\n') - -MIN_CHUNK_SIZE=10 -NEW_CHUNK_SIZE=$(( CHUNK_SIZE * 8 / 10 )) -if [ $NEW_CHUNK_SIZE -lt $MIN_CHUNK_SIZE ]; then - NEW_CHUNK_SIZE=$MIN_CHUNK_SIZE -fi - -if [ $CHUNK_SIZE -eq $NEW_CHUNK_SIZE ]; then - echo "Chunk size unchanged: $CHUNK_SIZE" - exit 0 -fi - -echo "[gh-ost-on-batch-copy-retry]: Changing chunk size from $CHUNK_SIZE to $NEW_CHUNK_SIZE" -echo "chunk-size=$NEW_CHUNK_SIZE" | nc -U $SOCKET_PATH -echo "[gh-ost-on-batch-copy-retry]: Done, exiting..." -` - err = os.WriteFile(hookScript, []byte(hookContent), 0755) - suite.Require().NoError(err) - - origStdout := os.Stdout - origStderr := os.Stderr - - rOut, wOut, _ := os.Pipe() - rErr, wErr, _ := os.Pipe() - os.Stdout = wOut - os.Stderr = wErr - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := base.NewMigrationContext() - migrationContext.AllowedRunningOnMaster = true - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.DatabaseName = "test" - migrationContext.SkipPortValidation = true - migrationContext.OriginalTableName = "test_retry_batch" - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatementOptions = "MODIFY name LONGTEXT, ENGINE=InnoDB" - migrationContext.ReplicaServerId = 99999 - migrationContext.HeartbeatIntervalMilliseconds = 100 - migrationContext.ThrottleHTTPIntervalMillis = 100 - migrationContext.ThrottleHTTPTimeoutMillis = 1000 - migrationContext.HooksPath = tmpDir - migrationContext.ChunkSize = 1000 - migrationContext.SetDefaultNumRetries(10) - migrationContext.ServeSocketFile = "/tmp/gh-ost.sock" - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - suite.Require().NoError(err) - - wOut.Close() - wErr.Close() - os.Stdout = origStdout - os.Stderr = origStderr - - var bufOut, bufErr bytes.Buffer - io.Copy(&bufOut, rOut) - io.Copy(&bufErr, rErr) - - outStr := bufOut.String() - errStr := bufErr.String() - - suite.Assert().Contains(outStr, "chunk-size: 1000") - suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 1000 to 800") - suite.Assert().Contains(outStr, "chunk-size: 800") - - suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 800 to 640") - suite.Assert().Contains(outStr, "chunk-size: 640") - - suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 640 to 512") - suite.Assert().Contains(outStr, "chunk-size: 512") - - var count int - err = suite.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM test.test_retry_batch").Scan(&count) - suite.Require().NoError(err) - suite.Assert().Equal(3000, count) -} - -func (suite *MigratorTestSuite) TestCopierIntPK() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(64), age INT);", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - migrationContext.AlterStatementOptions = "ENGINE=InnoDB" - migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "name", "age"}) - migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "name", "age"}) - migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "name", "age"}) - migrationContext.UniqueKey = &sql.UniqueKey{ - Name: "PRIMARY", - NameInGhostTable: "PRIMARY", - Columns: *sql.NewColumnList([]string{"id"}), - } - - chunkSize := int64(73) - migrationContext.ChunkSize = chunkSize - - // fill with some rows - numRows := int64(3421) - for i := range numRows { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (id, name, age) VALUES (%d, 'user-%d', %d)", getTestTableName(), i, i, i%99)) - suite.Require().NoError(err) - } - - migrator := NewMigrator(migrationContext, "0.0.0") - suite.Require().NoError(migrator.initiateApplier()) - suite.Require().NoError(migrator.applier.prepareQueries()) - suite.Require().NoError(migrator.applier.ReadMigrationRangeValues()) - - go migrator.iterateChunks() - go func() { - if err := <-migrator.rowCopyComplete; err != nil { - migrator.migrationContext.PanicAbort <- err - } - atomic.StoreInt64(&migrator.rowCopyCompleteFlag, 1) - }() - - for { - if atomic.LoadInt64(&migrator.rowCopyCompleteFlag) == 1 { - suite.Assert().Equal((numRows/chunkSize)+1, migrator.migrationContext.GetIteration()) - return - } - select { - case copyRowsFunc := <-migrator.copyRowsQueue: - { - suite.Require().NoError(copyRowsFunc()) - - // check ghost table has expected number of rows - var ghostRows int64 - suite.db.QueryRowContext(ctx, - fmt.Sprintf(`SELECT COUNT(*) FROM %s`, getTestGhostTableName()), - ).Scan(&ghostRows) - suite.Assert().Equal(migrator.migrationContext.TotalRowsCopied, ghostRows) - } - default: - time.Sleep(time.Second) - } - } -} - -func (suite *MigratorTestSuite) TestCopierCompositePK() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT UNSIGNED, t CHAR(32), PRIMARY KEY (t, id));", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - migrationContext.AlterStatementOptions = "ENGINE=InnoDB" - migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "t"}) - migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "t"}) - migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "t"}) - migrationContext.UniqueKey = &sql.UniqueKey{ - Name: "PRIMARY", - NameInGhostTable: "PRIMARY", - Columns: *sql.NewColumnList([]string{"t", "id"}), - } - - chunkSize := int64(100) - migrationContext.ChunkSize = chunkSize - - // fill with some rows - numRows := int64(2049) - for i := range numRows { - query := fmt.Sprintf(`INSERT INTO %s (id, t) VALUES (FLOOR(100000000 * RAND(%d)), MD5(RAND(%d)))`, getTestTableName(), i, i) - _, err = suite.db.ExecContext(ctx, query) - suite.Require().NoError(err) - } - - migrator := NewMigrator(migrationContext, "0.0.0") - suite.Require().NoError(migrator.initiateApplier()) - suite.Require().NoError(migrator.applier.prepareQueries()) - suite.Require().NoError(migrator.applier.ReadMigrationRangeValues()) - - go migrator.iterateChunks() - go func() { - if err := <-migrator.rowCopyComplete; err != nil { - migrator.migrationContext.PanicAbort <- err - } - atomic.StoreInt64(&migrator.rowCopyCompleteFlag, 1) - }() - - for { - if atomic.LoadInt64(&migrator.rowCopyCompleteFlag) == 1 { - suite.Assert().Equal((numRows/chunkSize)+1, migrator.migrationContext.GetIteration()) - return - } - select { - case copyRowsFunc := <-migrator.copyRowsQueue: - { - suite.Require().NoError(copyRowsFunc()) - - // check ghost table has expected number of rows - var ghostRows int64 - suite.db.QueryRowContext(ctx, - fmt.Sprintf(`SELECT COUNT(*) FROM %s`, getTestGhostTableName()), - ).Scan(&ghostRows) - suite.Assert().Equal(migrator.migrationContext.TotalRowsCopied, ghostRows) - } - default: - time.Sleep(time.Second) - } - } -} - -func TestMigratorRetry(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrator := NewMigrator(migrationContext, "1.2.3") - - var sleeps = 0 - RetrySleepFn = func(duration time.Duration) { - assert.Equal(t, 1*time.Second, duration) - sleeps++ - } - - var tries = 0 - retryable := func() error { - tries++ - if tries < int(migrationContext.MaxRetries()) { - return errors.New("Backoff") - } - return nil - } - - result := migrator.retryOperation(retryable, false) - assert.NoError(t, result) - assert.Equal(t, sleeps, 99) - assert.Equal(t, tries, 100) -} - -func TestMigratorRetryWithExponentialBackoff(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrationContext.SetExponentialBackoffMaxInterval(42) - migrator := NewMigrator(migrationContext, "1.2.3") - - var sleeps = 0 - expected := []int{ - 1, 2, 4, 8, 16, 32, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, - } - RetrySleepFn = func(duration time.Duration) { - assert.Equal(t, time.Duration(expected[sleeps])*time.Second, duration) - sleeps++ - } - - var tries = 0 - retryable := func() error { - tries++ - if tries < int(migrationContext.MaxRetries()) { - return errors.New("Backoff") - } - return nil - } - - result := migrator.retryOperationWithExponentialBackoff(retryable, false) - assert.NoError(t, result) - assert.Equal(t, sleeps, 99) - assert.Equal(t, tries, 100) -} - -func TestMigratorRetryAbortsOnContextCancellation(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - // No sleep needed for this test - } - - var tries = 0 - retryable := func() error { - tries++ - if tries == 5 { - // Cancel context on 5th try - migrationContext.CancelContext() - } - return errors.New("Simulated error") - } - - result := migrator.retryOperation(retryable, false) - assert.Error(t, result) - // Should abort after 6 tries: 5 failures + 1 checkAbort detection - assert.True(t, tries <= 6, "Expected tries <= 6, got %d", tries) - // Verify we got context cancellation error - assert.Contains(t, result.Error(), "context canceled") -} - -func TestMigratorRetryWithExponentialBackoffAbortsOnContextCancellation(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrationContext.SetExponentialBackoffMaxInterval(42) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - // No sleep needed for this test - } - - var tries = 0 - retryable := func() error { - tries++ - if tries == 5 { - // Cancel context on 5th try - migrationContext.CancelContext() - } - return errors.New("Simulated error") - } - - result := migrator.retryOperationWithExponentialBackoff(retryable, false) - assert.Error(t, result) - // Should abort after 6 tries: 5 failures + 1 checkAbort detection - assert.True(t, tries <= 6, "Expected tries <= 6, got %d", tries) - // Verify we got context cancellation error - assert.Contains(t, result.Error(), "context canceled") -} - -func TestMigratorRetrySkipsRetriesForWarnings(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - t.Fatal("Should not sleep/retry for warning errors") - } - - var tries = 0 - retryable := func() error { - tries++ - return errors.New("warnings detected in statement 1 of 1: [Warning: Duplicate entry 'test' for key 'idx' (1062)]") - } - - result := migrator.retryOperation(retryable, false) - assert.Error(t, result) - // Should only try once - no retries for warnings - assert.Equal(t, 1, tries, "Expected exactly 1 try (no retries) for warning error") - assert.Contains(t, result.Error(), "warnings detected") -} - -func TestMigratorRetryWithExponentialBackoffSkipsRetriesForWarnings(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrationContext.SetExponentialBackoffMaxInterval(42) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - t.Fatal("Should not sleep/retry for warning errors") - } - - var tries = 0 - retryable := func() error { - tries++ - return errors.New("warnings detected in statement 1 of 1: [Warning: Duplicate entry 'test' for key 'idx' (1062)]") - } - - result := migrator.retryOperationWithExponentialBackoff(retryable, false) - assert.Error(t, result) - // Should only try once - no retries for warnings - assert.Equal(t, 1, tries, "Expected exactly 1 try (no retries) for warning error") - assert.Contains(t, result.Error(), "warnings detected") -} - -func (suite *MigratorTestSuite) TestCutOverLossDataCaseLockGhostBeforeRename() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(64))", getTestTableName())) - suite.Require().NoError(err) - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("insert into %s values(1,'a')", getTestTableName())) - suite.Require().NoError(err) - - done := make(chan error, 1) - go func() { - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - if err != nil { - done <- err - return - } - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AllowSetupMetadataLockInstruments = true - migrationContext.AlterStatementOptions = "ADD COLUMN foobar varchar(255)" - migrationContext.HeartbeatIntervalMilliseconds = 100 - migrationContext.CutOverLockTimeoutSeconds = 4 - - _, filename, _, _ := runtime.Caller(0) - migrationContext.PostponeCutOverFlagFile = filepath.Join(filepath.Dir(filename), "../../tmp/ghost.postpone.flag") - - migrator := NewMigrator(migrationContext, "0.0.0") - - //nolint:contextcheck - done <- migrator.Migrate() - }() - - time.Sleep(2 * time.Second) - //nolint:dogsled - _, filename, _, _ := runtime.Caller(0) - err = os.Remove(filepath.Join(filepath.Dir(filename), "../../tmp/ghost.postpone.flag")) - if err != nil { - suite.Require().NoError(err) - } - time.Sleep(1 * time.Second) - go func() { - holdConn, err := suite.db.Conn(ctx) - suite.Require().NoError(err) - _, err = holdConn.ExecContext(ctx, "SELECT *, sleep(2) FROM test._testing_gho WHERE id = 1") - suite.Require().NoError(err) - }() - - dmlConn, err := suite.db.Conn(ctx) - suite.Require().NoError(err) - - _, err = dmlConn.ExecContext(ctx, fmt.Sprintf("insert into %s (id, name) values(2,'b')", getTestTableName())) - fmt.Println("insert into table original table") - suite.Require().NoError(err) - - migrateErr := <-done - suite.Require().NoError(migrateErr) - - // Verify the new column was added - var delValue, OriginalValue int64 - err = suite.db.QueryRow( - fmt.Sprintf("select count(*) from %s._%s_del", testMysqlDatabase, testMysqlTableName), - ).Scan(&delValue) - suite.Require().NoError(err) - - err = suite.db.QueryRow("select count(*) from " + getTestTableName()).Scan(&OriginalValue) - suite.Require().NoError(err) - - suite.Require().LessOrEqual(delValue, OriginalValue) - - var tableName, createTableSQL string - err = suite.db.QueryRow("SHOW CREATE TABLE "+getTestTableName()).Scan(&tableName, &createTableSQL) - suite.Require().NoError(err) - - suite.Require().Equal(testMysqlTableName, tableName) - suite.Require().Equal("CREATE TABLE `testing` (\n `id` int NOT NULL,\n `name` varchar(64) DEFAULT NULL,\n `foobar` varchar(255) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createTableSQL) -} - -func (suite *MigratorTestSuite) TestRevertEmpty() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, s CHAR(32))", getTestTableName())) - suite.Require().NoError(err) - - var oldTableName string - - // perform original migration - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatement = "ADD COLUMN newcol CHAR(32)" - migrationContext.Checkpoint = true - migrationContext.CheckpointIntervalSeconds = 10 - migrationContext.DropServeSocket = true - migrationContext.InitiallyDropOldTable = true - migrationContext.UseGTIDs = true - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - oldTableName = migrationContext.GetOldTableName() - suite.Require().NoError(err) - suite.Require().Less(migrationContext.TimeSinceLastHeartbeatOnChangelog(), 24*time.Hour) - } - - // revert the original migration - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.DropServeSocket = true - migrationContext.UseGTIDs = true - migrationContext.Revert = true - migrationContext.OkToDropTable = true - migrationContext.OldTableName = oldTableName - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Revert() - suite.Require().NoError(err) - suite.Require().Less(migrationContext.TimeSinceLastHeartbeatOnChangelog(), 24*time.Hour) - } -} - -func (suite *MigratorTestSuite) TestRevert() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, s CHAR(32))", getTestTableName())) - suite.Require().NoError(err) - - numRows := 0 - for range 100 { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (id, s) VALUES (%d, MD5('%d'))", getTestTableName(), numRows, numRows)) - suite.Require().NoError(err) - numRows += 1 - } - - var oldTableName string - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - // perform original migration - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatement = "ADD INDEX idx1 (s)" - migrationContext.Checkpoint = true - migrationContext.CheckpointIntervalSeconds = 10 - migrationContext.DropServeSocket = true - migrationContext.InitiallyDropOldTable = true - migrationContext.UseGTIDs = true - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - oldTableName = migrationContext.GetOldTableName() - suite.Require().NoError(err) - } - - // do some writes - for range 100 { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (id, s) VALUES (%d, MD5('%d'))", getTestTableName(), numRows, numRows)) - suite.Require().NoError(err) - numRows += 1 - } - for i := 0; i < numRows; i += 7 { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("UPDATE %s SET s=MD5('%d') where id=%d", getTestTableName(), 2*i, i)) - suite.Require().NoError(err) - } - - // revert the original migration - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.DropServeSocket = true - migrationContext.UseGTIDs = true - migrationContext.Revert = true - migrationContext.OldTableName = oldTableName - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Revert() - oldTableName = migrationContext.GetOldTableName() - suite.Require().NoError(err) - } - - // checksum original and reverted table - var _tableName, checksum1, checksum2 string - rows, err := suite.db.Query(fmt.Sprintf("CHECKSUM TABLE %s, %s", testMysqlTableName, oldTableName)) - suite.Require().NoError(err) - defer rows.Close() - suite.Require().True(rows.Next()) - suite.Require().NoError(rows.Scan(&_tableName, &checksum1)) - suite.Require().True(rows.Next()) - suite.Require().NoError(rows.Scan(&_tableName, &checksum2)) - suite.Require().NoError(rows.Err()) - - suite.Require().Equal(checksum1, checksum2) -} - -func TestMigrator(t *testing.T) { - if testing.Short() { - t.Skip("skipping migrator test suite in short mode") - } - suite.Run(t, new(MigratorTestSuite)) -} - -func TestPanicAbort_PropagatesError(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Send an error to PanicAbort - testErr := errors.New("test abort error") - go func() { - migrationContext.PanicAbort <- testErr - }() - - // Wait a bit for error to be processed - time.Sleep(100 * time.Millisecond) - - // Verify error was stored - got := migrationContext.GetAbortError() - if got != testErr { //nolint:errorlint // Testing pointer equality for sentinel error - t.Errorf("Expected error %v, got %v", testErr, got) - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - context was cancelled - default: - t.Error("Expected context to be cancelled") - } -} - -func TestPanicAbort_FirstErrorWins(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Send first error - err1 := errors.New("first error") - go func() { - migrationContext.PanicAbort <- err1 - }() - - // Wait for first error to be processed - time.Sleep(50 * time.Millisecond) - - // Try to send second error (should be ignored) - err2 := errors.New("second error") - migrationContext.SetAbortError(err2) - - // Verify only first error is stored - got := migrationContext.GetAbortError() - if got != err1 { //nolint:errorlint // Testing pointer equality for sentinel error - t.Errorf("Expected first error %v, got %v", err1, got) - } -} - -func TestAbort_AfterRowCopy(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Give listenOnPanicAbort time to start - time.Sleep(20 * time.Millisecond) - - // Simulate row copy error by sending to rowCopyComplete in a goroutine - // (unbuffered channel, so send must be async) - testErr := errors.New("row copy failed") - go func() { - migrator.rowCopyComplete <- testErr - }() - - // Consume the error (simulating what Migrate() does) - // This is a blocking call that waits for the error - migrator.consumeRowCopyComplete() - - // Wait for the error to be processed by listenOnPanicAbort - time.Sleep(50 * time.Millisecond) - - // Check that error was stored - if got := migrationContext.GetAbortError(); got == nil { - t.Fatal("Expected abort error to be stored after row copy error") - } else if got.Error() != "row copy failed" { - t.Errorf("Expected 'row copy failed', got %v", got) - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - case <-time.After(1 * time.Second): - t.Error("Expected context to be cancelled after row copy error") - } -} - -func TestAbort_DuringInspection(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Simulate error during inspection phase - testErr := errors.New("inspection failed") - go func() { - time.Sleep(10 * time.Millisecond) - select { - case migrationContext.PanicAbort <- testErr: - case <-migrationContext.GetContext().Done(): - } - }() - - // Wait for abort to be processed - time.Sleep(50 * time.Millisecond) - - // Call checkAbort (simulating what Migrate() does after initiateInspector) - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error after abort during inspection") - } - - if err.Error() != "inspection failed" { - t.Errorf("Expected 'inspection failed', got %v", err) - } -} - -func TestAbort_DuringStreaming(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Simulate error from streaming goroutine - testErr := errors.New("streaming error") - go func() { - time.Sleep(10 * time.Millisecond) - // Use select pattern like actual code does - select { - case migrationContext.PanicAbort <- testErr: - case <-migrationContext.GetContext().Done(): - } - }() - - // Wait for abort to be processed - time.Sleep(50 * time.Millisecond) - - // Verify error stored and context cancelled - if got := migrationContext.GetAbortError(); got == nil { - t.Fatal("Expected abort error to be stored") - } else if got.Error() != "streaming error" { - t.Errorf("Expected 'streaming error', got %v", got) - } - - // Verify checkAbort catches it - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error after streaming abort") - } -} - -func TestRetryExhaustion_TriggersAbort(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(2) // Only 2 retries - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Operation that always fails - callCount := 0 - operation := func() error { - callCount++ - return errors.New("persistent failure") - } - - // Call retryOperation (with notFatalHint=false so it sends to PanicAbort) - err := migrator.retryOperation(operation) - - // Should have called operation MaxRetries times - if callCount != 2 { - t.Errorf("Expected 2 retry attempts, got %d", callCount) - } - - // Should return the error - if err == nil { - t.Fatal("Expected retryOperation to return error") - } - - // Wait for abort to be processed - time.Sleep(100 * time.Millisecond) - - // Verify error was sent to PanicAbort and stored - if got := migrationContext.GetAbortError(); got == nil { - t.Error("Expected abort error to be stored after retry exhaustion") - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - default: - t.Error("Expected context to be cancelled after retry exhaustion") - } -} - -func TestRevert_AbortsOnError(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrationContext.Revert = true - migrationContext.OldTableName = "_test_del" - migrationContext.OriginalTableName = "test" - migrationContext.DatabaseName = "testdb" - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Simulate error during revert - testErr := errors.New("revert failed") - go func() { - time.Sleep(10 * time.Millisecond) - select { - case migrationContext.PanicAbort <- testErr: - case <-migrationContext.GetContext().Done(): - } - }() - - // Wait for abort to be processed - time.Sleep(50 * time.Millisecond) - - // Verify checkAbort catches it - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error during revert") - } - - if err.Error() != "revert failed" { - t.Errorf("Expected 'revert failed', got %v", err) - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - default: - t.Error("Expected context to be cancelled during revert abort") - } -} - -func TestCheckAbort_ReturnsNilWhenNoError(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // No error has occurred - err := migrator.checkAbort() - if err != nil { - t.Errorf("Expected no error, got %v", err) - } -} - -func TestCheckAbort_DetectsContextCancellation(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Cancel context directly (without going through PanicAbort) - migrationContext.CancelContext() - - // checkAbort should detect the cancellation - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error when context is cancelled") - } -} - -func (suite *MigratorTestSuite) TestPanicOnWarningsDuplicateDuringCutoverWithHighRetries() { - ctx := context.Background() - - // Create table with email column (no unique constraint initially) - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY AUTO_INCREMENT, email VARCHAR(100))", getTestTableName())) - suite.Require().NoError(err) - - // Insert initial rows with unique email values - passes pre-flight validation - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user1@example.com')", getTestTableName())) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user2@example.com')", getTestTableName())) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user3@example.com')", getTestTableName())) - suite.Require().NoError(err) - - // Verify we have 3 rows - var count int - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count) - suite.Require().NoError(err) - suite.Require().Equal(3, count) - - // Create postpone flag file - tmpDir, err := os.MkdirTemp("", "gh-ost-postpone-test") - suite.Require().NoError(err) - defer os.RemoveAll(tmpDir) - postponeFlagFile := filepath.Join(tmpDir, "postpone.flag") - err = os.WriteFile(postponeFlagFile, []byte{}, 0644) - suite.Require().NoError(err) - - // Start migration in goroutine - done := make(chan error, 1) - go func() { - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - if err != nil { - done <- err - return - } - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatementOptions = "ADD UNIQUE KEY unique_email_idx (email)" - migrationContext.HeartbeatIntervalMilliseconds = 100 - migrationContext.PostponeCutOverFlagFile = postponeFlagFile - migrationContext.PanicOnWarnings = true - - // High retry count + exponential backoff means retries will take a long time and fail the test if not properly aborted - migrationContext.SetDefaultNumRetries(30) - migrationContext.CutOverExponentialBackoff = true - migrationContext.SetExponentialBackoffMaxInterval(128) - - migrator := NewMigrator(migrationContext, "0.0.0") - - //nolint:contextcheck - done <- migrator.Migrate() - }() - - // Wait for migration to reach postponed state - // TODO replace this with an actual check for postponed state - time.Sleep(3 * time.Second) - - // Now insert a duplicate email value while migration is postponed - // This simulates data arriving during migration that would violate the unique constraint - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user1@example.com')", getTestTableName())) - suite.Require().NoError(err) - - // Verify we now have 4 rows (including the duplicate) - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count) - suite.Require().NoError(err) - suite.Require().Equal(4, count) - - // Unpostpone the migration - gh-ost will now try to apply binlog events with the duplicate - err = os.Remove(postponeFlagFile) - suite.Require().NoError(err) - - // Wait for Migrate() to return - with timeout to detect if it hangs - select { - case migrateErr := <-done: - // Success - Migrate() returned - // It should return an error due to the duplicate - suite.Require().Error(migrateErr, "Expected migration to fail due to duplicate key violation") - suite.Require().Contains(migrateErr.Error(), "Duplicate entry", "Error should mention duplicate entry") - case <-time.After(5 * time.Minute): - suite.FailNow("Migrate() hung and did not return within 5 minutes - failure to abort on warnings in retry loop") - } - - // Verify all 4 rows are still in the original table (no silent data loss) - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count) - suite.Require().NoError(err) - suite.Require().Equal(4, count, "Original table should still have all 4 rows") - - // Verify both user1@example.com entries still exist - var duplicateCount int - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE email = 'user1@example.com'", getTestTableName())).Scan(&duplicateCount) - suite.Require().NoError(err) - suite.Require().Equal(2, duplicateCount, "Should have 2 duplicate email entries") -} diff --git a/go/logic/server.go b/go/logic/server.go index 4705ba9b9..f3fac0d09 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -32,6 +32,7 @@ var ( ) type printStatusFunc func(PrintStatusRule, io.Writer) +type printWorkersFunc func(io.Writer) // Server listens for requests on a socket file or via TCP type Server struct { @@ -40,14 +41,16 @@ type Server struct { tcpListener net.Listener hooksExecutor base.Hooks printStatus printStatusFunc + printWorkers printWorkersFunc isCPUProfiling int64 } -func NewServer(migrationContext *base.MigrationContext, hooksExecutor base.Hooks, printStatus printStatusFunc) *Server { +func NewServer(migrationContext *base.MigrationContext, hooksExecutor base.Hooks, printStatus printStatusFunc, printWorkers printWorkersFunc) *Server { return &Server{ migrationContext: migrationContext, hooksExecutor: hooksExecutor, printStatus: printStatus, + printWorkers: printWorkers, } } @@ -243,6 +246,11 @@ help # This message return ForcePrintStatusOnlyRule, nil case "info", "status": return ForcePrintStatusAndHintRule, nil + case "worker-stats": + if srv.printWorkers != nil { + srv.printWorkers(writer) + } + return NoPrintStatusRule, nil case "cpu-profile": cpuProfile, err := srv.runCPUProfile(arg) if err == nil { diff --git a/go/logic/streamer.go b/go/logic/streamer.go index ecb936069..86fe3d985 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -1,244 +1,10 @@ /* Copyright 2022 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE + See https://github.com/github/gh-ost/blob/master/LICENSE */ package logic -import ( - gosql "database/sql" - "fmt" - "strings" - "sync" - "time" - - "github.com/github/gh-ost/go/base" - "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" - - gomysql "github.com/go-mysql-org/go-mysql/mysql" - "github.com/openark/golib/sqlutils" -) - -type BinlogEventListener struct { - async bool - databaseName string - tableName string - onDmlEvent func(event *binlog.BinlogEntry) error -} - -const ( - EventsChannelBufferSize = 1 - ReconnectStreamerSleepSeconds = 1 -) - -// EventsStreamer reads data from binary logs and streams it on. It acts as a publisher, -// and interested parties may subscribe for per-table events. -type EventsStreamer struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - dbVersion string - migrationContext *base.MigrationContext - initialBinlogCoordinates mysql.BinlogCoordinates - listeners [](*BinlogEventListener) - listenersMutex *sync.Mutex - eventsChannel chan *binlog.BinlogEntry - binlogReader *binlog.GoMySQLReader - name string -} - -func NewEventsStreamer(migrationContext *base.MigrationContext) *EventsStreamer { - return &EventsStreamer{ - connectionConfig: migrationContext.InspectorConnectionConfig, - migrationContext: migrationContext, - listeners: [](*BinlogEventListener){}, - listenersMutex: &sync.Mutex{}, - eventsChannel: make(chan *binlog.BinlogEntry, EventsChannelBufferSize), - name: "streamer", - initialBinlogCoordinates: migrationContext.InitialStreamerCoords, - } -} - -// AddListener registers a new listener for binlog events, on a per-table basis -func (es *EventsStreamer) AddListener( - async bool, databaseName string, tableName string, onDmlEvent func(event *binlog.BinlogEntry) error) (err error) { - es.listenersMutex.Lock() - defer es.listenersMutex.Unlock() - - if databaseName == "" { - return fmt.Errorf("empty database name in AddListener") - } - if tableName == "" { - return fmt.Errorf("empty table name in AddListener") - } - listener := &BinlogEventListener{ - async: async, - databaseName: databaseName, - tableName: tableName, - onDmlEvent: onDmlEvent, - } - es.listeners = append(es.listeners, listener) - return nil -} - -// notifyListeners will notify relevant listeners with given DML event. Only -// listeners registered for changes on the table on which the DML operates are notified. -func (es *EventsStreamer) notifyListeners(binlogEntry *binlog.BinlogEntry) { - es.listenersMutex.Lock() - defer es.listenersMutex.Unlock() - - for _, listener := range es.listeners { - listener := listener - if !strings.EqualFold(listener.databaseName, binlogEntry.DmlEvent.DatabaseName) { - continue - } - if !strings.EqualFold(listener.tableName, binlogEntry.DmlEvent.TableName) { - continue - } - if listener.async { - go func() { - listener.onDmlEvent(binlogEntry) - }() - } else { - listener.onDmlEvent(binlogEntry) - } - } -} - -func (es *EventsStreamer) InitDBConnections() (err error) { - EventsStreamerUri := es.connectionConfig.GetDBUri(es.migrationContext.DatabaseName) - if es.db, _, err = mysql.GetDB(es.migrationContext.Uuid, EventsStreamerUri); err != nil { - return err - } - version, err := base.ValidateConnection(es.db, es.connectionConfig, es.migrationContext, es.name) - if err != nil { - return err - } - es.dbVersion = version - if es.initialBinlogCoordinates == nil || es.initialBinlogCoordinates.IsEmpty() { - if err := es.readCurrentBinlogCoordinates(); err != nil { - return err - } - } - if err := es.initBinlogReader(es.initialBinlogCoordinates); err != nil { - return err - } - - return nil -} - -// initBinlogReader creates and connects the reader: we hook up to a MySQL server as a replica -func (es *EventsStreamer) initBinlogReader(binlogCoordinates mysql.BinlogCoordinates) error { - goMySQLReader := binlog.NewGoMySQLReader(es.migrationContext) - if err := goMySQLReader.ConnectBinlogStreamer(binlogCoordinates); err != nil { - return err - } - es.binlogReader = goMySQLReader - return nil -} - -func (es *EventsStreamer) GetCurrentBinlogCoordinates() mysql.BinlogCoordinates { - return es.binlogReader.GetCurrentBinlogCoordinates() -} - -// readCurrentBinlogCoordinates reads master status from hooked server -func (es *EventsStreamer) readCurrentBinlogCoordinates() error { - binaryLogStatusTerm := mysql.ReplicaTermFor(es.dbVersion, "master status") - query := fmt.Sprintf("show /* gh-ost readCurrentBinlogCoordinates */ %s", binaryLogStatusTerm) - foundMasterStatus := false - err := sqlutils.QueryRowsMap(es.db, query, func(m sqlutils.RowMap) error { - if es.migrationContext.UseGTIDs { - execGtidSet := m.GetString("Executed_Gtid_Set") - gtidSet, err := gomysql.ParseMysqlGTIDSet(execGtidSet) - if err != nil { - return err - } - es.initialBinlogCoordinates = &mysql.GTIDBinlogCoordinates{GTIDSet: gtidSet.(*gomysql.MysqlGTIDSet)} - } else { - es.initialBinlogCoordinates = &mysql.FileBinlogCoordinates{ - LogFile: m.GetString("File"), - LogPos: m.GetInt64("Position"), - } - } - foundMasterStatus = true - return nil - }) - if err != nil { - return err - } - if !foundMasterStatus { - return fmt.Errorf("got no results from SHOW %s. Bailing out", strings.ToUpper(binaryLogStatusTerm)) - } - es.migrationContext.Log.Debugf("Streamer binlog coordinates: %+v", es.initialBinlogCoordinates) - return nil -} - -// StreamEvents will begin streaming events. It will be blocking, so should be -// executed by a goroutine -func (es *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { - go func() { - for binlogEntry := range es.eventsChannel { - if binlogEntry.DmlEvent != nil { - es.notifyListeners(binlogEntry) - } - } - }() - // The next should block and execute forever, unless there's a serious error. - var successiveFailures int - var reconnectCoords mysql.BinlogCoordinates - ctx := es.migrationContext.GetContext() - for { - // Check for context cancellation each iteration - if err := ctx.Err(); err != nil { - return err - } - if canStopStreaming() { - return nil - } - // We will reconnect the binlog streamer at the coordinates - // of the last trx that was read completely from the streamer. - // Since row event application is idempotent, it's OK if we reapply some events. - if err := es.binlogReader.StreamEvents(canStopStreaming, es.eventsChannel); err != nil { - if canStopStreaming() { - return nil - } - - es.migrationContext.Log.Infof("StreamEvents encountered unexpected error: %+v", err) - es.migrationContext.MarkPointOfInterest() - time.Sleep(ReconnectStreamerSleepSeconds * time.Second) - - // See if there's retry overflow - if es.migrationContext.BinlogSyncerMaxReconnectAttempts > 0 && successiveFailures >= es.migrationContext.BinlogSyncerMaxReconnectAttempts { - return fmt.Errorf("%d successive failures in streamer reconnect at coordinates %+v", successiveFailures, reconnectCoords) - } - - // Reposition at same coordinates - if es.binlogReader.LastTrxCoords != nil { - reconnectCoords = es.binlogReader.LastTrxCoords.Clone() - } else { - reconnectCoords = es.initialBinlogCoordinates.Clone() - } - if !reconnectCoords.SmallerThan(es.GetCurrentBinlogCoordinates()) { - successiveFailures += 1 - } else { - successiveFailures = 0 - } - - es.migrationContext.Log.Infof("Reconnecting EventsStreamer... Will resume at %+v", reconnectCoords) - _ = es.binlogReader.Close() - if err := es.initBinlogReader(reconnectCoords); err != nil { - return err - } - } - } -} - -func (es *EventsStreamer) Close() (err error) { - err = es.binlogReader.Close() - es.migrationContext.Log.Infof("Closed streamer connection. err=%+v", err) - return err -} - -func (es *EventsStreamer) Teardown() { - es.db.Close() -} +// EventsStreamer has been replaced by the transaction-aware streaming +// in gomysql_reader.go (StreamTransactions / handleTransactionEvent). +// The coordinator now handles event dispatching and parallel application. diff --git a/go/logic/streamer_test.go b/go/logic/streamer_test.go index e8c0812d2..c48054717 100644 --- a/go/logic/streamer_test.go +++ b/go/logic/streamer_test.go @@ -1,267 +1,4 @@ package logic -import ( - "context" - gosql "database/sql" - "fmt" - "testing" - "time" - - "github.com/github/gh-ost/go/binlog" - "github.com/stretchr/testify/suite" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/mysql" - - "golang.org/x/sync/errgroup" -) - -type EventsStreamerTestSuite struct { - suite.Suite - - mysqlContainer testcontainers.Container - db *gosql.DB -} - -func (suite *EventsStreamerTestSuite) SetupSuite() { - ctx := context.Background() - mysqlContainer, err := mysql.Run(ctx, - testMysqlContainerImage, - mysql.WithDatabase(testMysqlDatabase), - mysql.WithUsername(testMysqlUser), - mysql.WithPassword(testMysqlPass), - ) - suite.Require().NoError(err) - - suite.mysqlContainer = mysqlContainer - dsn, err := mysqlContainer.ConnectionString(ctx) - suite.Require().NoError(err) - - db, err := gosql.Open("mysql", dsn) - suite.Require().NoError(err) - - suite.db = db -} - -func (suite *EventsStreamerTestSuite) TeardownSuite() { - suite.Assert().NoError(suite.db.Close()) - suite.Assert().NoError(testcontainers.TerminateContainer(suite.mysqlContainer)) -} - -func (suite *EventsStreamerTestSuite) SetupTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+testMysqlDatabase) - suite.Require().NoError(err) -} - -func (suite *EventsStreamerTestSuite) TearDownTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestGhostTableName()) - suite.Require().NoError(err) -} - -func (suite *EventsStreamerTestSuite) TestStreamEvents() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(255))", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - streamer := NewEventsStreamer(migrationContext) - - err = streamer.InitDBConnections() - suite.Require().NoError(err) - defer streamer.Close() - defer streamer.Teardown() - - streamCtx, cancel := context.WithCancel(context.Background()) - - dmlEvents := make([]*binlog.BinlogDMLEvent, 0) - err = streamer.AddListener(false, testMysqlDatabase, testMysqlTableName, func(event *binlog.BinlogEntry) error { - dmlEvents = append(dmlEvents, event.DmlEvent) - - // Stop once we've collected three events - if len(dmlEvents) == 3 { - cancel() - } - - return nil - }) - suite.Require().NoError(err) - - group := errgroup.Group{} - group.Go(func() error { - //nolint:contextcheck - return streamer.StreamEvents(func() bool { - return streamCtx.Err() != nil - }) - }) - - group.Go(func() error { - var err error - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, 'foo')", getTestTableName())) - if err != nil { - return err - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (2, 'bar')", getTestTableName())) - if err != nil { - return err - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (3, 'baz')", getTestTableName())) - if err != nil { - return err - } - - // Bug: Need to write fourth event to hit the canStopStreaming function again - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (4, 'qux')", getTestTableName())) - if err != nil { - return err - } - - return nil - }) - - err = group.Wait() - suite.Require().NoError(err) - - suite.Require().Len(dmlEvents, 3) -} - -func (suite *EventsStreamerTestSuite) TestStreamEventsAutomaticallyReconnects() { - ctx := context.Background() - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(255))", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - streamer := NewEventsStreamer(migrationContext) - - err = streamer.InitDBConnections() - suite.Require().NoError(err) - defer streamer.Close() - defer streamer.Teardown() - - streamCtx, cancel := context.WithCancel(context.Background()) - - dmlEvents := make([]*binlog.BinlogDMLEvent, 0) - err = streamer.AddListener(false, testMysqlDatabase, testMysqlTableName, func(event *binlog.BinlogEntry) error { - dmlEvents = append(dmlEvents, event.DmlEvent) - - // Stop once we've collected three events - if len(dmlEvents) == 3 { - cancel() - } - - return nil - }) - suite.Require().NoError(err) - - group := errgroup.Group{} - group.Go(func() error { - //nolint:contextcheck - return streamer.StreamEvents(func() bool { - return streamCtx.Err() != nil - }) - }) - - group.Go(func() error { - var err error - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, 'foo')", getTestTableName())) - if err != nil { - return err - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (2, 'bar')", getTestTableName())) - if err != nil { - return err - } - - var currentConnectionId int - err = suite.db.QueryRowContext(ctx, "SELECT CONNECTION_ID()").Scan(¤tConnectionId) - if err != nil { - return err - } - - rows, err := suite.db.Query("SHOW FULL PROCESSLIST") - if err != nil { - return err - } - defer rows.Close() - - connectionIdsToKill := make([]int, 0) - - var id, stateTime int - var user, host, dbName, command, state, info gosql.NullString - for rows.Next() { - err = rows.Scan(&id, &user, &host, &dbName, &command, &stateTime, &state, &info) - if err != nil { - return err - } - - fmt.Printf("id: %d, user: %s, host: %s, dbName: %s, command: %s, time: %d, state: %s, info: %s\n", id, user.String, host.String, dbName.String, command.String, stateTime, state.String, info.String) - - if id != currentConnectionId && user.String == testMysqlUser { - connectionIdsToKill = append(connectionIdsToKill, id) - } - } - - if err := rows.Err(); err != nil { - return err - } - - for _, connectionIdToKill := range connectionIdsToKill { - _, err = suite.db.ExecContext(ctx, "KILL ?", connectionIdToKill) - if err != nil { - return err - } - } - - // Bug: We need to wait here for the streamer to reconnect - time.Sleep(time.Second * 2) - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (3, 'baz')", getTestTableName())) - if err != nil { - return err - } - - // Bug: Need to write fourth event to hit the canStopStreaming function again - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (4, 'qux')", getTestTableName())) - if err != nil { - return err - } - - return nil - }) - - err = group.Wait() - suite.Require().NoError(err) - - suite.Require().Len(dmlEvents, 3) -} - -func TestEventsStreamer(t *testing.T) { - if testing.Short() { - t.Skip("skipping events streamer test suite in short mode") - } - suite.Run(t, new(EventsStreamerTestSuite)) -} +// Legacy EventsStreamer tests removed. +// See coordinator_test.go for the replacement test suite. diff --git a/go/logic/test_helpers_test.go b/go/logic/test_helpers_test.go new file mode 100644 index 000000000..3b067fcfb --- /dev/null +++ b/go/logic/test_helpers_test.go @@ -0,0 +1,38 @@ +package logic + +import ( + "context" + "fmt" + + "github.com/github/gh-ost/go/mysql" + "github.com/testcontainers/testcontainers-go" +) + +func GetDSN(ctx context.Context, container testcontainers.Container) (string, error) { + host, err := container.Host(ctx) + if err != nil { + return "", err + } + port, err := container.MappedPort(ctx, "3306/tcp") + if err != nil { + return "", err + } + return fmt.Sprintf("root:root-password@tcp(%s:%s)/", host, port.Port()), nil +} + +func GetConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) { + host, err := container.Host(ctx) + if err != nil { + return nil, err + } + port, err := container.MappedPort(ctx, "3306/tcp") + if err != nil { + return nil, err + } + config := mysql.NewConnectionConfig() + config.Key.Hostname = host + config.Key.Port = port.Int() + config.User = "root" + config.Password = "root-password" + return config, nil +} diff --git a/go/logic/test_utils.go b/go/logic/test_utils.go index f552cfc76..d532e0920 100644 --- a/go/logic/test_utils.go +++ b/go/logic/test_utils.go @@ -28,14 +28,6 @@ func getTestGhostTableName() string { return fmt.Sprintf("`%s`.`_%s_gho`", testMysqlDatabase, testMysqlTableName) } -func getTestRevertedTableName() string { - return fmt.Sprintf("`%s`.`_%s_rev_del`", testMysqlDatabase, testMysqlTableName) -} - -func getTestOldTableName() string { - return fmt.Sprintf("`%s`.`_%s_del`", testMysqlDatabase, testMysqlTableName) -} - func getTestConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) { host, err := container.Host(ctx) if err != nil { diff --git a/go/metrics/client.go b/go/metrics/client.go new file mode 100644 index 000000000..ed6acc096 --- /dev/null +++ b/go/metrics/client.go @@ -0,0 +1,69 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package metrics + +import ( + "time" + + "github.com/DataDog/datadog-go/v5/statsd" + "github.com/openark/golib/log" +) + +// Noop is a StatsD client that discards all metrics. NewClient("", ...) returns +// this exact pointer so callers can use `client == metrics.Noop`. +var Noop = &Client{} + +// Client wraps a StatsD client with namespace and global tags (from --statsd-tags). +type Client struct { + sd *statsd.Client +} + +// NewClient connects to addr for StatsD. If addr is empty, returns Noop and nil error. +// namespace is typically "gh_ost." (metrics are named namespace + short name, e.g. gh_ost.startup). +// tags are global tags applied to every metric (repeatable --statsd-tags). +func NewClient(addr string, tags []string, namespace string) (*Client, error) { + if addr == "" { + return Noop, nil + } + sd, err := statsd.New(addr, + statsd.WithNamespace(namespace), + statsd.WithTags(tags), + statsd.WithoutTelemetry(), + statsd.WithoutOriginDetection(), + statsd.WithClientSideAggregation(), + statsd.WithExtendedClientSideAggregation(), + statsd.WithMaxSamplesPerContext(1_000), + statsd.WithMaxBytesPerPayload(8_172), + statsd.WithAggregationInterval(5*time.Second), + ) + if err != nil { + return nil, err + } + log.Infof("metrics: DogStatsD client connected to %s (namespace: %s)", addr, namespace) + return &Client{sd: sd}, nil +} + +func (c *Client) Gauge(name string, value float64, tags ...string) { + if c.sd == nil { + return + } + _ = c.sd.Gauge(name, value, tags, 1.0) +} + +func (c *Client) Count(name string, value int64, tags ...string) { + if c.sd == nil { + return + } + _ = c.sd.Count(name, value, tags, 1.0) +} + +// Close flushes buffered metrics; safe for Noop. +func (c *Client) Close() error { + if c.sd == nil { + return nil + } + return c.sd.Close() +} diff --git a/go/metrics/client_test.go b/go/metrics/client_test.go new file mode 100644 index 000000000..a2fb81261 --- /dev/null +++ b/go/metrics/client_test.go @@ -0,0 +1,57 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package metrics + +import ( + "slices" + "testing" +) + +func TestNewClient_NoAddr_ReturnsNoopSingleton(t *testing.T) { + c, err := NewClient("", []string{"env:test"}, "gh_ost.") + if err != nil { + t.Fatal(err) + } + if c != Noop || c.sd != nil { + t.Fatalf("expected Noop singleton without statsd connection, got %p sd=%v", c, c.sd) + } + if err := c.Close(); err != nil { + t.Fatal(err) + } +} + +func TestMergeTagSlices(t *testing.T) { + tests := []struct { + name string + global []string + perCall []string + want []string + }{ + {"nil_global", nil, []string{"k:v"}, []string{"k:v"}}, + {"empty_extra", []string{"env:prod"}, nil, []string{"env:prod"}}, + {"combined", []string{"env:prod"}, []string{"shard:1"}, []string{"env:prod", "shard:1"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeTagSlices(tt.global, tt.perCall) + if !slices.Equal(got, tt.want) { + t.Fatalf("got %#v want %#v", got, tt.want) + } + }) + } +} + +func mergeTagSlices(global, perCall []string) []string { + if len(global) == 0 { + return perCall + } + if len(perCall) == 0 { + return global + } + out := make([]string, 0, len(global)+len(perCall)) + return append(append(out, global...), perCall...) +} diff --git a/go/metrics/go_runtime.go b/go/metrics/go_runtime.go new file mode 100644 index 000000000..24ae2c6b5 --- /dev/null +++ b/go/metrics/go_runtime.go @@ -0,0 +1,61 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package metrics + +import ( + "context" + "runtime" + "time" +) + +// MemStatsGaugeEmitter is implemented by *Client; used for tests without UDP. +type MemStatsGaugeEmitter interface { + Gauge(name string, value float64, tags ...string) +} + +// EmitGoRuntimeGauges emits gh_ost.go_runtime.* gauges (namespace is applied by the client). +// m and numGoroutine are typically from runtime.ReadMemStats and runtime.NumGoroutine. +func EmitGoRuntimeGauges(emit MemStatsGaugeEmitter, m *runtime.MemStats, numGoroutine int) { + if emit == nil || m == nil { + return + } + emit.Gauge("go_runtime.alloc_bytes", float64(m.Alloc)) + emit.Gauge("go_runtime.sys_bytes", float64(m.Sys)) + emit.Gauge("go_runtime.heap_inuse_bytes", float64(m.HeapInuse)) + emit.Gauge("go_runtime.num_gc", float64(m.NumGC)) + emit.Gauge("go_runtime.gc_pause_total_ns", float64(m.PauseTotalNs)) + emit.Gauge("go_runtime.goroutines", float64(numGoroutine)) +} + +// StartGoRuntimeReporter periodically samples runtime memory and goroutines and emits gauges +// until ctx is cancelled. It is a no-op when interval <= 0, client is nil, or StatsD is disabled +// (noop client). +func StartGoRuntimeReporter(ctx context.Context, client *Client, interval time.Duration) { + if ctx == nil || client == nil || interval <= 0 || client.sd == nil { + return + } + + emit := func() { + var m runtime.MemStats + runtime.ReadMemStats(&m) + EmitGoRuntimeGauges(client, &m, runtime.NumGoroutine()) + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + emit() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + emit() + } + } + }() +} diff --git a/go/metrics/go_runtime_test.go b/go/metrics/go_runtime_test.go new file mode 100644 index 000000000..24811206b --- /dev/null +++ b/go/metrics/go_runtime_test.go @@ -0,0 +1,67 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package metrics + +import ( + "context" + "runtime" + "testing" + "time" +) + +type gaugeSpy struct { + names []string + values []float64 +} + +func (g *gaugeSpy) Gauge(name string, value float64, _ ...string) { + g.names = append(g.names, name) + g.values = append(g.values, value) +} + +func TestEmitGoRuntimeGauges(t *testing.T) { + spy := &gaugeSpy{} + m := &runtime.MemStats{ + Alloc: 100, + Sys: 200, + HeapInuse: 300, + NumGC: 7, + PauseTotalNs: 42, + } + EmitGoRuntimeGauges(spy, m, 123) + + wantNames := []string{ + "go_runtime.alloc_bytes", + "go_runtime.sys_bytes", + "go_runtime.heap_inuse_bytes", + "go_runtime.num_gc", + "go_runtime.gc_pause_total_ns", + "go_runtime.goroutines", + } + wantVals := []float64{100, 200, 300, 7, 42, 123} + + if len(spy.names) != len(wantNames) { + t.Fatalf("got %d gauges, want %d", len(spy.names), len(wantNames)) + } + for i := range wantNames { + if spy.names[i] != wantNames[i] || spy.values[i] != wantVals[i] { + t.Fatalf("[%d] got %s=%v want %s=%v", i, spy.names[i], spy.values[i], wantNames[i], wantVals[i]) + } + } +} + +func TestEmitGoRuntimeGauges_nilSafe(t *testing.T) { + EmitGoRuntimeGauges(nil, &runtime.MemStats{}, 1) + EmitGoRuntimeGauges(&gaugeSpy{}, nil, 1) +} + +func TestStartGoRuntimeReporter_stopsOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := &Client{} // sd nil — should not start + StartGoRuntimeReporter(ctx, c, time.Millisecond) + cancel() + time.Sleep(20 * time.Millisecond) +} diff --git a/localtests/sysbench/generate_load b/localtests/sysbench/generate_load new file mode 100755 index 000000000..f1f641af1 --- /dev/null +++ b/localtests/sysbench/generate_load @@ -0,0 +1 @@ +#!/usr/bin/env bash diff --git a/localtests/test.sh b/localtests/test.sh index d918d473b..575cf6a72 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -16,6 +16,7 @@ toxiproxy=false gtid=false storage_engine=innodb exec_command_file=/tmp/gh-ost-test.bash +generate_load_file=/tmp/gh-ost-generate-load.bash ghost_structure_output_file=/tmp/gh-ost-test.ghost.structure.sql orig_content_output_file=/tmp/gh-ost-test.orig.content.csv ghost_content_output_file=/tmp/gh-ost-test.ghost.content.csv diff --git a/script/test b/script/test index 7c58e6520..3f66288d1 100755 --- a/script/test +++ b/script/test @@ -6,7 +6,7 @@ set -e echo "Verifying code is formatted via 'gofmt -s -w go/'" gofmt -s -w go/ -git diff --exit-code --quiet go/ +git diff --exit-code --quiet echo "Building" script/build @@ -14,4 +14,4 @@ script/build cd .gopath/src/github.com/github/gh-ost echo "Running unit tests" -go test "$@" -v -covermode=atomic ./go/... +go test -v -p 1 -covermode=atomic -race ./go/... diff --git a/vendor/github.com/DataDog/datadog-go/v5/LICENSE.txt b/vendor/github.com/DataDog/datadog-go/v5/LICENSE.txt new file mode 100644 index 000000000..97cd06d7f --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/LICENSE.txt @@ -0,0 +1,19 @@ +Copyright (c) 2015 Datadog, Inc + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/README.md b/vendor/github.com/DataDog/datadog-go/v5/statsd/README.md new file mode 100644 index 000000000..2fc899687 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/README.md @@ -0,0 +1,4 @@ +## Overview + +Package `statsd` provides a Go [dogstatsd](http://docs.datadoghq.com/guides/dogstatsd/) client. Dogstatsd extends Statsd, adding tags +and histograms. diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/aggregator.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/aggregator.go new file mode 100644 index 000000000..ed18f8f5c --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/aggregator.go @@ -0,0 +1,349 @@ +package statsd + +import ( + "strings" + "sync" + "sync/atomic" + "time" +) + +type ( + countsMap map[string]*countMetric + gaugesMap map[string]*gaugeMetric + setsMap map[string]*setMetric + bufferedMetricMap map[string]*bufferedMetric +) + +type countShard struct { + sync.RWMutex + counts countsMap +} + +type gaugeShard struct { + sync.RWMutex + gauges gaugesMap +} + +type setShard struct { + sync.RWMutex + sets setsMap +} + +type aggregator struct { + nbContextGauge uint64 + nbContextCount uint64 + nbContextSet uint64 + + shardsCount int + countShards []*countShard + gaugeShards []*gaugeShard + setShards []*setShard + + histograms bufferedMetricContexts + distributions bufferedMetricContexts + timings bufferedMetricContexts + + closed chan struct{} + + client *ClientEx + + // aggregator implements channelMode mechanism to receive histograms, + // distributions and timings. Since they need sampling they need to + // lock for random. When using both channelMode and ExtendedAggregation + // we don't want goroutine to fight over the lock. + inputMetrics chan metric + stopChannelMode chan struct{} + wg sync.WaitGroup +} + +func newAggregator(c *ClientEx, maxSamplesPerContext int64, shardsCount int) *aggregator { + agg := &aggregator{ + client: c, + shardsCount: shardsCount, + countShards: make([]*countShard, shardsCount), + gaugeShards: make([]*gaugeShard, shardsCount), + setShards: make([]*setShard, shardsCount), + histograms: newBufferedContexts(newHistogramMetric, maxSamplesPerContext), + distributions: newBufferedContexts(newDistributionMetric, maxSamplesPerContext), + timings: newBufferedContexts(newTimingMetric, maxSamplesPerContext), + closed: make(chan struct{}), + stopChannelMode: make(chan struct{}), + } + for i := 0; i < shardsCount; i++ { + agg.countShards[i] = &countShard{counts: countsMap{}} + agg.gaugeShards[i] = &gaugeShard{gauges: gaugesMap{}} + agg.setShards[i] = &setShard{sets: setsMap{}} + } + return agg +} + +func (a *aggregator) start(flushInterval time.Duration) { + ticker := time.NewTicker(flushInterval) + + go func() { + for { + select { + case <-ticker.C: + a.flush() + case <-a.closed: + ticker.Stop() + return + } + } + }() +} + +func (a *aggregator) startReceivingMetric(bufferSize int, nbWorkers int) { + a.inputMetrics = make(chan metric, bufferSize) + for i := 0; i < nbWorkers; i++ { + a.wg.Add(1) + go a.pullMetric() + } +} + +func (a *aggregator) stopReceivingMetric() { + close(a.stopChannelMode) + a.wg.Wait() +} + +func (a *aggregator) stop() { + a.closed <- struct{}{} +} + +func (a *aggregator) pullMetric() { + for { + select { + case m := <-a.inputMetrics: + switch m.metricType { + case histogram: + a.histogram(m.name, m.fvalue, m.tags, m.rate, m.cardinality) + case distribution: + a.distribution(m.name, m.fvalue, m.tags, m.rate, m.cardinality) + case timing: + a.timing(m.name, m.fvalue, m.tags, m.rate, m.cardinality) + } + case <-a.stopChannelMode: + a.wg.Done() + return + } + } +} + +func (a *aggregator) flush() { + for _, m := range a.flushMetrics() { + a.client.sendBlocking(m) + } +} + +func (a *aggregator) flushTelemetryMetrics(t *Telemetry) { + if a == nil { + // aggregation is disabled + return + } + + t.AggregationNbContextGauge = atomic.LoadUint64(&a.nbContextGauge) + t.AggregationNbContextCount = atomic.LoadUint64(&a.nbContextCount) + t.AggregationNbContextSet = atomic.LoadUint64(&a.nbContextSet) + t.AggregationNbContextHistogram = a.histograms.getNbContext() + t.AggregationNbContextDistribution = a.distributions.getNbContext() + t.AggregationNbContextTiming = a.timings.getNbContext() +} + +func (a *aggregator) flushMetrics() []metric { + metrics := []metric{} + + // We reset the values to avoid sending 'zero' values for metrics not + // sampled during this flush interval + + for _, shard := range a.setShards { + shard.Lock() + sets := shard.sets + shard.sets = setsMap{} + shard.Unlock() + for _, s := range sets { + metrics = append(metrics, s.flushUnsafe()...) + } + atomic.AddUint64(&a.nbContextSet, uint64(len(sets))) + } + + for _, shard := range a.gaugeShards { + shard.Lock() + gauges := shard.gauges + shard.gauges = gaugesMap{} + shard.Unlock() + for _, g := range gauges { + metrics = append(metrics, g.flushUnsafe()) + } + atomic.AddUint64(&a.nbContextGauge, uint64(len(gauges))) + } + + for _, shard := range a.countShards { + shard.Lock() + counts := shard.counts + shard.counts = countsMap{} + shard.Unlock() + for _, c := range counts { + metrics = append(metrics, c.flushUnsafe()) + } + atomic.AddUint64(&a.nbContextCount, uint64(len(counts))) + } + + metrics = a.histograms.flush(metrics) + metrics = a.distributions.flush(metrics) + metrics = a.timings.flush(metrics) + + return metrics +} + +// getContext returns the context for a metric name, tags, and cardinality. +// +// The context is the metric name, tags, and cardinality separated by separator symbols. +// It is not intended to be used as a metric name but as a unique key to aggregate +func getContext(name string, tags []string, cardinality Cardinality) string { + c, _ := getContextAndTags(name, tags, cardinality) + return c +} + +// getContextAndTags returns the context and tags for a metric name, tags, and cardinality. +// +// See getContext for usage for context +// The tags are the tags separated by a separator symbol and can be re-used to pass down to the writer +func getContextAndTags(name string, tags []string, cardinality Cardinality) (string, string) { + cardString := cardinality.String() + if len(tags) == 0 { + if cardString == "" { + return name, "" + } + return name + nameSeparatorSymbol + cardString, "" + } + + n := len(name) + len(nameSeparatorSymbol) + len(tagSeparatorSymbol)*(len(tags)-1) + for _, s := range tags { + n += len(s) + } + var cardStringLen = 0 + if cardString != "" { + n += len(cardString) + len(cardSeparatorSymbol) + cardStringLen = len(cardString) + len(cardSeparatorSymbol) + } + + var sb strings.Builder + sb.Grow(n) + sb.WriteString(name) + sb.WriteString(nameSeparatorSymbol) + if cardString != "" { + sb.WriteString(cardString) + sb.WriteString(cardSeparatorSymbol) + } + sb.WriteString(tags[0]) + for _, s := range tags[1:] { + sb.WriteString(tagSeparatorSymbol) + sb.WriteString(s) + } + + s := sb.String() + + return s, s[len(name)+len(nameSeparatorSymbol)+cardStringLen:] +} + +func getShardIndex(shardsCount int, context string) int { + if shardsCount <= 1 { + return 0 + } + return int(hashString32(context) % uint32(shardsCount)) +} + +func (a *aggregator) count(name string, value int64, tags []string, cardinality Cardinality) error { + context := getContext(name, tags, cardinality) + shard := a.countShards[getShardIndex(a.shardsCount, context)] + shard.RLock() + if count, found := shard.counts[context]; found { + count.sample(value) + shard.RUnlock() + return nil + } + shard.RUnlock() + + metric := newCountMetric(name, value, tags, cardinality) + + shard.Lock() + // Check if another goroutines hasn't created the value between the RUnlock and 'Lock' + if count, found := shard.counts[context]; found { + count.sample(value) + shard.Unlock() + return nil + } + + shard.counts[context] = metric + shard.Unlock() + return nil +} + +func (a *aggregator) gauge(name string, value float64, tags []string, cardinality Cardinality) error { + context := getContext(name, tags, cardinality) + shard := a.gaugeShards[getShardIndex(a.shardsCount, context)] + shard.RLock() + if gauge, found := shard.gauges[context]; found { + gauge.sample(value) + shard.RUnlock() + return nil + } + shard.RUnlock() + + gauge := newGaugeMetric(name, value, tags, cardinality) + + shard.Lock() + // Check if another goroutines hasn't created the value between the 'RUnlock' and 'Lock' + if gauge, found := shard.gauges[context]; found { + gauge.sample(value) + shard.Unlock() + return nil + } + shard.gauges[context] = gauge + shard.Unlock() + return nil +} + +func (a *aggregator) set(name string, value string, tags []string, cardinality Cardinality) error { + context := getContext(name, tags, cardinality) + shard := a.setShards[getShardIndex(a.shardsCount, context)] + shard.RLock() + if set, found := shard.sets[context]; found { + set.sample(value) + shard.RUnlock() + return nil + } + shard.RUnlock() + + metric := newSetMetric(name, value, tags, cardinality) + + shard.Lock() + // Check if another goroutines hasn't created the value between the 'RUnlock' and 'Lock' + if set, found := shard.sets[context]; found { + set.sample(value) + shard.Unlock() + return nil + } + shard.sets[context] = metric + shard.Unlock() + return nil +} + +// Only histograms, distributions and timings are sampled with a rate since we +// only pack them in on message instead of aggregating them. Discarding the +// sample rate will have impacts on the CPU and memory usage of the Agent. + +// type alias for Client.sendToAggregator +type bufferedMetricSampleFunc func(name string, value float64, tags []string, rate float64, cardinality Cardinality) error + +func (a *aggregator) histogram(name string, value float64, tags []string, rate float64, cardinality Cardinality) error { + return a.histograms.sample(name, value, tags, rate, cardinality) +} + +func (a *aggregator) distribution(name string, value float64, tags []string, rate float64, cardinality Cardinality) error { + return a.distributions.sample(name, value, tags, rate, cardinality) +} + +func (a *aggregator) timing(name string, value float64, tags []string, rate float64, cardinality Cardinality) error { + return a.timings.sample(name, value, tags, rate, cardinality) +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/buffer.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/buffer.go new file mode 100644 index 000000000..2b604090c --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/buffer.go @@ -0,0 +1,208 @@ +package statsd + +import ( + "strconv" +) + +// MessageTooLongError is an error returned when a sample, event or service check is too large once serialized. See +// WithMaxBytesPerPayload option for more details. +type MessageTooLongError struct{} + +func (e MessageTooLongError) Error() string { + return "message too long. See 'WithMaxBytesPerPayload' documentation." +} + +var errBufferFull = MessageTooLongError{} + +type partialWriteError string + +func (e partialWriteError) Error() string { return string(e) } + +const errPartialWrite = partialWriteError("value partially written") + +const metricOverhead = 512 + +// statsdBuffer is a buffer containing statsd messages +// this struct methods are NOT safe for concurrent use +type statsdBuffer struct { + buffer []byte + maxSize int + maxElements int + elementCount int +} + +func newStatsdBuffer(maxSize, maxElements int) *statsdBuffer { + return &statsdBuffer{ + buffer: make([]byte, 0, maxSize+metricOverhead), // pre-allocate the needed size + metricOverhead to avoid having Go re-allocate on it's own if an element does not fit + maxSize: maxSize, + maxElements: maxElements, + } +} + +func (b *statsdBuffer) writeGauge(namespace string, globalTags []string, name string, value float64, tags []string, rate float64, timestamp int64, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendGauge(b.buffer, namespace, globalTags, name, value, tags, rate, originDetection) + b.buffer = appendTimestamp(b.buffer, timestamp) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) writeCount(namespace string, globalTags []string, name string, value int64, tags []string, rate float64, timestamp int64, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendCount(b.buffer, namespace, globalTags, name, value, tags, rate, originDetection) + b.buffer = appendTimestamp(b.buffer, timestamp) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) writeHistogram(namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendHistogram(b.buffer, namespace, globalTags, name, value, tags, rate, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +// writeAggregated serialized as many values as possible in the current buffer and return the position in values where it stopped. +func (b *statsdBuffer) writeAggregated(metricSymbol []byte, namespace string, globalTags []string, name string, values []float64, tags string, tagSize int, precision int, rate float64, originDetection bool, cardinality Cardinality) (int, error) { + if b.elementCount >= b.maxElements { + return 0, errBufferFull + } + + originalBuffer := b.buffer + b.buffer = appendHeader(b.buffer, namespace, name) + + // buffer already full + if len(b.buffer)+tagSize > b.maxSize { + b.buffer = originalBuffer + return 0, errBufferFull + } + + // We add as many value as possible + var position int + for idx, v := range values { + previousBuffer := b.buffer + if idx != 0 { + b.buffer = append(b.buffer, ':') + } + + b.buffer = strconv.AppendFloat(b.buffer, v, 'f', precision, 64) + + // Should we stop serializing and switch to another buffer + if len(b.buffer)+tagSize > b.maxSize { + b.buffer = previousBuffer + break + } + position = idx + 1 + } + + // we could not add a single value + if position == 0 { + b.buffer = originalBuffer + return 0, errBufferFull + } + + b.buffer = append(b.buffer, '|') + b.buffer = append(b.buffer, metricSymbol...) + b.buffer = appendRate(b.buffer, rate) + b.buffer = appendTagsAggregated(b.buffer, globalTags, tags) + b.buffer = appendContainerID(b.buffer) + b.buffer = appendExternalEnv(b.buffer, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + b.elementCount++ + + if position != len(values) { + return position, errPartialWrite + } + return position, nil + +} + +func (b *statsdBuffer) writeDistribution(namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendDistribution(b.buffer, namespace, globalTags, name, value, tags, rate, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) writeSet(namespace string, globalTags []string, name string, value string, tags []string, rate float64, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendSet(b.buffer, namespace, globalTags, name, value, tags, rate, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) writeTiming(namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendTiming(b.buffer, namespace, globalTags, name, value, tags, rate, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) writeEvent(event *Event, globalTags []string, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendEvent(b.buffer, event, globalTags, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) writeServiceCheck(serviceCheck *ServiceCheck, globalTags []string, originDetection bool, cardinality Cardinality) error { + if b.elementCount >= b.maxElements { + return errBufferFull + } + originalBuffer := b.buffer + b.buffer = appendServiceCheck(b.buffer, serviceCheck, globalTags, originDetection) + b.buffer = appendTagCardinality(b.buffer, cardinality) + b.writeSeparator() + return b.validateNewElement(originalBuffer) +} + +func (b *statsdBuffer) validateNewElement(originalBuffer []byte) error { + if len(b.buffer) > b.maxSize { + b.buffer = originalBuffer + return errBufferFull + } + b.elementCount++ + return nil +} + +func (b *statsdBuffer) writeSeparator() { + b.buffer = append(b.buffer, '\n') +} + +func (b *statsdBuffer) reset() { + b.buffer = b.buffer[:0] + b.elementCount = 0 +} + +func (b *statsdBuffer) bytes() []byte { + return b.buffer +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/buffer_pool.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/buffer_pool.go new file mode 100644 index 000000000..7a3e3c9d2 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/buffer_pool.go @@ -0,0 +1,40 @@ +package statsd + +type bufferPool struct { + pool chan *statsdBuffer + bufferMaxSize int + bufferMaxElements int +} + +func newBufferPool(poolSize, bufferMaxSize, bufferMaxElements int) *bufferPool { + p := &bufferPool{ + pool: make(chan *statsdBuffer, poolSize), + bufferMaxSize: bufferMaxSize, + bufferMaxElements: bufferMaxElements, + } + for i := 0; i < poolSize; i++ { + p.addNewBuffer() + } + return p +} + +func (p *bufferPool) addNewBuffer() { + p.pool <- newStatsdBuffer(p.bufferMaxSize, p.bufferMaxElements) +} + +func (p *bufferPool) borrowBuffer() *statsdBuffer { + select { + case b := <-p.pool: + return b + default: + return newStatsdBuffer(p.bufferMaxSize, p.bufferMaxElements) + } +} + +func (p *bufferPool) returnBuffer(buffer *statsdBuffer) { + buffer.reset() + select { + case p.pool <- buffer: + default: + } +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/buffered_metric_context.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/buffered_metric_context.go new file mode 100644 index 000000000..85cab2a17 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/buffered_metric_context.go @@ -0,0 +1,104 @@ +package statsd + +import ( + "math/rand" + "sync" + "sync/atomic" + "time" +) + +// bufferedMetricContexts represent the contexts for Histograms, Distributions +// and Timing. Since those 3 metric types behave the same way and are sampled +// with the same type they're represented by the same class. +type bufferedMetricContexts struct { + nbContext uint64 + mutex sync.RWMutex + values bufferedMetricMap + newMetric func(string, float64, string, float64, Cardinality) *bufferedMetric + + // Each bufferedMetricContexts uses its own random source and random + // lock to prevent goroutines from contending for the lock on the + // "math/rand" package-global random source (e.g. calls like + // "rand.Float64()" must acquire a shared lock to get the next + // pseudorandom number). + random *rand.Rand + randomLock sync.Mutex +} + +func newBufferedContexts(newMetric func(string, float64, string, int64, float64, Cardinality) *bufferedMetric, maxSamples int64) bufferedMetricContexts { + return bufferedMetricContexts{ + values: bufferedMetricMap{}, + newMetric: func(name string, value float64, stringTags string, rate float64, cardinality Cardinality) *bufferedMetric { + return newMetric(name, value, stringTags, maxSamples, rate, cardinality) + }, + // Note that calling "time.Now().UnixNano()" repeatedly quickly may return + // very similar values. That's fine for seeding the worker-specific random + // source because we just need an evenly distributed stream of float values. + // Do not use this random source for cryptographic randomness. + random: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (bc *bufferedMetricContexts) flush(metrics []metric) []metric { + bc.mutex.Lock() + values := bc.values + bc.values = bufferedMetricMap{} + bc.mutex.Unlock() + + for _, d := range values { + d.Lock() + metrics = append(metrics, d.flushUnsafe()) + d.Unlock() + } + atomic.AddUint64(&bc.nbContext, uint64(len(values))) + return metrics +} + +func (bc *bufferedMetricContexts) sample(name string, value float64, tags []string, rate float64, cardinality Cardinality) error { + keepingSample := shouldSample(rate, bc.random, &bc.randomLock) + + // If we don't keep the sample, return early. If we do keep the sample + // we end up storing the *first* observed sampling rate in the metric. + // This is the *wrong* behavior but it's the one we had before and the alternative would increase lock contention too + // much with the current code. + // TODO: change this behavior in the future, probably by introducing thread-local storage and lockless stuctures. + // If this code is removed, also remove the observed sampling rate in the metric and fix `bufferedMetric.flushUnsafe()` + if !keepingSample { + return nil + } + + context, stringTags := getContextAndTags(name, tags, cardinality) + var v *bufferedMetric + + bc.mutex.RLock() + v, _ = bc.values[context] + bc.mutex.RUnlock() + + // Create it if it wasn't found + if v == nil { + bc.mutex.Lock() + // It might have been created by another goroutine since last call + v, _ = bc.values[context] + if v == nil { + // If we might keep a sample that we should have skipped, but that should not drastically affect performances. + bc.values[context] = bc.newMetric(name, value, stringTags, rate, cardinality) + // We added a new value, we need to unlock the mutex and quit + bc.mutex.Unlock() + return nil + } + bc.mutex.Unlock() + } + + // Now we can keep the sample or skip it + if keepingSample { + v.maybeKeepSample(value, bc.random, &bc.randomLock) + } else { + v.skipSample() + } + + return nil +} + +func (bc *bufferedMetricContexts) getNbContext() uint64 { + return atomic.LoadUint64(&bc.nbContext) +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/container.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/container.go new file mode 100644 index 000000000..20d69ef63 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/container.go @@ -0,0 +1,19 @@ +package statsd + +import ( + "sync" +) + +var ( + // containerID holds the container ID. + containerID = "" + + initOnce sync.Once +) + +// getContainerID returns the container ID configured at the client creation +// It can either be auto-discovered with origin detection or provided by the user. +// User-defined container ID is prioritized. +func getContainerID() string { + return containerID +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/container_linux.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/container_linux.go new file mode 100644 index 000000000..125132349 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/container_linux.go @@ -0,0 +1,219 @@ +//go:build linux +// +build linux + +package statsd + +import ( + "bufio" + "fmt" + "io" + "os" + "path" + "regexp" + "strings" + "syscall" +) + +const ( + // cgroupPath is the path to the cgroup file where we can find the container id if one exists. + cgroupPath = "/proc/self/cgroup" + + // selfMountinfo is the path to the mountinfo path where we can find the container id in case cgroup namespace is preventing the use of /proc/self/cgroup + selfMountInfoPath = "/proc/self/mountinfo" + + // defaultCgroupMountPath is the default path to the cgroup mount point. + defaultCgroupMountPath = "/sys/fs/cgroup" + + // cgroupV1BaseController is the controller used to identify the container-id for cgroup v1 + cgroupV1BaseController = "memory" + + uuidSource = "[0-9a-f]{8}[-_][0-9a-f]{4}[-_][0-9a-f]{4}[-_][0-9a-f]{4}[-_][0-9a-f]{12}" + containerSource = "[0-9a-f]{64}" + taskSource = "[0-9a-f]{32}-\\d+" + + containerdSandboxPrefix = "sandboxes" + + // ContainerRegexpStr defines the regexp used to match container IDs + // ([0-9a-f]{64}) is standard container id used pretty much everywhere + // ([0-9a-f]{32}-\d+) is container id used by AWS ECS + // ([0-9a-f]{8}(-[0-9a-f]{4}){4}$) is container id used by Garden + containerRegexpStr = "([0-9a-f]{64})|([0-9a-f]{32}-\\d+)|([0-9a-f]{8}(-[0-9a-f]{4}){4}$)" + // cIDRegexpStr defines the regexp used to match container IDs in /proc/self/mountinfo + cIDRegexpStr = `.*/([^\s/]+)/(` + containerRegexpStr + `)/[\S]*hostname` + + // From https://github.com/torvalds/linux/blob/5859a2b1991101d6b978f3feb5325dad39421f29/include/linux/proc_ns.h#L41-L49 + // Currently, host namespace inode number are hardcoded, which can be used to detect + // if we're running in host namespace or not (does not work when running in DinD) + hostCgroupNamespaceInode = 0xEFFFFFFB +) + +var ( + // expLine matches a line in the /proc/self/cgroup file. It has a submatch for the last element (path), which contains the container ID. + expLine = regexp.MustCompile(`^\d+:[^:]*:(.+)$`) + + // expContainerID matches contained IDs and sources. Source: https://github.com/Qard/container-info/blob/master/index.js + expContainerID = regexp.MustCompile(fmt.Sprintf(`(%s|%s|%s)(?:.scope)?$`, uuidSource, containerSource, taskSource)) + + cIDMountInfoRegexp = regexp.MustCompile(cIDRegexpStr) + + // initContainerID initializes the container ID. + initContainerID = internalInitContainerID +) + +// parseContainerID finds the first container ID reading from r and returns it. +func parseContainerID(r io.Reader) string { + scn := bufio.NewScanner(r) + for scn.Scan() { + path := expLine.FindStringSubmatch(scn.Text()) + if len(path) != 2 { + // invalid entry, continue + continue + } + if parts := expContainerID.FindStringSubmatch(path[1]); len(parts) == 2 { + return parts[1] + } + } + return "" +} + +// readContainerID attempts to return the container ID from the provided file path or empty on failure. +func readContainerID(fpath string) string { + f, err := os.Open(fpath) + if err != nil { + return "" + } + defer f.Close() + return parseContainerID(f) +} + +// Parsing /proc/self/mountinfo is not always reliable in Kubernetes+containerd (at least) +// We're still trying to use it as it may help in some cgroupv2 configurations (Docker, ECS, raw containerd) +func parseMountinfo(r io.Reader) string { + scn := bufio.NewScanner(r) + for scn.Scan() { + line := scn.Text() + allMatches := cIDMountInfoRegexp.FindAllStringSubmatch(line, -1) + if len(allMatches) == 0 { + continue + } + + // We're interest in rightmost match + matches := allMatches[len(allMatches)-1] + if len(matches) > 0 && matches[1] != containerdSandboxPrefix { + return matches[2] + } + } + + return "" +} + +func readMountinfo(path string) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + return parseMountinfo(f) +} + +func isHostCgroupNamespace() bool { + fi, err := os.Stat("/proc/self/ns/cgroup") + if err != nil { + return false + } + + inode := fi.Sys().(*syscall.Stat_t).Ino + + return inode == hostCgroupNamespaceInode +} + +// parseCgroupNodePath parses /proc/self/cgroup and returns a map of controller to its associated cgroup node path. +func parseCgroupNodePath(r io.Reader) map[string]string { + res := make(map[string]string) + scn := bufio.NewScanner(r) + for scn.Scan() { + line := scn.Text() + tokens := strings.Split(line, ":") + if len(tokens) != 3 { + continue + } + if tokens[1] == cgroupV1BaseController || tokens[1] == "" { + res[tokens[1]] = tokens[2] + } + } + return res +} + +// getCgroupInode returns the cgroup controller inode if it exists otherwise an empty string. +// The inode is prefixed by "in-" and is used by the agent to retrieve the container ID. +// For cgroup v1, we use the memory controller. +func getCgroupInode(cgroupMountPath, procSelfCgroupPath string) string { + // Parse /proc/self/cgroup to retrieve the paths to the memory controller (cgroupv1) and the cgroup node (cgroupv2) + f, err := os.Open(procSelfCgroupPath) + if err != nil { + return "" + } + defer f.Close() + cgroupControllersPaths := parseCgroupNodePath(f) + // Retrieve the cgroup inode from /sys/fs/cgroup+controller+cgroupNodePath + for _, controller := range []string{cgroupV1BaseController, ""} { + cgroupNodePath, ok := cgroupControllersPaths[controller] + if !ok { + continue + } + inode := inodeForPath(path.Join(cgroupMountPath, controller, cgroupNodePath)) + if inode != "" { + return inode + } + } + return "" +} + +// inodeForPath returns the inode for the provided path or empty on failure. +func inodeForPath(path string) string { + fi, err := os.Stat(path) + if err != nil { + return "" + } + stats, ok := fi.Sys().(*syscall.Stat_t) + if !ok { + return "" + } + return fmt.Sprintf("in-%d", stats.Ino) +} + +// internalInitContainerID initializes the container ID. +// It can either be provided by the user or read from cgroups. +func internalInitContainerID(userProvidedID string, cgroupFallback, isHostCgroupNs bool) { + initOnce.Do(func() { + readCIDOrInode(userProvidedID, cgroupPath, selfMountInfoPath, defaultCgroupMountPath, cgroupFallback, isHostCgroupNs) + }) +} + +// readCIDOrInode reads the container ID from the user provided ID, cgroups or mountinfo. +func readCIDOrInode(userProvidedID, cgroupPath, selfMountInfoPath, defaultCgroupMountPath string, cgroupFallback, isHostCgroupNs bool) { + if userProvidedID != "" { + containerID = userProvidedID + return + } + + if cgroupFallback { + containerID = readContainerID(cgroupPath) + if containerID != "" { + return + } + + containerID = readMountinfo(selfMountInfoPath) + if containerID != "" { + return + } + + // If we're in the host cgroup namespace, the cid should be retrievable in /proc/self/cgroup + // In private cgroup namespace, we can retrieve the cgroup controller inode. + if containerID == "" && isHostCgroupNs { + return + } + + containerID = getCgroupInode(defaultCgroupMountPath, cgroupPath) + } +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/container_stub.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/container_stub.go new file mode 100644 index 000000000..29ab7f2c9 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/container_stub.go @@ -0,0 +1,17 @@ +//go:build !linux +// +build !linux + +package statsd + +func isHostCgroupNamespace() bool { + return false +} + +var initContainerID = func(userProvidedID string, _, _ bool) { + initOnce.Do(func() { + if userProvidedID != "" { + containerID = userProvidedID + return + } + }) +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/error_handler.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/error_handler.go new file mode 100644 index 000000000..007626273 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/error_handler.go @@ -0,0 +1,22 @@ +package statsd + +import ( + "log" +) + +func LoggingErrorHandler(err error) { + if e, ok := err.(*ErrorInputChannelFull); ok { + log.Printf( + "Input Queue is full (%d elements): %s %s dropped - %s - increase channel buffer size with `WithChannelModeBufferSize()`", + e.ChannelSize, e.Metric.name, e.Metric.tags, e.Msg, + ) + return + } else if e, ok := err.(*ErrorSenderChannelFull); ok { + log.Printf( + "Sender Queue is full (%d elements): %d metrics dropped - %s - increase sender queue size with `WithSenderQueueSize()`", + e.ChannelSize, e.LostElements, e.Msg, + ) + } else { + log.Printf("Error: %v", err) + } +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/event.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/event.go new file mode 100644 index 000000000..a2ca4faf7 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/event.go @@ -0,0 +1,75 @@ +package statsd + +import ( + "fmt" + "time" +) + +// Events support +// EventAlertType and EventAlertPriority became exported types after this issue was submitted: https://github.com/DataDog/datadog-go/issues/41 +// The reason why they got exported is so that client code can directly use the types. + +// EventAlertType is the alert type for events +type EventAlertType string + +const ( + // Info is the "info" AlertType for events + Info EventAlertType = "info" + // Error is the "error" AlertType for events + Error EventAlertType = "error" + // Warning is the "warning" AlertType for events + Warning EventAlertType = "warning" + // Success is the "success" AlertType for events + Success EventAlertType = "success" +) + +// EventPriority is the event priority for events +type EventPriority string + +const ( + // Normal is the "normal" Priority for events + Normal EventPriority = "normal" + // Low is the "low" Priority for events + Low EventPriority = "low" +) + +// An Event is an object that can be posted to your DataDog event stream. +type Event struct { + // Title of the event. Required. + Title string + // Text is the description of the event. + Text string + // Timestamp is a timestamp for the event. If not provided, the dogstatsd + // server will set this to the current time. + Timestamp time.Time + // Hostname for the event. + Hostname string + // AggregationKey groups this event with others of the same key. + AggregationKey string + // Priority of the event. Can be statsd.Low or statsd.Normal. + Priority EventPriority + // SourceTypeName is a source type for the event. + SourceTypeName string + // AlertType can be statsd.Info, statsd.Error, statsd.Warning, or statsd.Success. + // If absent, the default value applied by the dogstatsd server is Info. + AlertType EventAlertType + // Tags for the event. + Tags []string +} + +// NewEvent creates a new event with the given title and text. Error checking +// against these values is done at send-time, or upon running e.Check. +func NewEvent(title, text string) *Event { + return &Event{ + Title: title, + Text: text, + } +} + +// Check verifies that an event is valid. +func (e *Event) Check() error { + if len(e.Title) == 0 { + return fmt.Errorf("statsd.Event title is required") + } + return nil +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/external_env.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/external_env.go new file mode 100644 index 000000000..2c9b13a4c --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/external_env.go @@ -0,0 +1,46 @@ +package statsd + +import ( + "os" + "sync" + "unicode" +) + +// ddExternalEnvVarName specifies the env var to inject the environment name. +const ddExternalEnvVarName = "DD_EXTERNAL_ENV" + +var ( + externalEnv = "" + externalEnvMu sync.RWMutex // Protects concurrent access to externalEnv +) + +// initExternalEnv initializes the external environment name. +func initExternalEnv() { + var value = os.Getenv(ddExternalEnvVarName) + if value != "" { + externalEnvMu.Lock() + externalEnv = sanitizeExternalEnv(value) + externalEnvMu.Unlock() + } +} + +// sanitizeExternalEnv removes non-printable characters and pipe characters from the external environment name. +func sanitizeExternalEnv(externalEnv string) string { + if externalEnv == "" { + return "" + } + var output string + for _, r := range externalEnv { + if unicode.IsPrint(r) && r != '|' { + output += string(r) + } + } + + return output +} + +func getExternalEnv() string { + externalEnvMu.RLock() + defer externalEnvMu.RUnlock() + return externalEnv +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/fnv1a.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/fnv1a.go new file mode 100644 index 000000000..03dc8a07c --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/fnv1a.go @@ -0,0 +1,39 @@ +package statsd + +const ( + // FNV-1a + offset32 = uint32(2166136261) + prime32 = uint32(16777619) + + // init32 is what 32 bits hash values should be initialized with. + init32 = offset32 +) + +// HashString32 returns the hash of s. +func hashString32(s string) uint32 { + return addString32(init32, s) +} + +// AddString32 adds the hash of s to the precomputed hash value h. +func addString32(h uint32, s string) uint32 { + i := 0 + n := (len(s) / 8) * 8 + + for i != n { + h = (h ^ uint32(s[i])) * prime32 + h = (h ^ uint32(s[i+1])) * prime32 + h = (h ^ uint32(s[i+2])) * prime32 + h = (h ^ uint32(s[i+3])) * prime32 + h = (h ^ uint32(s[i+4])) * prime32 + h = (h ^ uint32(s[i+5])) * prime32 + h = (h ^ uint32(s[i+6])) * prime32 + h = (h ^ uint32(s[i+7])) * prime32 + i += 8 + } + + for _, c := range s[i:] { + h = (h ^ uint32(c)) * prime32 + } + + return h +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/format.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/format.go new file mode 100644 index 000000000..52f906355 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/format.go @@ -0,0 +1,306 @@ +package statsd + +import ( + "strconv" + "strings" +) + +var ( + gaugeSymbol = []byte("g") + countSymbol = []byte("c") + histogramSymbol = []byte("h") + distributionSymbol = []byte("d") + setSymbol = []byte("s") + timingSymbol = []byte("ms") +) + +const ( + tagSeparatorSymbol = "," + nameSeparatorSymbol = ":" + cardSeparatorSymbol = "|" +) + +func appendHeader(buffer []byte, namespace string, name string) []byte { + if namespace != "" { + buffer = append(buffer, namespace...) + } + buffer = append(buffer, name...) + buffer = append(buffer, ':') + return buffer +} + +func appendRate(buffer []byte, rate float64) []byte { + if rate < 1 { + buffer = append(buffer, "|@"...) + buffer = strconv.AppendFloat(buffer, rate, 'f', -1, 64) + } + return buffer +} + +func appendWithoutNewlines(buffer []byte, s string) []byte { + // fastpath for strings without newlines + if strings.IndexByte(s, '\n') == -1 { + return append(buffer, s...) + } + + for _, b := range []byte(s) { + if b != '\n' { + buffer = append(buffer, b) + } + } + return buffer +} + +func appendTags(buffer []byte, globalTags []string, tags []string) []byte { + if len(globalTags) == 0 && len(tags) == 0 { + return buffer + } + buffer = append(buffer, "|#"...) + firstTag := true + + for _, tag := range globalTags { + if !firstTag { + buffer = append(buffer, tagSeparatorSymbol...) + } + buffer = appendWithoutNewlines(buffer, tag) + firstTag = false + } + for _, tag := range tags { + if !firstTag { + buffer = append(buffer, tagSeparatorSymbol...) + } + buffer = appendWithoutNewlines(buffer, tag) + firstTag = false + } + return buffer +} + +func appendTagsAggregated(buffer []byte, globalTags []string, tags string) []byte { + if len(globalTags) == 0 && tags == "" { + return buffer + } + + buffer = append(buffer, "|#"...) + firstTag := true + + for _, tag := range globalTags { + if !firstTag { + buffer = append(buffer, tagSeparatorSymbol...) + } + buffer = appendWithoutNewlines(buffer, tag) + firstTag = false + } + if tags != "" { + if !firstTag { + buffer = append(buffer, tagSeparatorSymbol...) + } + buffer = appendWithoutNewlines(buffer, tags) + } + return buffer +} + +func appendFloatMetric(buffer []byte, typeSymbol []byte, namespace string, globalTags []string, name string, value float64, tags []string, rate float64, precision int, originDetection bool) []byte { + buffer = appendHeader(buffer, namespace, name) + buffer = strconv.AppendFloat(buffer, value, 'f', precision, 64) + buffer = append(buffer, '|') + buffer = append(buffer, typeSymbol...) + buffer = appendRate(buffer, rate) + buffer = appendTags(buffer, globalTags, tags) + buffer = appendContainerID(buffer) + buffer = appendExternalEnv(buffer, originDetection) + return buffer +} + +func appendIntegerMetric(buffer []byte, typeSymbol []byte, namespace string, globalTags []string, name string, value int64, tags []string, rate float64, originDetection bool) []byte { + buffer = appendHeader(buffer, namespace, name) + buffer = strconv.AppendInt(buffer, value, 10) + buffer = append(buffer, '|') + buffer = append(buffer, typeSymbol...) + buffer = appendRate(buffer, rate) + buffer = appendTags(buffer, globalTags, tags) + buffer = appendContainerID(buffer) + buffer = appendExternalEnv(buffer, originDetection) + return buffer +} + +func appendStringMetric(buffer []byte, typeSymbol []byte, namespace string, globalTags []string, name string, value string, tags []string, rate float64, originDetection bool) []byte { + buffer = appendHeader(buffer, namespace, name) + buffer = append(buffer, value...) + buffer = append(buffer, '|') + buffer = append(buffer, typeSymbol...) + buffer = appendRate(buffer, rate) + buffer = appendTags(buffer, globalTags, tags) + buffer = appendContainerID(buffer) + buffer = appendExternalEnv(buffer, originDetection) + return buffer +} + +func appendGauge(buffer []byte, namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool) []byte { + return appendFloatMetric(buffer, gaugeSymbol, namespace, globalTags, name, value, tags, rate, -1, originDetection) +} + +func appendCount(buffer []byte, namespace string, globalTags []string, name string, value int64, tags []string, rate float64, originDetection bool) []byte { + return appendIntegerMetric(buffer, countSymbol, namespace, globalTags, name, value, tags, rate, originDetection) +} + +func appendHistogram(buffer []byte, namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool) []byte { + return appendFloatMetric(buffer, histogramSymbol, namespace, globalTags, name, value, tags, rate, -1, originDetection) +} + +func appendDistribution(buffer []byte, namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool) []byte { + return appendFloatMetric(buffer, distributionSymbol, namespace, globalTags, name, value, tags, rate, -1, originDetection) +} + +func appendSet(buffer []byte, namespace string, globalTags []string, name string, value string, tags []string, rate float64, originDetection bool) []byte { + return appendStringMetric(buffer, setSymbol, namespace, globalTags, name, value, tags, rate, originDetection) +} + +func appendTiming(buffer []byte, namespace string, globalTags []string, name string, value float64, tags []string, rate float64, originDetection bool) []byte { + return appendFloatMetric(buffer, timingSymbol, namespace, globalTags, name, value, tags, rate, 6, originDetection) +} + +func escapedEventTextLen(text string) int { + return len(text) + strings.Count(text, "\n") +} + +func appendEscapedEventText(buffer []byte, text string) []byte { + for _, b := range []byte(text) { + if b != '\n' { + buffer = append(buffer, b) + } else { + buffer = append(buffer, "\\n"...) + } + } + return buffer +} + +func appendEvent(buffer []byte, event *Event, globalTags []string, originDetection bool) []byte { + escapedTextLen := escapedEventTextLen(event.Text) + + buffer = append(buffer, "_e{"...) + buffer = strconv.AppendInt(buffer, int64(len(event.Title)), 10) + buffer = append(buffer, tagSeparatorSymbol...) + buffer = strconv.AppendInt(buffer, int64(escapedTextLen), 10) + buffer = append(buffer, "}:"...) + buffer = append(buffer, event.Title...) + buffer = append(buffer, '|') + if escapedTextLen != len(event.Text) { + buffer = appendEscapedEventText(buffer, event.Text) + } else { + buffer = append(buffer, event.Text...) + } + + if !event.Timestamp.IsZero() { + buffer = append(buffer, "|d:"...) + buffer = strconv.AppendInt(buffer, int64(event.Timestamp.Unix()), 10) + } + + if len(event.Hostname) != 0 { + buffer = append(buffer, "|h:"...) + buffer = append(buffer, event.Hostname...) + } + + if len(event.AggregationKey) != 0 { + buffer = append(buffer, "|k:"...) + buffer = append(buffer, event.AggregationKey...) + } + + if len(event.Priority) != 0 { + buffer = append(buffer, "|p:"...) + buffer = append(buffer, event.Priority...) + } + + if len(event.SourceTypeName) != 0 { + buffer = append(buffer, "|s:"...) + buffer = append(buffer, event.SourceTypeName...) + } + + if len(event.AlertType) != 0 { + buffer = append(buffer, "|t:"...) + buffer = append(buffer, string(event.AlertType)...) + } + + buffer = appendTags(buffer, globalTags, event.Tags) + buffer = appendContainerID(buffer) + buffer = appendExternalEnv(buffer, originDetection) + return buffer +} + +func appendEscapedServiceCheckText(buffer []byte, text string) []byte { + for i := 0; i < len(text); i++ { + if text[i] == '\n' { + buffer = append(buffer, "\\n"...) + } else if text[i] == 'm' && i+1 < len(text) && text[i+1] == ':' { + buffer = append(buffer, "m\\:"...) + i++ + } else { + buffer = append(buffer, text[i]) + } + } + return buffer +} + +func appendServiceCheck(buffer []byte, serviceCheck *ServiceCheck, globalTags []string, originDetection bool) []byte { + buffer = append(buffer, "_sc|"...) + buffer = append(buffer, serviceCheck.Name...) + buffer = append(buffer, '|') + buffer = strconv.AppendInt(buffer, int64(serviceCheck.Status), 10) + + if !serviceCheck.Timestamp.IsZero() { + buffer = append(buffer, "|d:"...) + buffer = strconv.AppendInt(buffer, int64(serviceCheck.Timestamp.Unix()), 10) + } + + if len(serviceCheck.Hostname) != 0 { + buffer = append(buffer, "|h:"...) + buffer = append(buffer, serviceCheck.Hostname...) + } + + buffer = appendTags(buffer, globalTags, serviceCheck.Tags) + + if len(serviceCheck.Message) != 0 { + buffer = append(buffer, "|m:"...) + buffer = appendEscapedServiceCheckText(buffer, serviceCheck.Message) + } + + buffer = appendContainerID(buffer) + buffer = appendExternalEnv(buffer, originDetection) + return buffer +} + +func appendSeparator(buffer []byte) []byte { + return append(buffer, '\n') +} + +func appendContainerID(buffer []byte) []byte { + if containerID := getContainerID(); len(containerID) > 0 { + buffer = append(buffer, "|c:"...) + buffer = append(buffer, containerID...) + } + return buffer +} + +func appendTimestamp(buffer []byte, timestamp int64) []byte { + if timestamp > noTimestamp { + buffer = append(buffer, "|T"...) + buffer = strconv.AppendInt(buffer, timestamp, 10) + } + return buffer +} + +func appendExternalEnv(buffer []byte, originDetection bool) []byte { + if externalEnv := getExternalEnv(); externalEnv != "" && originDetection { + buffer = append(buffer, "|e:"...) + buffer = append(buffer, externalEnv...) + } + return buffer +} + +func appendTagCardinality(buffer []byte, cardinality Cardinality) []byte { + cardString := cardinality.String() + if cardString != "" { + buffer = append(buffer, "|card:"...) + buffer = append(buffer, cardString...) + } + return buffer +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/metrics.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/metrics.go new file mode 100644 index 000000000..ea78730ea --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/metrics.go @@ -0,0 +1,283 @@ +package statsd + +import ( + "math" + "math/rand" + "sync" + "sync/atomic" +) + +/* +Those are metrics type that can be aggregated on the client side: + - Gauge + - Count + - Set +*/ + +type countMetric struct { + value int64 + name string + tags []string + cardinality Cardinality +} + +func newCountMetric(name string, value int64, tags []string, cardinality Cardinality) *countMetric { + return &countMetric{ + value: value, + name: name, + tags: copySlice(tags), + cardinality: cardinality, + } +} + +func (c *countMetric) sample(v int64) { + atomic.AddInt64(&c.value, v) +} + +func (c *countMetric) flushUnsafe() metric { + return metric{ + metricType: count, + name: c.name, + tags: c.tags, + rate: 1, + ivalue: c.value, + cardinality: c.cardinality, + } +} + +// Gauge + +type gaugeMetric struct { + value uint64 + name string + tags []string + cardinality Cardinality +} + +func newGaugeMetric(name string, value float64, tags []string, cardinality Cardinality) *gaugeMetric { + return &gaugeMetric{ + value: math.Float64bits(value), + name: name, + tags: copySlice(tags), + cardinality: cardinality, + } +} + +func (g *gaugeMetric) sample(v float64) { + atomic.StoreUint64(&g.value, math.Float64bits(v)) +} + +func (g *gaugeMetric) flushUnsafe() metric { + return metric{ + metricType: gauge, + name: g.name, + tags: g.tags, + rate: 1, + fvalue: math.Float64frombits(g.value), + cardinality: g.cardinality, + } +} + +// Set + +type setMetric struct { + data map[string]struct{} + name string + tags []string + cardinality Cardinality + sync.Mutex +} + +func newSetMetric(name string, value string, tags []string, cardinality Cardinality) *setMetric { + set := &setMetric{ + data: map[string]struct{}{}, + name: name, + tags: copySlice(tags), + cardinality: cardinality, + } + set.data[value] = struct{}{} + return set +} + +func (s *setMetric) sample(v string) { + s.Lock() + defer s.Unlock() + s.data[v] = struct{}{} +} + +// Sets are aggregated on the agent side too. We flush the keys so a set from +// multiple application can be correctly aggregated on the agent side. +func (s *setMetric) flushUnsafe() []metric { + if len(s.data) == 0 { + return nil + } + + metrics := make([]metric, len(s.data)) + i := 0 + for value := range s.data { + metrics[i] = metric{ + metricType: set, + name: s.name, + tags: s.tags, + rate: 1, + svalue: value, + cardinality: s.cardinality, + } + i++ + } + return metrics +} + +// Histograms, Distributions and Timings + +type bufferedMetric struct { + sync.Mutex + + // Kept samples (after sampling) + data []float64 + // Total stored samples (after sampling) + storedSamples int64 + // Total number of observed samples (before sampling). This is used to keep + // the sampling rate correct. + totalSamples int64 + + name string + // Histograms and Distributions store tags as one string since we need + // to compute its size multiple time when serializing. + tags string + mtype metricType + + // maxSamples is the maximum number of samples we keep in memory + maxSamples int64 + + // The first observed user-specified sample rate. When specified + // it is used because we don't know better. + specifiedRate float64 + + cardinality Cardinality +} + +func (s *bufferedMetric) sample(v float64) { + s.Lock() + defer s.Unlock() + s.sampleUnsafe(v) +} + +func (s *bufferedMetric) sampleUnsafe(v float64) { + s.data = append(s.data, v) + s.storedSamples++ + // Total samples needs to be incremented though an atomic because it can be accessed without the lock. + atomic.AddInt64(&s.totalSamples, 1) +} + +func (s *bufferedMetric) maybeKeepSample(v float64, rand *rand.Rand, randLock *sync.Mutex) { + s.Lock() + defer s.Unlock() + if s.maxSamples > 0 { + if s.storedSamples >= s.maxSamples { + // We reached the maximum number of samples we can keep in memory, so we randomly + // replace a sample. + randLock.Lock() + i := rand.Int63n(atomic.LoadInt64(&s.totalSamples)) + randLock.Unlock() + if i < s.maxSamples { + s.data[i] = v + } + } else { + s.data[s.storedSamples] = v + s.storedSamples++ + } + s.totalSamples++ + } else { + // This code path appends to the slice since we did not pre-allocate memory in this case. + s.sampleUnsafe(v) + } +} + +func (s *bufferedMetric) skipSample() { + atomic.AddInt64(&s.totalSamples, 1) +} + +func (s *bufferedMetric) flushUnsafe() metric { + totalSamples := atomic.LoadInt64(&s.totalSamples) + var rate float64 + + // If the user had a specified rate send it because we don't know better. + // This code should be removed once we can also remove the early return at the top of + // `bufferedMetricContexts.sample` + if s.specifiedRate != 1.0 { + rate = s.specifiedRate + } else { + rate = float64(s.storedSamples) / float64(totalSamples) + } + + return metric{ + metricType: s.mtype, + name: s.name, + stags: s.tags, + rate: rate, + fvalues: s.data[:s.storedSamples], + cardinality: s.cardinality, + } +} + +type histogramMetric = bufferedMetric + +func newHistogramMetric(name string, value float64, stringTags string, maxSamples int64, rate float64, cardinality Cardinality) *histogramMetric { + return &histogramMetric{ + data: newData(value, maxSamples), + totalSamples: 1, + storedSamples: 1, + name: name, + tags: stringTags, + mtype: histogramAggregated, + maxSamples: maxSamples, + specifiedRate: rate, + cardinality: cardinality, + } +} + +type distributionMetric = bufferedMetric + +func newDistributionMetric(name string, value float64, stringTags string, maxSamples int64, rate float64, cardinality Cardinality) *distributionMetric { + return &distributionMetric{ + data: newData(value, maxSamples), + totalSamples: 1, + storedSamples: 1, + name: name, + tags: stringTags, + mtype: distributionAggregated, + maxSamples: maxSamples, + specifiedRate: rate, + cardinality: cardinality, + } +} + +type timingMetric = bufferedMetric + +func newTimingMetric(name string, value float64, stringTags string, maxSamples int64, rate float64, cardinality Cardinality) *timingMetric { + return &timingMetric{ + data: newData(value, maxSamples), + totalSamples: 1, + storedSamples: 1, + name: name, + tags: stringTags, + mtype: timingAggregated, + maxSamples: maxSamples, + specifiedRate: rate, + cardinality: cardinality, + } +} + +// newData creates a new slice of float64 with the given capacity. If maxSample +// is less than or equal to 0, it returns a slice with the given value as the +// only element. +func newData(value float64, maxSample int64) []float64 { + if maxSample <= 0 { + return []float64{value} + } else { + data := make([]float64, maxSample) + data[0] = value + return data + } +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/noop.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/noop.go new file mode 100644 index 000000000..6500cde9a --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/noop.go @@ -0,0 +1,118 @@ +package statsd + +import "time" + +// NoOpClient is a statsd client that does nothing. Can be useful in testing +// situations for library users. +type NoOpClient struct{} + +// Gauge does nothing and returns nil +func (n *NoOpClient) Gauge(name string, value float64, tags []string, rate float64) error { + return nil +} + +// GaugeWithTimestamp does nothing and returns nil +func (n *NoOpClient) GaugeWithTimestamp(name string, value float64, tags []string, rate float64, timestamp time.Time) error { + return nil +} + +// Count does nothing and returns nil +func (n *NoOpClient) Count(name string, value int64, tags []string, rate float64) error { + return nil +} + +// CountWithTimestamp does nothing and returns nil +func (n *NoOpClient) CountWithTimestamp(name string, value int64, tags []string, rate float64, timestamp time.Time) error { + return nil +} + +// Histogram does nothing and returns nil +func (n *NoOpClient) Histogram(name string, value float64, tags []string, rate float64) error { + return nil +} + +// Distribution does nothing and returns nil +func (n *NoOpClient) Distribution(name string, value float64, tags []string, rate float64) error { + return nil +} + +// Decr does nothing and returns nil +func (n *NoOpClient) Decr(name string, tags []string, rate float64) error { + return nil +} + +// Incr does nothing and returns nil +func (n *NoOpClient) Incr(name string, tags []string, rate float64) error { + return nil +} + +// Set does nothing and returns nil +func (n *NoOpClient) Set(name string, value string, tags []string, rate float64) error { + return nil +} + +// Timing does nothing and returns nil +func (n *NoOpClient) Timing(name string, value time.Duration, tags []string, rate float64) error { + return nil +} + +// TimeInMilliseconds does nothing and returns nil +func (n *NoOpClient) TimeInMilliseconds(name string, value float64, tags []string, rate float64) error { + return nil +} + +// Event does nothing and returns nil +func (n *NoOpClient) Event(e *Event) error { + return nil +} + +// SimpleEvent does nothing and returns nil +func (n *NoOpClient) SimpleEvent(title, text string) error { + return nil +} + +// ServiceCheck does nothing and returns nil +func (n *NoOpClient) ServiceCheck(sc *ServiceCheck) error { + return nil +} + +// SimpleServiceCheck does nothing and returns nil +func (n *NoOpClient) SimpleServiceCheck(name string, status ServiceCheckStatus) error { + return nil +} + +// Close does nothing and returns nil +func (n *NoOpClient) Close() error { + return nil +} + +// Flush does nothing and returns nil +func (n *NoOpClient) Flush() error { + return nil +} + +// IsClosed does nothing and return false +func (n *NoOpClient) IsClosed() bool { + return false +} + +// GetTelemetry does nothing and returns an empty Telemetry +func (n *NoOpClient) GetTelemetry() Telemetry { + return Telemetry{} +} + +// Verify that NoOpClient implements the ClientInterface. +// https://golang.org/doc/faq#guarantee_satisfies_interface +var _ ClientInterface = &NoOpClient{} + +// NoOpClientDirect implements ClientDirectInterface and does nothing. +type NoOpClientDirect struct { + NoOpClient +} + +// DistributionSamples does nothing and returns nil +func (n *NoOpClientDirect) DistributionSamples(name string, values []float64, tags []string, rate float64) error { + return nil +} + +var _ ClientDirectInterface = &NoOpClientDirect{} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/options.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/options.go new file mode 100644 index 000000000..225a5aea5 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/options.go @@ -0,0 +1,443 @@ +package statsd + +import ( + "fmt" + "math" + "strings" + "time" +) + +var ( + defaultNamespace = "" + defaultTags = []string{} + defaultMaxBytesPerPayload = 0 + defaultMaxMessagesPerPayload = math.MaxInt32 + defaultBufferPoolSize = 0 + defaultBufferFlushInterval = 100 * time.Millisecond + defaultWorkerCount = 32 + defaultSenderQueueSize = 0 + defaultWriteTimeout = 100 * time.Millisecond + defaultConnectTimeout = 1000 * time.Millisecond + defaultTelemetry = true + defaultReceivingMode = mutexMode + defaultChannelModeBufferSize = 4096 + defaultAggregationFlushInterval = 2 * time.Second + defaultAggregation = true + defaultExtendedAggregation = false + defaultMaxBufferedSamplesPerContext = -1 + defaultOriginDetection = true + defaultChannelModeErrorsWhenFull = false + defaultErrorHandler = func(error) {} + defaultAggregatorShardCount = 1 +) + +// Options contains the configuration options for a client. +type Options struct { + namespace string + tags []string + maxBytesPerPayload int + maxMessagesPerPayload int + bufferPoolSize int + bufferFlushInterval time.Duration + workersCount int + senderQueueSize int + writeTimeout time.Duration + connectTimeout time.Duration + telemetry bool + receiveMode receivingMode + channelModeBufferSize int + aggregationFlushInterval time.Duration + aggregation bool + extendedAggregation bool + maxBufferedSamplesPerContext int + aggregatorShardCount int + telemetryAddr string + originDetection bool + containerID string + channelModeErrorsWhenFull bool + errorHandler ErrorHandler + tagCardinality *Cardinality +} + +func resolveOptions(options []Option) (*Options, error) { + o := &Options{ + namespace: defaultNamespace, + tags: defaultTags, + maxBytesPerPayload: defaultMaxBytesPerPayload, + maxMessagesPerPayload: defaultMaxMessagesPerPayload, + bufferPoolSize: defaultBufferPoolSize, + bufferFlushInterval: defaultBufferFlushInterval, + workersCount: defaultWorkerCount, + senderQueueSize: defaultSenderQueueSize, + writeTimeout: defaultWriteTimeout, + connectTimeout: defaultConnectTimeout, + telemetry: defaultTelemetry, + receiveMode: defaultReceivingMode, + channelModeBufferSize: defaultChannelModeBufferSize, + aggregationFlushInterval: defaultAggregationFlushInterval, + aggregation: defaultAggregation, + extendedAggregation: defaultExtendedAggregation, + maxBufferedSamplesPerContext: defaultMaxBufferedSamplesPerContext, + originDetection: defaultOriginDetection, + channelModeErrorsWhenFull: defaultChannelModeErrorsWhenFull, + errorHandler: defaultErrorHandler, + aggregatorShardCount: defaultAggregatorShardCount, + } + + for _, option := range options { + err := option(o) + if err != nil { + return nil, err + } + } + + return o, nil +} + +// Option is a client option. Can return an error if validation fails. +type Option func(*Options) error + +// WithNamespace sets a string to be prepend to all metrics, events and service checks name. +// +// A '.' will automatically be added after the namespace if needed. For example a metrics 'test' with a namespace 'prod' +// will produce a final metric named 'prod.test'. +func WithNamespace(namespace string) Option { + return func(o *Options) error { + if strings.HasSuffix(namespace, ".") { + o.namespace = namespace + } else { + o.namespace = namespace + "." + } + return nil + } +} + +// WithTags sets global tags to be applied to every metrics, events and service checks. +func WithTags(tags []string) Option { + return func(o *Options) error { + o.tags = tags + return nil + } +} + +// WithMaxMessagesPerPayload sets the maximum number of metrics, events and/or service checks that a single payload can +// contain. +// +// The default is 'math.MaxInt32' which will most likely let the WithMaxBytesPerPayload option take precedence. This +// option can be set to `1` to create an unbuffered client (each metrics/event/service check will be send in its own +// payload to the agent). +func WithMaxMessagesPerPayload(maxMessagesPerPayload int) Option { + return func(o *Options) error { + o.maxMessagesPerPayload = maxMessagesPerPayload + return nil + } +} + +// WithMaxBytesPerPayload sets the maximum number of bytes a single payload can contain. Each sample, even and service +// check must be lower than this value once serialized or an `MessageTooLongError` is returned. +// +// The default value 0 which will set the option to the optimal size for the transport protocol used: 1432 for UDP and +// named pipe and 8192 for UDS. Those values offer the best performances. +// Be careful when changing this option, see +// https://docs.datadoghq.com/developers/dogstatsd/high_throughput/#ensure-proper-packet-sizes. +func WithMaxBytesPerPayload(MaxBytesPerPayload int) Option { + return func(o *Options) error { + o.maxBytesPerPayload = MaxBytesPerPayload + return nil + } +} + +// WithBufferPoolSize sets the size of the pool of buffers used to serialized metrics, events and service_checks. +// +// The default, 0, will set the option to the optimal size for the transport protocol used: 2048 for UDP and named pipe +// and 512 for UDS. +func WithBufferPoolSize(bufferPoolSize int) Option { + return func(o *Options) error { + o.bufferPoolSize = bufferPoolSize + return nil + } +} + +// WithBufferFlushInterval sets the interval after which the current buffer is flushed. +// +// A buffers are used to serialized data, they're flushed either when full (see WithMaxBytesPerPayload) or when it's +// been open for longer than this interval. +// +// With apps sending a high number of metrics/events/service_checks the interval rarely timeout. But with slow sending +// apps increasing this value will reduce the number of payload sent on the wire as more data is serialized in the same +// payload. +// +// Default is 100ms +func WithBufferFlushInterval(bufferFlushInterval time.Duration) Option { + return func(o *Options) error { + o.bufferFlushInterval = bufferFlushInterval + return nil + } +} + +// WithWorkersCount sets the number of workers that will be used to serialized data. +// +// Those workers allow the use of multiple buffers at the same time (see WithBufferPoolSize) to reduce lock contention. +// +// Default is 32. +func WithWorkersCount(workersCount int) Option { + return func(o *Options) error { + if workersCount < 1 { + return fmt.Errorf("workersCount must be a positive integer") + } + o.workersCount = workersCount + return nil + } +} + +// WithSenderQueueSize sets the size of the sender queue in number of buffers. +// +// After data has been serialized in a buffer they're pushed to a queue that the sender will consume and then each one +// ot the agent. +// +// The default value 0 will set the option to the optimal size for the transport protocol used: 2048 for UDP and named +// pipe and 512 for UDS. +func WithSenderQueueSize(senderQueueSize int) Option { + return func(o *Options) error { + o.senderQueueSize = senderQueueSize + return nil + } +} + +// WithWriteTimeout sets the timeout for network communication with the Agent, after this interval a payload is +// dropped. This is only used for UDS and named pipes connection. +func WithWriteTimeout(writeTimeout time.Duration) Option { + return func(o *Options) error { + o.writeTimeout = writeTimeout + return nil + } +} + +// WithConnectTimeout sets the timeout for network connection with the Agent, after this interval the connection +// attempt is aborted. This is only used for UDS connection. This will also reset the connection if nothing can be +// written to it for this duration. +func WithConnectTimeout(connectTimeout time.Duration) Option { + return func(o *Options) error { + o.connectTimeout = connectTimeout + return nil + } +} + +// WithChannelMode make the client use channels to receive metrics +// +// This determines how the client receive metrics from the app (for example when calling the `Gauge()` method). +// The client will either drop the metrics if its buffers are full (WithChannelMode option) or block the caller until the +// metric can be handled (WithMutexMode option). By default, the client use mutexes. +// +// WithChannelMode uses a channel (see WithChannelModeBufferSize to configure its size) to receive metrics and drops metrics if +// the channel is full. Sending metrics in this mode is much slower that WithMutexMode (because of the channel), but will not +// block the application. This mode is made for application using statsd directly into the application code instead of +// a separated periodic reporter. The goal is to not slow down the application at the cost of dropping metrics and having a lower max +// throughput. +func WithChannelMode() Option { + return func(o *Options) error { + o.receiveMode = channelMode + return nil + } +} + +// WithMutexMode will use mutex to receive metrics from the app through the API. +// +// This determines how the client receive metrics from the app (for example when calling the `Gauge()` method). +// The client will either drop the metrics if its buffers are full (WithChannelMode option) or block the caller until the +// metric can be handled (WithMutexMode option). By default the client use mutexes. +// +// WithMutexMode uses mutexes to receive metrics which is much faster than channels but can cause some lock contention +// when used with a high number of goroutines sending the same metrics. Mutexes are sharded based on the metrics name +// which limit mutex contention when multiple goroutines send different metrics (see WithWorkersCount). This is the +// default behavior which will produce the best throughput. +func WithMutexMode() Option { + return func(o *Options) error { + o.receiveMode = mutexMode + return nil + } +} + +// WithChannelModeBufferSize sets the size of the channel holding incoming metrics when WithChannelMode is used. +func WithChannelModeBufferSize(bufferSize int) Option { + return func(o *Options) error { + o.channelModeBufferSize = bufferSize + return nil + } +} + +// WithChannelModeErrorsWhenFull makes the client return an error when the channel is full. +// This should be enabled if you want to be notified when the client is dropping metrics. You +// will also need to set `WithErrorHandler` to be notified of sender error. This might have +// a small performance impact. +func WithChannelModeErrorsWhenFull() Option { + return func(o *Options) error { + o.channelModeErrorsWhenFull = true + return nil + } +} + +// WithoutChannelModeErrorsWhenFull makes the client not return an error when the channel is full. +func WithoutChannelModeErrorsWhenFull() Option { + return func(o *Options) error { + o.channelModeErrorsWhenFull = false + return nil + } +} + +// WithErrorHandler sets a function that will be called when an error occurs. +func WithErrorHandler(errorHandler ErrorHandler) Option { + return func(o *Options) error { + o.errorHandler = errorHandler + return nil + } +} + +// WithAggregationInterval sets the interval at which aggregated metrics are flushed. See WithClientSideAggregation and +// WithExtendedClientSideAggregation for more. +// +// The default interval is 2s. The interval must divide the Agent reporting period (default=10s) evenly to reduce "aliasing" +// that can cause values to appear irregular/spiky. +// +// For example a 3s aggregation interval will create spikes in the final graph: a application sending a count metric +// that increments at a constant 1000 time per second will appear noisy with an interval of 3s. This is because +// client-side aggregation would report every 3 seconds, while the agent is reporting every 10 seconds. This means in +// each agent bucket, the values are: 9000, 9000, 12000. +func WithAggregationInterval(interval time.Duration) Option { + return func(o *Options) error { + o.aggregationFlushInterval = interval + return nil + } +} + +// WithClientSideAggregation enables client side aggregation for Gauges, Counts and Sets. +func WithClientSideAggregation() Option { + return func(o *Options) error { + o.aggregation = true + return nil + } +} + +// WithoutClientSideAggregation disables client side aggregation. +func WithoutClientSideAggregation() Option { + return func(o *Options) error { + o.aggregation = false + o.extendedAggregation = false + return nil + } +} + +// WithExtendedClientSideAggregation enables client side aggregation for all types. This feature is only compatible with +// Agent's version >=6.25.0 && <7.0.0 or Agent's versions >=7.25.0. +// When enabled, the use of `rate` with distribution is discouraged and `WithMaxSamplesPerContext()` should be used. +// If `rate` is used with different values of `rate` the resulting rate is not guaranteed to be correct. +func WithExtendedClientSideAggregation() Option { + return func(o *Options) error { + o.aggregation = true + o.extendedAggregation = true + return nil + } +} + +// WithMaxSamplesPerContext limits the number of sample for metric types that require multiple samples to be send +// over statsd to the agent, such as distributions or timings. This limits the number of sample per +// context for a distribution to a given number. Gauges and counts will not be affected as a single sample per context +// is sent with client side aggregation. +// - This will enable client side aggregation for all metrics. +// - This feature should be used with `WithExtendedClientSideAggregation` for optimal results. +func WithMaxSamplesPerContext(maxSamplesPerDistribution int) Option { + return func(o *Options) error { + o.aggregation = true + o.maxBufferedSamplesPerContext = maxSamplesPerDistribution + return nil + } +} + +// WithoutTelemetry disables the client telemetry. +// +// More on this here: https://docs.datadoghq.com/developers/dogstatsd/high_throughput/#client-side-telemetry +func WithoutTelemetry() Option { + return func(o *Options) error { + o.telemetry = false + return nil + } +} + +// WithTelemetryAddr sets a different address for telemetry metrics. By default the same address as the client is used +// for telemetry. +// +// More on this here: https://docs.datadoghq.com/developers/dogstatsd/high_throughput/#client-side-telemetry +func WithTelemetryAddr(addr string) Option { + return func(o *Options) error { + o.telemetryAddr = addr + return nil + } +} + +// WithoutOriginDetection disables the client origin detection. +// When enabled, the client tries to discover its container ID and sends it to the Agent +// to enrich the metrics with container tags. +// If the container id is not found and the client is running in a private cgroup namespace, the client +// sends the base cgroup controller inode. +// Origin detection can also be disabled by configuring the environment variabe DD_ORIGIN_DETECTION_ENABLED=false +// The client tries to read the container ID by parsing the file /proc/self/cgroup, this is not supported on Windows. +// +// More on this here: https://docs.datadoghq.com/developers/dogstatsd/?tab=kubernetes#origin-detection-over-udp +func WithoutOriginDetection() Option { + return func(o *Options) error { + o.originDetection = false + return nil + } +} + +// WithOriginDetection enables the client origin detection. +// This feature requires Datadog Agent version >=6.35.0 && <7.0.0 or Agent versions >=7.35.0. +// When enabled, the client tries to discover its container ID and sends it to the Agent +// to enrich the metrics with container tags. +// If the container id is not found and the client is running in a private cgroup namespace, the client +// sends the base cgroup controller inode. +// Origin detection can be disabled by configuring the environment variable DD_ORIGIN_DETECTION_ENABLED=false +// +// More on this here: https://docs.datadoghq.com/developers/dogstatsd/?tab=kubernetes#origin-detection-over-udp +func WithOriginDetection() Option { + return func(o *Options) error { + o.originDetection = true + return nil + } +} + +// WithContainerID allows passing the container ID, this will be used by the Agent to enrich metrics with container tags. +// This feature requires Datadog Agent version >=6.35.0 && <7.0.0 or Agent versions >=7.35.0. +// When configured, the provided container ID is prioritized over the container ID discovered via Origin Detection. +// The client prioritizes the value passed via DD_ENTITY_ID (if set) over the container ID. +func WithContainerID(id string) Option { + return func(o *Options) error { + o.containerID = id + return nil + } +} + +// WithCardinality sets the tag cardinality of the metric. +func WithCardinality(card Cardinality) Option { + return func(o *Options) error { + if !card.isValid() { + return fmt.Errorf("invalid cardinality %d", card) + } + o.tagCardinality = &card + return nil + } +} + +// WithAggregatorShardCount sets the number of shards used for the aggregator. +// Higher values reduce lock contention but increase memory usage. +// +// The default is 1 as to mimic current behavior. +func WithAggregatorShardCount(shardCount int) Option { + return func(o *Options) error { + if shardCount < 1 { + return fmt.Errorf("shardCount must be a positive integer") + } + o.aggregatorShardCount = shardCount + return nil + } +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/pipe.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/pipe.go new file mode 100644 index 000000000..1188b00f3 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/pipe.go @@ -0,0 +1,13 @@ +//go:build !windows +// +build !windows + +package statsd + +import ( + "errors" + "time" +) + +func newWindowsPipeWriter(pipepath string, writeTimeout time.Duration) (Transport, error) { + return nil, errors.New("Windows Named Pipes are only supported on Windows") +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/pipe_windows.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/pipe_windows.go new file mode 100644 index 000000000..c27434ccf --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/pipe_windows.go @@ -0,0 +1,81 @@ +//go:build windows +// +build windows + +package statsd + +import ( + "net" + "sync" + "time" + + "github.com/Microsoft/go-winio" +) + +type pipeWriter struct { + mu sync.RWMutex + conn net.Conn + timeout time.Duration + pipepath string +} + +func (p *pipeWriter) Write(data []byte) (n int, err error) { + conn, err := p.ensureConnection() + if err != nil { + return 0, err + } + + p.mu.RLock() + conn.SetWriteDeadline(time.Now().Add(p.timeout)) + p.mu.RUnlock() + + n, err = conn.Write(data) + if err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + // disconnected; retry again on next attempt + p.mu.Lock() + p.conn = nil + p.mu.Unlock() + } + } + return n, err +} + +func (p *pipeWriter) ensureConnection() (net.Conn, error) { + p.mu.RLock() + conn := p.conn + p.mu.RUnlock() + if conn != nil { + return conn, nil + } + + // looks like we might need to connect - try again with write locking. + p.mu.Lock() + defer p.mu.Unlock() + if p.conn != nil { + return p.conn, nil + } + newconn, err := winio.DialPipe(p.pipepath, nil) + if err != nil { + return nil, err + } + p.conn = newconn + return newconn, nil +} + +func (p *pipeWriter) Close() error { + return p.conn.Close() +} + +// GetTransportName returns the name of the transport +func (p *pipeWriter) GetTransportName() string { + return writerWindowsPipe +} + +func newWindowsPipeWriter(pipepath string, writeTimeout time.Duration) (*pipeWriter, error) { + // Defer connection establishment to first write + return &pipeWriter{ + conn: nil, + timeout: writeTimeout, + pipepath: pipepath, + }, nil +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/sender.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/sender.go new file mode 100644 index 000000000..fc80395c3 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/sender.go @@ -0,0 +1,145 @@ +package statsd + +import ( + "io" + "sync/atomic" +) + +// senderTelemetry contains telemetry about the health of the sender +type senderTelemetry struct { + totalPayloadsSent uint64 + totalPayloadsDroppedQueueFull uint64 + totalPayloadsDroppedWriter uint64 + totalBytesSent uint64 + totalBytesDroppedQueueFull uint64 + totalBytesDroppedWriter uint64 +} + +type Transport interface { + io.WriteCloser + + // GetTransportName returns the name of the transport + GetTransportName() string +} + +type sender struct { + transport Transport + pool *bufferPool + queue chan *statsdBuffer + telemetry *senderTelemetry + stop chan struct{} + flushSignal chan struct{} + errorHandler ErrorHandler +} + +type ErrorSenderChannelFull struct { + LostElements int + ChannelSize int + Msg string +} + +func (e *ErrorSenderChannelFull) Error() string { + return e.Msg +} + +func newSender(transport Transport, queueSize int, pool *bufferPool, errorHandler ErrorHandler) *sender { + sender := &sender{ + transport: transport, + pool: pool, + queue: make(chan *statsdBuffer, queueSize), + telemetry: &senderTelemetry{}, + stop: make(chan struct{}), + flushSignal: make(chan struct{}), + errorHandler: errorHandler, + } + + go sender.sendLoop() + return sender +} + +func (s *sender) send(buffer *statsdBuffer) { + select { + case s.queue <- buffer: + default: + if s.errorHandler != nil { + err := &ErrorSenderChannelFull{ + LostElements: buffer.elementCount, + ChannelSize: len(s.queue), + Msg: "Sender queue is full", + } + s.errorHandler(err) + } + atomic.AddUint64(&s.telemetry.totalPayloadsDroppedQueueFull, 1) + atomic.AddUint64(&s.telemetry.totalBytesDroppedQueueFull, uint64(len(buffer.bytes()))) + s.pool.returnBuffer(buffer) + } +} + +func (s *sender) write(buffer *statsdBuffer) { + _, err := s.transport.Write(buffer.bytes()) + if err != nil { + atomic.AddUint64(&s.telemetry.totalPayloadsDroppedWriter, 1) + atomic.AddUint64(&s.telemetry.totalBytesDroppedWriter, uint64(len(buffer.bytes()))) + if s.errorHandler != nil { + s.errorHandler(err) + } + } else { + atomic.AddUint64(&s.telemetry.totalPayloadsSent, 1) + atomic.AddUint64(&s.telemetry.totalBytesSent, uint64(len(buffer.bytes()))) + } + s.pool.returnBuffer(buffer) +} + +func (s *sender) flushTelemetryMetrics(t *Telemetry) { + t.TotalPayloadsSent = atomic.LoadUint64(&s.telemetry.totalPayloadsSent) + t.TotalPayloadsDroppedQueueFull = atomic.LoadUint64(&s.telemetry.totalPayloadsDroppedQueueFull) + t.TotalPayloadsDroppedWriter = atomic.LoadUint64(&s.telemetry.totalPayloadsDroppedWriter) + + t.TotalBytesSent = atomic.LoadUint64(&s.telemetry.totalBytesSent) + t.TotalBytesDroppedQueueFull = atomic.LoadUint64(&s.telemetry.totalBytesDroppedQueueFull) + t.TotalBytesDroppedWriter = atomic.LoadUint64(&s.telemetry.totalBytesDroppedWriter) +} + +func (s *sender) sendLoop() { + defer close(s.stop) + for { + select { + case buffer := <-s.queue: + s.write(buffer) + case <-s.stop: + return + case <-s.flushSignal: + // At that point we know that the workers are paused (the statsd client + // will pause them before calling sender.flush()). + // So we can fully flush the input queue + s.flushInputQueue() + s.flushSignal <- struct{}{} + } + } +} + +func (s *sender) flushInputQueue() { + for { + select { + case buffer := <-s.queue: + s.write(buffer) + default: + return + } + } +} +func (s *sender) flush() { + s.flushSignal <- struct{}{} + <-s.flushSignal +} + +func (s *sender) close() error { + s.stop <- struct{}{} + <-s.stop + s.flushInputQueue() + return s.transport.Close() +} + +func (s *sender) getTransportName() string { + return s.transport.GetTransportName() +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/service_check.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/service_check.go new file mode 100644 index 000000000..e2850465c --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/service_check.go @@ -0,0 +1,57 @@ +package statsd + +import ( + "fmt" + "time" +) + +// ServiceCheckStatus support +type ServiceCheckStatus byte + +const ( + // Ok is the "ok" ServiceCheck status + Ok ServiceCheckStatus = 0 + // Warn is the "warning" ServiceCheck status + Warn ServiceCheckStatus = 1 + // Critical is the "critical" ServiceCheck status + Critical ServiceCheckStatus = 2 + // Unknown is the "unknown" ServiceCheck status + Unknown ServiceCheckStatus = 3 +) + +// A ServiceCheck is an object that contains status of DataDog service check. +type ServiceCheck struct { + // Name of the service check. Required. + Name string + // Status of service check. Required. + Status ServiceCheckStatus + // Timestamp is a timestamp for the serviceCheck. If not provided, the dogstatsd + // server will set this to the current time. + Timestamp time.Time + // Hostname for the serviceCheck. + Hostname string + // A message describing the current state of the serviceCheck. + Message string + // Tags for the serviceCheck. + Tags []string +} + +// NewServiceCheck creates a new serviceCheck with the given name and status. Error checking +// against these values is done at send-time, or upon running sc.Check. +func NewServiceCheck(name string, status ServiceCheckStatus) *ServiceCheck { + return &ServiceCheck{ + Name: name, + Status: status, + } +} + +// Check verifies that a service check is valid. +func (sc *ServiceCheck) Check() error { + if len(sc.Name) == 0 { + return fmt.Errorf("statsd.ServiceCheck name is required") + } + if byte(sc.Status) < 0 || byte(sc.Status) > 3 { + return fmt.Errorf("statsd.ServiceCheck status has invalid value") + } + return nil +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/statsd.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/statsd.go new file mode 100644 index 000000000..1f09ec79a --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/statsd.go @@ -0,0 +1,318 @@ +// Copyright 2013 Ooyala, Inc. + +/* +Package statsd provides a Go dogstatsd client. Dogstatsd extends the popular statsd, +adding tags and histograms and pushing upstream to Datadog. + +Refer to http://docs.datadoghq.com/guides/dogstatsd/ for information about DogStatsD. + +statsd is based on go-statsd-client. +*/ +package statsd + +//go:generate mockgen -source=statsd.go -destination=mocks/statsd.go + +import ( + "io" + "time" +) + +// ClientInterface is an interface that exposes the common client functions for the +// purpose of being able to provide a no-op client or even mocking. This can aid +// downstream users' with their testing. +type ClientInterface interface { + // Gauge measures the value of a metric at a particular time. + Gauge(name string, value float64, tags []string, rate float64) error + + // GaugeWithTimestamp measures the value of a metric at a given time. + // BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ + // The value will bypass any aggregation on the client side and agent side, this is + // useful when sending points in the past. + // + // Minimum Datadog Agent version: 7.40.0 + GaugeWithTimestamp(name string, value float64, tags []string, rate float64, timestamp time.Time) error + + // Count tracks how many times something happened per second. + Count(name string, value int64, tags []string, rate float64) error + + // CountWithTimestamp tracks how many times something happened at the given second. + // BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ + // The value will bypass any aggregation on the client side and agent side, this is + // useful when sending points in the past. + // + // Minimum Datadog Agent version: 7.40.0 + CountWithTimestamp(name string, value int64, tags []string, rate float64, timestamp time.Time) error + + // Histogram tracks the statistical distribution of a set of values on each host. + Histogram(name string, value float64, tags []string, rate float64) error + + // Distribution tracks the statistical distribution of a set of values across your infrastructure. + // + // It is recommended to use `WithMaxBufferedMetricsPerContext` to avoid dropping metrics at high throughput, `rate` can + // also be used to limit the load. Both options can *not* be used together. + Distribution(name string, value float64, tags []string, rate float64) error + + // Decr is just Count of -1 + Decr(name string, tags []string, rate float64) error + + // Incr is just Count of 1 + Incr(name string, tags []string, rate float64) error + + // Set counts the number of unique elements in a group. + Set(name string, value string, tags []string, rate float64) error + + // Timing sends timing information, it is an alias for TimeInMilliseconds + Timing(name string, value time.Duration, tags []string, rate float64) error + + // TimeInMilliseconds sends timing information in milliseconds. + // It is flushed by statsd with percentiles, mean and other info (https://github.com/etsy/statsd/blob/master/docs/metric_types.md#timing) + TimeInMilliseconds(name string, value float64, tags []string, rate float64) error + + // Event sends the provided Event. + Event(e *Event) error + + // SimpleEvent sends an event with the provided title and text. + SimpleEvent(title, text string) error + + // ServiceCheck sends the provided ServiceCheck. + ServiceCheck(sc *ServiceCheck) error + + // SimpleServiceCheck sends an serviceCheck with the provided name and status. + SimpleServiceCheck(name string, status ServiceCheckStatus) error + + // Close the client connection. + Close() error + + // Flush forces a flush of all the queued dogstatsd payloads. + Flush() error + + // IsClosed returns if the client has been closed. + IsClosed() bool + + // GetTelemetry return the telemetry metrics for the client since it started. + GetTelemetry() Telemetry +} + +// A Client is a handle for sending messages to dogstatsd. It is safe to +// use one Client from multiple goroutines simultaneously. +type Client struct { + clientEx *ClientEx +} + +// Verify that Client implements the ClientInterface. +// https://golang.org/doc/faq#guarantee_satisfies_interface +var _ ClientInterface = &Client{} + +// New returns a pointer to a new Client given an addr in the format "hostname:port" for UDP, +// "unix:///path/to/socket" for UDS or "\\.\pipe\path\to\pipe" for Windows Named Pipes. +func New(addr string, options ...Option) (*Client, error) { + clientEx, err := NewEx(addr, options...) + if err != nil { + return nil, err + } + + return &Client{ + clientEx: clientEx, + }, nil +} + +// NewWithWriter creates a new Client with given writer. Writer is a +// io.WriteCloser +func NewWithWriter(w io.WriteCloser, options ...Option) (*Client, error) { + clientEx, err := NewWithWriterEx(w, options...) + if err != nil { + return nil, err + } + + return &Client{ + clientEx: clientEx, + }, nil +} + +// CloneWithExtraOptions create a new Client with extra options +func CloneWithExtraOptions(c *Client, options ...Option) (*Client, error) { + if c == nil { + return nil, ErrNoClient + } + + clientEx, err := CloneWithExtraOptionsEx(c.clientEx, options...) + if err != nil { + return nil, err + } + + return &Client{ + clientEx: clientEx, + }, nil +} + +// Flush forces a flush of all the queued dogstatsd payloads This method is +// blocking and will not return until everything is sent through the network. +// In mutexMode, this will also block sampling new data to the client while the +// workers and sender are flushed. +func (c *Client) Flush() error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Flush() +} + +// IsClosed returns if the client has been closed. +func (c *Client) IsClosed() bool { + return c.clientEx.IsClosed() +} + +// GetTelemetry return the telemetry metrics for the client since it started. +func (c *Client) GetTelemetry() Telemetry { + return c.clientEx.GetTelemetry() +} + +// GetTransport return the name of the transport used. +func (c *Client) GetTransport() string { + return c.clientEx.GetTransport() +} + +// Gauge measures the value of a metric at a particular time. +func (c *Client) Gauge(name string, value float64, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Gauge(name, value, tags, rate) +} + +// GaugeWithTimestamp measures the value of a metric at a given time. +// BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ +// The value will bypass any aggregation on the client side and agent side, this is +// useful when sending points in the past. +// +// Minimum Datadog Agent version: 7.40.0 +func (c *Client) GaugeWithTimestamp(name string, value float64, tags []string, rate float64, timestamp time.Time) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.GaugeWithTimestamp(name, value, tags, rate, timestamp) +} + +// Count tracks how many times something happened per second. +func (c *Client) Count(name string, value int64, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Count(name, value, tags, rate) +} + +// CountWithTimestamp tracks how many times something happened at the given second. +// BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ +// The value will bypass any aggregation on the client side and agent side, this is +// useful when sending points in the past. +// +// Minimum Datadog Agent version: 7.40.0 +func (c *Client) CountWithTimestamp(name string, value int64, tags []string, rate float64, timestamp time.Time) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.CountWithTimestamp(name, value, tags, rate, timestamp) +} + +// Histogram tracks the statistical distribution of a set of values on each host. +func (c *Client) Histogram(name string, value float64, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Histogram(name, value, tags, rate) +} + +// Distribution tracks the statistical distribution of a set of values across your infrastructure. +func (c *Client) Distribution(name string, value float64, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Distribution(name, value, tags, rate) +} + +// Decr is just Count of -1 +func (c *Client) Decr(name string, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Decr(name, tags, rate) +} + +// Incr is just Count of 1 +func (c *Client) Incr(name string, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Incr(name, tags, rate) +} + +// Set counts the number of unique elements in a group. +func (c *Client) Set(name string, value string, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Set(name, value, tags, rate) + +} + +// Timing sends timing information, it is an alias for TimeInMilliseconds +func (c *Client) Timing(name string, value time.Duration, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Timing(name, value, tags, rate) +} + +// TimeInMilliseconds sends timing information in milliseconds. +// It is flushed by statsd with percentiles, mean and other info (https://github.com/etsy/statsd/blob/master/docs/metric_types.md#timing) +func (c *Client) TimeInMilliseconds(name string, value float64, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.TimeInMilliseconds(name, value, tags, rate) +} + +// Event sends the provided Event. +func (c *Client) Event(e *Event) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Event(e) +} + +// SimpleEvent sends an event with the provided title and text. +func (c *Client) SimpleEvent(title, text string) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.SimpleEvent(title, text) +} + +// ServiceCheck sends the provided ServiceCheck. +func (c *Client) ServiceCheck(sc *ServiceCheck) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.ServiceCheck(sc) +} + +// SimpleServiceCheck sends an serviceCheck with the provided name and status. +func (c *Client) SimpleServiceCheck(name string, status ServiceCheckStatus) error { + if c == nil { + return ErrNoClient + } + return c.clientEx.SimpleServiceCheck(name, status) + +} + +// Close the client connection. +func (c *Client) Close() error { + if c == nil { + return ErrNoClient + } + return c.clientEx.Close() +} + +// sendBlocking is used by the aggregator to inject aggregated metrics. +func (c *Client) sendBlocking(m metric) error { + return c.clientEx.sendBlocking(m) +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/statsd_direct.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/statsd_direct.go new file mode 100644 index 000000000..150ee2c81 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/statsd_direct.go @@ -0,0 +1,69 @@ +package statsd + +import ( + "io" + "strings" + "sync/atomic" +) + +type ClientDirectInterface interface { + DistributionSamples(name string, values []float64, tags []string, rate float64) error +} + +// ClientDirect is an *experimental* statsd client that gives direct access to some dogstatsd features. +// +// It is not recommended to use this client in production. This client might allow you to take advantage of +// new features in the agent before they are released, but it might also break your application. +type ClientDirect struct { + *Client +} + +// NewDirect returns a pointer to a new ClientDirect given an addr in the format "hostname:port" for UDP, +// "unix:///path/to/socket" for UDS or "\\.\pipe\path\to\pipe" for Windows Named Pipes. +func NewDirect(addr string, options ...Option) (*ClientDirect, error) { + client, err := New(addr, options...) + if err != nil { + return nil, err + } + return &ClientDirect{ + client, + }, nil +} + +func NewDirectWithWriter(writer io.WriteCloser, options ...Option) (*ClientDirect, error) { + client, err := NewWithWriter(writer, options...) + if err != nil { + return nil, err + } + return &ClientDirect{ + client, + }, nil +} + +// DistributionSamples is similar to Distribution, but it lets the client deals with the sampling. +// +// The provided `rate` is the sampling rate applied by the client and will *not* be used to apply further +// sampling. This is recommended in high performance cases were the overhead of the statsd library might be +// significant and the sampling is already done by the client. +// +// `WithMaxBufferedMetricsPerContext` is ignored when using this method. +func (c *ClientDirect) DistributionSamples(name string, values []float64, tags []string, rate float64) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.clientEx.telemetry.totalMetricsDistribution, uint64(len(values))) + return c.clientEx.send(metric{ + metricType: distributionAggregated, + name: name, + fvalues: values, + tags: tags, + stags: strings.Join(tags, tagSeparatorSymbol), + rate: rate, + globalTags: c.clientEx.tags, + namespace: c.clientEx.namespace, + }) +} + +// Validate that ClientDirect implements ClientDirectInterface and ClientInterface. +var _ ClientDirectInterface = (*ClientDirect)(nil) +var _ ClientInterface = (*ClientDirect)(nil) diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/statsdex.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/statsdex.go new file mode 100644 index 000000000..faa3a1947 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/statsdex.go @@ -0,0 +1,953 @@ +// Copyright 2013 Ooyala, Inc. + +/* +Package statsd provides a Go dogstatsd client. Dogstatsd extends the popular statsd, +adding tags and histograms and pushing upstream to Datadog. + +Refer to http://docs.datadoghq.com/guides/dogstatsd/ for information about DogStatsD. + +statsd is based on go-statsd-client. +*/ +package statsd + +//go:generate mockgen -source=statsd.go -destination=mocks/statsd.go + +import ( + "errors" + "fmt" + "io" + "net/url" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +/* +OptimalUDPPayloadSize defines the optimal payload size for a UDP datagram, 1432 bytes +is optimal for regular networks with an MTU of 1500 so datagrams don't get +fragmented. It's generally recommended not to fragment UDP datagrams as losing +a single fragment will cause the entire datagram to be lost. +*/ +const OptimalUDPPayloadSize = 1432 + +/* +MaxUDPPayloadSize defines the maximum payload size for a UDP datagram. +Its value comes from the calculation: 65535 bytes Max UDP datagram size - +8byte UDP header - 60byte max IP headers +any number greater than that will see frames being cut out. +*/ +const MaxUDPPayloadSize = 65467 + +// DefaultUDPBufferPoolSize is the default size of the buffer pool for UDP clients. +const DefaultUDPBufferPoolSize = 2048 + +// DefaultUDSBufferPoolSize is the default size of the buffer pool for UDS clients. +const DefaultUDSBufferPoolSize = 512 + +/* +DefaultMaxAgentPayloadSize is the default maximum payload size the agent +can receive. This can be adjusted by changing dogstatsd_buffer_size in the +agent configuration file datadog.yaml. This is also used as the optimal payload size +for UDS datagrams. +*/ +const DefaultMaxAgentPayloadSize = 8192 + +/* +UnixAddressPrefix holds the prefix to use to enable Unix Domain Socket +traffic instead of UDP. The type of the socket will be guessed. +*/ +const UnixAddressPrefix = "unix://" + +/* +UnixDatagramAddressPrefix holds the prefix to use to enable Unix Domain Socket +datagram traffic instead of UDP. +*/ +const UnixAddressDatagramPrefix = "unixgram://" + +/* +UnixAddressStreamPrefix holds the prefix to use to enable Unix Domain Socket +stream traffic instead of UDP. +*/ +const UnixAddressStreamPrefix = "unixstream://" + +/* +WindowsPipeAddressPrefix holds the prefix to use to enable Windows Named Pipes +traffic instead of UDP. +*/ +const WindowsPipeAddressPrefix = `\\.\pipe\` + +var ( + AddressPrefixes = []string{UnixAddressPrefix, UnixAddressDatagramPrefix, UnixAddressStreamPrefix, WindowsPipeAddressPrefix} +) + +const ( + agentHostEnvVarName = "DD_AGENT_HOST" + agentPortEnvVarName = "DD_DOGSTATSD_PORT" + agentURLEnvVarName = "DD_DOGSTATSD_URL" + defaultUDPPort = "8125" +) + +const ( + // ddEntityID specifies client-side user-specified entity ID injection. + // This env var can be set to the Pod UID on Kubernetes via the downward API. + // Docs: https://docs.datadoghq.com/developers/dogstatsd/?tab=kubernetes#origin-detection-over-udp + ddEntityID = "DD_ENTITY_ID" + + // ddEntityIDTag specifies the tag name for the client-side entity ID injection + // The Agent expects this tag to contain a non-prefixed Kubernetes Pod UID. + ddEntityIDTag = "dd.internal.entity_id" + + // originDetectionEnabled specifies the env var to enable/disable sending the container ID field. + originDetectionEnabled = "DD_ORIGIN_DETECTION_ENABLED" +) + +/* +ddEnvTagsMapping is a mapping of each "DD_" prefixed environment variable +to a specific tag name. We use a slice to keep the order and simplify tests. +*/ +var ddEnvTagsMapping = []struct{ envName, tagName string }{ + {ddEntityID, ddEntityIDTag}, // Client-side entity ID injection for container tagging. + {"DD_ENV", "env"}, // The name of the env in which the service runs. + {"DD_SERVICE", "service"}, // The name of the running service. + {"DD_VERSION", "version"}, // The current version of the running service. +} + +type metricType int + +const ( + gauge metricType = iota + count + histogram + histogramAggregated + distribution + distributionAggregated + set + timing + timingAggregated + event + serviceCheck +) + +type receivingMode int + +const ( + mutexMode receivingMode = iota + channelMode +) + +const ( + writerNameUDP string = "udp" + writerNameUDS string = "uds" + writerNameUDSStream string = "uds-stream" + writerWindowsPipe string = "pipe" + writerNameCustom string = "custom" +) + +// noTimestamp is used as a value for metric without a given timestamp. +const noTimestamp = int64(0) + +type metric struct { + metricType metricType + namespace string + globalTags []string + name string + fvalue float64 + fvalues []float64 + ivalue int64 + svalue string + evalue *Event + scvalue *ServiceCheck + tags []string + stags string + rate float64 + timestamp int64 + originDetection bool + cardinality Cardinality +} + +type noClientErr string + +// ErrNoClient is returned if statsd reporting methods are invoked on +// a nil client. +const ErrNoClient = noClientErr("statsd client is nil") + +func (e noClientErr) Error() string { + return string(e) +} + +type invalidTimestampErr string + +// InvalidTimestamp is returned if a provided timestamp is invalid. +const InvalidTimestamp = invalidTimestampErr("invalid timestamp") + +func (e invalidTimestampErr) Error() string { + return string(e) +} + +// ClientInterfaceEx is an temporary interface that is similar to ClientInterface +// but with the addition of a `...Parameter` for the telemetry functions. This is currently +// just used to specify the tag cardinality. We want to avoid changing ClientInterface +// at present as that would require a new major release. +// Users should avoid implementing this interface as it will be deprecated in the next version. +type ClientInterfaceEx interface { + // Gauge measures the value of a metric at a particular time. + Gauge(name string, value float64, tags []string, rate float64, parameters ...Parameter) error + + // GaugeWithTimestamp measures the value of a metric at a given time. + // BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ + // The value will bypass any aggregation on the client side and agent side, this is + // useful when sending points in the past. + // + // Minimum Datadog Agent version: 7.40.0 + GaugeWithTimestamp(name string, value float64, tags []string, rate float64, timestamp time.Time, parameters ...Parameter) error + + // Count tracks how many times something happened per second. + Count(name string, value int64, tags []string, rate float64, parameters ...Parameter) error + + // CountWithTimestamp tracks how many times something happened at the given second. + // BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ + // The value will bypass any aggregation on the client side and agent side, this is + // useful when sending points in the past. + // + // Minimum Datadog Agent version: 7.40.0 + CountWithTimestamp(name string, value int64, tags []string, rate float64, timestamp time.Time, parameters ...Parameter) error + + // Histogram tracks the statistical distribution of a set of values on each host. + Histogram(name string, value float64, tags []string, rate float64, parameters ...Parameter) error + + // Distribution tracks the statistical distribution of a set of values across your infrastructure. + // + // It is recommended to use `WithMaxBufferedMetricsPerContext` to avoid dropping metrics at high throughput, `rate` can + // also be used to limit the load. Both options can *not* be used together. + Distribution(name string, value float64, tags []string, rate float64, parameters ...Parameter) error + + // Decr is just Count of -1 + Decr(name string, tags []string, rate float64, parameters ...Parameter) error + + // Incr is just Count of 1 + Incr(name string, tags []string, rate float64, parameters ...Parameter) error + + // Set counts the number of unique elements in a group. + Set(name string, value string, tags []string, rate float64, parameters ...Parameter) error + + // Timing sends timing information, it is an alias for TimeInMilliseconds + Timing(name string, value time.Duration, tags []string, rate float64, parameters ...Parameter) error + + // TimeInMilliseconds sends timing information in milliseconds. + // It is flushed by statsd with percentiles, mean and other info (https://github.com/etsy/statsd/blob/master/docs/metric_types.md#timing) + TimeInMilliseconds(name string, value float64, tags []string, rate float64, parameters ...Parameter) error + + // Event sends the provided Event. + Event(e *Event, parameters ...Parameter) error + + // SimpleEvent sends an event with the provided title and text. + SimpleEvent(title, text string, parameters ...Parameter) error + + // ServiceCheck sends the provided ServiceCheck. + ServiceCheck(sc *ServiceCheck, parameters ...Parameter) error + + // SimpleServiceCheck sends an serviceCheck with the provided name and status. + SimpleServiceCheck(name string, status ServiceCheckStatus, parameters ...Parameter) error + + // Close the client connection. + Close() error + + // Flush forces a flush of all the queued dogstatsd payloads. + Flush() error + + // IsClosed returns if the client has been closed. + IsClosed() bool + + // GetTelemetry return the telemetry metrics for the client since it started. + GetTelemetry() Telemetry + + // Ensure this interface can't be implemented outside of this package. + // ClientInterfaceEx is a temporary measure to allow us to release a version of the library with the + // extra `...Parameter` parameter (currently used to specify the tag cardinality) in the metric functions + // without having to release a new major version. + // This interface will be deprecated with the next release. + private() +} + +type ErrorHandler func(error) + +// A Client is a handle for sending messages to dogstatsd. It is safe to +// use one Client from multiple goroutines simultaneously. +type ClientEx struct { + // Sender handles the underlying networking protocol + sender *sender + // namespace to prepend to all statsd calls + namespace string + // tags are global tags to be added to every statsd call + tags []string + flushTime time.Duration + telemetry *statsdTelemetry + telemetryClient *telemetryClient + stop chan struct{} + wg sync.WaitGroup + workers []*worker + closerLock sync.Mutex + workersMode receivingMode + aggregatorMode receivingMode + agg *aggregator + aggExtended *aggregator + options []Option + addrOption string + isClosed bool + errorOnBlockedChannel bool + errorHandler ErrorHandler + originDetection bool + defaultCardinality Cardinality +} + +// statsdTelemetry contains telemetry metrics about the client +type statsdTelemetry struct { + totalMetricsGauge uint64 + totalMetricsCount uint64 + totalMetricsHistogram uint64 + totalMetricsDistribution uint64 + totalMetricsSet uint64 + totalMetricsTiming uint64 + totalEvents uint64 + totalServiceChecks uint64 + totalDroppedOnReceive uint64 +} + +// Verify that ClientEx implements the ClientInterfaceEx interface. +// https://golang.org/doc/faq#guarantee_satisfies_interface +var _ ClientInterfaceEx = &ClientEx{} + +func resolveAddr(addr string) string { + envPort := "" + + if addr == "" { + addr = os.Getenv(agentHostEnvVarName) + envPort = os.Getenv(agentPortEnvVarName) + agentURL, _ := os.LookupEnv(agentURLEnvVarName) + agentURL = parseAgentURL(agentURL) + + // agentURLEnvVarName has priority over agentHostEnvVarName + if agentURL != "" { + return agentURL + } + } + + if addr == "" { + return "" + } + + for _, prefix := range AddressPrefixes { + if strings.HasPrefix(addr, prefix) { + return addr + } + } + // TODO: How does this work for IPv6? + if strings.Contains(addr, ":") { + return addr + } + if envPort != "" { + addr = fmt.Sprintf("%s:%s", addr, envPort) + } else { + addr = fmt.Sprintf("%s:%s", addr, defaultUDPPort) + } + return addr +} + +func parseAgentURL(agentURL string) string { + if agentURL != "" { + if strings.HasPrefix(agentURL, WindowsPipeAddressPrefix) { + return agentURL + } + + parsedURL, err := url.Parse(agentURL) + if err != nil { + return "" + } + + if parsedURL.Scheme == "udp" { + if strings.Contains(parsedURL.Host, ":") { + return parsedURL.Host + } + return fmt.Sprintf("%s:%s", parsedURL.Host, defaultUDPPort) + } + + if parsedURL.Scheme == "unix" { + return agentURL + } + } + return "" +} + +func createWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration) (Transport, string, error) { + if addr == "" { + return nil, "", errors.New("No address passed and autodetection from environment failed") + } + + switch { + case strings.HasPrefix(addr, WindowsPipeAddressPrefix): + w, err := newWindowsPipeWriter(addr, writeTimeout) + return w, writerWindowsPipe, err + case strings.HasPrefix(addr, UnixAddressPrefix): + w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, connectTimeout, "") + return w, writerNameUDS, err + case strings.HasPrefix(addr, UnixAddressDatagramPrefix): + w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, connectTimeout, "unixgram") + return w, writerNameUDS, err + case strings.HasPrefix(addr, UnixAddressStreamPrefix): + w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, connectTimeout, "unix") + return w, writerNameUDS, err + default: + w, err := newUDPWriter(addr, writeTimeout) + return w, writerNameUDP, err + } +} + +// New returns a pointer to a new Client given an addr in the format "hostname:port" for UDP, +// "unix:///path/to/socket" for UDS or "\\.\pipe\path\to\pipe" for Windows Named Pipes. +func NewEx(addr string, options ...Option) (*ClientEx, error) { + o, err := resolveOptions(options) + if err != nil { + return nil, err + } + + addr = resolveAddr(addr) + w, writerType, err := createWriter(addr, o.writeTimeout, o.connectTimeout) + if err != nil { + return nil, err + } + + client, err := newWithWriter(w, o, writerType) + if err == nil { + client.options = append(client.options, options...) + client.addrOption = addr + } + return client, err +} + +type customWriter struct { + io.WriteCloser +} + +func (w *customWriter) GetTransportName() string { + return writerNameCustom +} + +// NewWithWriter creates a new ClientEx with given writer. Writer is a +// io.WriteCloser +func NewWithWriterEx(w io.WriteCloser, options ...Option) (*ClientEx, error) { + o, err := resolveOptions(options) + if err != nil { + return nil, err + } + return newWithWriter(&customWriter{w}, o, writerNameCustom) +} + +// CloneWithExtraOptions create a new ClientEx with extra options +func CloneWithExtraOptionsEx(c *ClientEx, options ...Option) (*ClientEx, error) { + if c == nil { + return nil, ErrNoClient + } + + if c.addrOption == "" { + return nil, fmt.Errorf("can't clone client with no addrOption") + } + opt := append(c.options, options...) + return NewEx(c.addrOption, opt...) +} + +func newWithWriter(w Transport, o *Options, writerName string) (*ClientEx, error) { + c := ClientEx{ + namespace: o.namespace, + tags: o.tags, + telemetry: &statsdTelemetry{}, + errorOnBlockedChannel: o.channelModeErrorsWhenFull, + errorHandler: o.errorHandler, + originDetection: isOriginDetectionEnabled(o), + } + + // Inject values of DD_* environment variables as global tags. + for _, mapping := range ddEnvTagsMapping { + if value := os.Getenv(mapping.envName); value != "" { + c.tags = append(c.tags, fmt.Sprintf("%s:%s", mapping.tagName, value)) + } + } + // Whether origin detection is enabled or not for this client, we need to initialize the global + // external environment variable in case another client has enabled it and needs to access it. + initExternalEnv() + + if o.tagCardinality != nil { + c.defaultCardinality = *o.tagCardinality + } else if card, ok := envTagCardinality(); ok { + c.defaultCardinality = card + } else { + c.defaultCardinality = CardinalityNotSet + } + + initContainerID(o.containerID, fillInContainerID(o), isHostCgroupNamespace()) + isUDS := writerName == writerNameUDS + + if o.maxBytesPerPayload == 0 { + if isUDS { + o.maxBytesPerPayload = DefaultMaxAgentPayloadSize + } else { + o.maxBytesPerPayload = OptimalUDPPayloadSize + } + } + if o.bufferPoolSize == 0 { + if isUDS { + o.bufferPoolSize = DefaultUDSBufferPoolSize + } else { + o.bufferPoolSize = DefaultUDPBufferPoolSize + } + } + if o.senderQueueSize == 0 { + if isUDS { + o.senderQueueSize = DefaultUDSBufferPoolSize + } else { + o.senderQueueSize = DefaultUDPBufferPoolSize + } + } + + bufferPool := newBufferPool(o.bufferPoolSize, o.maxBytesPerPayload, o.maxMessagesPerPayload) + c.sender = newSender(w, o.senderQueueSize, bufferPool, o.errorHandler) + c.aggregatorMode = o.receiveMode + + c.workersMode = o.receiveMode + // channelMode mode at the worker level is not enabled when + // ExtendedAggregation is since the user app will not directly + // use the worker (the aggregator sit between the app and the + // workers). + if o.extendedAggregation { + c.workersMode = mutexMode + } + + if o.aggregation || o.extendedAggregation || o.maxBufferedSamplesPerContext > 0 { + c.agg = newAggregator(&c, int64(o.maxBufferedSamplesPerContext), o.aggregatorShardCount) + c.agg.start(o.aggregationFlushInterval) + + if o.extendedAggregation { + c.aggExtended = c.agg + + if c.aggregatorMode == channelMode { + c.agg.startReceivingMetric(o.channelModeBufferSize, o.workersCount) + } + } + } + + for i := 0; i < o.workersCount; i++ { + w := newWorker(bufferPool, c.sender) + c.workers = append(c.workers, w) + + if c.workersMode == channelMode { + w.startReceivingMetric(o.channelModeBufferSize) + } + } + + c.flushTime = o.bufferFlushInterval + c.stop = make(chan struct{}, 1) + + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.watch() + }() + + if o.telemetry { + if o.telemetryAddr == "" { + c.telemetryClient = newTelemetryClient(&c, c.agg != nil) + } else { + var err error + c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout, o.connectTimeout) + if err != nil { + return nil, err + } + } + c.telemetryClient.run(&c.wg, c.stop) + } + + return &c, nil +} + +func (c *ClientEx) watch() { + ticker := time.NewTicker(c.flushTime) + + for { + select { + case <-ticker.C: + for _, w := range c.workers { + w.flush() + } + case <-c.stop: + ticker.Stop() + return + } + } +} + +// Flush forces a flush of all the queued dogstatsd payloads This method is +// blocking and will not return until everything is sent through the network. +// In mutexMode, this will also block sampling new data to the client while the +// workers and sender are flushed. +func (c *ClientEx) Flush() error { + if c == nil { + return ErrNoClient + } + if c.agg != nil { + c.agg.flush() + } + for _, w := range c.workers { + w.pause() + defer w.unpause() + w.flushUnsafe() + } + // Now that the worker are pause the sender can flush the queue between + // worker and senders + c.sender.flush() + return nil +} + +// IsClosed returns if the client has been closed. +func (c *ClientEx) IsClosed() bool { + c.closerLock.Lock() + defer c.closerLock.Unlock() + return c.isClosed +} + +func (c *ClientEx) flushTelemetryMetrics(t *Telemetry) { + t.TotalMetricsGauge = atomic.LoadUint64(&c.telemetry.totalMetricsGauge) + t.TotalMetricsCount = atomic.LoadUint64(&c.telemetry.totalMetricsCount) + t.TotalMetricsSet = atomic.LoadUint64(&c.telemetry.totalMetricsSet) + t.TotalMetricsHistogram = atomic.LoadUint64(&c.telemetry.totalMetricsHistogram) + t.TotalMetricsDistribution = atomic.LoadUint64(&c.telemetry.totalMetricsDistribution) + t.TotalMetricsTiming = atomic.LoadUint64(&c.telemetry.totalMetricsTiming) + t.TotalEvents = atomic.LoadUint64(&c.telemetry.totalEvents) + t.TotalServiceChecks = atomic.LoadUint64(&c.telemetry.totalServiceChecks) + t.TotalDroppedOnReceive = atomic.LoadUint64(&c.telemetry.totalDroppedOnReceive) +} + +// GetTelemetry return the telemetry metrics for the client since it started. +func (c *ClientEx) GetTelemetry() Telemetry { + return c.telemetryClient.getTelemetry() +} + +// GetTransport return the name of the transport used. +func (c *ClientEx) GetTransport() string { + if c.sender == nil { + return "" + } + return c.sender.getTransportName() +} + +type ErrorInputChannelFull struct { + Metric metric + ChannelSize int + Msg string +} + +func (e ErrorInputChannelFull) Error() string { + return e.Msg +} + +func (c *ClientEx) send(m metric) error { + h := hashString32(m.name) + worker := c.workers[h%uint32(len(c.workers))] + + if c.workersMode == channelMode { + select { + case worker.inputMetrics <- m: + default: + atomic.AddUint64(&c.telemetry.totalDroppedOnReceive, 1) + err := &ErrorInputChannelFull{m, len(worker.inputMetrics), "Worker input channel full"} + if c.errorHandler != nil { + c.errorHandler(err) + } + if c.errorOnBlockedChannel { + return err + } + } + return nil + } + return worker.processMetric(m) +} + +// sendBlocking is used by the aggregator to inject aggregated metrics. +func (c *ClientEx) sendBlocking(m metric) error { + m.globalTags = c.tags + m.namespace = c.namespace + + h := hashString32(m.name) + worker := c.workers[h%uint32(len(c.workers))] + return worker.processMetric(m) +} + +func (c *ClientEx) sendToAggregator(mType metricType, name string, value float64, tags []string, rate float64, f bufferedMetricSampleFunc, cardinality Cardinality) error { + if c.aggregatorMode == channelMode { + m := metric{metricType: mType, name: name, fvalue: value, tags: tags, rate: rate, cardinality: cardinality} + select { + case c.aggExtended.inputMetrics <- m: + default: + atomic.AddUint64(&c.telemetry.totalDroppedOnReceive, 1) + err := &ErrorInputChannelFull{m, len(c.aggExtended.inputMetrics), "Aggregator input channel full"} + if c.errorHandler != nil { + c.errorHandler(err) + } + if c.errorOnBlockedChannel { + return err + } + } + return nil + } + return f(name, value, tags, rate, cardinality) +} + +// Gauge measures the value of a metric at a particular time. +func (c *ClientEx) Gauge(name string, value float64, tags []string, rate float64, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalMetricsGauge, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + if c.agg != nil { + return c.agg.gauge(name, value, tags, cardinality) + } + return c.send(metric{metricType: gauge, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// GaugeWithTimestamp measures the value of a metric at a given time. +// BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ +// The value will bypass any aggregation on the client side and agent side, this is +// useful when sending points in the past. +// +// Minimum Datadog Agent version: 7.40.0 +func (c *ClientEx) GaugeWithTimestamp(name string, value float64, tags []string, rate float64, timestamp time.Time, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + + if timestamp.IsZero() || timestamp.Unix() <= noTimestamp { + return InvalidTimestamp + } + + atomic.AddUint64(&c.telemetry.totalMetricsGauge, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + return c.send(metric{metricType: gauge, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, timestamp: timestamp.Unix(), originDetection: c.originDetection, cardinality: cardinality}) +} + +// Count tracks how many times something happened per second. +func (c *ClientEx) Count(name string, value int64, tags []string, rate float64, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalMetricsCount, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + if c.agg != nil { + return c.agg.count(name, value, tags, cardinality) + } + return c.send(metric{metricType: count, name: name, ivalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// CountWithTimestamp tracks how many times something happened at the given second. +// BETA - Please contact our support team for more information to use this feature: https://www.datadoghq.com/support/ +// The value will bypass any aggregation on the client side and agent side, this is +// useful when sending points in the past. +// +// Minimum Datadog Agent version: 7.40.0 +func (c *ClientEx) CountWithTimestamp(name string, value int64, tags []string, rate float64, timestamp time.Time, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + + if timestamp.IsZero() || timestamp.Unix() <= noTimestamp { + return InvalidTimestamp + } + + atomic.AddUint64(&c.telemetry.totalMetricsCount, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + return c.send(metric{metricType: count, name: name, ivalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, timestamp: timestamp.Unix(), originDetection: c.originDetection, cardinality: cardinality}) +} + +// Histogram tracks the statistical distribution of a set of values on each host. +func (c *ClientEx) Histogram(name string, value float64, tags []string, rate float64, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalMetricsHistogram, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + if c.aggExtended != nil { + return c.sendToAggregator(histogram, name, value, tags, rate, c.aggExtended.histogram, cardinality) + } + return c.send(metric{metricType: histogram, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// Distribution tracks the statistical distribution of a set of values across your infrastructure. +func (c *ClientEx) Distribution(name string, value float64, tags []string, rate float64, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalMetricsDistribution, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + if c.aggExtended != nil { + return c.sendToAggregator(distribution, name, value, tags, rate, c.aggExtended.distribution, cardinality) + } + return c.send(metric{metricType: distribution, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// Decr is just Count of -1 +func (c *ClientEx) Decr(name string, tags []string, rate float64, parameters ...Parameter) error { + return c.Count(name, -1, tags, rate, parameters...) +} + +// Incr is just Count of 1 +func (c *ClientEx) Incr(name string, tags []string, rate float64, parameters ...Parameter) error { + return c.Count(name, 1, tags, rate, parameters...) +} + +// Set counts the number of unique elements in a group. +func (c *ClientEx) Set(name string, value string, tags []string, rate float64, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalMetricsSet, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + if c.agg != nil { + return c.agg.set(name, value, tags, cardinality) + } + return c.send(metric{metricType: set, name: name, svalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// Timing sends timing information, it is an alias for TimeInMilliseconds +func (c *ClientEx) Timing(name string, value time.Duration, tags []string, rate float64, parameters ...Parameter) error { + return c.TimeInMilliseconds(name, value.Seconds()*1000, tags, rate, parameters...) +} + +// TimeInMilliseconds sends timing information in milliseconds. +// It is flushed by statsd with percentiles, mean and other info (https://github.com/etsy/statsd/blob/master/docs/metric_types.md#timing) +func (c *ClientEx) TimeInMilliseconds(name string, value float64, tags []string, rate float64, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalMetricsTiming, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + if c.aggExtended != nil { + return c.sendToAggregator(timing, name, value, tags, rate, c.aggExtended.timing, cardinality) + } + return c.send(metric{metricType: timing, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// Event sends the provided Event. +func (c *ClientEx) Event(e *Event, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalEvents, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + return c.send(metric{metricType: event, evalue: e, rate: 1, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// SimpleEvent sends an event with the provided title and text. +func (c *ClientEx) SimpleEvent(title, text string, parameters ...Parameter) error { + e := NewEvent(title, text) + return c.Event(e, parameters...) +} + +// ServiceCheck sends the provided ServiceCheck. +func (c *ClientEx) ServiceCheck(sc *ServiceCheck, parameters ...Parameter) error { + if c == nil { + return ErrNoClient + } + atomic.AddUint64(&c.telemetry.totalServiceChecks, 1) + cardinality := parameterCardinality(parameters, c.defaultCardinality) + return c.send(metric{metricType: serviceCheck, scvalue: sc, rate: 1, globalTags: c.tags, namespace: c.namespace, originDetection: c.originDetection, cardinality: cardinality}) +} + +// SimpleServiceCheck sends an serviceCheck with the provided name and status. +func (c *ClientEx) SimpleServiceCheck(name string, status ServiceCheckStatus, parameters ...Parameter) error { + sc := NewServiceCheck(name, status) + return c.ServiceCheck(sc, parameters...) +} + +// Close the client connection. +func (c *ClientEx) Close() error { + if c == nil { + return ErrNoClient + } + + // Acquire closer lock to ensure only one thread can close the stop channel + c.closerLock.Lock() + defer c.closerLock.Unlock() + + if c.isClosed { + return nil + } + + // Notify all other threads that they should stop + select { + case <-c.stop: + return nil + default: + } + close(c.stop) + + if c.workersMode == channelMode { + for _, w := range c.workers { + w.stopReceivingMetric() + } + } + + // flush the aggregator first + if c.agg != nil { + if c.aggExtended != nil && c.aggregatorMode == channelMode { + c.agg.stopReceivingMetric() + } + c.agg.stop() + } + + // Wait for the threads to stop + c.wg.Wait() + + c.Flush() + + c.isClosed = true + return c.sender.close() +} + +func (*ClientEx) private() { +} + +// isOriginDetectionEnabled returns whether origin detection is enabled. +// +// Disable origin detection only in one of the following cases: +// - DD_ORIGIN_DETECTION_ENABLED is explicitly set to false +// - o.originDetection is explicitly set to false, which is true by default +func isOriginDetectionEnabled(o *Options) bool { + if !o.originDetection { + return false + } + + envVarValue := os.Getenv(originDetectionEnabled) + if envVarValue == "" { + // DD_ORIGIN_DETECTION_ENABLED is not set + // default to true + return true + } + + enabled, err := strconv.ParseBool(envVarValue) + if err != nil { + // Error due to an unsupported DD_ORIGIN_DETECTION_ENABLED value + // default to true + return true + } + + return enabled +} + +// fillInContainerID returns whether the clients should fill the container field. +func fillInContainerID(o *Options) bool { + if o.containerID != "" { + return false + } + return isOriginDetectionEnabled(o) +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/tag_cardinality.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/tag_cardinality.go new file mode 100644 index 000000000..23d73d530 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/tag_cardinality.go @@ -0,0 +1,78 @@ +package statsd + +import ( + "os" + "strings" +) + +type Parameter interface{} + +type Cardinality int + +const ( + CardinalityNotSet Cardinality = iota + CardinalityNone + CardinalityLow + CardinalityOrchestrator + CardinalityHigh +) + +func (c Cardinality) isValid() bool { + return c >= CardinalityNotSet && c <= CardinalityHigh +} + +func (c Cardinality) String() string { + switch c { + case CardinalityNone: + return "none" + case CardinalityLow: + return "low" + case CardinalityOrchestrator: + return "orchestrator" + case CardinalityHigh: + return "high" + } + return "" +} + +// validateCardinality converts a string to Cardinality +func validateCardinality(card string) (Cardinality, bool) { + card = strings.ToLower(card) + switch card { + case "none": + return CardinalityNone, true + case "low": + return CardinalityLow, true + case "orchestrator": + return CardinalityOrchestrator, true + case "high": + return CardinalityHigh, true + default: + return CardinalityNotSet, false + } +} + +// envTagCardinality returns the tag cardinality value from the DD_CARDINALITY/DATADOG_CARDINALITY environment variable. +func envTagCardinality() (Cardinality, bool) { + // If the user has not provided a value, read the value from the DD_CARDINALITY environment variable. + if card, ok := validateCardinality(os.Getenv("DD_CARDINALITY")); ok { + return card, true + } + + // If DD_CARDINALITY is not set or valid, read the value from the DATADOG_CARDINALITY environment variable. + if card, ok := validateCardinality(os.Getenv("DATADOG_CARDINALITY")); ok { + return card, true + } + + return CardinalityNotSet, false +} + +func parameterCardinality(parameters []Parameter, defaultCardinality Cardinality) Cardinality { + for _, o := range parameters { + c, ok := o.(Cardinality) + if ok && c.isValid() { + return c + } + } + return defaultCardinality +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/telemetry.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/telemetry.go new file mode 100644 index 000000000..bfec2d72d --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/telemetry.go @@ -0,0 +1,307 @@ +package statsd + +import ( + "fmt" + "sync" + "time" +) + +/* +telemetryInterval is the interval at which telemetry will be sent by the client. +*/ +const telemetryInterval = 10 * time.Second + +/* +clientTelemetryTag is a tag identifying this specific client. +*/ +var clientTelemetryTag = "client:go" + +/* +clientVersionTelemetryTag is a tag identifying this specific client version. +*/ +var clientVersionTelemetryTag = "client_version:5.8.3" + +// Telemetry represents internal metrics about the client behavior since it started. +type Telemetry struct { + // + // Those are produced by the 'Client' + // + + // TotalMetrics is the total number of metrics sent by the client before aggregation and sampling. + TotalMetrics uint64 + // TotalMetricsGauge is the total number of gauges sent by the client before aggregation and sampling. + TotalMetricsGauge uint64 + // TotalMetricsCount is the total number of counts sent by the client before aggregation and sampling. + TotalMetricsCount uint64 + // TotalMetricsHistogram is the total number of histograms sent by the client before aggregation and sampling. + TotalMetricsHistogram uint64 + // TotalMetricsDistribution is the total number of distributions sent by the client before aggregation and + // sampling. + TotalMetricsDistribution uint64 + // TotalMetricsSet is the total number of sets sent by the client before aggregation and sampling. + TotalMetricsSet uint64 + // TotalMetricsTiming is the total number of timings sent by the client before aggregation and sampling. + TotalMetricsTiming uint64 + // TotalEvents is the total number of events sent by the client before aggregation and sampling. + TotalEvents uint64 + // TotalServiceChecks is the total number of service_checks sent by the client before aggregation and sampling. + TotalServiceChecks uint64 + + // TotalDroppedOnReceive is the total number metrics/event/service_checks dropped when using ChannelMode (see + // WithChannelMode option). + TotalDroppedOnReceive uint64 + + // + // Those are produced by the 'sender' + // + + // TotalPayloadsSent is the total number of payload (packet on the network) succesfully sent by the client. When + // using UDP we don't know if packet dropped or not, so all packet are considered as succesfully sent. + TotalPayloadsSent uint64 + // TotalPayloadsDropped is the total number of payload dropped by the client. This includes all cause of dropped + // (TotalPayloadsDroppedQueueFull and TotalPayloadsDroppedWriter). When using UDP This won't includes the + // network dropped. + TotalPayloadsDropped uint64 + // TotalPayloadsDroppedWriter is the total number of payload dropped by the writer (when using UDS or named + // pipe) due to network timeout or error. + TotalPayloadsDroppedWriter uint64 + // TotalPayloadsDroppedQueueFull is the total number of payload dropped internally because the queue of payloads + // waiting to be sent on the wire is full. This means the client is generating more metrics than can be sent on + // the wire. If your app sends metrics in batch look at WithSenderQueueSize option to increase the queue size. + TotalPayloadsDroppedQueueFull uint64 + + // TotalBytesSent is the total number of bytes succesfully sent by the client. When using UDP we don't know if + // packet dropped or not, so all packet are considered as succesfully sent. + TotalBytesSent uint64 + // TotalBytesDropped is the total number of bytes dropped by the client. This includes all cause of dropped + // (TotalBytesDroppedQueueFull and TotalBytesDroppedWriter). When using UDP This + // won't includes the network dropped. + TotalBytesDropped uint64 + // TotalBytesDroppedWriter is the total number of bytes dropped by the writer (when using UDS or named pipe) due + // to network timeout or error. + TotalBytesDroppedWriter uint64 + // TotalBytesDroppedQueueFull is the total number of bytes dropped internally because the queue of payloads + // waiting to be sent on the wire is full. This means the client is generating more metrics than can be sent on + // the wire. If your app sends metrics in batch look at WithSenderQueueSize option to increase the queue size. + TotalBytesDroppedQueueFull uint64 + + // + // Those are produced by the 'aggregator' + // + + // AggregationNbContext is the total number of contexts flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContext uint64 + // AggregationNbContextGauge is the total number of contexts for gauges flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContextGauge uint64 + // AggregationNbContextCount is the total number of contexts for counts flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContextCount uint64 + // AggregationNbContextSet is the total number of contexts for sets flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContextSet uint64 + // AggregationNbContextHistogram is the total number of contexts for histograms flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContextHistogram uint64 + // AggregationNbContextDistribution is the total number of contexts for distributions flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContextDistribution uint64 + // AggregationNbContextTiming is the total number of contexts for timings flushed by the aggregator when either + // WithClientSideAggregation or WithExtendedClientSideAggregation options are enabled. + AggregationNbContextTiming uint64 +} + +type telemetryClient struct { + sync.RWMutex // used mostly to change the transport tag. + + c *ClientEx + aggEnabled bool // is aggregation enabled and should we sent aggregation telemetry. + transport string + tags []string + tagsByType map[metricType][]string + transportTagKnown bool + sender *sender + worker *worker + lastSample Telemetry // The previous sample of telemetry sent +} + +func newTelemetryClient(c *ClientEx, aggregationEnabled bool) *telemetryClient { + t := &telemetryClient{ + c: c, + aggEnabled: aggregationEnabled, + tags: []string{}, + tagsByType: map[metricType][]string{}, + } + + t.setTags() + return t +} + +func newTelemetryClientWithCustomAddr(c *ClientEx, telemetryAddr string, aggregationEnabled bool, pool *bufferPool, + writeTimeout time.Duration, connectTimeout time.Duration, +) (*telemetryClient, error) { + telemetryAddr = resolveAddr(telemetryAddr) + telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout, connectTimeout) + if err != nil { + return nil, fmt.Errorf("Could not resolve telemetry address: %v", err) + } + + t := newTelemetryClient(c, aggregationEnabled) + + // Creating a custom sender/worker with 1 worker in mutex mode for the + // telemetry that share the same bufferPool. + // FIXME due to performance pitfall, we're always using UDP defaults + // even for UDS. + t.sender = newSender(telemetryWriter, DefaultUDPBufferPoolSize, pool, c.errorHandler) + t.worker = newWorker(pool, t.sender) + return t, nil +} + +func (t *telemetryClient) run(wg *sync.WaitGroup, stop chan struct{}) { + wg.Add(1) + go func() { + defer wg.Done() + ticker := time.NewTicker(telemetryInterval) + for { + select { + case <-ticker.C: + t.sendTelemetry() + case <-stop: + ticker.Stop() + if t.sender != nil { + t.sender.close() + } + return + } + } + }() +} + +func (t *telemetryClient) sendTelemetry() { + for _, m := range t.flush() { + if t.worker != nil { + t.worker.processMetric(m) + } else { + t.c.send(m) + } + } + + if t.worker != nil { + t.worker.flush() + } +} + +func (t *telemetryClient) getTelemetry() Telemetry { + if t == nil { + // telemetry was disabled through the WithoutTelemetry option + return Telemetry{} + } + + tlm := Telemetry{} + t.c.flushTelemetryMetrics(&tlm) + t.c.sender.flushTelemetryMetrics(&tlm) + t.c.agg.flushTelemetryMetrics(&tlm) + + tlm.TotalMetrics = tlm.TotalMetricsGauge + + tlm.TotalMetricsCount + + tlm.TotalMetricsSet + + tlm.TotalMetricsHistogram + + tlm.TotalMetricsDistribution + + tlm.TotalMetricsTiming + + tlm.TotalPayloadsDropped = tlm.TotalPayloadsDroppedQueueFull + tlm.TotalPayloadsDroppedWriter + tlm.TotalBytesDropped = tlm.TotalBytesDroppedQueueFull + tlm.TotalBytesDroppedWriter + + if t.aggEnabled { + tlm.AggregationNbContext = tlm.AggregationNbContextGauge + + tlm.AggregationNbContextCount + + tlm.AggregationNbContextSet + + tlm.AggregationNbContextHistogram + + tlm.AggregationNbContextDistribution + + tlm.AggregationNbContextTiming + } + return tlm +} + +// setTransportTag if it was never set and is now known. +func (t *telemetryClient) setTags() { + transport := t.c.GetTransport() + t.RLock() + // We need to refresh if we never set the tags or if the transport changed. + // For example when `unix://` is used we might return `uds` until we actually connect and detect that + // this is a UDS Stream socket and then return `uds-stream`. + needsRefresh := len(t.tags) == len(t.c.tags) || t.transport != transport + t.RUnlock() + + if !needsRefresh { + return + } + + t.Lock() + defer t.Unlock() + + t.transport = transport + t.tags = append(t.c.tags, clientTelemetryTag, clientVersionTelemetryTag) + if transport != "" { + t.tags = append(t.tags, "client_transport:"+transport) + } + t.tagsByType[gauge] = append(append([]string{}, t.tags...), "metrics_type:gauge") + t.tagsByType[count] = append(append([]string{}, t.tags...), "metrics_type:count") + t.tagsByType[set] = append(append([]string{}, t.tags...), "metrics_type:set") + t.tagsByType[timing] = append(append([]string{}, t.tags...), "metrics_type:timing") + t.tagsByType[histogram] = append(append([]string{}, t.tags...), "metrics_type:histogram") + t.tagsByType[distribution] = append(append([]string{}, t.tags...), "metrics_type:distribution") +} + +// flushTelemetry returns Telemetry metrics to be flushed. It's its own function to ease testing. +func (t *telemetryClient) flush() []metric { + m := []metric{} + + // same as Count but without global namespace + telemetryCount := func(name string, value int64, tags []string) { + m = append(m, metric{metricType: count, name: name, ivalue: value, tags: tags, rate: 1}) + } + + tlm := t.getTelemetry() + t.setTags() + + // We send the diff between now and the previous telemetry flush. This keep the same telemetry behavior from V4 + // so users dashboard's aren't broken when upgrading to V5. It also allow to graph on the same dashboard a mix + // of V4 and V5 apps. + telemetryCount("datadog.dogstatsd.client.metrics", int64(tlm.TotalMetrics-t.lastSample.TotalMetrics), t.tags) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(tlm.TotalMetricsGauge-t.lastSample.TotalMetricsGauge), t.tagsByType[gauge]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(tlm.TotalMetricsCount-t.lastSample.TotalMetricsCount), t.tagsByType[count]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(tlm.TotalMetricsHistogram-t.lastSample.TotalMetricsHistogram), t.tagsByType[histogram]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(tlm.TotalMetricsDistribution-t.lastSample.TotalMetricsDistribution), t.tagsByType[distribution]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(tlm.TotalMetricsSet-t.lastSample.TotalMetricsSet), t.tagsByType[set]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(tlm.TotalMetricsTiming-t.lastSample.TotalMetricsTiming), t.tagsByType[timing]) + telemetryCount("datadog.dogstatsd.client.events", int64(tlm.TotalEvents-t.lastSample.TotalEvents), t.tags) + telemetryCount("datadog.dogstatsd.client.service_checks", int64(tlm.TotalServiceChecks-t.lastSample.TotalServiceChecks), t.tags) + + telemetryCount("datadog.dogstatsd.client.metric_dropped_on_receive", int64(tlm.TotalDroppedOnReceive-t.lastSample.TotalDroppedOnReceive), t.tags) + + telemetryCount("datadog.dogstatsd.client.packets_sent", int64(tlm.TotalPayloadsSent-t.lastSample.TotalPayloadsSent), t.tags) + telemetryCount("datadog.dogstatsd.client.packets_dropped", int64(tlm.TotalPayloadsDropped-t.lastSample.TotalPayloadsDropped), t.tags) + telemetryCount("datadog.dogstatsd.client.packets_dropped_queue", int64(tlm.TotalPayloadsDroppedQueueFull-t.lastSample.TotalPayloadsDroppedQueueFull), t.tags) + telemetryCount("datadog.dogstatsd.client.packets_dropped_writer", int64(tlm.TotalPayloadsDroppedWriter-t.lastSample.TotalPayloadsDroppedWriter), t.tags) + + telemetryCount("datadog.dogstatsd.client.bytes_dropped", int64(tlm.TotalBytesDropped-t.lastSample.TotalBytesDropped), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_sent", int64(tlm.TotalBytesSent-t.lastSample.TotalBytesSent), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_dropped_queue", int64(tlm.TotalBytesDroppedQueueFull-t.lastSample.TotalBytesDroppedQueueFull), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_dropped_writer", int64(tlm.TotalBytesDroppedWriter-t.lastSample.TotalBytesDroppedWriter), t.tags) + + if t.aggEnabled { + telemetryCount("datadog.dogstatsd.client.aggregated_context", int64(tlm.AggregationNbContext-t.lastSample.AggregationNbContext), t.tags) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(tlm.AggregationNbContextGauge-t.lastSample.AggregationNbContextGauge), t.tagsByType[gauge]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(tlm.AggregationNbContextSet-t.lastSample.AggregationNbContextSet), t.tagsByType[set]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(tlm.AggregationNbContextCount-t.lastSample.AggregationNbContextCount), t.tagsByType[count]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(tlm.AggregationNbContextHistogram-t.lastSample.AggregationNbContextHistogram), t.tagsByType[histogram]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(tlm.AggregationNbContextDistribution-t.lastSample.AggregationNbContextDistribution), t.tagsByType[distribution]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(tlm.AggregationNbContextTiming-t.lastSample.AggregationNbContextTiming), t.tagsByType[timing]) + } + + t.lastSample = tlm + + return m +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/udp.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/udp.go new file mode 100644 index 000000000..b90f75279 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/udp.go @@ -0,0 +1,39 @@ +package statsd + +import ( + "net" + "time" +) + +// udpWriter is an internal class wrapping around management of UDP connection +type udpWriter struct { + conn net.Conn +} + +// New returns a pointer to a new udpWriter given an addr in the format "hostname:port". +func newUDPWriter(addr string, _ time.Duration) (*udpWriter, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, err + } + writer := &udpWriter{conn: conn} + return writer, nil +} + +// Write data to the UDP connection with no error handling +func (w *udpWriter) Write(data []byte) (int, error) { + return w.conn.Write(data) +} + +func (w *udpWriter) Close() error { + return w.conn.Close() +} + +// GetTransportName returns the transport used by the sender +func (w *udpWriter) GetTransportName() string { + return writerNameUDP +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/uds.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/uds.go new file mode 100644 index 000000000..ed26f3ea2 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/uds.go @@ -0,0 +1,190 @@ +//go:build !windows +// +build !windows + +package statsd + +import ( + "encoding/binary" + "net" + "strings" + "sync" + "time" +) + +// udsWriter is an internal class wrapping around management of UDS connection +type udsWriter struct { + // Address to send metrics to, needed to allow reconnection on error + addr string + // Transport used + transport string + // Established connection object, or nil if not connected yet + conn net.Conn + // write timeout + writeTimeout time.Duration + // connect timeout + connectTimeout time.Duration + sync.RWMutex // used to lock conn / writer can replace it +} + +// newUDSWriter returns a pointer to a new udsWriter given a socket file path as addr. +func newUDSWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration, transport string) (*udsWriter, error) { + // Defer connection to first Write + writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout, connectTimeout: connectTimeout} + return writer, nil +} + +// GetTransportName returns the transport used by the writer +func (w *udsWriter) GetTransportName() string { + w.RLock() + defer w.RUnlock() + + if w.transport == "unix" { + return writerNameUDSStream + } else { + return writerNameUDS + } +} + +func (w *udsWriter) shouldCloseConnection(err error, partialWrite bool) bool { + if err != nil && partialWrite { + // We can't recover from a partial write + return true + } + if err, isNetworkErr := err.(net.Error); err != nil && (!isNetworkErr || !err.Timeout()) { + // Statsd server disconnected, retry connecting at next packet + return true + } + return false +} + +// Write data to the UDS connection with write timeout and minimal error handling: +// create the connection if nil, and destroy it if the statsd server has disconnected +func (w *udsWriter) Write(data []byte) (int, error) { + var n int + partialWrite := false + conn, err := w.ensureConnection() + if err != nil { + return 0, err + } + stream := conn.LocalAddr().Network() == "unix" + + // When using streams the deadline will only make us drop the packet if we can't write it at all, + // once we've started writing we need to finish. + conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) + + // When using streams, we append the length of the packet to the data + if stream { + bs := []byte{0, 0, 0, 0} + binary.LittleEndian.PutUint32(bs, uint32(len(data))) + _, err = w.conn.Write(bs) + + partialWrite = true + + // W need to be able to finish to write partially written packets once we have started. + // But we will reset the connection if we can't write anything at all for a long time. + w.conn.SetWriteDeadline(time.Now().Add(w.connectTimeout)) + + // Continue writing only if we've written the length of the packet + if err == nil { + n, err = w.conn.Write(data) + if err == nil { + partialWrite = false + } + } + } else { + n, err = w.conn.Write(data) + } + + if w.shouldCloseConnection(err, partialWrite) { + w.unsetConnection() + } + return n, err +} + +func (w *udsWriter) Close() error { + if w.conn != nil { + return w.conn.Close() + } + return nil +} + +func (w *udsWriter) tryToDial(network string) (net.Conn, error) { + udsAddr, err := net.ResolveUnixAddr(network, w.addr) + if err != nil { + return nil, err + } + + // Try to gracefully reconnect to the socket when we encounter "connection refused", as it's likely that the Agent + // is restarting and the socket is not yet available. + connectAttemptsLeft := 3 + connectDeadline := time.Now().Add(w.connectTimeout) + + // Calculate the backoff time for connection refused errors, but don't exceed one second: this means we won't waste + // longer than 1 seconds worth of time if the socket becomes available immediately after our last connect attempt + connRefusedBackoff := w.connectTimeout / time.Duration(connectAttemptsLeft+1) + if connRefusedBackoff > time.Second { + connRefusedBackoff = time.Second + } + + for { + connectAttemptsLeft-- + + perCallTimeout := time.Until(connectDeadline) + newConn, err := net.DialTimeout(udsAddr.Network(), udsAddr.String(), perCallTimeout) + if err != nil { + if strings.HasSuffix(err.Error(), "connection refused") && connectAttemptsLeft > 0 { + // If we get a connection refused error, we need to wait a bit before trying again. + time.Sleep(connRefusedBackoff) + continue + } + return nil, err + } + return newConn, nil + } +} + +func (w *udsWriter) ensureConnection() (net.Conn, error) { + // Check if we've already got a socket we can use + w.RLock() + currentConn := w.conn + w.RUnlock() + + if currentConn != nil { + return currentConn, nil + } + + // Looks like we might need to connect - try again with write locking. + w.Lock() + defer w.Unlock() + if w.conn != nil { + return w.conn, nil + } + + var newConn net.Conn + var err error + + // Try to guess the transport if not specified. + if w.transport == "" { + newConn, err = w.tryToDial("unixgram") + // try to connect with unixgram failed, try again with unix streams. + if err != nil && strings.Contains(err.Error(), "protocol wrong type for socket") { + newConn, err = w.tryToDial("unix") + } + } else { + newConn, err = w.tryToDial(w.transport) + } + + if err != nil { + return nil, err + } + w.conn = newConn + w.transport = newConn.RemoteAddr().Network() + return newConn, nil +} + +func (w *udsWriter) unsetConnection() { + w.Lock() + defer w.Unlock() + _ = w.conn.Close() + w.conn = nil +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/uds_windows.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/uds_windows.go new file mode 100644 index 000000000..909f5a0a0 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/uds_windows.go @@ -0,0 +1,15 @@ +//go:build windows +// +build windows + +package statsd + +import ( + "fmt" + "time" +) + +// newUDSWriter is disabled on Windows, SOCK_DGRAM are still unavailable but +// SOCK_STREAM should work once implemented in the agent (https://devblogs.microsoft.com/commandline/af_unix-comes-to-windows/) +func newUDSWriter(_ string, _ time.Duration, _ time.Duration, _ string) (Transport, error) { + return nil, fmt.Errorf("Unix socket is not available on Windows") +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/utils.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/utils.go new file mode 100644 index 000000000..8c3ac8426 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/utils.go @@ -0,0 +1,32 @@ +package statsd + +import ( + "math/rand" + "sync" +) + +func shouldSample(rate float64, r *rand.Rand, lock *sync.Mutex) bool { + if rate >= 1 { + return true + } + // sources created by rand.NewSource() (ie. w.random) are not thread safe. + // TODO: use defer once the lowest Go version we support is 1.14 (defer + // has an overhead before that). + lock.Lock() + if r.Float64() > rate { + lock.Unlock() + return false + } + lock.Unlock() + return true +} + +func copySlice(src []string) []string { + if src == nil { + return nil + } + + c := make([]string, len(src)) + copy(c, src) + return c +} diff --git a/vendor/github.com/DataDog/datadog-go/v5/statsd/worker.go b/vendor/github.com/DataDog/datadog-go/v5/statsd/worker.go new file mode 100644 index 000000000..056282627 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/v5/statsd/worker.go @@ -0,0 +1,158 @@ +package statsd + +import ( + "math/rand" + "sync" + "time" +) + +type worker struct { + pool *bufferPool + buffer *statsdBuffer + sender *sender + random *rand.Rand + randomLock sync.Mutex + sync.Mutex + + inputMetrics chan metric + stop chan struct{} +} + +func newWorker(pool *bufferPool, sender *sender) *worker { + // Each worker uses its own random source and random lock to prevent + // workers in separate goroutines from contending for the lock on the + // "math/rand" package-global random source (e.g. calls like + // "rand.Float64()" must acquire a shared lock to get the next + // pseudorandom number). + // Note that calling "time.Now().UnixNano()" repeatedly quickly may return + // very similar values. That's fine for seeding the worker-specific random + // source because we just need an evenly distributed stream of float values. + // Do not use this random source for cryptographic randomness. + random := rand.New(rand.NewSource(time.Now().UnixNano())) + return &worker{ + pool: pool, + sender: sender, + buffer: pool.borrowBuffer(), + random: random, + stop: make(chan struct{}), + } +} + +func (w *worker) startReceivingMetric(bufferSize int) { + w.inputMetrics = make(chan metric, bufferSize) + go w.pullMetric() +} + +func (w *worker) stopReceivingMetric() { + w.stop <- struct{}{} +} + +func (w *worker) pullMetric() { + for { + select { + case m := <-w.inputMetrics: + w.processMetric(m) + case <-w.stop: + return + } + } +} + +func (w *worker) processMetric(m metric) error { + // Aggregated metrics are already sampled. + if m.metricType != distributionAggregated && m.metricType != histogramAggregated && m.metricType != timingAggregated { + if !shouldSample(m.rate, w.random, &w.randomLock) { + return nil + } + } + w.Lock() + var err error + if err = w.writeMetricUnsafe(m); err == errBufferFull { + w.flushUnsafe() + err = w.writeMetricUnsafe(m) + } + w.Unlock() + return err +} + +func (w *worker) writeAggregatedMetricUnsafe(m metric, metricSymbol []byte, precision int, rate float64) error { + globalPos := 0 + + // first check how much data we can write to the buffer: + // +3 + len(metricSymbol) because the message will include '||#' before the tags + // +1 for the potential line break at the start of the metric + extraSize := len(m.stags) + 4 + len(metricSymbol) + if m.rate < 1 { + // +2 for "|@" + // + the maximum size of a rate (https://en.wikipedia.org/wiki/IEEE_754-1985) + extraSize += 2 + 18 + } + for _, t := range m.globalTags { + extraSize += len(t) + 1 + } + + for { + pos, err := w.buffer.writeAggregated(metricSymbol, m.namespace, m.globalTags, m.name, m.fvalues[globalPos:], m.stags, extraSize, precision, rate, m.originDetection, m.cardinality) + if err == errPartialWrite { + // We successfully wrote part of the histogram metrics. + // We flush the current buffer and finish the histogram + // in a new one. + w.flushUnsafe() + globalPos += pos + } else { + return err + } + } +} + +func (w *worker) writeMetricUnsafe(m metric) error { + switch m.metricType { + case gauge: + return w.buffer.writeGauge(m.namespace, m.globalTags, m.name, m.fvalue, m.tags, m.rate, m.timestamp, m.originDetection, m.cardinality) + case count: + return w.buffer.writeCount(m.namespace, m.globalTags, m.name, m.ivalue, m.tags, m.rate, m.timestamp, m.originDetection, m.cardinality) + case histogram: + return w.buffer.writeHistogram(m.namespace, m.globalTags, m.name, m.fvalue, m.tags, m.rate, m.originDetection, m.cardinality) + case distribution: + return w.buffer.writeDistribution(m.namespace, m.globalTags, m.name, m.fvalue, m.tags, m.rate, m.originDetection, m.cardinality) + case set: + return w.buffer.writeSet(m.namespace, m.globalTags, m.name, m.svalue, m.tags, m.rate, m.originDetection, m.cardinality) + case timing: + return w.buffer.writeTiming(m.namespace, m.globalTags, m.name, m.fvalue, m.tags, m.rate, m.originDetection, m.cardinality) + case event: + return w.buffer.writeEvent(m.evalue, m.globalTags, m.originDetection, m.cardinality) + case serviceCheck: + return w.buffer.writeServiceCheck(m.scvalue, m.globalTags, m.originDetection, m.cardinality) + case histogramAggregated: + return w.writeAggregatedMetricUnsafe(m, histogramSymbol, -1, m.rate) + case distributionAggregated: + return w.writeAggregatedMetricUnsafe(m, distributionSymbol, -1, m.rate) + case timingAggregated: + return w.writeAggregatedMetricUnsafe(m, timingSymbol, 6, m.rate) + default: + return nil + } +} + +func (w *worker) flush() { + w.Lock() + w.flushUnsafe() + w.Unlock() +} + +func (w *worker) pause() { + w.Lock() +} + +func (w *worker) unpause() { + w.Unlock() +} + +// flush the current buffer. Lock must be held by caller. +// flushed buffer written to the network asynchronously. +func (w *worker) flushUnsafe() { + if len(w.buffer.bytes()) > 0 { + w.sender.send(w.buffer) + w.buffer = w.pool.borrowBuffer() + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index c13e36ca0..62ac1e341 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -9,6 +9,9 @@ filippo.io/edwards25519/field ## explicit; go 1.16 github.com/Azure/go-ansiterm github.com/Azure/go-ansiterm/winterm +# github.com/DataDog/datadog-go/v5 v5.8.3 +## explicit; go 1.13 +github.com/DataDog/datadog-go/v5/statsd # github.com/Masterminds/semver v1.5.0 ## explicit github.com/Masterminds/semver