diff --git a/CHANGELOG.md b/CHANGELOG.md index aae265a60..e63f1bca8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +- feat: forward RawRequest and disconnect AbortSignal into onCallGenkit (#1891) - feat: Add requiresRole developer API for declarative security support and automatic Manifest extraction (#1908) - Validate literal `timeoutSeconds` values per v2 trigger type (0-540s for events, 0-3600s for HTTPS/callable, 0-1800s for task queues, 0-7s for identity functions) so misconfigured values fail at function-definition or manifest-extraction time instead of at deploy time. (#1877) - chore: drop support for Node 18 and below (minimum supported version is now Node 20) diff --git a/spec/v2/providers/https.spec.ts b/spec/v2/providers/https.spec.ts index 5e778c594..f8546cd4a 100644 --- a/spec/v2/providers/https.spec.ts +++ b/spec/v2/providers/https.spec.ts @@ -719,6 +719,7 @@ describe("onCall", () => { describe("onCallGenkit", () => { it("calls with JSON requests", async () => { + let gotContext: any; const flow = { __action: { name: "test", @@ -726,7 +727,38 @@ describe("onCallGenkit", () => { run: sinon.stub(), stream: sinon.stub(), }; - flow.run.withArgs("answer").returns({ result: 42 }); + flow.run.callsFake((data, opts) => { + gotContext = opts.context; + expect(data).to.equal("answer"); + return { result: 42 }; + }); + flow.stream.throws("Unexpected stream"); + + const f = https.onCallGenkit(flow); + + const req = request({ data: "answer" }); + const res = await runHandler(f, req); + expect(JSON.parse(res.body)).to.deep.equal({ result: 42 }); + expect(gotContext[https.CALLABLE_RAW_REQUEST]).to.equal(req); + expect(gotContext[https.CALLABLE_RESPONSE_SIGNAL]).to.be.instanceOf(AbortSignal); + }); + + it("exposes rawRequest and response signal through Symbol.for lookups", async () => { + let gotRawRequest: unknown; + let gotResponseSignal: unknown; + const flow = { + __action: { + name: "test", + }, + run: sinon.stub(), + stream: sinon.stub(), + }; + flow.run.callsFake((data, opts) => { + expect(data).to.equal("answer"); + gotRawRequest = opts.context[Symbol.for("firebase.callable.rawRequest")]; + gotResponseSignal = opts.context[Symbol.for("firebase.callable.responseSignal")]; + return { result: 42 }; + }); flow.stream.throws("Unexpected stream"); const f = https.onCallGenkit(flow); @@ -734,9 +766,12 @@ describe("onCallGenkit", () => { const req = request({ data: "answer" }); const res = await runHandler(f, req); expect(JSON.parse(res.body)).to.deep.equal({ result: 42 }); + expect(gotRawRequest).to.equal(req); + expect(gotResponseSignal).to.be.instanceOf(AbortSignal); }); - it("Streams with SSE requests", async () => { + it("forwards rawRequest and response signal into streaming Genkit action context", async () => { + let gotContext: any; const flow = { __action: { name: "test", @@ -745,25 +780,119 @@ describe("onCallGenkit", () => { stream: sinon.stub(), }; flow.run.onFirstCall().throws(); - flow.stream.withArgs("answer").returns({ - stream: (async function* () { - await Promise.resolve(); - yield 1; - await Promise.resolve(); - yield 2; - })(), - output: Promise.resolve(42), + flow.stream.callsFake((data, opts) => { + gotContext = opts.context; + expect(data).to.equal("answer"); + return { + stream: (async function* () { + await Promise.resolve(); + yield 1; + await Promise.resolve(); + yield 2; + })(), + output: Promise.resolve(42), + }; }); const f = https.onCallGenkit(flow); const req = request({ data: "answer", headers: { accept: "text/event-stream" } }); const res = await runHandler(f, req); + expect(gotContext[https.CALLABLE_RAW_REQUEST]).to.equal(req); + expect(gotContext[https.CALLABLE_RESPONSE_SIGNAL]).to.be.instanceOf(AbortSignal); expect(res.body).to.equal( ['data: {"message":1}', 'data: {"message":2}', 'data: {"result":42}', ""].join("\n\n") ); }); + it("aborts the forwarded response signal when the client disconnects", async () => { + let capturedSignal: AbortSignal; + let resolveOutput: (value: number) => void; + const output = new Promise((resolve) => { + resolveOutput = resolve; + }); + let notifyStreamStarted: () => void; + const streamStarted = new Promise((resolve) => { + notifyStreamStarted = resolve; + }); + const flow = { + __action: { + name: "test", + }, + run: sinon.stub(), + stream: sinon.stub(), + }; + flow.run.onFirstCall().throws(); + flow.stream.callsFake((_data, opts) => { + capturedSignal = opts.context[https.CALLABLE_RESPONSE_SIGNAL] as AbortSignal; + notifyStreamStarted(); + return { + stream: (async function* () { + await output; + // This test intentionally doesn't emit any stream messages, but ESLint's `require-yield` + // rule requires at least one `yield` in a generator function. + return; + // eslint-disable-next-line no-unreachable + yield undefined; + })(), + output, + }; + }); + + const f = https.onCallGenkit(flow); + const req = request({ data: "answer", headers: { accept: "text/event-stream" } }); + const resPromise = runHandler(f, req); + + await streamStarted; + expect(capturedSignal.aborted).to.equal(false); + + req.emit("close"); + expect(capturedSignal.aborted).to.equal(true); + resolveOutput(42); + + const res = await resPromise; + expect(res.body).to.be.undefined; + }); + + it("aborts the forwarded response signal for non-streaming Genkit actions", async () => { + let capturedSignal: AbortSignal; + let resolveResult: (value: { result: number }) => void; + const result = new Promise<{ result: number }>((resolve) => { + resolveResult = resolve; + }); + let notifyRunStarted: () => void; + const runStarted = new Promise((resolve) => { + notifyRunStarted = resolve; + }); + const flow = { + __action: { + name: "test", + }, + run: sinon.stub(), + stream: sinon.stub(), + }; + flow.run.callsFake((_data, opts) => { + capturedSignal = opts.context[Symbol.for("firebase.callable.responseSignal")] as AbortSignal; + notifyRunStarted(); + return result; + }); + flow.stream.throws("Unexpected stream"); + + const f = https.onCallGenkit(flow); + const req = request({ data: "answer" }); + const resPromise = runHandler(f, req); + + await runStarted; + expect(capturedSignal.aborted).to.equal(false); + + req.emit("close"); + expect(capturedSignal.aborted).to.equal(true); + resolveResult({ result: 42 }); + + const res = await resPromise; + expect(res.body).to.be.undefined; + }); + it("Exports types that are compatible with the genkit library (compilation is success)", () => { const ai = genkit({}); const flow = ai.defineFlow("test", () => 42); diff --git a/src/common/providers/https.ts b/src/common/providers/https.ts index cb8f8c9fe..b3f81fecc 100644 --- a/src/common/providers/https.ts +++ b/src/common/providers/https.ts @@ -43,6 +43,16 @@ export const CALLABLE_AUTH_HEADER = "x-callable-context-auth"; export const ORIGINAL_AUTH_HEADER = "x-original-auth"; /** @internal */ export const DEFAULT_HEARTBEAT_SECONDS = 30; +/** + * Symbol key used by {@link https.onCallGenkit} to expose the underlying callable raw request + * inside the Genkit action context. + */ +export const CALLABLE_RAW_REQUEST = Symbol.for("firebase.callable.rawRequest"); +/** + * Symbol key used by {@link https.onCallGenkit} to expose the callable disconnect signal + * inside the Genkit action context. + */ +export const CALLABLE_RESPONSE_SIGNAL = Symbol.for("firebase.callable.responseSignal"); /** An express request with the wire format representation of the request body. */ export interface Request extends express.Request { diff --git a/src/v2/providers/https.ts b/src/v2/providers/https.ts index 8aa9438ec..0ca45d482 100644 --- a/src/v2/providers/https.ts +++ b/src/v2/providers/https.ts @@ -32,6 +32,8 @@ import { wrapTraceContext } from "../trace"; import { isDebugFeatureEnabled } from "../../common/debug"; import { ResetValue } from "../../common/options"; import { + CALLABLE_RAW_REQUEST, + CALLABLE_RESPONSE_SIGNAL, type CallableRequest, type CallableResponse, type FunctionsErrorCode, @@ -52,6 +54,7 @@ import { withInit } from "../../common/onInit"; import * as logger from "../../logger"; export type { Request, CallableRequest, CallableResponse, FunctionsErrorCode }; +export { CALLABLE_RAW_REQUEST, CALLABLE_RESPONSE_SIGNAL }; export { HttpsError }; /** @@ -585,8 +588,12 @@ export function onCallGenkit( const cloudFunction = onCall, Promise>, ActionStream>( opts, async (req, res) => { - const context: Omit = {}; + const context: Omit & { + [key: symbol]: unknown; + } = {}; copyIfPresent(context, req, "auth", "app", "instanceIdToken"); + context[CALLABLE_RAW_REQUEST] = req.rawRequest; + context[CALLABLE_RESPONSE_SIGNAL] = res?.signal; if (!req.acceptsStreaming) { const { result } = await action.run(req.data, { context });