Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 46 additions & 63 deletions runtime/drivers/athena/information_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,43 +115,67 @@ func (c *Connection) ListTables(ctx context.Context, database, databaseSchema st
return res, next, nil
}

func (c *Connection) GetTable(ctx context.Context, database, databaseSchema, table string) (*drivers.TableMetadata, error) {
func (c *Connection) Lookup(ctx context.Context, database, databaseSchema, name string) (*drivers.OlapTable, error) {
q := fmt.Sprintf(`
SELECT
CASE t.table_type WHEN 'VIEW' THEN true ELSE false END AS view,
column_name,
data_type
FROM %s.information_schema.columns c
JOIN %s.information_schema.tables t
ON t.table_schema = c.table_schema AND t.table_name = c.table_name
WHERE c.table_schema = ? AND c.table_name = ?
ORDER BY c.ordinal_position
`, sqlSafeName(database), sqlSafeName(database))
SELECT
CASE t.table_type WHEN 'VIEW' THEN true ELSE false END AS view,
column_name,
data_type
FROM %s.information_schema.columns c
JOIN %s.information_schema.tables t
ON t.table_schema = c.table_schema AND t.table_name = c.table_name
WHERE c.table_schema = ? AND c.table_name = ?
ORDER BY c.ordinal_position
`, sqlSafeName(database), sqlSafeName(database))

rows, err := c.Query(ctx, &drivers.Statement{
Query: q,
Args: []any{databaseSchema, table},
Args: []any{databaseSchema, name},
})
if err != nil {
return nil, err
}
defer rows.Close()

res := &drivers.TableMetadata{
Schema: make(map[string]string),
}
var view bool
var col, typ string
fields := make([]*runtimev1.StructType_Field, 0)
for rows.Next() {
err = rows.Scan(&res.View, &col, &typ)
if err != nil {
if err := rows.Scan(&view, &col, &typ); err != nil {
return nil, err
}
res.Schema[col] = typ
}
if err := rows.Err(); err != nil {
return nil, err
fields = append(fields, &runtimev1.StructType_Field{
Name: col,
Type: athenaTypeToRuntimeType(typ),
})
}
return res, nil

return &drivers.OlapTable{
Database: database,
DatabaseSchema: databaseSchema,
Name: name,
View: view,
Schema: &runtimev1.StructType{
Fields: fields,
},
UnsupportedCols: nil,
PhysicalSizeBytes: 0,
}, rows.Err()
}

// All implements drivers.OLAPInformationSchema.
func (c *Connection) All(ctx context.Context, like string, pageSize uint32, pageToken string) ([]*drivers.OlapTable, string, error) {
return drivers.AllFromInformationSchema(ctx, like, pageSize, pageToken, c)
}

// LoadPhysicalSize implements drivers.OLAPInformationSchema.
func (c *Connection) LoadPhysicalSize(ctx context.Context, tables []*drivers.OlapTable) error {
return nil
}

// LoadDDL implements drivers.OLAPInformationSchema.
func (c *Connection) LoadDDL(ctx context.Context, table *drivers.OlapTable) error {
return nil // Not implemented
}

func (c *Connection) listCatalogs(ctx context.Context, client *athena.Client) ([]string, error) {
Expand Down Expand Up @@ -181,47 +205,6 @@ func (c *Connection) listCatalogs(ctx context.Context, client *athena.Client) ([
return catalogs, nil
}

// All implements drivers.OLAPInformationSchema.
func (c *Connection) All(ctx context.Context, like string, pageSize uint32, pageToken string) ([]*drivers.OlapTable, string, error) {
return drivers.AllFromInformationSchema(ctx, like, pageSize, pageToken, c)
}

// LoadPhysicalSize implements drivers.OLAPInformationSchema.
func (c *Connection) LoadPhysicalSize(ctx context.Context, tables []*drivers.OlapTable) error {
return nil
}

// LoadDDL implements drivers.OLAPInformationSchema.
func (c *Connection) LoadDDL(ctx context.Context, table *drivers.OlapTable) error {
return nil // Not implemented
}

// Lookup implements drivers.OLAPInformationSchema.
func (c *Connection) Lookup(ctx context.Context, db, schema, name string) (*drivers.OlapTable, error) {
meta, err := c.GetTable(ctx, db, schema, name)
if err != nil {
return nil, err
}
runtimeSchema := &runtimev1.StructType{
Fields: make([]*runtimev1.StructType_Field, 0, len(meta.Schema)),
}
for name, typ := range meta.Schema {
runtimeSchema.Fields = append(runtimeSchema.Fields, &runtimev1.StructType_Field{
Name: name,
Type: athenaTypeToRuntimeType(typ),
})
}
return &drivers.OlapTable{
Database: db,
DatabaseSchema: schema,
Name: name,
View: meta.View,
Schema: runtimeSchema,
UnsupportedCols: nil,
PhysicalSizeBytes: 0,
}, nil
}

func (c *Connection) listSchemasForCatalog(ctx context.Context, client *athena.Client, catalog string) ([]*drivers.DatabaseSchemaInfo, error) {
// Use catalog if specified
var q string
Expand Down
14 changes: 7 additions & 7 deletions runtime/drivers/athena/information_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ func TestGetTable(t *testing.T) {
require.True(t, ok)

// Test getting metadata for the all_datatypes table
metadata, err := infoSchema.GetTable(ctx, "awsdatacatalog", "integration_test", "all_datatypes")
metadata, err := infoSchema.Lookup(ctx, "awsdatacatalog", "integration_test", "all_datatypes")
require.NoError(t, err)
require.NotNil(t, metadata)
require.False(t, metadata.View)
require.NotEmpty(t, metadata.Schema)

// Verify some expected columns exist
_, hasID := metadata.Schema["id"]
hasID := metadata.Schema.Fields[0].Name == "id"
require.True(t, hasID, "Expected 'id' column in table schema")

_, hasInt32 := metadata.Schema["int32_col"]
hasInt32 := metadata.Schema.Fields[2].Name == "int32_col"
require.True(t, hasInt32, "Expected 'int32_col' column in table schema")

_, hasFloat := metadata.Schema["float_col"]
hasFloat := metadata.Schema.Fields[4].Name == "float_col"
require.True(t, hasFloat, "Expected 'float_col' column in table schema")
})

Expand All @@ -53,17 +53,17 @@ func TestGetTable(t *testing.T) {
})

// Get metadata for the view
metadata, err := infoSchema.GetTable(ctx, "awsdatacatalog", "integration_test", "test_view")
metadata, err := infoSchema.Lookup(ctx, "awsdatacatalog", "integration_test", "test_view")
require.NoError(t, err)
require.NotNil(t, metadata)
require.True(t, metadata.View, "Expected test_view to be identified as a view")
require.NotEmpty(t, metadata.Schema)

// Verify columns from the view
_, hasID := metadata.Schema["id"]
hasID := metadata.Schema.Fields[0].Name == "id"
require.True(t, hasID, "Expected 'id' column in view schema")

_, hasInt32 := metadata.Schema["int32_col"]
hasInt32 := metadata.Schema.Fields[1].Name == "int32_col"
require.True(t, hasInt32, "Expected 'int32_col' column in view schema")
})
}
9 changes: 8 additions & 1 deletion runtime/drivers/athena/olap.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,14 @@ func (r *rows) runtimeSchema() *runtimev1.StructType {

func athenaTypeToRuntimeType(colType string) *runtimev1.Type {
t := &runtimev1.Type{RawType: colType}
switch strings.ToLower(colType) {

typeLower := strings.ToLower(colType)
baseType := typeLower
if idx := strings.Index(typeLower, "("); idx != -1 {
baseType = typeLower[:idx]
}

switch baseType {
case "tinyint":
t.Code = runtimev1.Type_CODE_INT8
case "smallint":
Expand Down
93 changes: 21 additions & 72 deletions runtime/drivers/bigquery/information_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,54 +124,37 @@ func (c *Connection) ListTables(ctx context.Context, database, databaseSchema st
return res, next, nil
}

func (c *Connection) GetTable(ctx context.Context, database, databaseSchema, table string) (*drivers.TableMetadata, error) {
q := fmt.Sprintf(`
SELECT
CASE t.table_type WHEN 'VIEW' THEN true else false END AS is_view,
c.column_name,
c.data_type
FROM `+"`%s.%s.INFORMATION_SCHEMA.TABLES`"+` AS t
JOIN `+"`%s.%s.INFORMATION_SCHEMA.COLUMNS`"+` AS c
ON t.table_name = c.table_name
WHERE c.table_name = @table
ORDER BY c.ordinal_position
`, database, databaseSchema, database, databaseSchema)

func (c *Connection) Lookup(ctx context.Context, database, databaseSchema, name string) (*drivers.OlapTable, error) {
client, err := c.getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery client: %w", err)
}
cq := client.Query(q)
cq.Parameters = []bigquery.QueryParameter{
{Name: "table", Value: table},
}

it, err := cq.Read(ctx)
if err != nil {
return nil, fmt.Errorf("failed to run INFORMATION_SCHEMA query: %w", err)
var table *bigquery.Table
if database != "" {
table = client.DatasetInProject(database, databaseSchema).Table(name)
} else {
table = client.Dataset(databaseSchema).Table(name)
}

r := &drivers.TableMetadata{
Schema: make(map[string]string),
meta, err := table.Metadata(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get table metadata: %w", err)
}
var row struct {
IsView bool `bigquery:"is_view"`
ColumnName string `bigquery:"column_name"`
DataType string `bigquery:"data_type"`
runtimeSchema, err := fromBQSchema(meta.Schema)
if err != nil {
return nil, err
}
for {
err := it.Next(&row)
if errors.Is(err, iterator.Done) {
break
}
if err != nil {
return nil, fmt.Errorf("failed to iterate over schema rows: %w", err)
}
r.Schema[row.ColumnName] = row.DataType
r.View = row.IsView
tbl := &drivers.OlapTable{
Database: database,
DatabaseSchema: databaseSchema,
Name: name,
View: meta.Type == bigquery.ViewTable,
Schema: runtimeSchema,
UnsupportedCols: nil, // all columns are currently being mapped though may not be as specific as in BigQuery
PhysicalSizeBytes: 0,
}

return r, nil
return tbl, nil
}

// All implements drivers.OLAPInformationSchema.
Expand Down Expand Up @@ -212,37 +195,3 @@ func (c *Connection) LoadDDL(ctx context.Context, table *drivers.OlapTable) erro
table.DDL = row.DDL
return nil
}

// Lookup implements drivers.OLAPInformationSchema.
func (c *Connection) Lookup(ctx context.Context, db, schema, name string) (*drivers.OlapTable, error) {
client, err := c.getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery client: %w", err)
}

var table *bigquery.Table
if db != "" {
table = client.DatasetInProject(db, schema).Table(name)
} else {
table = client.Dataset(schema).Table(name)
}

meta, err := table.Metadata(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get table metadata: %w", err)
}
runtimeSchema, err := fromBQSchema(meta.Schema)
if err != nil {
return nil, err
}
tbl := &drivers.OlapTable{
Database: db,
DatabaseSchema: schema,
Name: name,
View: meta.Type == bigquery.ViewTable,
Schema: runtimeSchema,
UnsupportedCols: nil, // all columns are currently being mapped though may not be as specific as in BigQuery
PhysicalSizeBytes: 0,
}
return tbl, nil
}
Loading
Loading