Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/docs/usage-guide/changing_a_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,21 @@ key = "..." # your Codestral api key

(you can obtain a Codestral key from [here](https://console.mistral.ai/codestral))

### Databricks

To use a model hosted on Databricks (e.g. an Azure Databricks serving endpoint), set:

```toml
[config] # in configuration.toml
model = "databricks/databricks-claude-sonnet-4"
fallback_models = ["databricks/databricks-claude-sonnet-4"]
[databricks] # in .secrets.toml
api_key = "..." # your Databricks personal access token (PAT)
api_base = "https://adb-xxxx.azuredatabricks.net/serving-endpoints" # your workspace serving-endpoints URL
```

The model name after the `databricks/` prefix is the name of your serving endpoint. See LiteLLM's [Databricks provider docs](https://docs.litellm.ai/docs/providers/databricks) for details.

### Openrouter

To use model from Openrouter, for example, set:
Expand Down
13 changes: 12 additions & 1 deletion pr_agent/algo/ai_handlers/litellm_ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ def __init__(self):
if get_settings().get("CODESTRAL.KEY", None):
os.environ["CODESTRAL_API_KEY"] = get_settings().get("CODESTRAL.KEY")

# Support Databricks-hosted models (e.g. Azure Databricks serving endpoints).
# Uses PAT/key authentication via LiteLLM's env vars.
# SEE https://docs.litellm.ai/docs/providers/databricks
if get_settings().get("DATABRICKS.API_KEY", None):
os.environ["DATABRICKS_API_KEY"] = get_settings().get("DATABRICKS.API_KEY")
if get_settings().get("DATABRICKS.API_BASE", None):
os.environ["DATABRICKS_API_BASE"] = get_settings().get("DATABRICKS.API_BASE")
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

# Check for Azure AD configuration
if get_settings().get("AZURE_AD.CLIENT_ID", None):
self.azure = True
Expand Down Expand Up @@ -540,7 +548,10 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:

# Inject api_key to the call. This key is populated during init by providers
# like Groq, SambaNova, XAI, Azure AD, and OpenRouter. Skip if None or placeholder.
if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY:
# Databricks authenticates via the DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars,
# so don't override it with another provider's key in multi-provider configs.
if (litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY
and not model.startswith("databricks/")):
kwargs["api_key"] = litellm.api_key
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

# Get completion with automatic streaming detection
Expand Down
4 changes: 4 additions & 0 deletions pr_agent/settings/.secrets_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ key = ""
[deepinfra]
key = ""

[databricks]
api_key = "" # Databricks personal access token (PAT)
api_base = "" # e.g. https://adb-xxxx.azuredatabricks.net/serving-endpoints

[azure_ad]
# Azure AD authentication for OpenAI services
client_id = "" # Your Azure AD application client ID
Expand Down
26 changes: 26 additions & 0 deletions tests/unittest/test_litellm_api_key_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,32 @@ async def test_sambanova_key_forwarded_for_non_ollama_model(self, monkeypatch):

assert mock_call.call_args[1].get("api_key") == sambanova_key

@pytest.mark.asyncio
async def test_databricks_model_does_not_forward_foreign_key(self, monkeypatch):
"""Databricks models authenticate via DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars.

In a multi-provider config another provider (e.g. Groq/OpenRouter) may have stored
its key in litellm.api_key during __init__. That key must NOT be forwarded for
databricks/* calls, otherwise it would override the intended env-var auth and break
Databricks authentication.
"""
foreign_key = "test-groq-key-shadowing-databricks"

with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion",
new_callable=AsyncMock) as mock_call:
mock_call.return_value = _mock_response()
handler = LiteLLMAIHandler()
# Simulate another provider having populated litellm.api_key during init
monkeypatch.setattr(litellm, "api_key", foreign_key)
await handler.chat_completion(
model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
)

assert "api_key" not in mock_call.call_args[1], (
f"Foreign provider key must not be forwarded for databricks/* models. "
f"kwargs had: {mock_call.call_args[1]}"
)

@pytest.mark.asyncio
async def test_ollama_and_groq_coexist(self, monkeypatch):
"""Verify both Ollama and Groq keys can coexist and be forwarded correctly.
Expand Down
88 changes: 88 additions & 0 deletions tests/unittest/test_litellm_databricks_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Tests for Databricks provider wiring in LiteLLMAIHandler.__init__.

Verifies that DATABRICKS.API_KEY / DATABRICKS.API_BASE settings are exported to
the env vars LiteLLM's Databricks provider reads (DATABRICKS_API_KEY /
DATABRICKS_API_BASE), and that nothing is exported when they are unset.
"""
import os

import pytest

import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

# Env vars LiteLLMAIHandler.__init__ branches on — clear them so the handler
# under test isn't influenced by (or leaking into) the runner environment.
_ISOLATED_ENV = (
"DATABRICKS_API_KEY",
"DATABRICKS_API_BASE",
"OPENAI_API_KEY",
"AWS_USE_IMDS",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_REGION_NAME",
)


def _make_settings(overrides):
"""Minimal settings whose top-level .get() returns the provided overrides."""
return type("Settings", (), {
"config": type("Config", (), {
"reasoning_effort": None,
"ai_timeout": 30,
"custom_reasoning_model": False,
"max_model_tokens": 32000,
"verbosity_level": 0,
"seed": -1,
"get": lambda self, key, default=None: default,
})(),
"litellm": type("LiteLLM", (), {
"get": lambda self, key, default=None: default,
})(),
"get": lambda self, key, default=None: overrides.get(key, default),
})()


@pytest.fixture(autouse=True)
def _isolate_env(monkeypatch):
for var in _ISOLATED_ENV:
monkeypatch.delenv(var, raising=False)
yield
# Drop anything the handler wrote so it can't leak into other tests;
# monkeypatch then restores any pre-existing originals.
for var in ("DATABRICKS_API_KEY", "DATABRICKS_API_BASE"):
os.environ.pop(var, None)


def test_databricks_env_vars_exported_from_settings(monkeypatch):
overrides = {
"DATABRICKS.API_KEY": "dapi-test-123",
"DATABRICKS.API_BASE": "https://adb-1234.azuredatabricks.net/serving-endpoints",
}
monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))

litellm_handler.LiteLLMAIHandler()

assert os.environ["DATABRICKS_API_KEY"] == "dapi-test-123"
assert os.environ["DATABRICKS_API_BASE"] == "https://adb-1234.azuredatabricks.net/serving-endpoints"


def test_databricks_env_vars_absent_when_unset(monkeypatch):
monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings({}))

litellm_handler.LiteLLMAIHandler()

assert "DATABRICKS_API_KEY" not in os.environ
assert "DATABRICKS_API_BASE" not in os.environ


def test_databricks_api_base_optional(monkeypatch):
"""API base is optional (a workspace default may be configured elsewhere)."""
overrides = {"DATABRICKS.API_KEY": "dapi-only-key"}
monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))

litellm_handler.LiteLLMAIHandler()

assert os.environ["DATABRICKS_API_KEY"] == "dapi-only-key"
assert "DATABRICKS_API_BASE" not in os.environ
Loading