diff --git a/docs/docs/usage-guide/changing_a_model.md b/docs/docs/usage-guide/changing_a_model.md index b56164255d..a5d8ff23d2 100644 --- a/docs/docs/usage-guide/changing_a_model.md +++ b/docs/docs/usage-guide/changing_a_model.md @@ -386,6 +386,21 @@ key = "..." # your Codestral api key (you can obtain a Codestral key from [here](https://console.mistral.ai/codestral)) +### Databricks + +To use a model hosted on Databricks (e.g. an Azure Databricks serving endpoint), set: + +```toml +[config] # in configuration.toml +model = "databricks/databricks-claude-sonnet-4" +fallback_models = ["databricks/databricks-claude-sonnet-4"] +[databricks] # in .secrets.toml +api_key = "..." # your Databricks personal access token (PAT) +api_base = "https://adb-xxxx.azuredatabricks.net/serving-endpoints" # your workspace serving-endpoints URL +``` + +The model name after the `databricks/` prefix is the name of your serving endpoint. See LiteLLM's [Databricks provider docs](https://docs.litellm.ai/docs/providers/databricks) for details. + ### Openrouter To use model from Openrouter, for example, set: diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 6acf6a878b..36003607a7 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -200,6 +200,14 @@ def __init__(self): if get_settings().get("CODESTRAL.KEY", None): os.environ["CODESTRAL_API_KEY"] = get_settings().get("CODESTRAL.KEY") + # Support Databricks-hosted models (e.g. Azure Databricks serving endpoints). + # Uses PAT/key authentication via LiteLLM's env vars. 
+ # SEE https://docs.litellm.ai/docs/providers/databricks + if get_settings().get("DATABRICKS.API_KEY", None): + os.environ["DATABRICKS_API_KEY"] = get_settings().get("DATABRICKS.API_KEY") + if get_settings().get("DATABRICKS.API_BASE", None): + os.environ["DATABRICKS_API_BASE"] = get_settings().get("DATABRICKS.API_BASE") + # Check for Azure AD configuration if get_settings().get("AZURE_AD.CLIENT_ID", None): self.azure = True @@ -408,7 +416,12 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: try: resp, finish_reason = None, None deployment_id = self.deployment_id - if self.azure: + # Capture the provider prefix before any rewriting below. Databricks auth/endpoint + # selection keys off this; rewriting (e.g. 'azure/' + model when Azure is enabled in + # a multi-provider config) would otherwise hide the 'databricks/' prefix and bypass + # the guards that keep Databricks on its own DATABRICKS_API_KEY/DATABRICKS_API_BASE. + is_databricks = model.startswith("databricks/") + if self.azure and not is_databricks: model = 'azure/' + model if 'claude' in model and not system: system = "No system prompt provided" @@ -461,12 +474,16 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: messages = [{"role": "user", "content": user}] # Build request kwargs after normalizing messages for the target model. + # Databricks selects its endpoint via the DATABRICKS_API_BASE env var; don't let an + # api_base configured by another provider (OpenRouter/Ollama/Azure AD/OpenAI) during + # __init__ override it in multi-provider configs. None lets LiteLLM read the env var. 
+ api_base = os.environ.get("DATABRICKS_API_BASE") if is_databricks else self.api_base kwargs = { "model": model, "deployment_id": deployment_id, "messages": messages, "timeout": get_settings().config.ai_timeout, - "api_base": self.api_base, + "api_base": api_base, } # Add temperature only if model supports it @@ -540,7 +557,10 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: # Inject api_key to the call. This key is populated during init by providers # like Groq, SambaNova, XAI, Azure AD, and OpenRouter. Skip if None or placeholder. - if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY: + # Databricks authenticates via the DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars, + # so don't override it with another provider's key in multi-provider configs. + if (litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY + and not is_databricks): kwargs["api_key"] = litellm.api_key # Get completion with automatic streaming detection diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index e0d8e11d10..6c9c9f734a 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -121,6 +121,10 @@ key = "" [deepinfra] key = "" +[databricks] +api_key = "" # Databricks personal access token (PAT) +api_base = "" # e.g. 
https://adb-xxxx.azuredatabricks.net/serving-endpoints
+
 [azure_ad]
 # Azure AD authentication for OpenAI services
 client_id = "" # Your Azure AD application client ID
diff --git a/tests/unittest/test_litellm_api_key_guard.py b/tests/unittest/test_litellm_api_key_guard.py
index dce655fc87..350ed874a9 100644
--- a/tests/unittest/test_litellm_api_key_guard.py
+++ b/tests/unittest/test_litellm_api_key_guard.py
@@ -320,6 +320,100 @@ async def test_sambanova_key_forwarded_for_non_ollama_model(self, monkeypatch):
 
         assert mock_call.call_args[1].get("api_key") == sambanova_key
 
