Skip to content
83 changes: 73 additions & 10 deletions spec/v2/providers/https.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -677,24 +677,32 @@ describe("onCall", () => {

describe("onCallGenkit", () => {
it("calls with JSON requests", async () => {
let gotContext: any;
const flow = {
__action: {
name: "test",
},
run: sinon.stub(),
stream: sinon.stub(),
};
flow.run.withArgs("answer").returns({ result: 42 });
flow.run.callsFake((data, opts) => {
gotContext = opts.context;
expect(data).to.equal("answer");
return { result: 42 };
});
flow.stream.throws("Unexpected stream");

const f = https.onCallGenkit(flow);

const req = request({ data: "answer" });
const res = await runHandler(f, req);
expect(JSON.parse(res.body)).to.deep.equal({ result: 42 });
expect(gotContext[https.CALLABLE_RAW_REQUEST]).to.equal(req);
expect(gotContext[https.CALLABLE_RESPONSE_SIGNAL]).to.be.instanceOf(AbortSignal);
});

it("Streams with SSE requests", async () => {
it("forwards rawRequest and response signal into streaming Genkit action context", async () => {
let gotContext: any;
const flow = {
__action: {
name: "test",
Expand All @@ -703,25 +711,80 @@ describe("onCallGenkit", () => {
stream: sinon.stub(),
};
flow.run.onFirstCall().throws();
flow.stream.withArgs("answer").returns({
stream: (async function* () {
await Promise.resolve();
yield 1;
await Promise.resolve();
yield 2;
})(),
output: Promise.resolve(42),
flow.stream.callsFake((data, opts) => {
gotContext = opts.context;
expect(data).to.equal("answer");
return {
stream: (async function* () {
await Promise.resolve();
yield 1;
await Promise.resolve();
yield 2;
})(),
output: Promise.resolve(42),
};
});

const f = https.onCallGenkit(flow);

const req = request({ data: "answer", headers: { accept: "text/event-stream" } });
const res = await runHandler(f, req);
expect(gotContext[https.CALLABLE_RAW_REQUEST]).to.equal(req);
expect(gotContext[https.CALLABLE_RESPONSE_SIGNAL]).to.be.instanceOf(AbortSignal);
expect(res.body).to.equal(
['data: {"message":1}', 'data: {"message":2}', 'data: {"result":42}', ""].join("\n\n")
);
});

it("aborts the forwarded response signal when the client disconnects", async () => {
let capturedSignal: AbortSignal;
let resolveOutput: (value: number) => void;
const output = new Promise<number>((resolve) => {
resolveOutput = resolve;
});
let notifyStreamStarted: () => void;
const streamStarted = new Promise<void>((resolve) => {
notifyStreamStarted = resolve;
});
const flow = {
__action: {
name: "test",
},
run: sinon.stub(),
stream: sinon.stub(),
};
flow.run.onFirstCall().throws();
flow.stream.callsFake((_data, opts) => {
capturedSignal = opts.context[https.CALLABLE_RESPONSE_SIGNAL] as AbortSignal;
notifyStreamStarted();
return {
stream: (async function* () {
await output;
// This test intentionally doesn't emit any stream messages, but ESLint's `require-yield`
// rule requires at least one `yield` in a generator function.
return;
// eslint-disable-next-line no-unreachable
yield undefined;
})(),
output,
};
});

const f = https.onCallGenkit(flow);
const req = request({ data: "answer", headers: { accept: "text/event-stream" } });
const resPromise = runHandler(f, req);

await streamStarted;
expect(capturedSignal.aborted).to.equal(false);

req.emit("close");
expect(capturedSignal.aborted).to.equal(true);
resolveOutput(42);

const res = await resPromise;
expect(res.body).to.be.undefined;
});

it("Exports types that are compatible with the genkit library (compilation is success)", () => {
const ai = genkit({});
const flow = ai.defineFlow("test", () => 42);
Expand Down
10 changes: 10 additions & 0 deletions src/common/providers/https.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ export const CALLABLE_AUTH_HEADER = "x-callable-context-auth";
export const ORIGINAL_AUTH_HEADER = "x-original-auth";
/** @internal */
export const DEFAULT_HEARTBEAT_SECONDS = 30;
/**
* Symbol key used by {@link https.onCallGenkit} to expose the underlying callable raw request
* inside the Genkit action context.
*/
export const CALLABLE_RAW_REQUEST = Symbol.for("firebase.callable.rawRequest");
/**
* Symbol key used by {@link https.onCallGenkit} to expose the callable disconnect signal
* inside the Genkit action context.
*/
export const CALLABLE_RESPONSE_SIGNAL = Symbol.for("firebase.callable.responseSignal");

/** An express request with the wire format representation of the request body. */
export interface Request extends express.Request {
Expand Down
9 changes: 8 additions & 1 deletion src/v2/providers/https.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import { wrapTraceContext } from "../trace";
import { isDebugFeatureEnabled } from "../../common/debug";
import { ResetValue } from "../../common/options";
import {
CALLABLE_RAW_REQUEST,
CALLABLE_RESPONSE_SIGNAL,
type CallableRequest,
type CallableResponse,
type FunctionsErrorCode,
Expand All @@ -50,6 +52,7 @@ import { withInit } from "../../common/onInit";
import * as logger from "../../logger";

export type { Request, CallableRequest, CallableResponse, FunctionsErrorCode };
export { CALLABLE_RAW_REQUEST, CALLABLE_RESPONSE_SIGNAL };
export { HttpsError };

/**
Expand Down Expand Up @@ -604,8 +607,12 @@ export function onCallGenkit<A extends GenkitAction>(
const cloudFunction = onCall<ActionInput<A>, Promise<ActionOutput<A>>, ActionStream<A>>(
opts,
async (req, res) => {
const context: Omit<CallableRequest, "data" | "rawRequest" | "acceptsStreaming"> = {};
const context: Omit<CallableRequest, "data" | "rawRequest" | "acceptsStreaming"> & {
[key: symbol]: unknown;
} = {};
copyIfPresent(context, req, "auth", "app", "instanceIdToken");
context[CALLABLE_RAW_REQUEST] = req.rawRequest;
context[CALLABLE_RESPONSE_SIGNAL] = res.signal;
Comment thread
IzaakGough marked this conversation as resolved.
Outdated

if (!req.acceptsStreaming) {
const { result } = await action.run(req.data, { context });
Expand Down
Loading