Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/docs/usage-guide/changing_a_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,21 @@ key = "..." # your Codestral api key

(you can obtain a Codestral key from [here](https://console.mistral.ai/codestral))

### Databricks

To use a model hosted on Databricks (e.g. an Azure Databricks serving endpoint), set:

```toml
[config] # in configuration.toml
model = "databricks/databricks-claude-sonnet-4"
fallback_models = ["databricks/databricks-claude-sonnet-4"]
[databricks] # in .secrets.toml
api_key = "..." # your Databricks personal access token (PAT)
api_base = "https://adb-xxxx.azuredatabricks.net/serving-endpoints" # your workspace serving-endpoints URL
```

The model name after the `databricks/` prefix is the name of your serving endpoint. See LiteLLM's [Databricks provider docs](https://docs.litellm.ai/docs/providers/databricks) for details.

### Openrouter

To use model from Openrouter, for example, set:
Expand Down
19 changes: 17 additions & 2 deletions pr_agent/algo/ai_handlers/litellm_ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ def __init__(self):
if get_settings().get("CODESTRAL.KEY", None):
os.environ["CODESTRAL_API_KEY"] = get_settings().get("CODESTRAL.KEY")

# Support Databricks-hosted models (e.g. Azure Databricks serving endpoints).
# Uses PAT/key authentication via LiteLLM's env vars.
# SEE https://docs.litellm.ai/docs/providers/databricks
if get_settings().get("DATABRICKS.API_KEY", None):
os.environ["DATABRICKS_API_KEY"] = get_settings().get("DATABRICKS.API_KEY")
if get_settings().get("DATABRICKS.API_BASE", None):
os.environ["DATABRICKS_API_BASE"] = get_settings().get("DATABRICKS.API_BASE")
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

# Check for Azure AD configuration
if get_settings().get("AZURE_AD.CLIENT_ID", None):
self.azure = True
Expand Down Expand Up @@ -461,12 +469,16 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
messages = [{"role": "user", "content": user}]

# Build request kwargs after normalizing messages for the target model.
# Databricks selects its endpoint via the DATABRICKS_API_BASE env var; don't let an
# api_base configured by another provider (OpenRouter/Ollama/Azure AD/OpenAI) during
# __init__ override it in multi-provider configs. None lets LiteLLM read the env var.
api_base = os.environ.get("DATABRICKS_API_BASE") if model.startswith("databricks/") else self.api_base
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.
Outdated
kwargs = {
"model": model,
"deployment_id": deployment_id,
"messages": messages,
"timeout": get_settings().config.ai_timeout,
"api_base": self.api_base,
"api_base": api_base,
}

# Add temperature only if model supports it
Expand Down Expand Up @@ -540,7 +552,10 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:

# Inject api_key to the call. This key is populated during init by providers
# like Groq, SambaNova, XAI, Azure AD, and OpenRouter. Skip if None or placeholder.
if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY:
# Databricks authenticates via the DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars,
# so don't override it with another provider's key in multi-provider configs.
if (litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY
and not model.startswith("databricks/")):
kwargs["api_key"] = litellm.api_key
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

# Get completion with automatic streaming detection
Expand Down
4 changes: 4 additions & 0 deletions pr_agent/settings/.secrets_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ key = ""
[deepinfra]
key = ""

[databricks]
api_key = "" # Databricks personal access token (PAT)
api_base = "" # e.g. https://adb-xxxx.azuredatabricks.net/serving-endpoints

[azure_ad]
# Azure AD authentication for OpenAI services
client_id = "" # Your Azure AD application client ID
Expand Down
54 changes: 54 additions & 0 deletions tests/unittest/test_litellm_api_key_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,60 @@ async def test_sambanova_key_forwarded_for_non_ollama_model(self, monkeypatch):

assert mock_call.call_args[1].get("api_key") == sambanova_key

@pytest.mark.asyncio
async def test_databricks_model_does_not_forward_foreign_key(self, monkeypatch):
"""Databricks models authenticate via DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars.

In a multi-provider config another provider (e.g. Groq/OpenRouter) may have stored
its key in litellm.api_key during __init__. That key must NOT be forwarded for
databricks/* calls, otherwise it would override the intended env-var auth and break
Databricks authentication.
"""
foreign_key = "test-groq-key-shadowing-databricks"

with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock) as mock_call:
mock_call.return_value = _mock_response()
handler = LiteLLMAIHandler()
# Simulate another provider having populated litellm.api_key during init
monkeypatch.setattr(litellm, "api_key", foreign_key)
await handler.chat_completion(
model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
)

assert "api_key" not in mock_call.call_args[1], (
f"Foreign provider key must not be forwarded for databricks/* models. "
f"kwargs had: {mock_call.call_args[1]}"
)

@pytest.mark.asyncio
async def test_databricks_model_does_not_forward_foreign_api_base(self, monkeypatch):
"""Databricks models select their endpoint via the DATABRICKS_API_BASE env var.

In a multi-provider config another provider (OpenRouter/Ollama/Azure AD/OpenAI) may
have set self.api_base during __init__. That base URL must NOT be forwarded for
databricks/* calls, otherwise it would route the request to the wrong host and
override the intended DATABRICKS_API_BASE endpoint. The Databricks base (or None,
which lets LiteLLM read the env var) must be used instead.
"""
databricks_base = "https://adb-1234.azuredatabricks.net/serving-endpoints"
monkeypatch.setenv("DATABRICKS_API_BASE", databricks_base)

with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock) as mock_call:
mock_call.return_value = _mock_response()
handler = LiteLLMAIHandler()
# Simulate another provider having set api_base during init
handler.api_base = "https://openrouter.ai/api/v1"
await handler.chat_completion(
model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
)