+    @pytest.mark.asyncio
+    async def test_databricks_model_does_not_forward_foreign_key(self, monkeypatch):
+        """Databricks models authenticate via DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars.
+
+        In a multi-provider config another provider (e.g. Groq/OpenRouter) may have stored
+        its key in litellm.api_key during __init__. That key must NOT be forwarded for
+        databricks/* calls, otherwise it would override the intended env-var auth and break
+        Databricks authentication.
+        """
+        foreign_key = "test-groq-key-shadowing-databricks"
+
+        with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+                   new_callable=AsyncMock) as mock_call:
+            mock_call.return_value = _mock_response()
+            handler = LiteLLMAIHandler()
+            # Simulate another provider having populated litellm.api_key during init
+            monkeypatch.setattr(litellm, "api_key", foreign_key)
+            await handler.chat_completion(
+                model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
+            )
+
+        assert "api_key" not in mock_call.call_args[1], (  # call_args[1] is the kwargs dict of the mocked acompletion call
+            f"Foreign provider key must not be forwarded for databricks/* models. "
+            f"kwargs had: {mock_call.call_args[1]}"
+        )
+
+    @pytest.mark.asyncio
+    async def test_databricks_model_does_not_forward_foreign_api_base(self, monkeypatch):
+        """Databricks models select their endpoint via the DATABRICKS_API_BASE env var.
+
+        In a multi-provider config another provider (OpenRouter/Ollama/Azure AD/OpenAI) may
+        have set self.api_base during __init__. That base URL must NOT be forwarded for
+        databricks/* calls, otherwise it would route the request to the wrong host and
+        override the intended DATABRICKS_API_BASE endpoint. The Databricks base (or None,
+        which lets LiteLLM read the env var) must be used instead.
+        """
+        databricks_base = "https://adb-1234.azuredatabricks.net/serving-endpoints"
+        monkeypatch.setenv("DATABRICKS_API_BASE", databricks_base)  # the value expected in the forwarded api_base
+
+        with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+                   new_callable=AsyncMock) as mock_call:
+            mock_call.return_value = _mock_response()
+            handler = LiteLLMAIHandler()
+            # Simulate another provider having set api_base during init
+            handler.api_base = "https://openrouter.ai/api/v1"
+            await handler.chat_completion(
+                model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
+            )
+
+        assert mock_call.call_args[1]["api_base"] == databricks_base, (  # must equal the env var, not handler.api_base
+            f"Databricks endpoint must come from DATABRICKS_API_BASE, not a foreign provider's "
+            f"api_base. kwargs had: {mock_call.call_args[1]}"
+        )
+
+    @pytest.mark.asyncio
+    async def test_databricks_guards_survive_azure_mode(self, monkeypatch):
+        """Azure mode must not rewrite databricks/* models and bypass the Databricks guards.
+
+        When Azure is enabled in a multi-provider config (OPENAI.API_TYPE=azure or AZURE_AD),
+        chat_completion() prepends 'azure/' to the model. If that rewrite happened for a
+        databricks/* model the prefix-based guards would never trigger, routing the call to
+        Azure with a foreign key/base. The model must keep its 'databricks/' prefix, the foreign
+        key must not be forwarded, and api_base must come from DATABRICKS_API_BASE.
+        """
+        foreign_key = "test-azure-key-shadowing-databricks"
+        databricks_base = "https://adb-1234.azuredatabricks.net/serving-endpoints"
+        monkeypatch.setenv("DATABRICKS_API_BASE", databricks_base)
+
+        with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
+                   new_callable=AsyncMock) as mock_call:
+            mock_call.return_value = _mock_response()
+            handler = LiteLLMAIHandler()
+            # Simulate Azure mode + a foreign key/base set by another provider during init
+            handler.azure = True  # set directly on the instance, after __init__ has run
+            handler.api_base = "https://my-azure.openai.azure.com"
+            monkeypatch.setattr(litellm, "api_key", foreign_key)
+            await handler.chat_completion(
+                model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
+            )
+
+        forwarded = mock_call.call_args[1]  # kwargs dict of the mocked acompletion call
+        assert forwarded["model"] == "databricks/databricks-claude-sonnet-4", (
+            f"databricks/* model must not be rewritten with an 'azure/' prefix. "
+            f"kwargs had: {forwarded}"
+        )
+        assert "api_key" not in forwarded, (
+            f"Foreign provider key must not be forwarded for databricks/* models even in Azure "
+            f"mode. kwargs had: {forwarded}"
+        )
+        assert forwarded["api_base"] == databricks_base, (
+            f"Databricks endpoint must come from DATABRICKS_API_BASE even in Azure mode. "
+            f"kwargs had: {forwarded}"
+        )
+
     @pytest.mark.asyncio
     async def test_ollama_and_groq_coexist(self, monkeypatch):
         """Verify both Ollama and Groq keys can coexist and be forwarded correctly.
diff --git a/tests/unittest/test_litellm_databricks_provider.py b/tests/unittest/test_litellm_databricks_provider.py
new file mode 100644
index 0000000000..7c15ee7687
--- /dev/null
+++ b/tests/unittest/test_litellm_databricks_provider.py
@@ -0,0 +1,94 @@
+"""
+Tests for Databricks provider wiring in LiteLLMAIHandler.__init__.
+
+Verifies that DATABRICKS.API_KEY / DATABRICKS.API_BASE settings are exported to
+the env vars LiteLLM's Databricks provider reads (DATABRICKS_API_KEY /
+DATABRICKS_API_BASE), and that nothing is exported when they are unset.
+"""
+import os
+
+import litellm
+import pytest
+
+import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler
+
+# Env vars LiteLLMAIHandler.__init__ branches on — clear them so the handler
+# under test is neither influenced by, nor leaks into, the runner environment.
+_ISOLATED_ENV = (
+    "DATABRICKS_API_KEY",
+    "DATABRICKS_API_BASE",
+    "OPENAI_API_KEY",
+    "AWS_USE_IMDS",
+    "AWS_ACCESS_KEY_ID",
+    "AWS_SECRET_ACCESS_KEY",
+    "AWS_SESSION_TOKEN",
+    "AWS_REGION_NAME",
+)
+
+
+def _make_settings(overrides):
+    """Return a stub settings object: top-level .get() looks keys up in `overrides`; the .config and .litellm .get() always return the default."""
+    return type("Settings", (), {
+        "config": type("Config", (), {
+            "reasoning_effort": None,
+            "ai_timeout": 30,
+            "custom_reasoning_model": False,
+            "max_model_tokens": 32000,
+            "verbosity_level": 0,
+            "seed": -1,
+            "get": lambda self, key, default=None: default,
+        })(),
+        "litellm": type("LiteLLM", (), {
+            "get": lambda self, key, default=None: default,
+        })(),
+        "get": lambda self, key, default=None: overrides.get(key, default),
+    })()
+
+
+@pytest.fixture(autouse=True)
+def _isolate_env(monkeypatch):
+    for var in _ISOLATED_ENV:
+        monkeypatch.delenv(var, raising=False)  # raising=False: no error if the var was never set
+    # LiteLLMAIHandler.__init__ mutates the global litellm.api_key (sets the dummy
+    # fallback when OPENAI_API_KEY is absent, as it is here). Snapshot it so this
+    # file can't leak that global into order-dependent later tests.
+    saved_api_key = litellm.api_key
+    yield  # the test runs here; everything below is teardown
+    litellm.api_key = saved_api_key
+    # Drop anything the handler wrote so it can't leak into other tests;
+    # monkeypatch then restores any pre-existing originals.
+    for var in ("DATABRICKS_API_KEY", "DATABRICKS_API_BASE"):
+        os.environ.pop(var, None)
+
+
+def test_databricks_env_vars_exported_from_settings(monkeypatch):
+    overrides = {
+        "DATABRICKS.API_KEY": "dapi-test-123",
+        "DATABRICKS.API_BASE": "https://adb-1234.azuredatabricks.net/serving-endpoints",
+    }
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))  # handler now reads the stub settings
+
+    litellm_handler.LiteLLMAIHandler()  # only __init__'s side effect on os.environ is under test
+
+    assert os.environ["DATABRICKS_API_KEY"] == "dapi-test-123"
+    assert os.environ["DATABRICKS_API_BASE"] == "https://adb-1234.azuredatabricks.net/serving-endpoints"
+
+
+def test_databricks_env_vars_absent_when_unset(monkeypatch):
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings({}))  # no Databricks settings at all
+
+    litellm_handler.LiteLLMAIHandler()
+
+    assert "DATABRICKS_API_KEY" not in os.environ  # both vars were cleared by the _isolate_env fixture beforehand
+    assert "DATABRICKS_API_BASE" not in os.environ
+
+
+def test_databricks_api_base_optional(monkeypatch):
+    """API base is optional (a workspace default may be configured elsewhere)."""
+    overrides = {"DATABRICKS.API_KEY": "dapi-only-key"}  # key only; DATABRICKS.API_BASE deliberately omitted
+    monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))
+
+    litellm_handler.LiteLLMAIHandler()
+
+    assert os.environ["DATABRICKS_API_KEY"] == "dapi-only-key"
+    assert "DATABRICKS_API_BASE" not in os.environ