diff --git a/oomagent/server.go b/oomagent/server.go index 334d20b8d..a3c784e6a 100644 --- a/oomagent/server.go +++ b/oomagent/server.go @@ -2,9 +2,11 @@ package main import ( "context" + "encoding/csv" "fmt" "io" "log" + "os" "time" "google.golang.org/grpc/codes" @@ -14,6 +16,7 @@ import ( "github.com/oom-ai/oomstore/pkg/errdefs" "github.com/oom-ai/oomstore/pkg/oomstore" "github.com/oom-ai/oomstore/pkg/oomstore/types" + "github.com/spf13/cast" ) type server struct { @@ -292,18 +295,55 @@ func (s *server) ChannelJoin(stream codegen.OomAgent_ChannelJoinServer) error { } func (s *server) Join(ctx context.Context, req *codegen.JoinRequest) (*codegen.JoinResponse, error) { - err := s.oomstore.Join(ctx, types.JoinOpt{ - FeatureNames: req.Features, - InputFilePath: req.InputFile, - OutputFilePath: req.OutputFile, + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + joinResult, err := s.oomstore.Join(ctx, types.JoinOpt{ + FeatureNames: req.Features, + InputFilePath: req.InputFile, }) if err != nil { return nil, internalError(err.Error()) } + if err := writeJoinResultToFile(req.OutputFile, joinResult); err != nil { + return nil, wrapErr(err) + } + return &codegen.JoinResponse{}, nil } +func writeJoinResultToFile(outputFilePath string, joinResult *types.JoinResult) error { + file, err := os.Create(outputFilePath) + if err != nil { + return err + } + defer file.Close() + w := csv.NewWriter(file) + defer w.Flush() + + if err := w.Write(joinResult.Header); err != nil { + return err + } + for row := range joinResult.Data { + if row.Error != nil { + return row.Error + } + if err := w.Write(joinRecord(row.Record)); err != nil { + return err + } + } + return nil +} + +func joinRecord(row []interface{}) []string { + record := make([]string, 0, len(row)) + for _, value := range row { + record = append(record, cast.ToString(value)) + } + return record +} + func (s *server) ChannelExport(req *codegen.ChannelExportRequest, stream codegen.OomAgent_ChannelExportServer) error { 
if len(req.Features) == 0 { return nil diff --git a/oomcli/cmd/join_helper.go b/oomcli/cmd/join_helper.go index 236c987e7..0391d9e6b 100644 --- a/oomcli/cmd/join_helper.go +++ b/oomcli/cmd/join_helper.go @@ -22,15 +22,9 @@ func join(ctx context.Context, store *oomstore.OomStore, opt JoinOpt, output str ctx, cancel := context.WithCancel(ctx) defer cancel() - entityRows, header, err := oomstore.GetEntityRowsFromInputFile(ctx, opt.InputFilePath) - if err != nil { - return err - } - - joinResult, err := store.ChannelJoin(ctx, types.ChannelJoinOpt{ - JoinFeatureNames: opt.FeatureNames, - EntityRows: entityRows, - ExistedFeatureNames: header[2:], + joinResult, err := store.Join(ctx, types.JoinOpt{ + FeatureNames: opt.FeatureNames, + InputFilePath: opt.InputFilePath, }) if err != nil { return err diff --git a/pkg/oomstore/join.go b/pkg/oomstore/join.go index bb6adc5de..a925f6330 100644 --- a/pkg/oomstore/join.go +++ b/pkg/oomstore/join.go @@ -9,8 +9,6 @@ import ( "sort" "strconv" - "github.com/spf13/cast" - "github.com/oom-ai/oomstore/internal/database/offline" "github.com/oom-ai/oomstore/pkg/errdefs" "github.com/oom-ai/oomstore/pkg/oomstore/types" @@ -73,31 +71,24 @@ func (s *OomStore) ChannelJoin(ctx context.Context, opt types.ChannelJoinOpt) (* } // Join gets point-in-time correct feature values for each entity row. -// The method is similar to Join, except that both input and output are files on disk. +// The method is similar to ChannelJoin, except that the input is read from a file on disk. // Input File should contain header, the first two columns of Input File should be // entity_key, unix_milli, then followed by other real-time feature values. 
-func (s *OomStore) Join(ctx context.Context, opt types.JoinOpt) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - +func (s *OomStore) Join(ctx context.Context, opt types.JoinOpt) (*types.JoinResult, error) { if err := util.ValidateFullFeatureNames(opt.FeatureNames...); err != nil { - return err + return nil, err } - entityRows, header, err := GetEntityRowsFromInputFile(ctx, opt.InputFilePath) + entityRows, header, err := getEntityRowsFromInputFile(ctx, opt.InputFilePath) if err != nil { - return err + return nil, err } - joinResult, err := s.ChannelJoin(ctx, types.ChannelJoinOpt{ + return s.ChannelJoin(ctx, types.ChannelJoinOpt{ JoinFeatureNames: opt.FeatureNames, EntityRows: entityRows, ExistedFeatureNames: header[2:], }) - if err != nil { - return err - } - return writeJoinResultToFile(opt.OutputFilePath, joinResult) } func (s *OomStore) buildRevisionRanges(ctx context.Context, group *types.Group) ([]*offline.RevisionRange, error) { @@ -142,7 +133,7 @@ func (s *OomStore) buildRevisionRanges(ctx context.Context, group *types.Group) return ranges, nil } -func GetEntityRowsFromInputFile(ctx context.Context, inputFilePath string) (<-chan types.EntityRow, []string, error) { +func getEntityRowsFromInputFile(ctx context.Context, inputFilePath string) (<-chan types.EntityRow, []string, error) { input, err := os.Open(inputFilePath) if err != nil { return nil, nil, errdefs.WithStack(err) @@ -207,34 +198,3 @@ func GetEntityRowsFromInputFile(ctx context.Context, inputFilePath string) (<-ch }() return entityRows, header, nil } - -func writeJoinResultToFile(outputFilePath string, joinResult *types.JoinResult) error { - file, err := os.Create(outputFilePath) - if err != nil { - return err - } - defer file.Close() - w := csv.NewWriter(file) - defer w.Flush() - - if err := w.Write(joinResult.Header); err != nil { - return err - } - for row := range joinResult.Data { - if row.Error != nil { - return row.Error - } - if err := w.Write(joinRecord(row.Record)); err 
!= nil { - return err - } - } - return nil -} - -func joinRecord(row []interface{}) []string { - record := make([]string, 0, len(row)) - for _, value := range row { - record = append(record, cast.ToString(value)) - } - return record -} diff --git a/pkg/oomstore/types/options.go b/pkg/oomstore/types/options.go index 59d0206c9..0492f2b05 100644 --- a/pkg/oomstore/types/options.go +++ b/pkg/oomstore/types/options.go @@ -70,9 +70,8 @@ type ChannelJoinOpt struct { } type JoinOpt struct { - FeatureNames []string - InputFilePath string - OutputFilePath string + FeatureNames []string + InputFilePath string } type UpdateEntityOpt struct {