assert mock_call.call_args[1]["api_base"] == databricks_base, (
f"Databricks endpoint must come from DATABRICKS_API_BASE, not a foreign provider's "
f"api_base. kwargs had: {mock_call.call_args[1]}"
)

@pytest.mark.asyncio
async def test_ollama_and_groq_coexist(self, monkeypatch):
"""Verify both Ollama and Groq keys can coexist and be forwarded correctly.
Expand Down
88 changes: 88 additions & 0 deletions tests/unittest/test_litellm_databricks_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Tests for Databricks provider wiring in LiteLLMAIHandler.__init__.

Verifies that DATABRICKS.API_KEY / DATABRICKS.API_BASE settings are exported to
the env vars LiteLLM's Databricks provider reads (DATABRICKS_API_KEY /
DATABRICKS_API_BASE), and that nothing is exported when they are unset.
"""
import os

import pytest

import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

# Env vars LiteLLMAIHandler.__init__ branches on — clear them so the handler
# under test isn't influenced by (or leaking into) the runner environment.
_ISOLATED_ENV = (
"DATABRICKS_API_KEY",
"DATABRICKS_API_BASE",
"OPENAI_API_KEY",
"AWS_USE_IMDS",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_REGION_NAME",
)


def _make_settings(overrides):
"""Minimal settings whose top-level .get() returns the provided overrides."""
return type("Settings", (), {
"config": type("Config", (), {
"reasoning_effort": None,
"ai_timeout": 30,
"custom_reasoning_model": False,
"max_model_tokens": 32000,
"verbosity_level": 0,
"seed": -1,
"get": lambda self, key, default=None: default,
})(),
"litellm": type("LiteLLM", (), {
"get": lambda self, key, default=None: default,
})(),
"get": lambda self, key, default=None: overrides.get(key, default),
})()


@pytest.fixture(autouse=True)
def _isolate_env(monkeypatch):
for var in _ISOLATED_ENV:
monkeypatch.delenv(var, raising=False)
yield
# Drop anything the handler wrote so it can't leak into other tests;
# monkeypatch then restores any pre-existing originals.
for var in ("DATABRICKS_API_KEY", "DATABRICKS_API_BASE"):
os.environ.pop(var, None)


def test_databricks_env_vars_exported_from_settings(monkeypatch):
overrides = {
"DATABRICKS.API_KEY": "dapi-test-123",
"DATABRICKS.API_BASE": "https://adb-1234.azuredatabricks.net/serving-endpoints",
}
monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))

litellm_handler.LiteLLMAIHandler()

assert os.environ["DATABRICKS_API_KEY"] == "dapi-test-123"
assert os.environ["DATABRICKS_API_BASE"] == "https://adb-1234.azuredatabricks.net/serving-endpoints"


def test_databricks_env_vars_absent_when_unset(monkeypatch):
monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings({}))

litellm_handler.LiteLLMAIHandler()

assert "DATABRICKS_API_KEY" not in os.environ
assert "DATABRICKS_API_BASE" not in os.environ


def test_databricks_api_base_optional(monkeypatch):
"""API base is optional (a workspace default may be configured elsewhere)."""
overrides = {"DATABRICKS.API_KEY": "dapi-only-key"}
monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))

litellm_handler.LiteLLMAIHandler()

assert os.environ["DATABRICKS_API_KEY"] == "dapi-only-key"
assert "DATABRICKS_API_BASE" not in os.environ
Loading