From f82ec9a48b2e83cde8165db18f0b3eb23cb3f962 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 15:11:28 +0000 Subject: [PATCH 1/6] Add discovery-based Firebolt connections Co-authored-by: Ivan Koptiev --- .github/workflows/integration-tests-core.yml | 63 +++-- src/firebolt/async_db/connection.py | 103 +++++++- src/firebolt/common/discovery.py | 245 ++++++++++++++++++ src/firebolt/db/connection.py | 103 +++++++- .../dbapi/sync/V2/test_discovery.py | 8 + tests/unit/async_db/test_connection.py | 77 ++++++ tests/unit/db/test_connection.py | 69 +++++ 7 files changed, 644 insertions(+), 24 deletions(-) create mode 100644 src/firebolt/common/discovery.py create mode 100644 tests/integration/dbapi/sync/V2/test_discovery.py diff --git a/.github/workflows/integration-tests-core.yml b/.github/workflows/integration-tests-core.yml index 890270780bd..3c745a0a4b3 100644 --- a/.github/workflows/integration-tests-core.yml +++ b/.github/workflows/integration-tests-core.yml @@ -53,11 +53,44 @@ jobs: with: repository: 'firebolt-db/firebolt-python-sdk' - - name: Setup Firebolt Core - id: setup-core - uses: firebolt-db/action-setup-core@eabcd701de0be41793fda0655d29d46c70c847c2 # main - with: - tag_version: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG }} + - name: Install Firebolt + env: + ENGINE_TAG: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG || 'dev' }} + run: | + bash <(curl -fsSL https://get.firebolt.io/) < /dev/null + + - name: Start Firebolt + env: + ENGINE_REPO: ghcr.io/firebolt-db/engine + ENGINE_TAG: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG || 'dev' }} + run: | + mkdir -p -m 777 firebolt-data + docker run \ + --detach \ + --user firebolt \ + --name firebolt \ + --rm \ + --ulimit memlock=8589934592:8589934592 \ + --security-opt seccomp=unconfined \ + -v "${PWD}/firebolt-data:/var/lib/firebolt" \ + -p 3473:3473 \ + "${ENGINE_REPO}:${ENGINE_TAG}" + + timeout=60 + until [ "$timeout" -eq 0 ]; do + response="$( + curl -s 'http://localhost:3473/?output_format=TabSeparatedWithNamesAndTypes' \ + --data-binary 'SELECT 42;' || true + )" + if [ "$response" = $'?column?\nint\n42' ]; then + exit 0 + fi + sleep 1 + timeout=$((timeout - 1)) + done + + docker logs firebolt + exit 1 - name: Set up Python uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 @@ -69,7 +102,7 @@ jobs: python -m pip install --upgrade pip pip install ".[dev]" - - name: Run integration tests HTTP + - name: Run integration tests env: SERVICE_ID: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }} SERVICE_SECRET: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }} @@ -78,23 +111,10 @@ jobs: STOPPED_ENGINE_NAME: "" API_ENDPOINT: "" ACCOUNT_NAME: "" - CORE_URL: ${{ steps.setup-core.outputs.service_url }} + CORE_URL: http://localhost:3473 run: | pytest -o log_cli=true -o log_cli_level=WARNING tests/integration -k "core" --alluredir=allure-results/ - - name: Run integration tests HTTPS - env: - SERVICE_ID: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }} - SERVICE_SECRET: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }} - DATABASE_NAME: "firebolt" - ENGINE_NAME: "" - STOPPED_ENGINE_NAME: "" - API_ENDPOINT: "" - ACCOUNT_NAME: "" - CORE_URL: ${{ steps.setup-core.outputs.service_https_url }} - run: | - pytest -o log_cli=true -o log_cli_level=WARNING tests/integration -k "core" --alluredir=allure-results-https/ - - name: Allure Reports uses: firebolt-db/action-allure-report@8cdc116f65f6eca845a992e347e72b75ca8ccf5f # v2.1.1 if: always() @@ -103,6 +123,5 @@ jobs: pages-branch: gh-pages mapping-json: | { - "allure-results": "core", - "allure-results-https": "core_https" + "allure-results": "core" } diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 9d7c8c36c9f..ce4201e3870 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -11,7 +11,7 @@ from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2 from firebolt.client import DEFAULT_API_URL -from firebolt.client.auth import Auth +from firebolt.client.auth import Auth, FireboltCore from firebolt.client.auth.base import FireboltAuthVersion from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 from firebolt.common.base_connection import ( @@ -27,6 +27,7 @@ set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.common.discovery import async_discover from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, @@ -285,14 +286,41 @@ async def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, database: Optional[str] = None, + engine: Optional[str] = None, engine_name: Optional[str] = None, engine_url: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, disable_cache: bool = False, url: Optional[str] = None, + host: Optional[str] = None, + ssl_mode: str = "strict", + settings: Optional[Dict[str, Any]] = None, autocommit: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: + if host: + return await connect_discovery( + host=host, + database=database, + engine=engine, + engine_name=engine_name, + engine_url=engine_url, + account_name=account_name, + api_endpoint=api_endpoint, + url=url, + auth=auth, + ssl_mode=ssl_mode, + settings=settings, + autocommit=autocommit, + additional_parameters=additional_parameters, + ) + + if engine and engine_name and engine != engine_name: + raise ConfigurationError( + "Both engine and engine_name are provided. Provide only one to connect." + ) + engine_name = engine_name or engine + # auth parameter is optional in function signature # but is required to connect. # PEP 249 recommends making it kwargs. @@ -348,6 +376,79 @@ async def connect( raise ConfigurationError(f"Unsupported auth type: {type(auth)}") +async def connect_discovery( + host: str, + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + engine_url: Optional[str] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + url: Optional[str] = None, + auth: Optional[Auth] = None, + ssl_mode: str = "strict", + settings: Optional[Dict[str, Any]] = None, + autocommit: bool = True, + additional_parameters: Dict[str, Any] = {}, +) -> Connection: + """Connect using the discovery-based Firebolt session model.""" + if account_name: + raise ConfigurationError( + "account_name is not compatible with discovery-based connections." + ) + if api_endpoint != DEFAULT_API_URL: + raise ConfigurationError( + "api_endpoint is not compatible with discovery-based connections." + ) + if engine_url: + raise ConfigurationError( + "engine_url is not compatible with discovery-based connections." + ) + if url: + raise ConfigurationError( + "url is not compatible with discovery-based connections. Use host instead." + ) + if auth and auth.get_firebolt_version() != FireboltAuthVersion.CORE: + raise ConfigurationError( + "auth is not compatible with discovery-based connections." + ) + + connection_id = uuid4().hex + discovery_info = await async_discover( + host=host, + ssl_mode=ssl_mode, + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) + core_auth = auth or FireboltCore() + user_agent_header = get_user_agent_for_connection( + core_auth, connection_id, None, additional_parameters, True + ) + + client = AsyncClientV2( + auth=core_auth, + account_name="", + base_url=discovery_info.engine_url, + api_endpoint=discovery_info.api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), + headers={"User-Agent": user_agent_header}, + verify=discovery_info.verify, + ) + + return Connection( + engine_url=discovery_info.engine_url, + database=None, + client=client, + cursor_type=CursorV2, + api_endpoint=discovery_info.api_endpoint, + init_parameters=discovery_info.parameters, + id=connection_id, + autocommit=autocommit, + ) + + async def connect_v2( auth: Auth, user_agent_header: str, diff --git a/src/firebolt/common/discovery.py b/src/firebolt/common/discovery.py new file mode 100644 index 00000000000..a54df37c6b8 --- /dev/null +++ b/src/firebolt/common/discovery.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from dataclasses import dataclass +from json import JSONDecodeError +from ssl import SSLContext +from typing import Any, Dict, Mapping, Optional, Union +from urllib.parse import urljoin, urlparse + +from httpx import AsyncClient as HttpxAsyncClient +from httpx import Client as HttpxClient +from httpx import Timeout, codes + +from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.exception import ConfigurationError, InterfaceError +from firebolt.utils.firebolt_core import get_core_certificate_context +from firebolt.utils.util import parse_url_and_params + +DISCOVERY_PATH = "/.well-known/firebolt" +SSL_MODE_STRICT = "strict" +SSL_MODE_NONE = "none" +SSL_MODES = {SSL_MODE_STRICT, SSL_MODE_NONE} + + +@dataclass(frozen=True) +class DiscoveryConnectionInfo: + engine_url: str + api_endpoint: str + parameters: Dict[str, Any] + verify: Union[SSLContext, bool] + + +def normalize_ssl_mode(ssl_mode: str) -> str: + mode = ssl_mode.lower() + if mode not in SSL_MODES: + allowed = ", ".join(sorted(SSL_MODES)) + raise ConfigurationError( + f"Invalid ssl_mode: {ssl_mode}. Expected one of: {allowed}." + ) + return mode + + +def normalize_host(host: str, ssl_mode: str) -> str: + """Normalize a discovery host into an HTTP(S) base URL.""" + if not host: + raise ConfigurationError("host is required for discovery-based connections.") + + mode = normalize_ssl_mode(ssl_mode) + default_scheme = "http" if mode == SSL_MODE_NONE else "https" + raw_url = host if "://" in host else f"{default_scheme}://{host}" + parsed = urlparse(raw_url) + + if parsed.scheme not in {"http", "https"}: + raise ConfigurationError( + f"Invalid host scheme: {parsed.scheme}. Expected 'http' or 'https'." + ) + if not parsed.netloc: + raise ConfigurationError( + f"Invalid host: {host}. Expected a hostname, optionally with scheme and port." + ) + if parsed.query or parsed.fragment: + raise ConfigurationError( + "host must not include query parameters or a fragment. " + "Pass connection parameters as connect() arguments instead." + ) + + return f"{parsed.scheme}://{parsed.netloc}{parsed.path.rstrip('/')}" + + +def get_tls_verify(base_url: str, ssl_mode: str) -> Union[SSLContext, bool]: + mode = normalize_ssl_mode(ssl_mode) + if mode == SSL_MODE_NONE: + return False + if urlparse(base_url).scheme == "https": + return get_core_certificate_context() + return True + + +def build_discovery_url(base_url: str) -> str: + return urljoin(base_url.rstrip("/") + "/", DISCOVERY_PATH.lstrip("/")) + + +def _string_value(data: Mapping[str, Any], *keys: str) -> Optional[str]: + for key in keys: + value = data.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _endpoint_from_mapping(data: Mapping[str, Any]) -> Optional[str]: + return _string_value( + data, + "engineUrl", + "engine_url", + "engineEndpoint", + "engine_endpoint", + "queryUrl", + "query_url", + "url", + "endpoint", + ) + + +def _extract_engine_url(discovery: Mapping[str, Any], base_url: str) -> str: + endpoint = _endpoint_from_mapping(discovery) + if endpoint: + return urljoin(base_url.rstrip("/") + "/", endpoint) + + endpoints = discovery.get("endpoints") + if isinstance(endpoints, Mapping): + for key in ("query", "sql", "engine", "http"): + value = endpoints.get(key) + if isinstance(value, str) and value: + return urljoin(base_url.rstrip("/") + "/", value) + if isinstance(value, Mapping): + endpoint = _endpoint_from_mapping(value) + if endpoint: + return urljoin(base_url.rstrip("/") + "/", endpoint) + + query = discovery.get("query") + if isinstance(query, Mapping): + endpoint = _endpoint_from_mapping(query) + if endpoint: + return urljoin(base_url.rstrip("/") + "/", endpoint) + + return base_url + + +def make_discovery_connection_info( + host: str, + ssl_mode: str, + discovery: Mapping[str, Any], + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, +) -> DiscoveryConnectionInfo: + base_url = normalize_host(host, ssl_mode) + verify = get_tls_verify(base_url, ssl_mode) + + if engine and engine_name and engine != engine_name: + raise ConfigurationError( + "Both engine and engine_name are provided. Provide only one to connect." + ) + engine_parameter = engine or engine_name + + endpoint = _extract_engine_url(discovery, base_url) + endpoint_url, endpoint_params = parse_url_and_params(endpoint) + + parameters: Dict[str, Any] = dict(endpoint_params) + if settings: + parameters.update(settings) + if database: + parameters["database"] = database + if engine_parameter: + parameters["engine"] = engine_parameter + + return DiscoveryConnectionInfo( + engine_url=endpoint_url, + api_endpoint=base_url, + parameters=parameters, + verify=verify, + ) + + +def _decode_discovery_response(response_text: str) -> Mapping[str, Any]: + try: + import json + + decoded = json.loads(response_text) + except JSONDecodeError as e: + raise InterfaceError("Unable to decode Firebolt discovery response.") from e + if not isinstance(decoded, Mapping): + raise InterfaceError("Firebolt discovery response must be a JSON object.") + return decoded + + +def discover( + host: str, + ssl_mode: str, + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, +) -> DiscoveryConnectionInfo: + base_url = normalize_host(host, ssl_mode) + verify = get_tls_verify(base_url, ssl_mode) + discovery_url = build_discovery_url(base_url) + + with HttpxClient( + verify=verify, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + response = client.get(discovery_url) + + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve Firebolt discovery document {discovery_url}: " + f"{response.status_code} {response.text}" + ) + + return make_discovery_connection_info( + host=host, + ssl_mode=ssl_mode, + discovery=_decode_discovery_response(response.text), + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) + + +async def async_discover( + host: str, + ssl_mode: str, + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, +) -> DiscoveryConnectionInfo: + base_url = normalize_host(host, ssl_mode) + verify = get_tls_verify(base_url, ssl_mode) + discovery_url = build_discovery_url(base_url) + + async with HttpxAsyncClient( + verify=verify, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + response = await client.get(discovery_url) + + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve Firebolt discovery document {discovery_url}: " + f"{response.status_code} {response.text}" + ) + + return make_discovery_connection_info( + host=host, + ssl_mode=ssl_mode, + discovery=_decode_discovery_response(response.text), + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index a0cc7a3c5fb..909390d5a64 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -11,7 +11,7 @@ from httpx import Request, Response, Timeout, codes from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2 -from firebolt.client.auth import Auth +from firebolt.client.auth import Auth, FireboltCore from firebolt.client.auth.base import FireboltAuthVersion from firebolt.common.base_connection import ( ASYNC_QUERY_CANCEL, @@ -26,6 +26,7 @@ set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.common.discovery import discover from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( @@ -54,14 +55,41 @@ def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, database: Optional[str] = None, + engine: Optional[str] = None, engine_name: Optional[str] = None, engine_url: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, disable_cache: bool = False, url: Optional[str] = None, + host: Optional[str] = None, + ssl_mode: str = "strict", + settings: Optional[Dict[str, Any]] = None, autocommit: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: + if host: + return connect_discovery( + host=host, + database=database, + engine=engine, + engine_name=engine_name, + engine_url=engine_url, + account_name=account_name, + api_endpoint=api_endpoint, + url=url, + auth=auth, + ssl_mode=ssl_mode, + settings=settings, + autocommit=autocommit, + additional_parameters=additional_parameters, + ) + + if engine and engine_name and engine != engine_name: + raise ConfigurationError( + "Both engine and engine_name are provided. Provide only one to connect." + ) + engine_name = engine_name or engine + # auth parameter is optional in function signature # but is required to connect. # PEP 249 recommends making it kwargs. @@ -118,6 +146,79 @@ def connect( raise ConfigurationError(f"Unsupported auth type: {type(auth)}") +def connect_discovery( + host: str, + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + engine_url: Optional[str] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + url: Optional[str] = None, + auth: Optional[Auth] = None, + ssl_mode: str = "strict", + settings: Optional[Dict[str, Any]] = None, + autocommit: bool = True, + additional_parameters: Dict[str, Any] = {}, +) -> Connection: + """Connect using the discovery-based Firebolt session model.""" + if account_name: + raise ConfigurationError( + "account_name is not compatible with discovery-based connections." + ) + if api_endpoint != DEFAULT_API_URL: + raise ConfigurationError( + "api_endpoint is not compatible with discovery-based connections." + ) + if engine_url: + raise ConfigurationError( + "engine_url is not compatible with discovery-based connections." + ) + if url: + raise ConfigurationError( + "url is not compatible with discovery-based connections. Use host instead." + ) + if auth and auth.get_firebolt_version() != FireboltAuthVersion.CORE: + raise ConfigurationError( + "auth is not compatible with discovery-based connections." + ) + + connection_id = uuid4().hex + discovery_info = discover( + host=host, + ssl_mode=ssl_mode, + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) + core_auth = auth or FireboltCore() + user_agent_header = get_user_agent_for_connection( + core_auth, connection_id, None, additional_parameters, True + ) + + client = ClientV2( + auth=core_auth, + account_name="", + base_url=discovery_info.engine_url, + api_endpoint=discovery_info.api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), + headers={"User-Agent": user_agent_header}, + verify=discovery_info.verify, + ) + + return Connection( + engine_url=discovery_info.engine_url, + database=None, + client=client, + cursor_type=CursorV2, + api_endpoint=discovery_info.api_endpoint, + init_parameters=discovery_info.parameters, + id=connection_id, + autocommit=autocommit, + ) + + def connect_v2( auth: Auth, user_agent_header: str, diff --git a/tests/integration/dbapi/sync/V2/test_discovery.py b/tests/integration/dbapi/sync/V2/test_discovery.py new file mode 100644 index 00000000000..ded2db56694 --- /dev/null +++ b/tests/integration/dbapi/sync/V2/test_discovery.py @@ -0,0 +1,8 @@ +from firebolt.db import connect + + +def test_core_discovery_connection(core_url: str): + with connect(host=core_url, ssl_mode="none", database="firebolt") as connection: + cursor = connection.cursor() + cursor.execute("SELECT 42") + assert cursor.fetchone()[0] == 42 diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index f7a01db6883..3522b591495 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -2,6 +2,7 @@ from unittest.mock import ANY as AnyValue from unittest.mock import MagicMock, patch +from httpx import Request, codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises from pytest_httpx import HTTPXMock @@ -15,8 +16,10 @@ ConfigurationError, ConnectionClosedError, FireboltError, + InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.response import Response @mark.skip("__slots__ is broken on Connection class") @@ -108,6 +111,80 @@ async def test_connect( assert await connection.cursor().execute("select *") == len(python_query_data) +async def test_connect_discovery( + db_name: str, + engine_name: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + python_query_data: List[List[ColType]], +): + """Discovery connections pass database, engine and settings as query params.""" + httpx_mock.add_response( + method="GET", + url="http://localhost:3473/.well-known/firebolt", + json={"engineUrl": "http://localhost:3473/?discovered_param=value"}, + ) + + def query_with_discovery_params(request: Request, **kwargs) -> Response: + params = dict(request.url.params) + assert "authorization" not in request.headers + assert params["database"] == db_name + assert params["engine"] == engine_name + assert params["custom_setting"] == "custom_value" + assert params["discovered_param"] == "value" + return query_callback(request, **kwargs) + + httpx_mock.add_callback(query_with_discovery_params, method="POST") + + async with await connect( + host="localhost:3473", + ssl_mode="none", + database=db_name, + engine=engine_name, + settings={"custom_setting": "custom_value"}, + ) as connection: + assert await connection.cursor().execute("select *") == len(python_query_data) + + +async def test_connect_discovery_rejects_legacy_parameters(auth: Auth): + with raises(ConfigurationError, match="account_name"): + await connect(host="localhost:3473", ssl_mode="none", account_name="account") + with raises(ConfigurationError, match="api_endpoint"): + await connect( + host="localhost:3473", + ssl_mode="none", + api_endpoint="api.example.com", + ) + with raises(ConfigurationError, match="engine_url"): + await connect( + host="localhost:3473", + ssl_mode="none", + engine_url="engine.example.com", + ) + with raises(ConfigurationError, match="url"): + await connect( + host="localhost:3473", + ssl_mode="none", + url="http://localhost:3473", + ) + with raises(ConfigurationError, match="auth"): + await connect(host="localhost:3473", ssl_mode="none", auth=auth) + + +async def test_connect_discovery_validation(httpx_mock: HTTPXMock): + with raises(ConfigurationError, match="ssl_mode"): + await connect(host="localhost:3473", ssl_mode="invalid") + + httpx_mock.add_response( + method="GET", + url="http://localhost:3473/.well-known/firebolt", + status_code=codes.NOT_FOUND, + text="not found", + ) + with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): + await connect(host="localhost:3473", ssl_mode="none") + + async def test_connect_database_failed( db_name: str, account_name: str, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 3602ebf5e59..a1b7c371006 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -4,6 +4,7 @@ from unittest.mock import ANY as AnyValue from unittest.mock import MagicMock, patch +from httpx import Request, codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns from pytest_httpx import HTTPXMock @@ -19,8 +20,10 @@ ConfigurationError, ConnectionClosedError, FireboltError, + InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.response import Response def test_connection_attributes(connection: Connection) -> None: @@ -112,6 +115,72 @@ def test_connect( assert connection.cursor().execute("select *") == len(python_query_data) +def test_connect_discovery( + db_name: str, + engine_name: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + python_query_data: List[List[ColType]], +): + """Discovery connections pass database, engine and settings as query params.""" + httpx_mock.add_response( + method="GET", + url="http://localhost:3473/.well-known/firebolt", + json={"engineUrl": "http://localhost:3473/?discovered_param=value"}, + ) + + def query_with_discovery_params(request: Request, **kwargs) -> Response: + params = dict(request.url.params) + assert "authorization" not in request.headers + assert params["database"] == db_name + assert params["engine"] == engine_name + assert params["custom_setting"] == "custom_value" + assert params["discovered_param"] == "value" + return query_callback(request, **kwargs) + + httpx_mock.add_callback(query_with_discovery_params, method="POST") + + with connect( + host="localhost:3473", + ssl_mode="none", + database=db_name, + engine=engine_name, + settings={"custom_setting": "custom_value"}, + ) as connection: + assert connection.cursor().execute("select *") == len(python_query_data) + + +def test_connect_discovery_rejects_legacy_parameters(auth: Auth): + with raises(ConfigurationError, match="account_name"): + connect(host="localhost:3473", ssl_mode="none", account_name="account") + with raises(ConfigurationError, match="api_endpoint"): + connect( + host="localhost:3473", + ssl_mode="none", + api_endpoint="api.example.com", + ) + with raises(ConfigurationError, match="engine_url"): + connect(host="localhost:3473", ssl_mode="none", engine_url="engine.example.com") + with raises(ConfigurationError, match="url"): + connect(host="localhost:3473", ssl_mode="none", url="http://localhost:3473") + with raises(ConfigurationError, match="auth"): + connect(host="localhost:3473", ssl_mode="none", auth=auth) + + +def test_connect_discovery_validation(httpx_mock: HTTPXMock): + with raises(ConfigurationError, match="ssl_mode"): + connect(host="localhost:3473", ssl_mode="invalid") + + httpx_mock.add_response( + method="GET", + url="http://localhost:3473/.well-known/firebolt", + status_code=codes.NOT_FOUND, + text="not found", + ) + with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): + connect(host="localhost:3473", ssl_mode="none") + + def test_connect_database_failed( db_name: str, account_name: str, From 7758cd8488aa802e8ae8e704504a517b00575123 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 15:14:26 +0000 Subject: [PATCH 2/6] Tighten discovery endpoint normalization Co-authored-by: Ivan Koptiev --- .github/workflows/integration-tests-core.yml | 2 +- src/firebolt/common/discovery.py | 17 +++++++++++++---- tests/unit/async_db/test_connection.py | 2 +- tests/unit/db/test_connection.py | 2 +- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/.github/workflows/integration-tests-core.yml b/.github/workflows/integration-tests-core.yml index 3c745a0a4b3..f9b6b0f9fb7 100644 --- a/.github/workflows/integration-tests-core.yml +++ b/.github/workflows/integration-tests-core.yml @@ -57,7 +57,7 @@ jobs: env: ENGINE_TAG: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG || 'dev' }} run: | - bash <(curl -fsSL https://get.firebolt.io/) < /dev/null + printf 'n\n' | bash <(curl -fsSL https://get.firebolt.io/) - name: Start Firebolt env: diff --git a/src/firebolt/common/discovery.py b/src/firebolt/common/discovery.py index a54df37c6b8..53a67c167e5 100644 --- a/src/firebolt/common/discovery.py +++ b/src/firebolt/common/discovery.py @@ -101,27 +101,36 @@ def _endpoint_from_mapping(data: Mapping[str, Any]) -> Optional[str]: ) +def _resolve_endpoint(base_url: str, endpoint: str) -> str: + if "://" in endpoint or endpoint.startswith("/"): + return urljoin(base_url.rstrip("/") + "/", endpoint) + parsed_base = urlparse(base_url) + if ":" in endpoint or "." in endpoint: + return f"{parsed_base.scheme}://{endpoint}" + return urljoin(base_url.rstrip("/") + "/", endpoint) + + def _extract_engine_url(discovery: Mapping[str, Any], base_url: str) -> str: endpoint = _endpoint_from_mapping(discovery) if endpoint: - return urljoin(base_url.rstrip("/") + "/", endpoint) + return _resolve_endpoint(base_url, endpoint) endpoints = discovery.get("endpoints") if isinstance(endpoints, Mapping): for key in ("query", "sql", "engine", "http"): value = endpoints.get(key) if isinstance(value, str) and value: - return urljoin(base_url.rstrip("/") + "/", value) + return _resolve_endpoint(base_url, value) if isinstance(value, Mapping): endpoint = _endpoint_from_mapping(value) if endpoint: - return urljoin(base_url.rstrip("/") + "/", endpoint) + return _resolve_endpoint(base_url, endpoint) query = discovery.get("query") if isinstance(query, Mapping): endpoint = _endpoint_from_mapping(query) if endpoint: - return urljoin(base_url.rstrip("/") + "/", endpoint) + return _resolve_endpoint(base_url, endpoint) return base_url diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 3522b591495..2b5fd4b01cd 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -122,7 +122,7 @@ async def test_connect_discovery( httpx_mock.add_response( method="GET", url="http://localhost:3473/.well-known/firebolt", - json={"engineUrl": "http://localhost:3473/?discovered_param=value"}, + json={"engineUrl": "localhost:3473/?discovered_param=value"}, ) def query_with_discovery_params(request: Request, **kwargs) -> Response: diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index a1b7c371006..9b54fcfdddd 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -126,7 +126,7 @@ def test_connect_discovery( httpx_mock.add_response( method="GET", url="http://localhost:3473/.well-known/firebolt", - json={"engineUrl": "http://localhost:3473/?discovered_param=value"}, + json={"engineUrl": "localhost:3473/?discovered_param=value"}, ) def query_with_discovery_params(request: Request, **kwargs) -> Response: From 4aab128181310fb60930b62acc13b9d7c892bf8f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 15:18:19 +0000 Subject: [PATCH 3/6] Refactor discovery connection helpers Co-authored-by: Ivan Koptiev --- src/firebolt/async_db/connection.py | 58 +++------ src/firebolt/common/discovery.py | 144 ++++++++++++++++++--- src/firebolt/db/connection.py | 58 +++------ tests/unit/async_db/test_connection.py | 55 +++----- tests/unit/db/test_connection.py | 62 +++------ tests/unit/discovery_connection_helpers.py | 69 ++++++++++ 6 files changed, 268 insertions(+), 178 deletions(-) create mode 100644 tests/unit/discovery_connection_helpers.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index ce4201e3870..7141e66092c 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -11,7 +11,7 @@ from firebolt.async_db.cursor import Cursor, CursorV1, CursorV2 from firebolt.client import DEFAULT_API_URL -from firebolt.client.auth import Auth, FireboltCore +from firebolt.client.auth import Auth from firebolt.client.auth.base import FireboltAuthVersion from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 from firebolt.common.base_connection import ( @@ -27,7 +27,13 @@ set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.common.discovery import async_discover +from firebolt.common.discovery import ( + async_discover, + make_discovery_client_kwargs, + prepare_discovery_connection, + resolve_engine_name, + validate_discovery_connection_parameters, +) from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, @@ -315,11 +321,7 @@ async def connect( additional_parameters=additional_parameters, ) - if engine and engine_name and engine != engine_name: - raise ConfigurationError( - "Both engine and engine_name are provided. Provide only one to connect." - ) - engine_name = engine_name or engine + engine_name = resolve_engine_name(engine, engine_name) # auth parameter is optional in function signature # but is required to connect. @@ -392,27 +394,13 @@ async def connect_discovery( additional_parameters: Dict[str, Any] = {}, ) -> Connection: """Connect using the discovery-based Firebolt session model.""" - if account_name: - raise ConfigurationError( - "account_name is not compatible with discovery-based connections." - ) - if api_endpoint != DEFAULT_API_URL: - raise ConfigurationError( - "api_endpoint is not compatible with discovery-based connections." - ) - if engine_url: - raise ConfigurationError( - "engine_url is not compatible with discovery-based connections." - ) - if url: - raise ConfigurationError( - "url is not compatible with discovery-based connections. Use host instead." - ) - if auth and auth.get_firebolt_version() != FireboltAuthVersion.CORE: - raise ConfigurationError( - "auth is not compatible with discovery-based connections." - ) - + validate_discovery_connection_parameters( + account_name=account_name, + api_endpoint=api_endpoint, + engine_url=engine_url, + url=url, + auth=auth, + ) connection_id = uuid4().hex discovery_info = await async_discover( host=host, @@ -422,19 +410,11 @@ async def connect_discovery( engine_name=engine_name, settings=settings, ) - core_auth = auth or FireboltCore() - user_agent_header = get_user_agent_for_connection( - core_auth, connection_id, None, additional_parameters, True + prepared_connection = prepare_discovery_connection( + auth, connection_id, additional_parameters ) - client = AsyncClientV2( - auth=core_auth, - account_name="", - base_url=discovery_info.engine_url, - api_endpoint=discovery_info.api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), - headers={"User-Agent": user_agent_header}, - verify=discovery_info.verify, + **make_discovery_client_kwargs(discovery_info, prepared_connection) ) return Connection( diff --git a/src/firebolt/common/discovery.py b/src/firebolt/common/discovery.py index 53a67c167e5..110180ef73f 100644 --- a/src/firebolt/common/discovery.py +++ b/src/firebolt/common/discovery.py @@ -10,6 +10,10 @@ from httpx import Client as HttpxClient from httpx import Timeout, codes +from firebolt.client.auth import Auth, FireboltCore +from firebolt.client.auth.base import FireboltAuthVersion +from firebolt.client.constants import DEFAULT_API_URL +from firebolt.common.base_connection import get_user_agent_for_connection from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.utils.exception import ConfigurationError, InterfaceError from firebolt.utils.firebolt_core import get_core_certificate_context @@ -29,6 +33,12 @@ class DiscoveryConnectionInfo: verify: Union[SSLContext, bool] +@dataclass(frozen=True) +class PreparedDiscoveryConnection: + auth: Auth + user_agent_header: str + + def normalize_ssl_mode(ssl_mode: str) -> str: mode = ssl_mode.lower() if mode not in SSL_MODES: @@ -79,6 +89,79 @@ def build_discovery_url(base_url: str) -> str: return urljoin(base_url.rstrip("/") + "/", DISCOVERY_PATH.lstrip("/")) +def resolve_engine_name( + engine: Optional[str], + engine_name: Optional[str], +) -> Optional[str]: + if engine and engine_name and engine != engine_name: + raise ConfigurationError( + "Both engine and engine_name are provided. Provide only one to connect." + ) + return engine_name or engine + + +def validate_discovery_connection_parameters( + account_name: Optional[str], + api_endpoint: str, + engine_url: Optional[str], + url: Optional[str], + auth: Optional[Auth], +) -> None: + if account_name: + raise ConfigurationError( + "account_name is not compatible with discovery-based connections." + ) + if api_endpoint != DEFAULT_API_URL: + raise ConfigurationError( + "api_endpoint is not compatible with discovery-based connections." + ) + if engine_url: + raise ConfigurationError( + "engine_url is not compatible with discovery-based connections." + ) + if url: + raise ConfigurationError( + "url is not compatible with discovery-based connections. Use host instead." + ) + if auth and auth.get_firebolt_version() != FireboltAuthVersion.CORE: + raise ConfigurationError( + "auth is not compatible with discovery-based connections." + ) + + +def prepare_discovery_connection( + auth: Optional[Auth], + connection_id: str, + additional_parameters: Dict[str, Any], +) -> PreparedDiscoveryConnection: + core_auth = auth or FireboltCore() + return PreparedDiscoveryConnection( + auth=core_auth, + user_agent_header=get_user_agent_for_connection( + core_auth, + connection_id, + None, + additional_parameters, + True, + ), + ) + + +def make_discovery_client_kwargs( + discovery_info: DiscoveryConnectionInfo, + prepared_connection: PreparedDiscoveryConnection, +) -> Dict[str, Any]: + return { + "auth": prepared_connection.auth, + "account_name": "", + "base_url": discovery_info.engine_url, + "api_endpoint": discovery_info.api_endpoint, + "timeout": Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), + "headers": {"User-Agent": prepared_connection.user_agent_header}, + "verify": discovery_info.verify, + } + + def _string_value(data: Mapping[str, Any], *keys: str) -> Optional[str]: for key in keys: value = data.get(key) @@ -147,11 +230,7 @@ def make_discovery_connection_info( base_url = normalize_host(host, ssl_mode) verify = get_tls_verify(base_url, ssl_mode) - if engine and engine_name and engine != engine_name: - raise ConfigurationError( - "Both engine and engine_name are provided. Provide only one to connect." - ) - engine_parameter = engine or engine_name + engine_parameter = resolve_engine_name(engine, engine_name) endpoint = _extract_engine_url(discovery, base_url) endpoint_url, endpoint_params = parse_url_and_params(endpoint) @@ -184,6 +263,37 @@ def _decode_discovery_response(response_text: str) -> Mapping[str, Any]: return decoded +def _raise_if_discovery_failed(status_code: int, text: str, discovery_url: str) -> None: + if status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve Firebolt discovery document {discovery_url}: " + f"{status_code} {text}" + ) + + +def _make_info_from_response( + status_code: int, + text: str, + discovery_url: str, + host: str, + ssl_mode: str, + database: Optional[str], + engine: Optional[str], + engine_name: Optional[str], + settings: Optional[Dict[str, Any]], +) -> DiscoveryConnectionInfo: + _raise_if_discovery_failed(status_code, text, discovery_url) + return make_discovery_connection_info( + host=host, + ssl_mode=ssl_mode, + discovery=_decode_discovery_response(text), + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) + + def discover( host: str, ssl_mode: str, @@ -202,16 +312,12 @@ def discover( ) as client: response = client.get(discovery_url) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve Firebolt discovery document {discovery_url}: " - f"{response.status_code} {response.text}" - ) - - return make_discovery_connection_info( + return _make_info_from_response( + status_code=response.status_code, + text=response.text, + discovery_url=discovery_url, host=host, ssl_mode=ssl_mode, - discovery=_decode_discovery_response(response.text), database=database, engine=engine, engine_name=engine_name, @@ -237,16 +343,12 @@ async def async_discover( ) as client: response = await client.get(discovery_url) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve Firebolt discovery document {discovery_url}: " - f"{response.status_code} {response.text}" - ) - - return make_discovery_connection_info( + return _make_info_from_response( + status_code=response.status_code, + text=response.text, + discovery_url=discovery_url, host=host, ssl_mode=ssl_mode, - discovery=_decode_discovery_response(response.text), database=database, engine=engine, engine_name=engine_name, diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 909390d5a64..e0bb58ae0e9 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -11,7 +11,7 @@ from httpx import Request, Response, Timeout, codes from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2 -from firebolt.client.auth import Auth, FireboltCore +from firebolt.client.auth import Auth from firebolt.client.auth.base import FireboltAuthVersion from firebolt.common.base_connection import ( ASYNC_QUERY_CANCEL, @@ -26,7 +26,13 @@ set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS -from firebolt.common.discovery import discover +from firebolt.common.discovery import ( + discover, + make_discovery_client_kwargs, + prepare_discovery_connection, + resolve_engine_name, + validate_discovery_connection_parameters, +) from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( @@ -84,11 +90,7 @@ def connect( additional_parameters=additional_parameters, ) - if engine and engine_name and engine != engine_name: - raise ConfigurationError( - "Both engine and engine_name are provided. Provide only one to connect." - ) - engine_name = engine_name or engine + engine_name = resolve_engine_name(engine, engine_name) # auth parameter is optional in function signature # but is required to connect. @@ -162,27 +164,13 @@ def connect_discovery( additional_parameters: Dict[str, Any] = {}, ) -> Connection: """Connect using the discovery-based Firebolt session model.""" - if account_name: - raise ConfigurationError( - "account_name is not compatible with discovery-based connections." - ) - if api_endpoint != DEFAULT_API_URL: - raise ConfigurationError( - "api_endpoint is not compatible with discovery-based connections." - ) - if engine_url: - raise ConfigurationError( - "engine_url is not compatible with discovery-based connections." - ) - if url: - raise ConfigurationError( - "url is not compatible with discovery-based connections. Use host instead." - ) - if auth and auth.get_firebolt_version() != FireboltAuthVersion.CORE: - raise ConfigurationError( - "auth is not compatible with discovery-based connections." - ) - + validate_discovery_connection_parameters( + account_name=account_name, + api_endpoint=api_endpoint, + engine_url=engine_url, + url=url, + auth=auth, + ) connection_id = uuid4().hex discovery_info = discover( host=host, @@ -192,19 +180,11 @@ def connect_discovery( engine_name=engine_name, settings=settings, ) - core_auth = auth or FireboltCore() - user_agent_header = get_user_agent_for_connection( - core_auth, connection_id, None, additional_parameters, True + prepared_connection = prepare_discovery_connection( + auth, connection_id, additional_parameters ) - client = ClientV2( - auth=core_auth, - account_name="", - base_url=discovery_info.engine_url, - api_endpoint=discovery_info.api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), - headers={"User-Agent": user_agent_header}, - verify=discovery_info.verify, + **make_discovery_client_kwargs(discovery_info, prepared_connection) ) return Connection( diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 2b5fd4b01cd..60a17e124db 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -2,7 +2,6 @@ from unittest.mock import ANY as AnyValue from unittest.mock import MagicMock, patch -from httpx import Request, codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises from pytest_httpx import HTTPXMock @@ -16,10 +15,15 @@ ConfigurationError, ConnectionClosedError, FireboltError, - InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage -from tests.unit.response import Response +from tests.unit.discovery_connection_helpers import ( + DISCOVERY_HOST, + DISCOVERY_SETTINGS, + assert_async_discovery_lookup_error, + mock_discovery_connection_flow, + mock_discovery_not_found, +) @mark.skip("__slots__ is broken on Connection class") @@ -119,70 +123,51 @@ async def test_connect_discovery( python_query_data: List[List[ColType]], ): """Discovery connections pass database, engine and settings as query params.""" - httpx_mock.add_response( - method="GET", - url="http://localhost:3473/.well-known/firebolt", - json={"engineUrl": "localhost:3473/?discovered_param=value"}, - ) - - def query_with_discovery_params(request: Request, **kwargs) -> Response: - params = dict(request.url.params) - assert "authorization" not in request.headers - assert params["database"] == db_name - assert params["engine"] == engine_name - assert params["custom_setting"] == "custom_value" - assert params["discovered_param"] == "value" - return query_callback(request, **kwargs) - - httpx_mock.add_callback(query_with_discovery_params, method="POST") + mock_discovery_connection_flow(httpx_mock, db_name, engine_name, query_callback) async with await connect( - host="localhost:3473", + host=DISCOVERY_HOST, ssl_mode="none", database=db_name, engine=engine_name, - settings={"custom_setting": "custom_value"}, + settings=DISCOVERY_SETTINGS, ) as connection: assert await connection.cursor().execute("select *") == len(python_query_data) async def test_connect_discovery_rejects_legacy_parameters(auth: Auth): with raises(ConfigurationError, match="account_name"): - await connect(host="localhost:3473", ssl_mode="none", account_name="account") + await connect(host=DISCOVERY_HOST, ssl_mode="none", account_name="account") with raises(ConfigurationError, match="api_endpoint"): await connect( - host="localhost:3473", + host=DISCOVERY_HOST, ssl_mode="none", api_endpoint="api.example.com", ) with raises(ConfigurationError, match="engine_url"): await connect( - host="localhost:3473", + host=DISCOVERY_HOST, ssl_mode="none", engine_url="engine.example.com", ) with raises(ConfigurationError, match="url"): await connect( - host="localhost:3473", + host=DISCOVERY_HOST, ssl_mode="none", - url="http://localhost:3473", + url=f"http://{DISCOVERY_HOST}", ) with raises(ConfigurationError, match="auth"): - await connect(host="localhost:3473", ssl_mode="none", auth=auth) + await connect(host=DISCOVERY_HOST, ssl_mode="none", auth=auth) async def test_connect_discovery_validation(httpx_mock: HTTPXMock): with raises(ConfigurationError, match="ssl_mode"): - await connect(host="localhost:3473", ssl_mode="invalid") + await connect(host=DISCOVERY_HOST, ssl_mode="invalid") - httpx_mock.add_response( - method="GET", - url="http://localhost:3473/.well-known/firebolt", - status_code=codes.NOT_FOUND, - text="not found", + mock_discovery_not_found(httpx_mock) + await assert_async_discovery_lookup_error( + lambda: connect(host=DISCOVERY_HOST, ssl_mode="none") ) - with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): - await connect(host="localhost:3473", ssl_mode="none") async def test_connect_database_failed( diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 9b54fcfdddd..64812c50d9f 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -4,7 +4,6 @@ from unittest.mock import ANY as AnyValue from unittest.mock import MagicMock, patch -from httpx import Request, codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns from pytest_httpx import HTTPXMock @@ -20,10 +19,16 @@ ConfigurationError, ConnectionClosedError, FireboltError, - InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage -from tests.unit.response import Response +from tests.unit.discovery_connection_helpers import ( + DISCOVERY_HOST, + DISCOVERY_SETTINGS, + assert_discovery_lookup_error, + assert_discovery_validation_errors, + mock_discovery_connection_flow, + mock_discovery_not_found, +) def test_connection_attributes(connection: Connection) -> None: @@ -123,62 +128,31 @@ def test_connect_discovery( python_query_data: List[List[ColType]], ): """Discovery connections pass database, engine and settings as query params.""" - httpx_mock.add_response( - method="GET", - url="http://localhost:3473/.well-known/firebolt", - json={"engineUrl": "localhost:3473/?discovered_param=value"}, - ) - - def query_with_discovery_params(request: Request, **kwargs) -> Response: - params = dict(request.url.params) - assert "authorization" not in request.headers - assert params["database"] == db_name - assert params["engine"] == engine_name - assert params["custom_setting"] == "custom_value" - assert params["discovered_param"] == "value" - return query_callback(request, **kwargs) - - httpx_mock.add_callback(query_with_discovery_params, method="POST") + mock_discovery_connection_flow(httpx_mock, db_name, engine_name, query_callback) with connect( - host="localhost:3473", + host=DISCOVERY_HOST, ssl_mode="none", database=db_name, engine=engine_name, - settings={"custom_setting": "custom_value"}, + settings=DISCOVERY_SETTINGS, ) as connection: assert connection.cursor().execute("select *") == len(python_query_data) def test_connect_discovery_rejects_legacy_parameters(auth: Auth): - with raises(ConfigurationError, match="account_name"): - connect(host="localhost:3473", ssl_mode="none", account_name="account") - with raises(ConfigurationError, match="api_endpoint"): - connect( - host="localhost:3473", - ssl_mode="none", - api_endpoint="api.example.com", - ) - with raises(ConfigurationError, match="engine_url"): - connect(host="localhost:3473", ssl_mode="none", engine_url="engine.example.com") - with raises(ConfigurationError, match="url"): - connect(host="localhost:3473", ssl_mode="none", url="http://localhost:3473") - with raises(ConfigurationError, match="auth"): - connect(host="localhost:3473", ssl_mode="none", auth=auth) + def connect_discovery(**kwargs): + return connect(host=DISCOVERY_HOST, ssl_mode="none", **kwargs) + + assert_discovery_validation_errors(connect_discovery, auth) def test_connect_discovery_validation(httpx_mock: HTTPXMock): with raises(ConfigurationError, match="ssl_mode"): - connect(host="localhost:3473", ssl_mode="invalid") + connect(host=DISCOVERY_HOST, ssl_mode="invalid") - httpx_mock.add_response( - method="GET", - url="http://localhost:3473/.well-known/firebolt", - status_code=codes.NOT_FOUND, - text="not found", - ) - with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): - connect(host="localhost:3473", ssl_mode="none") + mock_discovery_not_found(httpx_mock) + assert_discovery_lookup_error(lambda: connect(host=DISCOVERY_HOST, ssl_mode="none")) def test_connect_database_failed( diff --git a/tests/unit/discovery_connection_helpers.py b/tests/unit/discovery_connection_helpers.py new file mode 100644 index 00000000000..0162e233fe8 --- /dev/null +++ b/tests/unit/discovery_connection_helpers.py @@ -0,0 +1,69 @@ +from typing import Callable + +from httpx import Request, codes +from pytest import raises +from pytest_httpx import HTTPXMock + +from firebolt.client.auth import Auth +from firebolt.utils.exception import ConfigurationError, InterfaceError +from tests.unit.response import Response + +DISCOVERY_HOST = "localhost:3473" +DISCOVERY_URL = f"http://{DISCOVERY_HOST}/.well-known/firebolt" +DISCOVERY_SETTINGS = {"custom_setting": "custom_value"} + + +def mock_discovery_connection_flow( + httpx_mock: HTTPXMock, + db_name: str, + engine_name: str, + query_callback: Callable, +) -> None: + httpx_mock.add_response( + method="GET", + url=DISCOVERY_URL, + json={"engineUrl": f"{DISCOVERY_HOST}/?discovered_param=value"}, + ) + + def query_with_discovery_params(request: Request, **kwargs) -> Response: + params = dict(request.url.params) + assert "authorization" not in request.headers + assert params["database"] == db_name + assert params["engine"] == engine_name + assert params["custom_setting"] == DISCOVERY_SETTINGS["custom_setting"] + assert params["discovered_param"] == "value" + return query_callback(request, **kwargs) + + httpx_mock.add_callback(query_with_discovery_params, method="POST") + + +def assert_discovery_validation_errors(connect_call: Callable, auth: Auth) -> None: + with raises(ConfigurationError, match="account_name"): + connect_call(account_name="account") + with raises(ConfigurationError, match="api_endpoint"): + connect_call(api_endpoint="api.example.com") + with raises(ConfigurationError, match="engine_url"): + connect_call(engine_url="engine.example.com") + with raises(ConfigurationError, match="url"): + connect_call(url=f"http://{DISCOVERY_HOST}") + with raises(ConfigurationError, match="auth"): + connect_call(auth=auth) + + +def mock_discovery_not_found(httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response( + method="GET", + url=DISCOVERY_URL, + status_code=codes.NOT_FOUND, + text="not found", + ) + + +def assert_discovery_lookup_error(connect_call: Callable) -> None: + with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): + connect_call() + + +async def assert_async_discovery_lookup_error(connect_call: Callable) -> None: + with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): + await connect_call() From 84765448ff431a2b560540899266005b0a0d21ec Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 15:22:09 +0000 Subject: [PATCH 4/6] Fix discovery flake8 line length Co-authored-by: Ivan Koptiev --- src/firebolt/common/discovery.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/firebolt/common/discovery.py b/src/firebolt/common/discovery.py index 110180ef73f..8b2e394c0b1 100644 --- a/src/firebolt/common/discovery.py +++ b/src/firebolt/common/discovery.py @@ -65,7 +65,8 @@ def normalize_host(host: str, ssl_mode: str) -> str: ) if not parsed.netloc: raise ConfigurationError( - f"Invalid host: {host}. Expected a hostname, optionally with scheme and port." + f"Invalid host: {host}. Expected a hostname, optionally with scheme " + "and port." ) if parsed.query or parsed.fragment: raise ConfigurationError( From a3cc55f481fc976e6a4c5b5ded24b92d76fac23a Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 15:23:29 +0000 Subject: [PATCH 5/6] Reduce discovery test duplication Co-authored-by: Ivan Koptiev --- tests/unit/async_db/test_connection.py | 37 ---------------------- tests/unit/discovery_connection_helpers.py | 5 --- 2 files changed, 42 deletions(-) diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 60a17e124db..1d871204440 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -20,9 +20,7 @@ from tests.unit.discovery_connection_helpers import ( DISCOVERY_HOST, DISCOVERY_SETTINGS, - assert_async_discovery_lookup_error, mock_discovery_connection_flow, - mock_discovery_not_found, ) @@ -135,41 +133,6 @@ async def test_connect_discovery( assert await connection.cursor().execute("select *") == len(python_query_data) -async def test_connect_discovery_rejects_legacy_parameters(auth: Auth): - with raises(ConfigurationError, match="account_name"): - await connect(host=DISCOVERY_HOST, ssl_mode="none", account_name="account") - with raises(ConfigurationError, match="api_endpoint"): - await connect( - host=DISCOVERY_HOST, - ssl_mode="none", - api_endpoint="api.example.com", - ) - with raises(ConfigurationError, match="engine_url"): - await connect( - host=DISCOVERY_HOST, - ssl_mode="none", - engine_url="engine.example.com", - ) - with raises(ConfigurationError, match="url"): - await connect( - host=DISCOVERY_HOST, - ssl_mode="none", - url=f"http://{DISCOVERY_HOST}", - ) - with raises(ConfigurationError, match="auth"): - await connect(host=DISCOVERY_HOST, ssl_mode="none", auth=auth) - - -async def test_connect_discovery_validation(httpx_mock: HTTPXMock): - with raises(ConfigurationError, match="ssl_mode"): - await connect(host=DISCOVERY_HOST, ssl_mode="invalid") - - mock_discovery_not_found(httpx_mock) - await assert_async_discovery_lookup_error( - lambda: connect(host=DISCOVERY_HOST, ssl_mode="none") - ) - - async def test_connect_database_failed( db_name: str, account_name: str, diff --git a/tests/unit/discovery_connection_helpers.py b/tests/unit/discovery_connection_helpers.py index 0162e233fe8..1cf1bcf3e7b 100644 --- a/tests/unit/discovery_connection_helpers.py +++ b/tests/unit/discovery_connection_helpers.py @@ -62,8 +62,3 @@ def mock_discovery_not_found(httpx_mock: HTTPXMock) -> None: def assert_discovery_lookup_error(connect_call: Callable) -> None: with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): connect_call() - - -async def assert_async_discovery_lookup_error(connect_call: Callable) -> None: - with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): - await connect_call() From 25f3acc535f9d80f31cdfd9009d90638f5efe6ff Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 15:29:00 +0000 Subject: [PATCH 6/6] Share discovery connection assembly Co-authored-by: Ivan Koptiev --- src/firebolt/async_db/connection.py | 74 +++++------------------- src/firebolt/common/discovery.py | 90 ++++++++++++++++++++++++++++- src/firebolt/db/connection.py | 76 +++++------------------- 3 files changed, 114 insertions(+), 126 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 7141e66092c..a6db1185250 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -28,11 +28,11 @@ ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.common.discovery import ( + DiscoveryConnectConfig, async_discover, - make_discovery_client_kwargs, - prepare_discovery_connection, + make_connection_from_discovery, resolve_engine_name, - validate_discovery_connection_parameters, + validate_discovery_connect_config, ) from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( @@ -306,19 +306,7 @@ async def connect( ) -> Connection: if host: return await connect_discovery( - host=host, - database=database, - engine=engine, - engine_name=engine_name, - engine_url=engine_url, - account_name=account_name, - api_endpoint=api_endpoint, - url=url, - auth=auth, - ssl_mode=ssl_mode, - settings=settings, - autocommit=autocommit, - additional_parameters=additional_parameters, + DiscoveryConnectConfig.from_connect_kwargs(locals()) ) engine_name = resolve_engine_name(engine, engine_name) @@ -378,54 +366,18 @@ async def connect( raise ConfigurationError(f"Unsupported auth type: {type(auth)}") -async def connect_discovery( - host: str, - database: Optional[str] = None, - engine: Optional[str] = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, - account_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - url: Optional[str] = None, - auth: Optional[Auth] = None, - ssl_mode: str = "strict", - settings: Optional[Dict[str, Any]] = None, - autocommit: bool = True, - additional_parameters: Dict[str, Any] = {}, -) -> Connection: +async def connect_discovery(config: DiscoveryConnectConfig) -> Connection: """Connect using the discovery-based Firebolt session model.""" - validate_discovery_connection_parameters( - account_name=account_name, - api_endpoint=api_endpoint, - engine_url=engine_url, - url=url, - auth=auth, - ) + validate_discovery_connect_config(config) connection_id = uuid4().hex - discovery_info = await async_discover( - host=host, - ssl_mode=ssl_mode, - database=database, - engine=engine, - engine_name=engine_name, - settings=settings, - ) - prepared_connection = prepare_discovery_connection( - auth, connection_id, additional_parameters - ) - client = AsyncClientV2( - **make_discovery_client_kwargs(discovery_info, prepared_connection) - ) - - return Connection( - engine_url=discovery_info.engine_url, - database=None, - client=client, + discovery_info = await async_discover(**config.discovery_kwargs()) + return make_connection_from_discovery( + discovery_info=discovery_info, + config=config, + connection_id=connection_id, + client_type=AsyncClientV2, cursor_type=CursorV2, - api_endpoint=discovery_info.api_endpoint, - init_parameters=discovery_info.parameters, - id=connection_id, - autocommit=autocommit, + connection_type=Connection, ) diff --git a/src/firebolt/common/discovery.py b/src/firebolt/common/discovery.py index 8b2e394c0b1..844777a1cdb 100644 --- a/src/firebolt/common/discovery.py +++ b/src/firebolt/common/discovery.py @@ -1,9 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from json import JSONDecodeError from ssl import SSLContext -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, Mapping, Optional, Type, Union from urllib.parse import urljoin, urlparse from httpx import AsyncClient as HttpxAsyncClient @@ -39,6 +39,54 @@ class PreparedDiscoveryConnection: user_agent_header: str +@dataclass(frozen=True) +class DiscoveryConnectConfig: + host: str + database: Optional[str] = None + engine: Optional[str] = None + engine_name: Optional[str] = None + engine_url: Optional[str] = None + account_name: Optional[str] = None + api_endpoint: str = DEFAULT_API_URL + url: Optional[str] = None + auth: Optional[Auth] = None + ssl_mode: str = SSL_MODE_STRICT + settings: Optional[Dict[str, Any]] = None + autocommit: bool = True + additional_parameters: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_connect_kwargs( + cls, + kwargs: Mapping[str, Any], + ) -> "DiscoveryConnectConfig": + return cls( + host=kwargs["host"], + database=kwargs.get("database"), + engine=kwargs.get("engine"), + engine_name=kwargs.get("engine_name"), + engine_url=kwargs.get("engine_url"), + account_name=kwargs.get("account_name"), + api_endpoint=kwargs.get("api_endpoint", DEFAULT_API_URL), + url=kwargs.get("url"), + auth=kwargs.get("auth"), + ssl_mode=kwargs.get("ssl_mode", SSL_MODE_STRICT), + settings=kwargs.get("settings"), + autocommit=kwargs.get("autocommit", True), + additional_parameters=kwargs.get("additional_parameters") or {}, + ) + + def discovery_kwargs(self) -> Dict[str, Any]: + return { + "host": self.host, + "ssl_mode": self.ssl_mode, + "database": self.database, + "engine": self.engine, + "engine_name": self.engine_name, + "settings": self.settings, + } + + def normalize_ssl_mode(ssl_mode: str) -> str: mode = ssl_mode.lower() if mode not in SSL_MODES: @@ -130,6 +178,16 @@ def validate_discovery_connection_parameters( ) +def validate_discovery_connect_config(config: DiscoveryConnectConfig) -> None: + validate_discovery_connection_parameters( + account_name=config.account_name, + api_endpoint=config.api_endpoint, + engine_url=config.engine_url, + url=config.url, + auth=config.auth, + ) + + def prepare_discovery_connection( auth: Optional[Auth], connection_id: str, @@ -163,6 +221,34 @@ def make_discovery_client_kwargs( } +def make_connection_from_discovery( + discovery_info: DiscoveryConnectionInfo, + config: DiscoveryConnectConfig, + connection_id: str, + client_type: Type, + cursor_type: Type, + connection_type: Type, +) -> Any: + prepared_connection = prepare_discovery_connection( + config.auth, + connection_id, + config.additional_parameters, + ) + client = client_type( + **make_discovery_client_kwargs(discovery_info, prepared_connection) + ) + return connection_type( + engine_url=discovery_info.engine_url, + database=None, + client=client, + cursor_type=cursor_type, + api_endpoint=discovery_info.api_endpoint, + init_parameters=discovery_info.parameters, + id=connection_id, + autocommit=config.autocommit, + ) + + def _string_value(data: Mapping[str, Any], *keys: str) -> Optional[str]: for key in keys: value = data.get(key) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index e0bb58ae0e9..657138b8079 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -27,11 +27,11 @@ ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS from firebolt.common.discovery import ( + DiscoveryConnectConfig, discover, - make_discovery_client_kwargs, - prepare_discovery_connection, + make_connection_from_discovery, resolve_engine_name, - validate_discovery_connection_parameters, + validate_discovery_connect_config, ) from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.utils.cache import EngineInfo @@ -74,21 +74,7 @@ def connect( additional_parameters: Dict[str, Any] = {}, ) -> Connection: if host: - return connect_discovery( - host=host, - database=database, - engine=engine, - engine_name=engine_name, - engine_url=engine_url, - account_name=account_name, - api_endpoint=api_endpoint, - url=url, - auth=auth, - ssl_mode=ssl_mode, - settings=settings, - autocommit=autocommit, - additional_parameters=additional_parameters, - ) + return connect_discovery(DiscoveryConnectConfig.from_connect_kwargs(locals())) engine_name = resolve_engine_name(engine, engine_name) @@ -148,54 +134,18 @@ def connect( raise ConfigurationError(f"Unsupported auth type: {type(auth)}") -def connect_discovery( - host: str, - database: Optional[str] = None, - engine: Optional[str] = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, - account_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - url: Optional[str] = None, - auth: Optional[Auth] = None, - ssl_mode: str = "strict", - settings: Optional[Dict[str, Any]] = None, - autocommit: bool = True, - additional_parameters: Dict[str, Any] = {}, -) -> Connection: +def connect_discovery(config: DiscoveryConnectConfig) -> Connection: """Connect using the discovery-based Firebolt session model.""" - validate_discovery_connection_parameters( - account_name=account_name, - api_endpoint=api_endpoint, - engine_url=engine_url, - url=url, - auth=auth, - ) + validate_discovery_connect_config(config) connection_id = uuid4().hex - discovery_info = discover( - host=host, - ssl_mode=ssl_mode, - database=database, - engine=engine, - engine_name=engine_name, - settings=settings, - ) - prepared_connection = prepare_discovery_connection( - auth, connection_id, additional_parameters - ) - client = ClientV2( - **make_discovery_client_kwargs(discovery_info, prepared_connection) - ) - - return Connection( - engine_url=discovery_info.engine_url, - database=None, - client=client, + discovery_info = discover(**config.discovery_kwargs()) + return make_connection_from_discovery( + discovery_info=discovery_info, + config=config, + connection_id=connection_id, + client_type=ClientV2, cursor_type=CursorV2, - api_endpoint=discovery_info.api_endpoint, - init_parameters=discovery_info.parameters, - id=connection_id, - autocommit=autocommit, + connection_type=Connection, )