Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/docs/usage-guide/changing_a_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,21 @@ key = "..." # your Codestral api key

(you can obtain a Codestral key from [here](https://console.mistral.ai/codestral))

### Databricks

To use a model hosted on Databricks (e.g. an Azure Databricks serving endpoint), set:

```toml
[config] # in configuration.toml
model = "databricks/databricks-claude-sonnet-4"
fallback_models = ["databricks/databricks-claude-sonnet-4"]
[databricks] # in .secrets.toml
api_key = "..." # your Databricks personal access token (PAT)
api_base = "https://adb-xxxx.azuredatabricks.net/serving-endpoints" # your workspace serving-endpoints URL
```

The model name after the `databricks/` prefix is the name of your serving endpoint. See LiteLLM's [Databricks provider docs](https://docs.litellm.ai/docs/providers/databricks) for details.

### Openrouter

To use a model from OpenRouter, for example, set:
Expand Down
26 changes: 23 additions & 3 deletions pr_agent/algo/ai_handlers/litellm_ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ def __init__(self):
if get_settings().get("CODESTRAL.KEY", None):
os.environ["CODESTRAL_API_KEY"] = get_settings().get("CODESTRAL.KEY")

# Support Databricks-hosted models (e.g. Azure Databricks serving endpoints).
# Uses PAT/key authentication via LiteLLM's env vars.
# SEE https://docs.litellm.ai/docs/providers/databricks
if get_settings().get("DATABRICKS.API_KEY", None):
os.environ["DATABRICKS_API_KEY"] = get_settings().get("DATABRICKS.API_KEY")
if get_settings().get("DATABRICKS.API_BASE", None):
os.environ["DATABRICKS_API_BASE"] = get_settings().get("DATABRICKS.API_BASE")
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

# Check for Azure AD configuration
if get_settings().get("AZURE_AD.CLIENT_ID", None):
self.azure = True
Expand Down Expand Up @@ -408,7 +416,12 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
try:
resp, finish_reason = None, None
deployment_id = self.deployment_id
if self.azure:
# Capture the provider prefix before any rewriting below. Databricks auth/endpoint
# selection keys off this; rewriting (e.g. 'azure/' + model when Azure is enabled in
# a multi-provider config) would otherwise hide the 'databricks/' prefix and bypass
# the guards that keep Databricks on its own DATABRICKS_API_KEY/DATABRICKS_API_BASE.
is_databricks = model.startswith("databricks/")
if self.azure and not is_databricks:
model = 'azure/' + model
if 'claude' in model and not system:
system = "No system prompt provided"
Expand Down Expand Up @@ -461,12 +474,16 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
messages = [{"role": "user", "content": user}]

# Build request kwargs after normalizing messages for the target model.
# Databricks selects its endpoint via the DATABRICKS_API_BASE env var; don't let an
# api_base configured by another provider (OpenRouter/Ollama/Azure AD/OpenAI) during
# __init__ override it in multi-provider configs. None lets LiteLLM read the env var.
api_base = os.environ.get("DATABRICKS_API_BASE") if is_databricks else self.api_base
kwargs = {
"model": model,
"deployment_id": deployment_id,
"messages": messages,
"timeout": get_settings().config.ai_timeout,
"api_base": self.api_base,
"api_base": api_base,
}

# Add temperature only if model supports it
Expand Down Expand Up @@ -540,7 +557,10 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:

# Inject api_key to the call. This key is populated during init by providers
# like Groq, SambaNova, XAI, Azure AD, and OpenRouter. Skip if None or placeholder.
if litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY:
# Databricks authenticates via the DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars,
# so don't override it with another provider's key in multi-provider configs.
if (litellm.api_key and litellm.api_key != DUMMY_LITELLM_API_KEY
and not is_databricks):
kwargs["api_key"] = litellm.api_key
Comment thread
qodo-free-for-open-source-projects[bot] marked this conversation as resolved.

# Get completion with automatic streaming detection
Expand Down
4 changes: 4 additions & 0 deletions pr_agent/settings/.secrets_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ key = ""
[deepinfra]
key = ""

[databricks]
api_key = "" # Databricks personal access token (PAT)
api_base = "" # e.g. https://adb-xxxx.azuredatabricks.net/serving-endpoints

[azure_ad]
# Azure AD authentication for OpenAI services
client_id = "" # Your Azure AD application client ID
Expand Down
94 changes: 94 additions & 0 deletions tests/unittest/test_litellm_api_key_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,100 @@ async def test_sambanova_key_forwarded_for_non_ollama_model(self, monkeypatch):

assert mock_call.call_args[1].get("api_key") == sambanova_key

@pytest.mark.asyncio
async def test_databricks_model_does_not_forward_foreign_key(self, monkeypatch):
    """A key stored by another provider must never reach a databricks/* call.

    Databricks models are expected to authenticate through the
    DATABRICKS_API_KEY/DATABRICKS_API_BASE env vars. In a multi-provider
    config a different provider (e.g. Groq/OpenRouter) may have left its key
    in litellm.api_key during __init__; forwarding that key would override
    the env-var auth and break Databricks authentication.
    """
    acompletion_target = "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion"
    foreign_key = "test-groq-key-shadowing-databricks"

    with patch(acompletion_target, new_callable=AsyncMock) as completion_mock:
        completion_mock.return_value = _mock_response()
        handler = LiteLLMAIHandler()
        # Pretend another provider populated litellm.api_key while the handler
        # was being initialised.
        monkeypatch.setattr(litellm, "api_key", foreign_key)
        await handler.chat_completion(
            model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
        )

    # The mock keeps its recorded call after the patch context exits.
    sent_kwargs = completion_mock.call_args[1]
    assert "api_key" not in sent_kwargs, (
        f"Foreign provider key must not be forwarded for databricks/* models. "
        f"kwargs had: {sent_kwargs}"
    )

@pytest.mark.asyncio
async def test_databricks_model_does_not_forward_foreign_api_base(self, monkeypatch):
    """A base URL set by another provider must never reach a databricks/* call.

    Databricks models pick their endpoint from the DATABRICKS_API_BASE env
    var. In a multi-provider config another provider (OpenRouter/Ollama/
    Azure AD/OpenAI) may have assigned self.api_base during __init__;
    forwarding it would send the request to the wrong host. The value passed
    on must be the Databricks base taken from the environment.
    """
    acompletion_target = "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion"
    databricks_base = "https://adb-1234.azuredatabricks.net/serving-endpoints"
    monkeypatch.setenv("DATABRICKS_API_BASE", databricks_base)

    with patch(acompletion_target, new_callable=AsyncMock) as completion_mock:
        completion_mock.return_value = _mock_response()
        handler = LiteLLMAIHandler()
        # Pretend another provider assigned api_base while the handler was
        # being initialised.
        handler.api_base = "https://openrouter.ai/api/v1"
        await handler.chat_completion(
            model="databricks/databricks-claude-sonnet-4", system="sys", user="usr"
        )

    # The mock keeps its recorded call after the patch context exits.
    sent_kwargs = completion_mock.call_args[1]
    assert sent_kwargs["api_base"] == databricks_base, (
        f"Databricks endpoint must come from DATABRICKS_API_BASE, not a foreign provider's "
        f"api_base. kwargs had: {sent_kwargs}"
    )

@pytest.mark.asyncio
async def test_databricks_guards_survive_azure_mode(self, monkeypatch):
    """Enabling Azure must not rewrite databricks/* models or defeat their guards.

    With Azure enabled in a multi-provider config (OPENAI.API_TYPE=azure or
    AZURE_AD), chat_completion() prefixes the model with 'azure/'. Were that
    applied to a databricks/* model, the prefix-based guards would not fire
    and the call would go to Azure with a foreign key/base. This test checks
    three things, in order: the model keeps its 'databricks/' prefix, no
    foreign key is forwarded, and api_base comes from DATABRICKS_API_BASE.
    """
    acompletion_target = "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion"
    model_name = "databricks/databricks-claude-sonnet-4"
    foreign_key = "test-azure-key-shadowing-databricks"
    databricks_base = "https://adb-1234.azuredatabricks.net/serving-endpoints"
    monkeypatch.setenv("DATABRICKS_API_BASE", databricks_base)

    with patch(acompletion_target, new_callable=AsyncMock) as completion_mock:
        completion_mock.return_value = _mock_response()
        handler = LiteLLMAIHandler()
        # Pretend Azure mode is on and another provider left a key and a base
        # URL behind while the handler was being initialised.
        handler.azure = True
        handler.api_base = "https://my-azure.openai.azure.com"
        monkeypatch.setattr(litellm, "api_key", foreign_key)
        await handler.chat_completion(model=model_name, system="sys", user="usr")

    # The mock keeps its recorded call after the patch context exits.
    sent_kwargs = completion_mock.call_args[1]
    assert sent_kwargs["model"] == "databricks/databricks-claude-sonnet-4", (
        f"databricks/* model must not be rewritten with an 'azure/' prefix. "
        f"kwargs had: {sent_kwargs}"
    )
    assert "api_key" not in sent_kwargs, (
        f"Foreign provider key must not be forwarded for databricks/* models even in Azure "
        f"mode. kwargs had: {sent_kwargs}"
    )
    assert sent_kwargs["api_base"] == databricks_base, (
        f"Databricks endpoint must come from DATABRICKS_API_BASE even in Azure mode. "
        f"kwargs had: {sent_kwargs}"
    )

@pytest.mark.asyncio
async def test_ollama_and_groq_coexist(self, monkeypatch):
"""Verify both Ollama and Groq keys can coexist and be forwarded correctly.
Expand Down
94 changes: 94 additions & 0 deletions tests/unittest/test_litellm_databricks_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Tests for Databricks provider wiring in LiteLLMAIHandler.__init__.

Verifies that DATABRICKS.API_KEY / DATABRICKS.API_BASE settings are exported to
the env vars LiteLLM's Databricks provider reads (DATABRICKS_API_KEY /
DATABRICKS_API_BASE), and that nothing is exported when they are unset.
"""
import os

import litellm
import pytest

import pr_agent.algo.ai_handlers.litellm_ai_handler as litellm_handler
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

# Environment variables that LiteLLMAIHandler.__init__ branches on. The autouse
# `_isolate_env` fixture deletes each of them before every test so the handler
# under test is neither influenced by, nor leaking into, the runner environment.
_ISOLATED_ENV = (
    "DATABRICKS_API_KEY",
    "DATABRICKS_API_BASE",
    "OPENAI_API_KEY",
    "AWS_USE_IMDS",
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "AWS_SESSION_TOKEN",
    "AWS_REGION_NAME",
)


def _make_settings(overrides):
    """Build a minimal stand-in for the object returned by ``get_settings()``.

    The returned object's top-level ``get(key, default)`` looks *key* up in
    *overrides* and falls back to *default*. Its ``config`` and ``litellm``
    sections each expose a ``get`` that always returns the supplied default,
    and ``config`` carries a handful of fixed attributes.
    """

    class Config:
        reasoning_effort = None
        ai_timeout = 30
        custom_reasoning_model = False
        max_model_tokens = 32000
        verbosity_level = 0
        seed = -1

        def get(self, key, default=None):
            # No config key is ever overridden; always answer with the default.
            return default

    class LiteLLM:
        def get(self, key, default=None):
            # Same as Config.get: every lookup yields the default.
            return default

    class Settings:
        config = Config()
        litellm = LiteLLM()

        def get(self, key, default=None):
            # Only the top-level lookup consults the caller-provided overrides.
            return overrides.get(key, default)

    return Settings()


@pytest.fixture(autouse=True)
def _isolate_env(monkeypatch):
    """Run every test in this module with a clean environment and restore it after.

    Before the test, each variable in ``_ISOLATED_ENV`` is removed through
    ``monkeypatch``. After the test, the global ``litellm.api_key`` is put
    back and any Databricks variables the handler exported are dropped.
    """
    for name in _ISOLATED_ENV:
        monkeypatch.delenv(name, raising=False)

    # LiteLLMAIHandler.__init__ mutates the global litellm.api_key (it sets the
    # dummy fallback when OPENAI_API_KEY is absent, as it is here). Remember the
    # current value so this file can't leak it into order-dependent later tests.
    original_api_key = litellm.api_key

    yield

    litellm.api_key = original_api_key
    # Remove whatever the handler wrote so it can't leak into other tests;
    # monkeypatch then restores any pre-existing originals.
    os.environ.pop("DATABRICKS_API_KEY", None)
    os.environ.pop("DATABRICKS_API_BASE", None)


def test_databricks_env_vars_exported_from_settings(monkeypatch):
    """Both Databricks settings end up in the matching DATABRICKS_* env vars."""
    expected_key = "dapi-test-123"
    expected_base = "https://adb-1234.azuredatabricks.net/serving-endpoints"
    overrides = {
        "DATABRICKS.API_KEY": expected_key,
        "DATABRICKS.API_BASE": expected_base,
    }
    monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(overrides))

    # Constructing the handler is the action under test.
    litellm_handler.LiteLLMAIHandler()

    assert os.environ["DATABRICKS_API_KEY"] == expected_key
    assert os.environ["DATABRICKS_API_BASE"] == expected_base


def test_databricks_env_vars_absent_when_unset(monkeypatch):
    """With no Databricks settings, neither DATABRICKS_* env var is created."""
    monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings({}))

    # Constructing the handler is the action under test.
    litellm_handler.LiteLLMAIHandler()

    for name in ("DATABRICKS_API_KEY", "DATABRICKS_API_BASE"):
        assert name not in os.environ


def test_databricks_api_base_optional(monkeypatch):
    """Only the key is exported when no API base is configured.

    The base is optional (a workspace default may be configured elsewhere).
    """
    key_only = {"DATABRICKS.API_KEY": "dapi-only-key"}
    monkeypatch.setattr(litellm_handler, "get_settings", lambda: _make_settings(key_only))

    # Constructing the handler is the action under test.
    litellm_handler.LiteLLMAIHandler()

    assert os.environ["DATABRICKS_API_KEY"] == "dapi-only-key"
    assert "DATABRICKS_API_BASE" not in os.environ
Loading