diff --git a/backend/src/ee/services/pam-account/pam-account-schemas.test.ts b/backend/src/ee/services/pam-account/pam-account-schemas.test.ts index e02eac70609..f4bae64c814 100644 --- a/backend/src/ee/services/pam-account/pam-account-schemas.test.ts +++ b/backend/src/ee/services/pam-account/pam-account-schemas.test.ts @@ -13,13 +13,15 @@ import { // These assertions exercise the Zod-introspection path (buildPamAccountTypeMetadata reads schema internals to derive field descriptors) describe("buildPamAccountTypeMetadata", () => { - const metadata = buildPamAccountTypeMetadata(new Set([PamAccountType.Postgres, PamAccountType.SSH])); + const metadata = buildPamAccountTypeMetadata( + new Set([PamAccountType.Postgres, PamAccountType.MySQL, PamAccountType.SSH]) + ); const byType = new Map(metadata.map((m) => [m.type, m])); test("flags web-access support from the provided supported-types set", () => { expect(byType.get(PamAccountType.Postgres)?.supportsWebAccess).toBe(true); expect(byType.get(PamAccountType.SSH)?.supportsWebAccess).toBe(true); - expect(byType.get(PamAccountType.MySQL)?.supportsWebAccess).toBe(false); + expect(byType.get(PamAccountType.MySQL)?.supportsWebAccess).toBe(true); expect(byType.get(PamAccountType.Kubernetes)?.supportsWebAccess).toBe(false); }); diff --git a/backend/src/ee/services/pam-web-access/mysql/pam-mysql-connection-controller.ts b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-connection-controller.ts new file mode 100644 index 00000000000..1c1fd3c66b0 --- /dev/null +++ b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-connection-controller.ts @@ -0,0 +1,283 @@ +import mysql from "mysql2/promise"; + +import { logger } from "@app/lib/logger"; + +import { type ControllerParams } from "../pam-data-explorer-session-handler"; +import { + DataExplorerClientMessageType, + DataExplorerServerMessageType, + type TConnectionController, + type TTabScopedMessage +} from "../pam-data-explorer-ws-types"; +import { extractCommand, splitMysqlStatements } from "./pam-mysql-data-explorer-fns"; +import { getTableDetailQuery } from "./pam-mysql-data-explorer-metadata"; + +const MAX_ROWS = 1000; + +export const createMysqlConnectionController = async (params: ControllerParams): Promise => { + const { relayPort, username, database, sessionId, connectionId, sendResponse, onUnexpectedTermination } = params; + + const conn = await mysql.createConnection({ + host: "localhost", + port: relayPort, + user: username, + database: database || undefined, + password: "", + connectTimeout: 30_000, + multipleStatements: false, + supportBigNumbers: true, + bigNumberStrings: true, + dateStrings: true, + typeCast: (field, next) => { + if (field.type === "JSON") { + return field.string(); + } + if (field.type === "BIT" && field.length === 1) { + const buf = field.buffer(); + if (!buf) return null; + return buf[0] === 1 ? "true" : "false"; + } + if (field.type === "TINY" && field.length === 1) { + return field.string(); + } + return next(); + } + }); + + const [pidRows] = await conn.execute("SELECT CONNECTION_ID() AS pid"); + const nativeConnectionId = (pidRows[0]?.pid as number) ?? null; + + await conn.query(`SET SESSION max_execution_time = 30000, sql_select_limit = ${MAX_ROWS + 1}`); + + let isInTransaction = false; + let disposing = false; + + const queryTransactionState = async () => { + try { + const [result] = await conn.query("DO 0"); + // eslint-disable-next-line no-bitwise + isInTransaction = (result.serverStatus & 1) === 1; + } catch { + isInTransaction = false; + } + }; + + const sendQueryError = async (id: string, err: unknown) => { + const mysqlErr = err as { message?: string; sqlMessage?: string; code?: string }; + + try { + await conn.execute("ROLLBACK"); + } catch { + // ROLLBACK fails if there was no active transaction + } + + await queryTransactionState(); + + sendResponse({ + type: DataExplorerServerMessageType.Error, + id, + connectionId, + transactionOpen: isInTransaction, + error: mysqlErr.sqlMessage ?? mysqlErr.message ?? "Query execution failed", + detail: mysqlErr.code + }); + }; + + const cancelRunningQuery = async () => { + if (!nativeConnectionId) return; + let cancelConn: mysql.Connection | null = null; + try { + cancelConn = await mysql.createConnection({ + host: "localhost", + port: relayPort, + user: username, + database: database || undefined, + password: "", + connectTimeout: 5_000 + }); + cancelConn.on("error" as never, () => {}); + await cancelConn.execute("KILL QUERY ?", [nativeConnectionId]); + } catch (err) { + logger.debug(err, `Failed to cancel MySQL query [sessionId=${sessionId}] [connectionId=${connectionId}]`); + } finally { + if (cancelConn) await cancelConn.end().catch(() => {}); + } + }; + + // max_execution_time only covers SELECTs; this guards DML/DDL with KILL QUERY on a timer + const queryWithTimeout = async (fn: () => Promise, timeoutMs = 30_000): Promise => { + const timer = setTimeout(() => { + void cancelRunningQuery(); + }, timeoutMs); + try { + return await fn(); + } finally { + clearTimeout(timer); + } + }; + + let processingPromise: Promise = Promise.resolve(); + + const handleMessage = (message: TTabScopedMessage) => { + if (message.type === DataExplorerClientMessageType.Cancel) { + if (disposing) return; + void cancelRunningQuery(); + return; + } + + processingPromise = processingPromise + .then(async () => { + if (disposing) return; + + switch (message.type) { + case DataExplorerClientMessageType.GetTableDetail: { + try { + const query = getTableDetailQuery(message.schema, message.table); + const [rows] = await conn.execute(query.sql, query.values); + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const rawDetail = rows[0]?.result; + if (!rawDetail) { + sendResponse({ + type: DataExplorerServerMessageType.Error, + id: message.id, + connectionId, + transactionOpen: isInTransaction, + error: "Table not found or no metadata available" + }); + break; + } + const detail = + typeof rawDetail === "string" + ? (JSON.parse(rawDetail) as Record) + : (rawDetail as unknown as Record); + sendResponse({ + type: DataExplorerServerMessageType.TableDetail, + id: message.id, + connectionId, + transactionOpen: isInTransaction, + data: detail as { + columns: { + name: string; + type: string; + nullable: boolean; + identityGeneration: string | null; + }[]; + primaryKeys: string[]; + foreignKeys: { + constraintName: string; + columns: string[]; + targetSchema: string; + targetTable: string; + targetColumns: string[]; + }[]; + } + }); + } catch (err) { + await sendQueryError(message.id, err); + } + break; + } + + case DataExplorerClientMessageType.Query: { + try { + const startTime = performance.now(); + + const stmts = splitMysqlStatements(message.sql); + + let lastRows: Record[] = []; + let lastFields: { name: string }[] = []; + let lastRowCount: number | null = null; + let lastCommand = ""; + let lastIsTruncated = false; + + for (const stmtSql of stmts) { + // eslint-disable-next-line no-await-in-loop + await conn.query(`SET SESSION sql_select_limit = ${MAX_ROWS + 1}`); + // eslint-disable-next-line no-await-in-loop + const [result, fields] = await queryWithTimeout(() => conn.query(stmtSql)); + + if (Array.isArray(result)) { + const rows = result as mysql.RowDataPacket[]; + lastIsTruncated = rows.length > MAX_ROWS; + lastRows = lastIsTruncated ? rows.slice(0, MAX_ROWS) : rows; + lastFields = (fields ?? []).map((f) => ({ name: f.name })); + lastRowCount = rows.length; + lastCommand = "SELECT"; + } else { + const header = result as mysql.ResultSetHeader; + lastRowCount = header.affectedRows; + lastCommand = extractCommand(stmtSql); + lastRows = []; + lastFields = []; + lastIsTruncated = false; + } + } + + // eslint-disable-next-line no-await-in-loop + await queryTransactionState(); + + const safeRows = lastRows.map((row) => { + const out: Record = {}; + for (const [k, v] of Object.entries(row)) { + out[k] = Buffer.isBuffer(v) ? `\\x${v.toString("hex")}` : v; + } + return out; + }); + + const executionTimeMs = Math.round(performance.now() - startTime); + sendResponse({ + type: DataExplorerServerMessageType.QueryResult, + id: message.id, + connectionId, + rows: safeRows, + fields: lastFields, + rowCount: lastRowCount, + isTruncated: lastIsTruncated, + transactionOpen: isInTransaction, + command: lastCommand, + executionTimeMs + }); + } catch (err) { + await sendQueryError(message.id, err); + } + break; + } + + default: + break; + } + }) + .catch((err) => { + logger.error(err, `Error processing MySQL message [sessionId=${sessionId}] [connectionId=${connectionId}]`); + }); + }; + + conn.on("error" as never, (err: Error) => { + if (disposing) return; + logger.error(err, `MySQL tab connection error [sessionId=${sessionId}] [connectionId=${connectionId}]`); + disposing = true; + onUnexpectedTermination(err.message || "Database connection error"); + }); + + conn.on("end" as never, () => { + if (disposing) return; + disposing = true; + onUnexpectedTermination("Database connection ended"); + }); + + const dispose = () => { + if (disposing) return; + disposing = true; + void conn.end().catch((err) => { + logger.debug(err, `Error closing MySQL connection [sessionId=${sessionId}] [connectionId=${connectionId}]`); + }); + }; + + return { + connectionId, + nativeConnectionId, + handleMessage, + dispose, + isDisposing: () => disposing + }; +}; diff --git a/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-fns.test.ts b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-fns.test.ts new file mode 100644 index 00000000000..44aacaf69ec --- /dev/null +++ b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-fns.test.ts @@ -0,0 +1,191 @@ +import { describe, expect, test } from "vitest"; + +import { extractCommand, splitMysqlStatements } from "./pam-mysql-data-explorer-fns"; + +describe("splitMysqlStatements", () => { + test("single statement without semicolon", () => { + expect(splitMysqlStatements("SELECT 1")).toEqual(["SELECT 1"]); + }); + + test("single statement with trailing semicolon", () => { + expect(splitMysqlStatements("SELECT 1;")).toEqual(["SELECT 1"]); + }); + + test("multiple statements", () => { + expect(splitMysqlStatements("SELECT 1; SELECT 2; SELECT 3")).toEqual(["SELECT 1", "SELECT 2", "SELECT 3"]); + }); + + test("ignores empty statements", () => { + expect(splitMysqlStatements("SELECT 1;; ;SELECT 2")).toEqual(["SELECT 1", "SELECT 2"]); + }); + + test("empty input", () => { + expect(splitMysqlStatements("")).toEqual([]); + }); + + test("whitespace only", () => { + expect(splitMysqlStatements(" \n\t ")).toEqual([]); + }); + + test("semicolon inside single-quoted string", () => { + expect(splitMysqlStatements("SELECT 'a;b'; SELECT 2")).toEqual(["SELECT 'a;b'", "SELECT 2"]); + }); + + test("semicolon inside double-quoted string", () => { + expect(splitMysqlStatements('SELECT "a;b"; SELECT 2')).toEqual(['SELECT "a;b"', "SELECT 2"]); + }); + + test("semicolon inside backtick-quoted identifier", () => { + expect(splitMysqlStatements("SELECT `col;name` FROM t; SELECT 2")).toEqual([ + "SELECT `col;name` FROM t", + "SELECT 2" + ]); + }); + + test("escaped quote inside string", () => { + expect(splitMysqlStatements("SELECT 'it\\'s'; SELECT 2")).toEqual(["SELECT 'it\\'s'", "SELECT 2"]); + }); + + test("semicolon inside line comment", () => { + expect(splitMysqlStatements("SELECT 1 -- ; not a split\n; SELECT 2")).toEqual([ + "SELECT 1 -- ; not a split", + "SELECT 2" + ]); + }); + + test("semicolon inside hash comment", () => { + expect(splitMysqlStatements("SELECT 1 # ; not a split\n; SELECT 2")).toEqual([ + "SELECT 1 # ; not a split", + "SELECT 2" + ]); + }); + + test("semicolon inside block comment", () => { + expect(splitMysqlStatements("SELECT 1 /* ; still one */ ; SELECT 2")).toEqual([ + "SELECT 1 /* ; still one */", + "SELECT 2" + ]); + }); + + test("multi-line block comment", () => { + expect( + splitMysqlStatements(`SELECT 1 /* +; not a split +; still not +*/; SELECT 2`) + ).toEqual([ + `SELECT 1 /* +; not a split +; still not +*/`, + "SELECT 2" + ]); + }); + + test("transaction statements", () => { + expect(splitMysqlStatements("BEGIN; INSERT INTO t VALUES (1); COMMIT")).toEqual([ + "BEGIN", + "INSERT INTO t VALUES (1)", + "COMMIT" + ]); + }); + + test("mixed quoting styles", () => { + expect(splitMysqlStatements(`SELECT "a;b", 'c;d', \`e;f\`; SELECT 2`)).toEqual([ + "SELECT \"a;b\", 'c;d', `e;f`", + "SELECT 2" + ]); + }); + + test("doubled-quote escape in single-quoted string", () => { + expect(splitMysqlStatements("SELECT 'it''s here; still one'; SELECT 2")).toEqual([ + "SELECT 'it''s here; still one'", + "SELECT 2" + ]); + }); + + test("doubled-quote escape in double-quoted string", () => { + expect(splitMysqlStatements('SELECT "a""b;c"; SELECT 2')).toEqual(['SELECT "a""b;c"', "SELECT 2"]); + }); + + test("doubled backtick in identifier", () => { + expect(splitMysqlStatements("SELECT `col``; name` FROM t; SELECT 2")).toEqual([ + "SELECT `col``; name` FROM t", + "SELECT 2" + ]); + }); + + test("double-dash without trailing space is not a line comment", () => { + expect(splitMysqlStatements("SELECT 1--1; SELECT 2")).toEqual(["SELECT 1--1", "SELECT 2"]); + }); + + test("double-dash with trailing space is a line comment", () => { + expect(splitMysqlStatements("SELECT 1 -- comment\n; SELECT 2")).toEqual(["SELECT 1 -- comment", "SELECT 2"]); + }); + + test("double-dash at end of input is a line comment", () => { + expect(splitMysqlStatements("SELECT 1 --")).toEqual(["SELECT 1 --"]); + }); + + test("double-dash with tab is a line comment", () => { + expect(splitMysqlStatements("SELECT 1 --\tcomment\n; SELECT 2")).toEqual(["SELECT 1 --\tcomment", "SELECT 2"]); + }); + + test("backtick identifier with backslash is not escaped", () => { + expect(splitMysqlStatements("SELECT `col\\`; SELECT 2")).toEqual(["SELECT `col\\`", "SELECT 2"]); + }); + + test("backslash in single-quoted string still escapes", () => { + expect(splitMysqlStatements("SELECT 'a\\';b'; SELECT 2")).toEqual(["SELECT 'a\\';b'", "SELECT 2"]); + }); + + test("unterminated single-quoted string", () => { + expect(splitMysqlStatements("SELECT 'abc")).toEqual(["SELECT 'abc"]); + }); + + test("unterminated block comment", () => { + expect(splitMysqlStatements("SELECT 1 /* oops")).toEqual(["SELECT 1 /* oops"]); + }); +}); + +describe("extractCommand", () => { + test("simple SELECT", () => { + expect(extractCommand("SELECT 1")).toBe("SELECT"); + }); + + test("leading whitespace", () => { + expect(extractCommand(" \n INSERT INTO t VALUES (1)")).toBe("INSERT"); + }); + + test("leading line comment", () => { + expect(extractCommand("-- comment\nBEGIN")).toBe("BEGIN"); + }); + + test("leading hash comment", () => { + expect(extractCommand("# comment\nCOMMIT")).toBe("COMMIT"); + }); + + test("leading block comment", () => { + expect(extractCommand("/* note */ ROLLBACK")).toBe("ROLLBACK"); + }); + + test("multiple leading comments", () => { + expect(extractCommand("-- first\n/* second */ # third\nDELETE FROM t")).toBe("DELETE"); + }); + + test("double-dash without space is not a comment in extractCommand", () => { + expect(extractCommand("--nospc")).toBe("--NOSPC"); + }); + + test("START TRANSACTION", () => { + expect(extractCommand("START TRANSACTION")).toBe("START"); + }); + + test("empty input", () => { + expect(extractCommand("")).toBe(""); + }); + + test("only comments", () => { + expect(extractCommand("-- just a comment\n")).toBe(""); + }); +}); diff --git a/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-fns.ts b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-fns.ts new file mode 100644 index 00000000000..115e779dfc2 --- /dev/null +++ b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-fns.ts @@ -0,0 +1,109 @@ +const skipQuoted = (sql: string, pos: number, quote: string): number => { + let i = pos + 1; + const backslashEscape = quote !== "`"; + while (i < sql.length) { + if (backslashEscape && sql[i] === "\\") { + i += 2; + } else if (sql[i] === quote) { + if (i + 1 < sql.length && sql[i + 1] === quote) { + i += 2; + } else { + return i + 1; + } + } else { + i += 1; + } + } + return sql.length; +}; + +const skipLineComment = (sql: string, pos: number): number => { + let i = pos + 2; + while (i < sql.length && sql[i] !== "\n") i += 1; + return i; +}; + +const skipHashComment = (sql: string, pos: number): number => { + let i = pos + 1; + while (i < sql.length && sql[i] !== "\n") i += 1; + return i; +}; + +const skipBlockComment = (sql: string, pos: number): number => { + let i = pos + 2; + while (i + 1 < sql.length && !(sql[i] === "*" && sql[i + 1] === "/")) i += 1; + return Math.min(i + 2, sql.length); +}; + +const isLineComment = (sql: string, pos: number) => + sql[pos] === "-" && + sql[pos + 1] === "-" && + (pos + 2 >= sql.length || sql[pos + 2] === " " || sql[pos + 2] === "\t" || sql[pos + 2] === "\n"); + +const isBlockComment = (sql: string, pos: number) => sql[pos] === "/" && sql[pos + 1] === "*"; + +// mysql2 can't run multiple statements per query() call, so we split on ';' ourselves, skipping strings/comments. +export const splitMysqlStatements = (sql: string): string[] => { + const stmts: string[] = []; + let pos = 0; + let stmtStart = 0; + + while (pos < sql.length) { + const ch = sql[pos]; + + if (ch === "'" || ch === '"' || ch === "`") { + pos = skipQuoted(sql, pos, ch); + } else if (isLineComment(sql, pos)) { + pos = skipLineComment(sql, pos); + } else if (ch === "#") { + pos = skipHashComment(sql, pos); + } else if (isBlockComment(sql, pos)) { + pos = skipBlockComment(sql, pos); + } else if (ch === ";") { + const stmt = sql.slice(stmtStart, pos).trim(); + if (stmt.length > 0) stmts.push(stmt); + pos += 1; + stmtStart = pos; + } else { + pos += 1; + } + } + + const tail = sql.slice(stmtStart).trim(); + if (tail.length > 0) stmts.push(tail); + + return stmts; +}; + +export const extractCommand = (sql: string): string => { + let pos = 0; + const len = sql.length; + + // Skip leading whitespace and comments + while (pos < len) { + // Skip whitespace + if (sql[pos] === " " || sql[pos] === "\t" || sql[pos] === "\n" || sql[pos] === "\r") { + pos += 1; + } else if (isLineComment(sql, pos)) { + pos += 2; + while (pos < len && sql[pos] !== "\n") pos += 1; + if (pos < len) pos += 1; + } else if (sql[pos] === "#") { + pos += 1; + while (pos < len && sql[pos] !== "\n") pos += 1; + if (pos < len) pos += 1; + } else if (isBlockComment(sql, pos)) { + pos += 2; + while (pos + 1 < len && !(sql[pos] === "*" && sql[pos + 1] === "/")) pos += 1; + pos = Math.min(pos + 2, len); + } else { + break; + } + } + + const start = pos; + while (pos < len && sql[pos] !== " " && sql[pos] !== "\t" && sql[pos] !== "\n" && sql[pos] !== "\r") { + pos += 1; + } + return sql.slice(start, pos).toUpperCase(); +}; diff --git a/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-metadata.ts b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-metadata.ts new file mode 100644 index 00000000000..e73b65074af --- /dev/null +++ b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-data-explorer-metadata.ts @@ -0,0 +1,94 @@ +export const getSchemasQuery = () => ({ + sql: ` + SELECT SCHEMA_NAME AS name + FROM information_schema.SCHEMATA + WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys') + ORDER BY SCHEMA_NAME + `, + values: [] as string[] +}); + +export const getTablesQuery = (schema: string) => ({ + sql: ` + SELECT + TABLE_NAME AS name, + CASE TABLE_TYPE + WHEN 'BASE TABLE' THEN 'table' + WHEN 'VIEW' THEN 'view' + WHEN 'SYSTEM VIEW' THEN 'view' + ELSE 'other' + END AS tableType + FROM information_schema.TABLES + WHERE TABLE_SCHEMA = ? + ORDER BY TABLE_NAME + `, + values: [schema] +}); + +export const getTableDetailQuery = (schema: string, table: string) => ({ + sql: ` + SELECT JSON_OBJECT( + 'columns', COALESCE(( + SELECT JSON_ARRAYAGG( + JSON_OBJECT( + 'name', sub.COLUMN_NAME, + 'type', sub.COLUMN_TYPE, + 'nullable', CAST(IF(sub.IS_NULLABLE = 'YES', TRUE, FALSE) AS JSON), + 'identityGeneration', CASE + WHEN sub.EXTRA LIKE '%auto_increment%' THEN 'AUTO_INCREMENT' + ELSE NULL + END + ) + ) + FROM ( + SELECT c.COLUMN_NAME, c.COLUMN_TYPE, c.IS_NULLABLE, c.EXTRA + FROM information_schema.COLUMNS c + WHERE c.TABLE_SCHEMA = ? AND c.TABLE_NAME = ? + ORDER BY c.ORDINAL_POSITION + ) sub + ), JSON_ARRAY()), + 'primaryKeys', COALESCE(( + SELECT JSON_ARRAYAGG(sub.COLUMN_NAME) + FROM ( + SELECT kcu.COLUMN_NAME + FROM information_schema.KEY_COLUMN_USAGE kcu + JOIN information_schema.TABLE_CONSTRAINTS tc + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + AND tc.TABLE_NAME = kcu.TABLE_NAME + WHERE kcu.TABLE_SCHEMA = ? AND kcu.TABLE_NAME = ? + AND tc.CONSTRAINT_TYPE = 'PRIMARY KEY' + ORDER BY kcu.ORDINAL_POSITION + ) sub + ), JSON_ARRAY()), + 'foreignKeys', COALESCE(( + SELECT JSON_ARRAYAGG( + JSON_OBJECT( + 'constraintName', fk_info.CONSTRAINT_NAME, + 'columns', fk_info.fk_columns, + 'targetSchema', fk_info.REFERENCED_TABLE_SCHEMA, + 'targetTable', fk_info.REFERENCED_TABLE_NAME, + 'targetColumns', fk_info.ref_columns + ) + ) + FROM ( + SELECT + kcu.CONSTRAINT_NAME, + kcu.REFERENCED_TABLE_SCHEMA, + kcu.REFERENCED_TABLE_NAME, + JSON_ARRAYAGG(kcu.COLUMN_NAME) AS fk_columns, + JSON_ARRAYAGG(kcu.REFERENCED_COLUMN_NAME) AS ref_columns + FROM information_schema.KEY_COLUMN_USAGE kcu + JOIN information_schema.TABLE_CONSTRAINTS tc + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + AND tc.TABLE_NAME = kcu.TABLE_NAME + WHERE kcu.TABLE_SCHEMA = ? AND kcu.TABLE_NAME = ? + AND tc.CONSTRAINT_TYPE = 'FOREIGN KEY' + GROUP BY kcu.CONSTRAINT_NAME, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME + ) fk_info + ), JSON_ARRAY()) + ) AS result + `, + values: [schema, table, schema, table, schema, table] +}); diff --git a/backend/src/ee/services/pam-web-access/mysql/pam-mysql-metadata.ts b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-metadata.ts new file mode 100644 index 00000000000..9325f7aad62 --- /dev/null +++ b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-metadata.ts @@ -0,0 +1,56 @@ +import mysql from "mysql2/promise"; + +import { logger } from "@app/lib/logger"; + +import { type OneShotOptions } from "../pam-data-explorer-session-handler"; +import { getSchemasQuery, getTablesQuery } from "./pam-mysql-data-explorer-metadata"; + +const buildConnection = async ({ relayPort, username, database }: OneShotOptions): Promise => { + const conn = await mysql.createConnection({ + host: "localhost", + port: relayPort, + user: username, + database: database || undefined, + password: "", + connectTimeout: 10_000, + multipleStatements: false + }); + conn.on("error" as never, (err: Error) => { + logger.debug(err, "one-shot mysql connection error"); + }); + return conn; +}; + +const withConnection = async (opts: OneShotOptions, fn: (conn: mysql.Connection) => Promise): Promise => { + const conn = await buildConnection(opts); + try { + await conn.query("SET SESSION max_execution_time = 30000"); + return await fn(conn); + } finally { + await conn.end().catch((err) => { + logger.debug(err, "one-shot mysql connection end error"); + }); + } +}; + +export const fetchSchemasOneShot = (opts: OneShotOptions): Promise<{ name: string }[]> => + withConnection(opts, async (conn) => { + const query = getSchemasQuery(); + const [rows] = await conn.execute(query.sql, query.values); + return rows as { name: string }[]; + }); + +export const fetchTablesOneShot = ( + opts: OneShotOptions, + schema: string +): Promise<{ name: string; tableType: string }[]> => + withConnection(opts, async (conn) => { + const query = getTablesQuery(schema); + const [rows] = await conn.execute(query.sql, query.values); + return rows as { name: string; tableType: string }[]; + }); + +export const verifyReachabilityOneShot = (opts: OneShotOptions): Promise => + withConnection(opts, async (conn) => { + await conn.execute("SELECT 1"); + }); diff --git a/backend/src/ee/services/pam-web-access/mysql/pam-mysql-session-handler.ts b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-session-handler.ts new file mode 100644 index 00000000000..6bbd6dcaef0 --- /dev/null +++ b/backend/src/ee/services/pam-web-access/mysql/pam-mysql-session-handler.ts @@ -0,0 +1,15 @@ +import { createDataExplorerSessionHandler } from "../pam-data-explorer-session-handler"; +import { createMysqlConnectionController } from "./pam-mysql-connection-controller"; +import { fetchSchemasOneShot, fetchTablesOneShot, verifyReachabilityOneShot } from "./pam-mysql-metadata"; + +export const handleMysqlSession = createDataExplorerSessionHandler({ + dialectName: "MySQL", + createController: createMysqlConnectionController, + fetchSchemas: fetchSchemasOneShot, + fetchTables: fetchTablesOneShot, + verifyReachability: verifyReachabilityOneShot, + extractErrorFields: (err: unknown) => { + const mysqlErr = err as { message?: string; sqlMessage?: string; code?: string }; + return { message: mysqlErr.sqlMessage ?? mysqlErr.message, detail: mysqlErr.code }; + } +}); diff --git a/backend/src/ee/services/pam-web-access/pam-data-explorer-session-handler.ts b/backend/src/ee/services/pam-web-access/pam-data-explorer-session-handler.ts new file mode 100644 index 00000000000..23a0be3363c --- /dev/null +++ b/backend/src/ee/services/pam-web-access/pam-data-explorer-session-handler.ts @@ -0,0 +1,265 @@ +import crypto from "crypto"; + +import { logger } from "@app/lib/logger"; + +import { + DataExplorerClientMessageSchema, + DataExplorerClientMessageType, + DataExplorerServerMessageType, + type TConnectionController, + type TDataExplorerServerMessage +} from "./pam-data-explorer-ws-types"; +import { parseClientMessage } from "./pam-web-access-fns"; +import { + SessionEndReason, + TerminalServerMessageType, + TSessionContext, + TSessionHandlerResult +} from "./pam-web-access-types"; + +const MAX_CONNECTIONS_PER_WS = 20; + +export type OneShotOptions = { + relayPort: number; + username: string; + database?: string; +}; + +export type ControllerParams = { + relayPort: number; + username: string; + database?: string; + sessionId: string; + connectionId: string; + sendResponse: (msg: TDataExplorerServerMessage) => void; + onUnexpectedTermination: (reason: string) => void; +}; + +export type TDataExplorerDialectConfig = { + dialectName: string; + createController: (params: ControllerParams) => Promise; + fetchSchemas: (opts: OneShotOptions) => Promise<{ name: string }[]>; + fetchTables: (opts: OneShotOptions, schema: string) => Promise<{ name: string; tableType: string }[]>; + verifyReachability: (opts: OneShotOptions) => Promise; + extractErrorFields: (err: unknown) => { message?: string; detail?: string; hint?: string }; +}; + +export const createDataExplorerSessionHandler = (config: TDataExplorerDialectConfig) => { + const { dialectName, createController, fetchSchemas, fetchTables, verifyReachability, extractErrorFields } = config; + + return async ( + ctx: TSessionContext, + params: { connectionDetails: Record; credentials: Record } + ): Promise => { + const { socket, relayPort, resourceName, sessionId, sendMessage, sendSessionEnd, onCleanup } = ctx; + const connectionDetails = params.connectionDetails as { host: string; port: number; database?: string }; + const credentials = params.credentials as { username: string; password: string }; + + const oneShotOpts: OneShotOptions = { + relayPort, + username: credentials.username, + database: connectionDetails.database + }; + + await verifyReachability(oneShotOpts); + + const dbLabel = connectionDetails.database ? ` (${connectionDetails.database})` : ""; + sendMessage({ + type: TerminalServerMessageType.Ready, + data: `Connected to ${resourceName}${dbLabel} as ${credentials.username}\n\n` + }); + + logger.info(`${dialectName} web access session established [sessionId=${sessionId}]`); + + const sendResponse = (msg: TDataExplorerServerMessage) => { + try { + if (socket.readyState === socket.OPEN) { + socket.send(JSON.stringify(msg)); + } + } catch (err) { + logger.error(err, `Failed to send WebSocket message [sessionId=${sessionId}]`); + } + }; + + const controllers = new Map(); + let metadataPromise: Promise = Promise.resolve(); + + const openTabConnection = async (requestId: string) => { + if (controllers.size >= MAX_CONNECTIONS_PER_WS) { + sendResponse({ + type: DataExplorerServerMessageType.ConnectionOpenFailed, + id: requestId, + error: `Maximum ${MAX_CONNECTIONS_PER_WS} connections per session reached` + }); + return; + } + + const connectionId = crypto.randomUUID(); + controllers.set(connectionId, null); + try { + const controller = await createController({ + relayPort, + username: credentials.username, + database: connectionDetails.database, + sessionId, + connectionId, + sendResponse, + onUnexpectedTermination: (reason) => { + if (!controllers.has(connectionId)) return; + controllers.delete(connectionId); + sendResponse({ + type: DataExplorerServerMessageType.ConnectionClosed, + connectionId, + reason + }); + } + }); + if (!controllers.has(connectionId)) { + controller.dispose(); + return; + } + controllers.set(connectionId, controller); + sendResponse({ + type: DataExplorerServerMessageType.ConnectionOpened, + id: requestId, + connectionId, + nativeConnectionId: controller.nativeConnectionId + }); + } catch (err) { + controllers.delete(connectionId); + const msg = err instanceof Error ? err.message : "Failed to open connection"; + logger.error(err, `Failed to open ${dialectName} tab connection [sessionId=${sessionId}]`); + sendResponse({ + type: DataExplorerServerMessageType.ConnectionOpenFailed, + id: requestId, + error: msg + }); + } + }; + + const queueMetadata = ( + requestId: string, + fetcher: () => Promise, + onSuccess: (rows: T) => TDataExplorerServerMessage, + fallbackError: string + ) => { + metadataPromise = metadataPromise + .then(async () => { + try { + const rows = await fetcher(); + sendResponse(onSuccess(rows)); + } catch (err) { + const { message: errMsg, detail, hint } = extractErrorFields(err); + sendResponse({ + type: DataExplorerServerMessageType.Error, + id: requestId, + error: errMsg ?? fallbackError, + detail, + hint + }); + } + }) + .catch(() => {}); + }; + + socket.on("message", (rawData: Buffer | ArrayBuffer | Buffer[]) => { + const message = parseClientMessage(rawData, DataExplorerClientMessageSchema); + if (!message) return; + + switch (message.type) { + case DataExplorerClientMessageType.Control: { + if (message.data === "quit") { + sendSessionEnd(SessionEndReason.UserQuit); + onCleanup(); + socket.close(); + } + break; + } + + case DataExplorerClientMessageType.OpenConnection: { + void openTabConnection(message.id); + break; + } + + case DataExplorerClientMessageType.CloseConnection: { + const controller = controllers.get(message.connectionId); + if (!controller) return; + controllers.delete(message.connectionId); + controller.dispose(); + break; + } + + case DataExplorerClientMessageType.Cancel: { + const controller = controllers.get(message.connectionId); + if (!controller || controller.isDisposing()) { + logger.debug( + `Cancel on missing/disposing connection [sessionId=${sessionId}] [connectionId=${message.connectionId}]` + ); + return; + } + controller.handleMessage(message); + break; + } + + case DataExplorerClientMessageType.GetSchemas: { + queueMetadata( + message.id, + () => fetchSchemas(oneShotOpts), + (rows) => ({ type: DataExplorerServerMessageType.Schemas, id: message.id, data: rows }), + "Failed to fetch schemas" + ); + break; + } + + case DataExplorerClientMessageType.GetTables: { + queueMetadata( + message.id, + () => fetchTables(oneShotOpts, message.schema), + (rows) => ({ type: DataExplorerServerMessageType.Tables, id: message.id, data: rows }), + "Failed to fetch tables" + ); + break; + } + + case DataExplorerClientMessageType.GetTableDetail: + case DataExplorerClientMessageType.Query: { + const controller = controllers.get(message.connectionId); + if (!controller || controller.isDisposing()) { + sendResponse({ + type: DataExplorerServerMessageType.Error, + id: message.id, + connectionId: message.connectionId, + error: "Connection not found" + }); + return; + } + controller.handleMessage(message); + break; + } + + case DataExplorerClientMessageType.Activity: { + break; + } + + default: + break; + } + }); + + return { + cleanup: async () => { + const snapshot = Array.from(controllers.values()); + controllers.clear(); + for (const controller of snapshot) { + if (controller) { + try { + controller.dispose(); + } catch (err) { + logger.debug(err, `Error disposing ${dialectName} controller [sessionId=${sessionId}]`); + } + } + } + } + }; + }; +}; diff --git a/backend/src/ee/services/pam-web-access/postgres/pam-postgres-ws-types.ts b/backend/src/ee/services/pam-web-access/pam-data-explorer-ws-types.ts similarity index 65% rename from backend/src/ee/services/pam-web-access/postgres/pam-postgres-ws-types.ts rename to backend/src/ee/services/pam-web-access/pam-data-explorer-ws-types.ts index 878ddddb2d3..0486669c310 100644 --- a/backend/src/ee/services/pam-web-access/postgres/pam-postgres-ws-types.ts +++ b/backend/src/ee/services/pam-web-access/pam-data-explorer-ws-types.ts @@ -1,6 +1,6 @@ import { z } from "zod"; -export enum PostgresClientMessageType { +export enum DataExplorerClientMessageType { Control = "control", GetSchemas = "get-schemas", GetTables = "get-tables", @@ -12,7 +12,7 @@ export enum PostgresClientMessageType { Activity = "activity" } -export enum PostgresServerMessageType { +export enum DataExplorerServerMessageType { Schemas = "schemas", Tables = "tables", TableDetail = "table-detail", @@ -26,47 +26,47 @@ export enum PostgresServerMessageType { const CorrelatedBaseSchema = z.object({ id: z.string().uuid() }); const TabScopedBaseSchema = CorrelatedBaseSchema.extend({ connectionId: z.string().uuid() }); -const ControlSchema = z.object({ type: z.literal(PostgresClientMessageType.Control), data: z.string() }); +const ControlSchema = z.object({ type: z.literal(DataExplorerClientMessageType.Control), data: z.string() }); const GetSchemasRequestSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresClientMessageType.GetSchemas) + type: z.literal(DataExplorerClientMessageType.GetSchemas) }); const GetTablesRequestSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresClientMessageType.GetTables), + type: z.literal(DataExplorerClientMessageType.GetTables), schema: z.string() }); const GetTableDetailRequestSchema = TabScopedBaseSchema.extend({ - type: z.literal(PostgresClientMessageType.GetTableDetail), + type: z.literal(DataExplorerClientMessageType.GetTableDetail), schema: z.string(), table: z.string() }); const QueryRequestSchema = TabScopedBaseSchema.extend({ - type: z.literal(PostgresClientMessageType.Query), + type: z.literal(DataExplorerClientMessageType.Query), sql: z.string().max(50 * 1024) }); const CancelSchema = z.object({ - type: z.literal(PostgresClientMessageType.Cancel), + type: z.literal(DataExplorerClientMessageType.Cancel), connectionId: z.string().uuid() }); const OpenConnectionSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresClientMessageType.OpenConnection) + type: z.literal(DataExplorerClientMessageType.OpenConnection) }); const CloseConnectionSchema = z.object({ - type: z.literal(PostgresClientMessageType.CloseConnection), + type: z.literal(DataExplorerClientMessageType.CloseConnection), connectionId: z.string().uuid() }); const ActivitySchema = z.object({ - type: z.literal(PostgresClientMessageType.Activity) + type: z.literal(DataExplorerClientMessageType.Activity) }); -export const PostgresClientMessageSchema = z.discriminatedUnion("type", [ +export const DataExplorerClientMessageSchema = z.discriminatedUnion("type", [ ControlSchema, GetSchemasRequestSchema, GetTablesRequestSchema, @@ -78,20 +78,20 @@ export const PostgresClientMessageSchema = z.discriminatedUnion("type", [ ActivitySchema ]); -export type TPostgresClientMessage = z.infer; +export type TDataExplorerClientMessage = z.infer; const SchemasResponseSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.Schemas), + type: z.literal(DataExplorerServerMessageType.Schemas), data: z.array(z.object({ name: z.string() })) }); const TablesResponseSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.Tables), + type: z.literal(DataExplorerServerMessageType.Tables), data: z.array(z.object({ name: z.string(), tableType: z.string() })) }); const TableDetailResponseSchema = TabScopedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.TableDetail), + type: z.literal(DataExplorerServerMessageType.TableDetail), transactionOpen: z.boolean(), data: z.object({ columns: z.array( @@ -116,7 +116,7 @@ const TableDetailResponseSchema = TabScopedBaseSchema.extend({ }); const QueryResultResponseSchema = TabScopedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.QueryResult), + type: z.literal(DataExplorerServerMessageType.QueryResult), rows: z.array(z.record(z.string(), z.unknown())), fields: z.array(z.object({ name: z.string() })), rowCount: z.number().nullable(), @@ -127,7 +127,7 @@ const QueryResultResponseSchema = TabScopedBaseSchema.extend({ }); const ErrorResponseSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.Error), + type: z.literal(DataExplorerServerMessageType.Error), connectionId: z.string().uuid().optional(), transactionOpen: z.boolean().optional(), error: z.string(), @@ -136,23 +136,23 @@ const ErrorResponseSchema = CorrelatedBaseSchema.extend({ }); const ConnectionOpenedResponseSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.ConnectionOpened), + type: z.literal(DataExplorerServerMessageType.ConnectionOpened), connectionId: z.string().uuid(), - backendPid: z.number().nullable() + nativeConnectionId: z.number().nullable() }); const ConnectionOpenFailedResponseSchema = CorrelatedBaseSchema.extend({ - type: z.literal(PostgresServerMessageType.ConnectionOpenFailed), + type: z.literal(DataExplorerServerMessageType.ConnectionOpenFailed), error: z.string() }); const ConnectionClosedResponseSchema = z.object({ - type: z.literal(PostgresServerMessageType.ConnectionClosed), + type: z.literal(DataExplorerServerMessageType.ConnectionClosed), connectionId: z.string().uuid(), reason: z.string() }); -export type TPostgresCorrelatedServerMessage = z.infer< +export type TDataExplorerServerMessage = z.infer< | typeof SchemasResponseSchema | typeof TablesResponseSchema | typeof TableDetailResponseSchema @@ -162,3 +162,21 @@ export type TPostgresCorrelatedServerMessage = z.infer< | typeof ConnectionOpenFailedResponseSchema | typeof ConnectionClosedResponseSchema >; + +export type TTabScopedMessage = Extract< + TDataExplorerClientMessage, + { + type: + | DataExplorerClientMessageType.GetTableDetail + | DataExplorerClientMessageType.Query + | DataExplorerClientMessageType.Cancel; + } +>; + +export type TConnectionController = { + connectionId: string; + nativeConnectionId: number | null; + handleMessage: (msg: TTabScopedMessage) => void; + dispose: () => void; + isDisposing: () => boolean; +}; diff --git a/backend/src/ee/services/pam-web-access/pam-session-handlers.ts b/backend/src/ee/services/pam-web-access/pam-session-handlers.ts index dfdcdece47c..5ac04010c3c 100644 --- a/backend/src/ee/services/pam-web-access/pam-session-handlers.ts +++ b/backend/src/ee/services/pam-web-access/pam-session-handlers.ts @@ -1,5 +1,6 @@ import { PamAccountType } from "@app/ee/services/pam/pam-enums"; +import { handleMysqlSession } from "./mysql/pam-mysql-session-handler"; import { TSessionContext, TSessionHandlerResult } from "./pam-web-access-types"; import { handlePostgresSession } from "./postgres/pam-postgres-session-handler"; import { handleRdpSession } from "./rdp/pam-rdp-session-handler"; @@ -20,6 +21,10 @@ export const SESSION_HANDLERS: Partial void; - onUnexpectedTermination: (reason: string) => void; -}; - -type TTabScopedMessage = Extract< - TPostgresClientMessage, - { - type: PostgresClientMessageType.GetTableDetail | PostgresClientMessageType.Query | PostgresClientMessageType.Cancel; - } ->; - -export type TPostgresConnectionController = { - connectionId: string; - backendPid: number | null; - handleMessage: (msg: TTabScopedMessage) => void; - dispose: () => void; - isDisposing: () => boolean; -}; + DataExplorerClientMessageType, + DataExplorerServerMessageType, + type TConnectionController, + type TTabScopedMessage +} from "../pam-data-explorer-ws-types"; +import { getTableDetailQuery } from "./pam-postgres-data-explorer-metadata"; const pgTypes = { getTypeParser: (oid: number) => { @@ -48,9 +24,7 @@ const pgTypes = { } }; -export const createPostgresConnectionController = async ( - params: ControllerParams -): Promise => { +export const createPostgresConnectionController = async (params: ControllerParams): Promise => { const { relayPort, username, database, sessionId, connectionId, sendResponse, onUnexpectedTermination } = params; const pgClient = new pg.Client({ @@ -68,7 +42,7 @@ export const createPostgresConnectionController = async ( await pgClient.connect(); const { rows: pidRows } = await pgClient.query<{ pid: number }>("SELECT pg_backend_pid() AS pid"); - const backendPid = pidRows[0]?.pid ?? null; + const nativeConnectionId = pidRows[0]?.pid ?? null; let isInTransaction = false; let disposing = false; @@ -85,7 +59,7 @@ export const createPostgresConnectionController = async ( isInTransaction = false; sendResponse({ - type: PostgresServerMessageType.Error, + type: DataExplorerServerMessageType.Error, id, connectionId, transactionOpen: false, @@ -96,7 +70,7 @@ export const createPostgresConnectionController = async ( }; const cancelRunningQuery = async () => { - if (!backendPid) return; + if (!nativeConnectionId) return; const cancelClient = new pg.Client({ host: "localhost", port: relayPort, @@ -111,7 +85,7 @@ export const createPostgresConnectionController = async ( }); try { await cancelClient.connect(); - await cancelClient.query("SELECT pg_cancel_backend($1)", [backendPid]); + await cancelClient.query("SELECT pg_cancel_backend($1)", [nativeConnectionId]); } catch (err) { logger.debug(err, `Failed to cancel backend query [sessionId=${sessionId}] [connectionId=${connectionId}]`); } finally { @@ -122,7 +96,7 @@ export const createPostgresConnectionController = async ( let processingPromise: Promise = Promise.resolve(); const handleMessage = (message: TTabScopedMessage) => { - if (message.type === PostgresClientMessageType.Cancel) { + if (message.type === DataExplorerClientMessageType.Cancel) { if (disposing) return; void cancelRunningQuery(); return; @@ -133,14 +107,14 @@ export const createPostgresConnectionController = async ( if (disposing) return; switch (message.type) { - case PostgresClientMessageType.GetTableDetail: { + case DataExplorerClientMessageType.GetTableDetail: { try { const query = getTableDetailQuery(message.schema, message.table); const result = await pgClient.query<{ result: string }>(query.text, query.values); const rawDetail = result.rows[0]?.result; if (!rawDetail) { sendResponse({ - type: PostgresServerMessageType.Error, + type: DataExplorerServerMessageType.Error, id: message.id, connectionId, transactionOpen: isInTransaction, @@ -153,7 +127,7 @@ export const createPostgresConnectionController = async ( ? (JSON.parse(rawDetail) as Record) : (rawDetail as unknown as Record); sendResponse({ - type: PostgresServerMessageType.TableDetail, + type: DataExplorerServerMessageType.TableDetail, id: message.id, connectionId, transactionOpen: isInTransaction, @@ -180,7 +154,7 @@ export const createPostgresConnectionController = async ( break; } - case PostgresClientMessageType.Query: { + case DataExplorerClientMessageType.Query: { try { const startTime = performance.now(); const MAX_ROWS = 1000; @@ -225,7 +199,7 @@ export const createPostgresConnectionController = async ( const executionTimeMs = Math.round(performance.now() - startTime); sendResponse({ - type: PostgresServerMessageType.QueryResult, + type: DataExplorerServerMessageType.QueryResult, id: message.id, connectionId, rows: lastRows, @@ -274,7 +248,7 @@ export const createPostgresConnectionController = async ( return { connectionId, - backendPid, + nativeConnectionId, handleMessage, dispose, isDisposing: () => disposing diff --git a/backend/src/ee/services/pam-web-access/postgres/pam-postgres-metadata.ts b/backend/src/ee/services/pam-web-access/postgres/pam-postgres-metadata.ts index fd5e6032506..0ecb07bb005 100644 --- a/backend/src/ee/services/pam-web-access/postgres/pam-postgres-metadata.ts +++ b/backend/src/ee/services/pam-web-access/postgres/pam-postgres-metadata.ts @@ -2,14 +2,9 @@ import pg from "pg"; import { logger } from "@app/lib/logger"; +import { type OneShotOptions } from "../pam-data-explorer-session-handler"; import { getSchemasQuery, getTablesQuery } from "./pam-postgres-data-explorer-metadata"; -type OneShotOptions = { - relayPort: number; - username: string; - database: string; -}; - const buildClient = ({ relayPort, username, database }: OneShotOptions): pg.Client => { const client = new pg.Client({ host: "localhost", diff --git a/backend/src/ee/services/pam-web-access/postgres/pam-postgres-session-handler.ts b/backend/src/ee/services/pam-web-access/postgres/pam-postgres-session-handler.ts index 62b481bf5bc..8e90bd32d41 100644 --- a/backend/src/ee/services/pam-web-access/postgres/pam-postgres-session-handler.ts +++ b/backend/src/ee/services/pam-web-access/postgres/pam-postgres-session-handler.ts @@ -1,244 +1,15 @@ -import crypto from "crypto"; - -import { logger } from "@app/lib/logger"; - -import { parseClientMessage } from "../pam-web-access-fns"; -import { - SessionEndReason, - TerminalServerMessageType, - TSessionContext, - TSessionHandlerResult -} from "../pam-web-access-types"; -import { - createPostgresConnectionController, - type TPostgresConnectionController -} from "./pam-postgres-connection-controller"; +import { createDataExplorerSessionHandler } from "../pam-data-explorer-session-handler"; +import { createPostgresConnectionController } from "./pam-postgres-connection-controller"; import { fetchSchemasOneShot, fetchTablesOneShot, verifyReachabilityOneShot } from "./pam-postgres-metadata"; -import { - PostgresClientMessageSchema, - PostgresClientMessageType, - PostgresServerMessageType, - type TPostgresCorrelatedServerMessage -} from "./pam-postgres-ws-types"; - -const MAX_CONNECTIONS_PER_WS = 20; - -const toPgErrorFields = (err: unknown) => { - const pgErr = err as { message?: string; detail?: string; hint?: string }; - return { message: pgErr.message, detail: pgErr.detail, hint: pgErr.hint }; -}; - -export const handlePostgresSession = async ( - ctx: TSessionContext, - params: { connectionDetails: Record; credentials: Record } -): Promise => { - const { socket, relayPort, resourceName, sessionId, sendMessage, sendSessionEnd, onCleanup } = ctx; - const connectionDetails = params.connectionDetails as { host: string; port: number; database: string }; - const credentials = params.credentials as { username: string; password: string }; - - const oneShotOpts = { - relayPort, - username: credentials.username, - database: connectionDetails.database - }; - - await verifyReachabilityOneShot(oneShotOpts); - - sendMessage({ - type: TerminalServerMessageType.Ready, - data: `Connected to ${resourceName} (${connectionDetails.database}) as ${credentials.username}\n\n` - }); - - logger.info(`Postgres web access session established [sessionId=${sessionId}]`); - - const sendResponse = (msg: TPostgresCorrelatedServerMessage) => { - try { - if (socket.readyState === socket.OPEN) { - socket.send(JSON.stringify(msg)); - } - } catch (err) { - logger.error(err, `Failed to send WebSocket message [sessionId=${sessionId}]`); - } - }; - - const controllers = new Map(); - let metadataPromise: Promise = Promise.resolve(); - - const openTabConnection = async (requestId: string) => { - if (controllers.size >= MAX_CONNECTIONS_PER_WS) { - sendResponse({ - type: PostgresServerMessageType.ConnectionOpenFailed, - id: requestId, - error: `Maximum ${MAX_CONNECTIONS_PER_WS} connections per session reached` - }); - return; - } - - const connectionId = crypto.randomUUID(); - controllers.set(connectionId, null); - try { - const controller = await createPostgresConnectionController({ - relayPort, - username: credentials.username, - database: connectionDetails.database, - sessionId, - connectionId, - sendResponse, - onUnexpectedTermination: (reason) => { - if (!controllers.has(connectionId)) return; - controllers.delete(connectionId); - sendResponse({ - type: PostgresServerMessageType.ConnectionClosed, - connectionId, - reason - }); - } - }); - if (!controllers.has(connectionId)) { - controller.dispose(); - return; - } - controllers.set(connectionId, controller); - sendResponse({ - type: PostgresServerMessageType.ConnectionOpened, - id: requestId, - connectionId, - backendPid: controller.backendPid - }); - } catch (err) { - controllers.delete(connectionId); - const msg = err instanceof Error ? err.message : "Failed to open connection"; - logger.error(err, `Failed to open tab connection [sessionId=${sessionId}]`); - sendResponse({ - type: PostgresServerMessageType.ConnectionOpenFailed, - id: requestId, - error: msg - }); - } - }; - - const queueMetadata = ( - requestId: string, - fetcher: () => Promise, - onSuccess: (rows: T) => TPostgresCorrelatedServerMessage, - fallbackError: string - ) => { - metadataPromise = metadataPromise - .then(async () => { - try { - const rows = await fetcher(); - sendResponse(onSuccess(rows)); - } catch (err) { - const { message: errMsg, detail, hint } = toPgErrorFields(err); - sendResponse({ - type: PostgresServerMessageType.Error, - id: requestId, - error: errMsg ?? fallbackError, - detail, - hint - }); - } - }) - .catch(() => {}); - }; - - socket.on("message", (rawData: Buffer | ArrayBuffer | Buffer[]) => { - const message = parseClientMessage(rawData, PostgresClientMessageSchema); - if (!message) return; - - switch (message.type) { - case PostgresClientMessageType.Control: { - if (message.data === "quit") { - sendSessionEnd(SessionEndReason.UserQuit); - onCleanup(); - socket.close(); - } - break; - } - - case PostgresClientMessageType.OpenConnection: { - void openTabConnection(message.id); - break; - } - - case PostgresClientMessageType.CloseConnection: { - const controller = controllers.get(message.connectionId); - if (!controller) return; - controllers.delete(message.connectionId); - controller.dispose(); - break; - } - - case PostgresClientMessageType.Cancel: { - const controller = controllers.get(message.connectionId); - if (!controller || controller.isDisposing()) { - logger.debug( - `Cancel on missing/disposing connection [sessionId=${sessionId}] [connectionId=${message.connectionId}]` - ); - return; - } - controller.handleMessage(message); - break; - } - - case PostgresClientMessageType.GetSchemas: { - queueMetadata( - message.id, - () => fetchSchemasOneShot(oneShotOpts), - (rows) => ({ type: PostgresServerMessageType.Schemas, id: message.id, data: rows }), - "Failed to fetch schemas" - ); - break; - } - - case PostgresClientMessageType.GetTables: { - queueMetadata( - message.id, - () => fetchTablesOneShot(oneShotOpts, message.schema), - (rows) => ({ type: PostgresServerMessageType.Tables, id: message.id, data: rows }), - "Failed to fetch tables" - ); - break; - } - - case PostgresClientMessageType.GetTableDetail: - case PostgresClientMessageType.Query: { - const controller = controllers.get(message.connectionId); - if (!controller || controller.isDisposing()) { - sendResponse({ - type: PostgresServerMessageType.Error, - id: message.id, - connectionId: message.connectionId, - error: "Connection not found" - }); - return; - } - controller.handleMessage(message); - break; - } - - case PostgresClientMessageType.Activity: { - break; - } - - default: - break; - } - }); - return { - cleanup: async () => { - const snapshot = Array.from(controllers.values()); - controllers.clear(); - for (const controller of snapshot) { - if (controller) { - try { - controller.dispose(); - } catch (err) { - logger.debug(err, `Error disposing controller [sessionId=${sessionId}]`); - } - } - } - } - }; -}; +export const handlePostgresSession = createDataExplorerSessionHandler({ + dialectName: "Postgres", + createController: createPostgresConnectionController, + fetchSchemas: fetchSchemasOneShot, + fetchTables: fetchTablesOneShot, + verifyReachability: verifyReachabilityOneShot, + extractErrorFields: (err: unknown) => { + const pgErr = err as { message?: string; detail?: string; hint?: string }; + return { message: pgErr.message, detail: pgErr.detail, hint: pgErr.hint }; + } +}); diff --git a/frontend/src/pages/pam/PamAccountAccessPage/PamAccountAccessPage.tsx b/frontend/src/pages/pam/PamAccountAccessPage/PamAccountAccessPage.tsx index bf05c51696a..7bfa05ca190 100644 --- a/frontend/src/pages/pam/PamAccountAccessPage/PamAccountAccessPage.tsx +++ b/frontend/src/pages/pam/PamAccountAccessPage/PamAccountAccessPage.tsx @@ -117,7 +117,10 @@ const PageContent = () => { return ( {({ reason, mfaSessionId }) => { - if (account.accountType === PamAccountType.Postgres) { + if ( + account.accountType === PamAccountType.Postgres || + account.accountType === PamAccountType.MySQL + ) { return ; } if ( diff --git a/frontend/src/pages/pam/PamDataExplorerPage/PamDataExplorerPage.tsx b/frontend/src/pages/pam/PamDataExplorerPage/PamDataExplorerPage.tsx index 51133141b72..80f8252af0a 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/PamDataExplorerPage.tsx +++ b/frontend/src/pages/pam/PamDataExplorerPage/PamDataExplorerPage.tsx @@ -14,13 +14,14 @@ import { import { Spinner } from "@app/components/v2"; import { Button } from "@app/components/v3/generic/Button"; import { cn } from "@app/components/v3/utils"; -import { useGetPamAccountById } from "@app/hooks/api/pam"; +import { PamAccountType, useGetPamAccountById } from "@app/hooks/api/pam"; import { WebAccessStatusCard } from "../PamAccountAccessPage/WebAccessStatusCard"; import { DataExplorerGrid } from "./components/DataExplorerGrid"; import { DataExplorerSidebar } from "./components/DataExplorerSidebar"; import { QueryPanel } from "./components/QueryPanel"; import type { SchemaInfo, TableInfo } from "./data-explorer-types"; +import type { SqlDialect } from "./sql-generation"; import { useDataExplorerSession } from "./use-data-explorer-session"; import { useQueryTabs } from "./use-query-tabs"; @@ -39,10 +40,16 @@ export const PamDataExplorerPage = ({ reason, mfaSessionId }: Props = {}) => { const { data: account } = useGetPamAccountById(accountId); + const dialect: SqlDialect = account?.accountType === PamAccountType.MySQL ? "mysql" : "postgres"; + const defaultSchema = + dialect === "mysql" + ? ((account?.connectionDetails as { database?: string })?.database ?? "") + : "public"; + // Sidebar-only view state. Switching schemas in the sidebar does not alter // open tabs — tabs are bound to their own (schema, table) at open time. const [schemas, setSchemas] = useState([]); - const [selectedSchema, setSelectedSchema] = useState("public"); + const [selectedSchema, setSelectedSchema] = useState(defaultSchema); const [tables, setTables] = useState([]); const [isLoadingSchemas, setIsLoadingSchemas] = useState(false); const [isLoadingTables, setIsLoadingTables] = useState(false); @@ -173,7 +180,7 @@ export const PamDataExplorerPage = ({ reason, mfaSessionId }: Props = {}) => { const result = await fetchSchemas(); setSchemas(result); const hasSelected = result.find((s) => s.name === selectedSchema); - const activeSchema = hasSelected ? selectedSchema : (result[0]?.name ?? "public"); + const activeSchema = hasSelected ? selectedSchema : (result[0]?.name ?? defaultSchema); if (!hasSelected && result.length > 0 && !keepSelected) { setSelectedSchema(activeSchema); } @@ -436,6 +443,7 @@ export const PamDataExplorerPage = ({ reason, mfaSessionId }: Props = {}) => { executeQuery={executeQuery} isLoading={tab.isLoadingDetail} onRefresh={() => handleTabRefresh(tab.id)} + dialect={dialect} /> ); } else { @@ -447,6 +455,7 @@ export const PamDataExplorerPage = ({ reason, mfaSessionId }: Props = {}) => { cancelQuery={cancelQuery} onSqlChange={(sql) => updateTabSql(tab.id, sql)} onTransactionStateChange={(open) => setTabTransactionOpen(tab.id, open)} + dialect={dialect} /> ); } diff --git a/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerGrid.tsx b/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerGrid.tsx index e641e4f8243..2febc2dfc51 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerGrid.tsx +++ b/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerGrid.tsx @@ -13,7 +13,7 @@ import { Skeleton } from "@app/components/v3/generic/Skeleton"; import type { ColumnInfo, FieldInfo, ForeignKeyInfo, TableDetail } from "../data-explorer-types"; import { getColumnIndicator } from "../data-explorer-utils"; import { copyData, exportData } from "../data-export"; -import type { FilterCondition, SortCondition } from "../sql-generation"; +import type { FilterCondition, SortCondition, SqlDialect } from "../sql-generation"; import { buildCountQuery, buildDeleteQuery, @@ -42,6 +42,7 @@ type DataExplorerGridProps = { }>; isLoading: boolean; onRefresh?: () => Promise; + dialect: SqlDialect; }; const ROW_KEY_PREFIX = "__new_"; @@ -182,8 +183,16 @@ export const DataExplorerGrid = ({ connectionId, executeQuery, isLoading, - onRefresh + onRefresh, + dialect }: DataExplorerGridProps) => { + const executeStatements = useCallback( + async (statements: string[]) => { + await executeQuery(connectionId, wrapInTransaction(statements)); + }, + [executeQuery, connectionId] + ); + const [originalData, setOriginalData] = useState[]>([]); const [currentData, setCurrentData] = useState[]>([]); const [totalCount, setTotalCount] = useState(0); @@ -234,9 +243,10 @@ export const DataExplorerGrid = ({ sorts: s, limit: ps, offset: o, - primaryKeys + primaryKeys, + dialect }); - const countSql = buildCountQuery({ schema, table, filters: f }); + const countSql = buildCountQuery({ schema, table, filters: f, dialect }); // These two queries don't share a database snapshot (the backend processes them // sequentially, not in a single transaction), so the count could be off by 1 if @@ -271,7 +281,7 @@ export const DataExplorerGrid = ({ setIsDataLoading(false); } }, - [tableDetail, schema, table, primaryKeys, executeQuery, connectionId] + [tableDetail, schema, table, primaryKeys, executeQuery, connectionId, dialect] ); // Fetch data when filters/sorts/pagination change. @@ -379,7 +389,12 @@ export const DataExplorerGrid = ({ tempIdsToRemove.push(tempId); } else { deleteStatements.push( - buildDeleteQuery({ schema, table, primaryKeyMatch: getPkMatch(row, primaryKeys) }) + buildDeleteQuery({ + schema, + table, + primaryKeyMatch: getPkMatch(row, primaryKeys), + dialect + }) ); } }); @@ -400,8 +415,7 @@ export const DataExplorerGrid = ({ // Execute DELETE SQL for persisted rows immediately if (deleteStatements.length > 0) { try { - const sql = wrapInTransaction(deleteStatements); - await executeQuery(connectionId, sql); + await executeStatements(deleteStatements); createNotification({ text: `Deleted ${deleteStatements.length} row${deleteStatements.length !== 1 ? "s" : ""}`, type: "success" @@ -424,8 +438,8 @@ export const DataExplorerGrid = ({ primaryKeys, schema, table, - connectionId, - executeQuery, + dialect, + executeStatements, fetchData, offset, pageSize, @@ -458,7 +472,7 @@ export const DataExplorerGrid = ({ values[col.name] = row[col.name]; } }); - statements.push(buildInsertQuery({ schema, table, row: values })); + statements.push(buildInsertQuery({ schema, table, row: values, dialect })); }); // Updates (changed rows) — use PK-based lookup so prepends don't misalign @@ -480,7 +494,8 @@ export const DataExplorerGrid = ({ schema, table, changes, - primaryKeyMatch: getPkMatch(original, primaryKeys) + primaryKeyMatch: getPkMatch(original, primaryKeys), + dialect }) ); } @@ -492,12 +507,9 @@ export const DataExplorerGrid = ({ return; } - const sql = wrapInTransaction(statements); - - // The WebSocket maxPayload is 64 KB. Guard against sending a query - // that would exceed the limit and silently kill the connection. + const combined = statements.join(";\n"); const MAX_QUERY_SIZE = 50 * 1024; - if (sql.length > MAX_QUERY_SIZE) { + if (combined.length > MAX_QUERY_SIZE) { createNotification({ title: "Save failed", text: "Changes are too large to save at once. Try saving fewer or smaller changes.", @@ -507,7 +519,7 @@ export const DataExplorerGrid = ({ return; } - await executeQuery(connectionId, sql); + await executeStatements(statements); createNotification({ text: `Saved ${statements.length} change${statements.length !== 1 ? "s" : ""}`, type: "success" @@ -531,8 +543,8 @@ export const DataExplorerGrid = ({ primaryKeys, schema, table, - connectionId, - executeQuery, + dialect, + executeStatements, fetchData, offset, pageSize, @@ -622,7 +634,12 @@ export const DataExplorerGrid = ({ tempIdsToRemove.push(tempId); } else { deleteStatements.push( - buildDeleteQuery({ schema, table, primaryKeyMatch: getPkMatch(row, primaryKeys) }) + buildDeleteQuery({ + schema, + table, + primaryKeyMatch: getPkMatch(row, primaryKeys), + dialect + }) ); } }); @@ -643,8 +660,7 @@ export const DataExplorerGrid = ({ // Execute DELETE SQL for persisted rows if (deleteStatements.length > 0) { try { - const sql = wrapInTransaction(deleteStatements); - await executeQuery(connectionId, sql); + await executeStatements(deleteStatements); createNotification({ text: `Deleted ${deleteStatements.length} row${deleteStatements.length !== 1 ? "s" : ""}`, type: "success" @@ -669,8 +685,8 @@ export const DataExplorerGrid = ({ schema, table, primaryKeys, - connectionId, - executeQuery, + dialect, + executeStatements, fetchData, offset, pageSize, @@ -729,6 +745,7 @@ export const DataExplorerGrid = ({ ) } hasData={currentData.length > 0} + dialect={dialect} /> {!hasPrimaryKey && ( diff --git a/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerToolbar.tsx b/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerToolbar.tsx index 13018d85ec8..c7898719e1b 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerToolbar.tsx +++ b/frontend/src/pages/pam/PamDataExplorerPage/components/DataExplorerToolbar.tsx @@ -15,7 +15,7 @@ import { Popover, PopoverContent, PopoverTrigger } from "@app/components/v3/gene import type { ColumnInfo } from "../data-explorer-types"; import type { ExportFormat } from "../data-export"; -import type { FilterCondition, SortCondition } from "../sql-generation"; +import type { FilterCondition, SortCondition, SqlDialect } from "../sql-generation"; import { ExportDropdown } from "./ExportDropdown"; import { FilterPopover } from "./FilterPopover"; import { SortPopover } from "./SortPopover"; @@ -46,6 +46,7 @@ type DataExplorerToolbarProps = { onExport: (format: ExportFormat) => void; onCopy: (format: ExportFormat) => void; hasData: boolean; + dialect: SqlDialect; }; export const DataExplorerToolbar = ({ @@ -73,7 +74,8 @@ export const DataExplorerToolbar = ({ isRefreshing = false, onExport, onCopy, - hasData + hasData, + dialect }: DataExplorerToolbarProps) => { const rangeStart = totalCount === 0 ? 0 : offset + 1; const rangeEnd = Math.min(offset + pageSize, totalCount); @@ -86,7 +88,12 @@ export const DataExplorerToolbar = ({ onMouseDown={(e) => e.stopPropagation()} >
- + Promise; + dialect: SqlDialect; }; -export const FilterPopover = ({ columns, filters, onFiltersChange }: FilterPopoverProps) => { +export const FilterPopover = ({ + columns, + filters, + onFiltersChange, + dialect +}: FilterPopoverProps) => { const [open, setOpen] = useState(false); const [draft, setDraft] = useState(filters); const [isApplying, setIsApplying] = useState(false); @@ -212,11 +232,13 @@ export const FilterPopover = ({ columns, filters, onFiltersChange }: FilterPopov - {OPERATORS.map((op) => ( - - {op.label} - - ))} + {OPERATORS.filter((op) => dialect !== "mysql" || op.value !== "ILIKE").map( + (op) => ( + + {op.label} + + ) + )} diff --git a/frontend/src/pages/pam/PamDataExplorerPage/components/QueryPanel.tsx b/frontend/src/pages/pam/PamDataExplorerPage/components/QueryPanel.tsx index 1436612e2c1..71824bd3153 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/components/QueryPanel.tsx +++ b/frontend/src/pages/pam/PamDataExplorerPage/components/QueryPanel.tsx @@ -5,6 +5,7 @@ import { cn } from "@app/components/v3/utils"; import type { FieldInfo } from "../data-explorer-types"; import { copyData, exportData } from "../data-export"; +import type { SqlDialect } from "../sql-generation"; import type { QueryTab } from "../use-query-tabs"; import { ExportDropdown } from "./ExportDropdown"; import { QueryResultsTable } from "./QueryResultsTable"; @@ -32,6 +33,7 @@ type Props = { cancelQuery: (connectionId: string) => void; onSqlChange: (sql: string) => void; onTransactionStateChange: (open: boolean) => void; + dialect: SqlDialect; }; export function QueryPanel({ @@ -39,7 +41,8 @@ export function QueryPanel({ executeQuery, cancelQuery, onSqlChange, - onTransactionStateChange + onTransactionStateChange, + dialect }: Props) { const [isRunning, setIsRunning] = useState(false); const [result, setResult] = useState(null); @@ -142,6 +145,7 @@ export function QueryPanel({ onSqlToRunChange={(s) => { sqlToRunRef.current = s; }} + sqlDialect={dialect} />
diff --git a/frontend/src/pages/pam/PamDataExplorerPage/components/SqlEditor.tsx b/frontend/src/pages/pam/PamDataExplorerPage/components/SqlEditor.tsx index 87b36729cd5..ba80ec3dadb 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/components/SqlEditor.tsx +++ b/frontend/src/pages/pam/PamDataExplorerPage/components/SqlEditor.tsx @@ -1,11 +1,13 @@ import { useEffect, useRef } from "react"; import { defaultKeymap, history, historyKeymap } from "@codemirror/commands"; -import { PostgreSQL, sql } from "@codemirror/lang-sql"; +import { MySQL, PostgreSQL, sql } from "@codemirror/lang-sql"; import { HighlightStyle, syntaxHighlighting } from "@codemirror/language"; import { EditorState, type Transaction } from "@codemirror/state"; import { EditorView, keymap, placeholder, type ViewUpdate } from "@codemirror/view"; import { tags } from "@lezer/highlight"; +import type { SqlDialect } from "../sql-generation"; + const infisicalTheme = EditorView.theme({ "&": { height: "100%", fontSize: "13px", backgroundColor: "#16181a" }, "&.cm-editor": { backgroundColor: "#16181a" }, @@ -69,6 +71,7 @@ type Props = { onExecute: (sql: string) => void; onSelectionChange: (hasSelection: boolean) => void; onSqlToRunChange: (sql: string) => void; + sqlDialect: SqlDialect; }; export function SqlEditor({ @@ -76,7 +79,8 @@ export function SqlEditor({ onChange, onExecute, onSelectionChange, - onSqlToRunChange + onSqlToRunChange, + sqlDialect }: Props) { const containerRef = useRef(null); const viewRef = useRef(null); @@ -114,7 +118,7 @@ export function SqlEditor({ ]), maxSqlLength, placeholder("Start writing SQL..."), - sql({ dialect: PostgreSQL }), + sql({ dialect: sqlDialect === "mysql" ? MySQL : PostgreSQL }), infisicalTheme, syntaxHighlighting(infisicalHighlight), EditorView.updateListener.of((update: ViewUpdate) => { diff --git a/frontend/src/pages/pam/PamDataExplorerPage/data-explorer-types.ts b/frontend/src/pages/pam/PamDataExplorerPage/data-explorer-types.ts index 0c423f56a51..00e36d63f97 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/data-explorer-types.ts +++ b/frontend/src/pages/pam/PamDataExplorerPage/data-explorer-types.ts @@ -1,7 +1,7 @@ -// WebSocket message types for the Postgres Data Explorer +// WebSocket message types for the Data Explorer export type DataExplorerClientMessage = - // Metadata (no connectionId) — served by short-lived BE pg.Clients + // Metadata (no connectionId) -- served by short-lived BE database connections | { type: "get-schemas"; id: string } | { type: "get-tables"; id: string; schema: string } // Tab-scoped — routed to a specific BE controller @@ -46,7 +46,12 @@ export type DataExplorerServerMessage = detail?: string; hint?: string; } - | { type: "connection-opened"; id: string; connectionId: string; backendPid: number | null } + | { + type: "connection-opened"; + id: string; + connectionId: string; + nativeConnectionId: number | null; + } | { type: "connection-open-failed"; id: string; error: string } | { type: "connection-closed"; connectionId: string; reason: string } | { type: "ready" } diff --git a/frontend/src/pages/pam/PamDataExplorerPage/sql-generation.ts b/frontend/src/pages/pam/PamDataExplorerPage/sql-generation.ts index 8790880a931..6b8f123950c 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/sql-generation.ts +++ b/frontend/src/pages/pam/PamDataExplorerPage/sql-generation.ts @@ -1,11 +1,14 @@ -// Client-side SQL generation for the Postgres Data Explorer. +// Client-side SQL generation for the Data Explorer. // All identifiers are properly quoted to prevent SQL injection. -export function quoteIdent(name: string): string { +export type SqlDialect = "postgres" | "mysql"; + +function quoteIdent(name: string, dialect: SqlDialect = "postgres"): string { + if (dialect === "mysql") return `\`${name.replace(/`/g, "``")}\``; return `"${name.replace(/"/g, '""')}"`; } -export function quoteLiteral(value: unknown): string { +function quoteLiteral(value: unknown, dialect: SqlDialect = "postgres"): string { if (value === null || value === undefined) return "NULL"; if (typeof value === "number") { if (!Number.isFinite(value)) return "NULL"; @@ -13,11 +16,12 @@ export function quoteLiteral(value: unknown): string { } if (typeof value === "boolean") return value ? "TRUE" : "FALSE"; const str = String(value); - // Use dollar-quoting if the string contains single quotes and no dollar signs - if (str.includes("'") && !str.includes("$")) { + if (dialect === "postgres" && str.includes("'") && !str.includes("$")) { return `$$${str}$$`; } - return `'${str.replace(/'/g, "''")}'`; + const escaped = + dialect === "mysql" ? str.replace(/\\/g, "\\\\").replace(/'/g, "''") : str.replace(/'/g, "''"); + return `'${escaped}'`; } export type FilterCondition = { @@ -44,10 +48,10 @@ export type SortCondition = { direction: "ASC" | "DESC"; }; -function buildWhereClause(filters: FilterCondition[]): string { +function buildWhereClause(filters: FilterCondition[], dialect: SqlDialect): string { if (filters.length === 0) return ""; const conditions = filters.map((f) => { - const col = quoteIdent(f.column); + const col = quoteIdent(f.column, dialect); switch (f.operator) { case "IS NULL": return `${col} IS NULL`; @@ -56,23 +60,25 @@ function buildWhereClause(filters: FilterCondition[]): string { case "IN": { const values = f.value .split(",") - .map((v) => quoteLiteral(v.trim())) + .map((v) => quoteLiteral(v.trim(), dialect)) .join(", "); return `${col} IN (${values})`; } - case "LIKE": case "ILIKE": - return `${col} ${f.operator} ${quoteLiteral(f.value)}`; + if (dialect === "mysql") return `${col} LIKE ${quoteLiteral(f.value, dialect)}`; + return `${col} ILIKE ${quoteLiteral(f.value, dialect)}`; + case "LIKE": + return `${col} LIKE ${quoteLiteral(f.value, dialect)}`; default: - return `${col} ${f.operator} ${quoteLiteral(f.value)}`; + return `${col} ${f.operator} ${quoteLiteral(f.value, dialect)}`; } }); return ` WHERE ${conditions.join(" AND ")}`; } -function buildOrderByClause(sorts: SortCondition[]): string { +function buildOrderByClause(sorts: SortCondition[], dialect: SqlDialect): string { if (sorts.length === 0) return ""; - const parts = sorts.map((s) => `${quoteIdent(s.column)} ${s.direction}`); + const parts = sorts.map((s) => `${quoteIdent(s.column, dialect)} ${s.direction}`); return ` ORDER BY ${parts.join(", ")}`; } @@ -84,18 +90,28 @@ export function buildSelectQuery(params: { limit: number; offset: number; primaryKeys?: string[]; + dialect?: SqlDialect; }): string { - const { schema, table, filters, sorts, limit, offset, primaryKeys } = params; - const tableName = `${quoteIdent(schema)}.${quoteIdent(table)}`; - const where = buildWhereClause(filters); + const { + schema, + table, + filters, + sorts, + limit, + offset, + primaryKeys, + dialect = "postgres" + } = params; + const tableName = `${quoteIdent(schema, dialect)}.${quoteIdent(table, dialect)}`; + const where = buildWhereClause(filters, dialect); - // Default sort by PK for stable pagination let orderBy: string; if (sorts.length > 0) { - orderBy = buildOrderByClause(sorts); + orderBy = buildOrderByClause(sorts, dialect); } else if (primaryKeys && primaryKeys.length > 0) { orderBy = buildOrderByClause( - primaryKeys.map((pk) => ({ column: pk, direction: "ASC" as const })) + primaryKeys.map((pk) => ({ column: pk, direction: "ASC" as const })), + dialect ); } else { orderBy = ""; @@ -108,10 +124,11 @@ export function buildCountQuery(params: { schema: string; table: string; filters: FilterCondition[]; + dialect?: SqlDialect; }): string { - const { schema, table, filters } = params; - const tableName = `${quoteIdent(schema)}.${quoteIdent(table)}`; - const where = buildWhereClause(filters); + const { schema, table, filters, dialect = "postgres" } = params; + const tableName = `${quoteIdent(schema, dialect)}.${quoteIdent(table, dialect)}`; + const where = buildWhereClause(filters, dialect); return `SELECT COUNT(*) AS count FROM ${tableName}${where}`; } @@ -119,15 +136,18 @@ export function buildInsertQuery(params: { schema: string; table: string; row: Record; + dialect?: SqlDialect; }): string { - const { schema, table, row } = params; - const tableName = `${quoteIdent(schema)}.${quoteIdent(table)}`; + const { schema, table, row, dialect = "postgres" } = params; + const tableName = `${quoteIdent(schema, dialect)}.${quoteIdent(table, dialect)}`; const entries = Object.entries(row).filter(([, v]) => v !== undefined && v !== ""); if (entries.length === 0) { + if (dialect === "mysql") return `INSERT INTO ${tableName} () VALUES ()`; return `INSERT INTO ${tableName} DEFAULT VALUES RETURNING *`; } - const columns = entries.map(([k]) => quoteIdent(k)).join(", "); - const values = entries.map(([, v]) => quoteLiteral(v)).join(", "); + const columns = entries.map(([k]) => quoteIdent(k, dialect)).join(", "); + const values = entries.map(([, v]) => quoteLiteral(v, dialect)).join(", "); + if (dialect === "mysql") return `INSERT INTO ${tableName} (${columns}) VALUES (${values})`; return `INSERT INTO ${tableName} (${columns}) VALUES (${values}) RETURNING *`; } @@ -136,18 +156,20 @@ export function buildUpdateQuery(params: { table: string; changes: Record; primaryKeyMatch: Record; + dialect?: SqlDialect; }): string { - const { schema, table, changes, primaryKeyMatch } = params; + const { schema, table, changes, primaryKeyMatch, dialect = "postgres" } = params; if (Object.keys(primaryKeyMatch).length === 0) { throw new Error("UPDATE requires at least one primary key condition"); } - const tableName = `${quoteIdent(schema)}.${quoteIdent(table)}`; + const tableName = `${quoteIdent(schema, dialect)}.${quoteIdent(table, dialect)}`; const setClauses = Object.entries(changes) - .map(([col, val]) => `${quoteIdent(col)} = ${quoteLiteral(val)}`) + .map(([col, val]) => `${quoteIdent(col, dialect)} = ${quoteLiteral(val, dialect)}`) .join(", "); const whereClauses = Object.entries(primaryKeyMatch) - .map(([col, val]) => `${quoteIdent(col)} = ${quoteLiteral(val)}`) + .map(([col, val]) => `${quoteIdent(col, dialect)} = ${quoteLiteral(val, dialect)}`) .join(" AND "); + if (dialect === "mysql") return `UPDATE ${tableName} SET ${setClauses} WHERE ${whereClauses}`; return `UPDATE ${tableName} SET ${setClauses} WHERE ${whereClauses} RETURNING *`; } @@ -155,20 +177,20 @@ export function buildDeleteQuery(params: { schema: string; table: string; primaryKeyMatch: Record; + dialect?: SqlDialect; }): string { - const { schema, table, primaryKeyMatch } = params; + const { schema, table, primaryKeyMatch, dialect = "postgres" } = params; if (Object.keys(primaryKeyMatch).length === 0) { throw new Error("DELETE requires at least one primary key condition"); } - const tableName = `${quoteIdent(schema)}.${quoteIdent(table)}`; + const tableName = `${quoteIdent(schema, dialect)}.${quoteIdent(table, dialect)}`; const whereClauses = Object.entries(primaryKeyMatch) - .map(([col, val]) => `${quoteIdent(col)} = ${quoteLiteral(val)}`) + .map(([col, val]) => `${quoteIdent(col, dialect)} = ${quoteLiteral(val, dialect)}`) .join(" AND "); return `DELETE FROM ${tableName} WHERE ${whereClauses}`; } -// Note: RETURNING * in individual INSERT/UPDATE statements is harmless but unused when -// wrapped here — pgClient.query() returns only the last statement's result (COMMIT). +// RETURNING * in individual INSERT/UPDATE is harmless but unused when wrapped here. // The frontend re-fetches data after save anyway. export function wrapInTransaction(statements: string[]): string { return `BEGIN;\n${statements.join(";\n")};\nCOMMIT;`; diff --git a/frontend/src/pages/pam/PamDataExplorerPage/use-data-explorer-session.ts b/frontend/src/pages/pam/PamDataExplorerPage/use-data-explorer-session.ts index f54bc655778..029242e9624 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/use-data-explorer-session.ts +++ b/frontend/src/pages/pam/PamDataExplorerPage/use-data-explorer-session.ts @@ -340,14 +340,14 @@ export const useDataExplorerSession = ({ const openConnection = useCallback(async (): Promise<{ connectionId: string; - backendPid: number | null; + nativeConnectionId: number | null; }> => { const resp = await sendRequest< Extract >({ type: "open-connection" }); - return { connectionId: resp.connectionId, backendPid: resp.backendPid }; + return { connectionId: resp.connectionId, nativeConnectionId: resp.nativeConnectionId }; }, [sendRequest]); const closeConnection = useCallback((connectionId: string): void => { diff --git a/frontend/src/pages/pam/PamDataExplorerPage/use-query-tabs.ts b/frontend/src/pages/pam/PamDataExplorerPage/use-query-tabs.ts index 5d488a827a1..3c9804f8eec 100644 --- a/frontend/src/pages/pam/PamDataExplorerPage/use-query-tabs.ts +++ b/frontend/src/pages/pam/PamDataExplorerPage/use-query-tabs.ts @@ -7,7 +7,7 @@ import type { TableDetail } from "./data-explorer-types"; type TabBase = { id: string; connectionId: string; - backendPid: number | null; + nativeConnectionId: number | null; isInTransaction: boolean; lastFocusedAt: number; isDead?: boolean; @@ -33,7 +33,7 @@ type Tab = BrowseTab | QueryTab; const MAX_TABS = 20; type UseQueryTabsOptions = { - openConnection: () => Promise<{ connectionId: string; backendPid: number | null }>; + openConnection: () => Promise<{ connectionId: string; nativeConnectionId: number | null }>; closeConnection: (connectionId: string) => void; fetchTableDetail: ( connectionId: string, @@ -71,7 +71,7 @@ export const useQueryTabs = ({ // failure as a toast. Returns null if the caller should bail out. const acquireTabConnection = useCallback(async (): Promise<{ connectionId: string; - backendPid: number | null; + nativeConnectionId: number | null; } | null> => { if (!guardLimit()) return null; setIsOpeningTab(true); @@ -101,7 +101,7 @@ export const useQueryTabs = ({ kind: "query", id, connectionId: conn.connectionId, - backendPid: conn.backendPid, + nativeConnectionId: conn.nativeConnectionId, isInTransaction: false, lastFocusedAt: Date.now(), title, @@ -145,7 +145,7 @@ export const useQueryTabs = ({ kind: "browse", id, connectionId: conn.connectionId, - backendPid: conn.backendPid, + nativeConnectionId: conn.nativeConnectionId, isInTransaction: false, lastFocusedAt: Date.now(), title: `${schema}.${table}`,