diff --git a/.github/workflows/integration-tests-core.yml b/.github/workflows/integration-tests-core.yml index 890270780bd..f9b6b0f9fb7 100644 --- a/.github/workflows/integration-tests-core.yml +++ b/.github/workflows/integration-tests-core.yml @@ -53,11 +53,44 @@ jobs: with: repository: 'firebolt-db/firebolt-python-sdk' - - name: Setup Firebolt Core - id: setup-core - uses: firebolt-db/action-setup-core@eabcd701de0be41793fda0655d29d46c70c847c2 # main - with: - tag_version: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG }} + - name: Install Firebolt + env: + ENGINE_TAG: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG || 'dev' }} + run: | + printf 'n\n' | bash <(curl -fsSL https://get.firebolt.io/) + + - name: Start Firebolt + env: + ENGINE_REPO: ghcr.io/firebolt-db/engine + ENGINE_TAG: ${{ inputs.tag_version || vars.DEFAULT_CORE_IMAGE_TAG || 'dev' }} + run: | + mkdir -p -m 777 firebolt-data + docker run \ + --detach \ + --user firebolt \ + --name firebolt \ + --rm \ + --ulimit memlock=8589934592:8589934592 \ + --security-opt seccomp=unconfined \ + -v "${PWD}/firebolt-data:/var/lib/firebolt" \ + -p 3473:3473 \ + "${ENGINE_REPO}:${ENGINE_TAG}" + + timeout=60 + until [ "$timeout" -eq 0 ]; do + response="$( + curl -s 'http://localhost:3473/?output_format=TabSeparatedWithNamesAndTypes' \ + --data-binary 'SELECT 42;' || true + )" + if [ "$response" = $'?column?\nint\n42' ]; then + exit 0 + fi + sleep 1 + timeout=$((timeout - 1)) + done + + docker logs firebolt + exit 1 - name: Set up Python uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 @@ -69,7 +102,7 @@ jobs: python -m pip install --upgrade pip pip install ".[dev]" - - name: Run integration tests HTTP + - name: Run integration tests env: SERVICE_ID: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }} SERVICE_SECRET: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }} @@ -78,23 +111,10 @@ jobs: STOPPED_ENGINE_NAME: "" API_ENDPOINT: "" ACCOUNT_NAME: "" - CORE_URL: ${{ steps.setup-core.outputs.service_url }} + CORE_URL: http://localhost:3473 run: | pytest -o log_cli=true -o log_cli_level=WARNING tests/integration -k "core" --alluredir=allure-results/ - - name: Run integration tests HTTPS - env: - SERVICE_ID: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }} - SERVICE_SECRET: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }} - DATABASE_NAME: "firebolt" - ENGINE_NAME: "" - STOPPED_ENGINE_NAME: "" - API_ENDPOINT: "" - ACCOUNT_NAME: "" - CORE_URL: ${{ steps.setup-core.outputs.service_https_url }} - run: | - pytest -o log_cli=true -o log_cli_level=WARNING tests/integration -k "core" --alluredir=allure-results-https/ - - name: Allure Reports uses: firebolt-db/action-allure-report@8cdc116f65f6eca845a992e347e72b75ca8ccf5f # v2.1.1 if: always() @@ -103,6 +123,5 @@ jobs: pages-branch: gh-pages mapping-json: | { - "allure-results": "core", - "allure-results-https": "core_https" + "allure-results": "core" } diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 9d7c8c36c9f..a6db1185250 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -27,6 +27,13 @@ set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.common.discovery import ( + DiscoveryConnectConfig, + async_discover, + make_connection_from_discovery, + resolve_engine_name, + validate_discovery_connect_config, +) from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( AccountNotFoundOrNoAccessError, @@ -285,14 +292,25 @@ async def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, database: Optional[str] = None, + engine: Optional[str] = None, engine_name: Optional[str] = None, engine_url: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, disable_cache: bool = False, url: Optional[str] = None, + host: Optional[str] = None, + ssl_mode: str = "strict", + settings: Optional[Dict[str, Any]] = None, autocommit: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: + if host: + return await connect_discovery( + DiscoveryConnectConfig.from_connect_kwargs(locals()) + ) + + engine_name = resolve_engine_name(engine, engine_name) + # auth parameter is optional in function signature # but is required to connect. # PEP 249 recommends making it kwargs. @@ -348,6 +366,21 @@ async def connect( raise ConfigurationError(f"Unsupported auth type: {type(auth)}") +async def connect_discovery(config: DiscoveryConnectConfig) -> Connection: + """Connect using the discovery-based Firebolt session model.""" + validate_discovery_connect_config(config) + connection_id = uuid4().hex + discovery_info = await async_discover(**config.discovery_kwargs()) + return make_connection_from_discovery( + discovery_info=discovery_info, + config=config, + connection_id=connection_id, + client_type=AsyncClientV2, + cursor_type=CursorV2, + connection_type=Connection, + ) + + async def connect_v2( auth: Auth, user_agent_header: str, diff --git a/src/firebolt/common/discovery.py b/src/firebolt/common/discovery.py new file mode 100644 index 00000000000..844777a1cdb --- /dev/null +++ b/src/firebolt/common/discovery.py @@ -0,0 +1,443 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from json import JSONDecodeError +from ssl import SSLContext +from typing import Any, Dict, Mapping, Optional, Type, Union +from urllib.parse import urljoin, urlparse + +from httpx import AsyncClient as HttpxAsyncClient +from httpx import Client as HttpxClient +from httpx import Timeout, codes + +from firebolt.client.auth import Auth, FireboltCore +from firebolt.client.auth.base import FireboltAuthVersion +from firebolt.client.constants import DEFAULT_API_URL +from firebolt.common.base_connection import get_user_agent_for_connection +from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.exception import ConfigurationError, InterfaceError +from firebolt.utils.firebolt_core import get_core_certificate_context +from firebolt.utils.util import parse_url_and_params + +DISCOVERY_PATH = "/.well-known/firebolt" +SSL_MODE_STRICT = "strict" +SSL_MODE_NONE = "none" +SSL_MODES = {SSL_MODE_STRICT, SSL_MODE_NONE} + + +@dataclass(frozen=True) +class DiscoveryConnectionInfo: + engine_url: str + api_endpoint: str + parameters: Dict[str, Any] + verify: Union[SSLContext, bool] + + +@dataclass(frozen=True) +class PreparedDiscoveryConnection: + auth: Auth + user_agent_header: str + + +@dataclass(frozen=True) +class DiscoveryConnectConfig: + host: str + database: Optional[str] = None + engine: Optional[str] = None + engine_name: Optional[str] = None + engine_url: Optional[str] = None + account_name: Optional[str] = None + api_endpoint: str = DEFAULT_API_URL + url: Optional[str] = None + auth: Optional[Auth] = None + ssl_mode: str = SSL_MODE_STRICT + settings: Optional[Dict[str, Any]] = None + autocommit: bool = True + additional_parameters: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_connect_kwargs( + cls, + kwargs: Mapping[str, Any], + ) -> "DiscoveryConnectConfig": + return cls( + host=kwargs["host"], + database=kwargs.get("database"), + engine=kwargs.get("engine"), + engine_name=kwargs.get("engine_name"), + engine_url=kwargs.get("engine_url"), + account_name=kwargs.get("account_name"), + api_endpoint=kwargs.get("api_endpoint", DEFAULT_API_URL), + url=kwargs.get("url"), + auth=kwargs.get("auth"), + ssl_mode=kwargs.get("ssl_mode", SSL_MODE_STRICT), + settings=kwargs.get("settings"), + autocommit=kwargs.get("autocommit", True), + additional_parameters=kwargs.get("additional_parameters") or {}, + ) + + def discovery_kwargs(self) -> Dict[str, Any]: + return { + "host": self.host, + "ssl_mode": self.ssl_mode, + "database": self.database, + "engine": self.engine, + "engine_name": self.engine_name, + "settings": self.settings, + } + + +def normalize_ssl_mode(ssl_mode: str) -> str: + mode = ssl_mode.lower() + if mode not in SSL_MODES: + allowed = ", ".join(sorted(SSL_MODES)) + raise ConfigurationError( + f"Invalid ssl_mode: {ssl_mode}. Expected one of: {allowed}." + ) + return mode + + +def normalize_host(host: str, ssl_mode: str) -> str: + """Normalize a discovery host into an HTTP(S) base URL.""" + if not host: + raise ConfigurationError("host is required for discovery-based connections.") + + mode = normalize_ssl_mode(ssl_mode) + default_scheme = "http" if mode == SSL_MODE_NONE else "https" + raw_url = host if "://" in host else f"{default_scheme}://{host}" + parsed = urlparse(raw_url) + + if parsed.scheme not in {"http", "https"}: + raise ConfigurationError( + f"Invalid host scheme: {parsed.scheme}. Expected 'http' or 'https'." + ) + if not parsed.netloc: + raise ConfigurationError( + f"Invalid host: {host}. Expected a hostname, optionally with scheme " + "and port." + ) + if parsed.query or parsed.fragment: + raise ConfigurationError( + "host must not include query parameters or a fragment. " + "Pass connection parameters as connect() arguments instead." + ) + + return f"{parsed.scheme}://{parsed.netloc}{parsed.path.rstrip('/')}" + + +def get_tls_verify(base_url: str, ssl_mode: str) -> Union[SSLContext, bool]: + mode = normalize_ssl_mode(ssl_mode) + if mode == SSL_MODE_NONE: + return False + if urlparse(base_url).scheme == "https": + return get_core_certificate_context() + return True + + +def build_discovery_url(base_url: str) -> str: + return urljoin(base_url.rstrip("/") + "/", DISCOVERY_PATH.lstrip("/")) + + +def resolve_engine_name( + engine: Optional[str], + engine_name: Optional[str], +) -> Optional[str]: + if engine and engine_name and engine != engine_name: + raise ConfigurationError( + "Both engine and engine_name are provided. Provide only one to connect." + ) + return engine_name or engine + + +def validate_discovery_connection_parameters( + account_name: Optional[str], + api_endpoint: str, + engine_url: Optional[str], + url: Optional[str], + auth: Optional[Auth], +) -> None: + if account_name: + raise ConfigurationError( + "account_name is not compatible with discovery-based connections." + ) + if api_endpoint != DEFAULT_API_URL: + raise ConfigurationError( + "api_endpoint is not compatible with discovery-based connections." + ) + if engine_url: + raise ConfigurationError( + "engine_url is not compatible with discovery-based connections." + ) + if url: + raise ConfigurationError( + "url is not compatible with discovery-based connections. Use host instead." + ) + if auth and auth.get_firebolt_version() != FireboltAuthVersion.CORE: + raise ConfigurationError( + "auth is not compatible with discovery-based connections." + ) + + +def validate_discovery_connect_config(config: DiscoveryConnectConfig) -> None: + validate_discovery_connection_parameters( + account_name=config.account_name, + api_endpoint=config.api_endpoint, + engine_url=config.engine_url, + url=config.url, + auth=config.auth, + ) + + +def prepare_discovery_connection( + auth: Optional[Auth], + connection_id: str, + additional_parameters: Dict[str, Any], +) -> PreparedDiscoveryConnection: + core_auth = auth or FireboltCore() + return PreparedDiscoveryConnection( + auth=core_auth, + user_agent_header=get_user_agent_for_connection( + core_auth, + connection_id, + None, + additional_parameters, + True, + ), + ) + + +def make_discovery_client_kwargs( + discovery_info: DiscoveryConnectionInfo, + prepared_connection: PreparedDiscoveryConnection, +) -> Dict[str, Any]: + return { + "auth": prepared_connection.auth, + "account_name": "", + "base_url": discovery_info.engine_url, + "api_endpoint": discovery_info.api_endpoint, + "timeout": Timeout(DEFAULT_TIMEOUT_SECONDS, read=None), + "headers": {"User-Agent": prepared_connection.user_agent_header}, + "verify": discovery_info.verify, + } + + +def make_connection_from_discovery( + discovery_info: DiscoveryConnectionInfo, + config: DiscoveryConnectConfig, + connection_id: str, + client_type: Type, + cursor_type: Type, + connection_type: Type, +) -> Any: + prepared_connection = prepare_discovery_connection( + config.auth, + connection_id, + config.additional_parameters, + ) + client = client_type( + **make_discovery_client_kwargs(discovery_info, prepared_connection) + ) + return connection_type( + engine_url=discovery_info.engine_url, + database=None, + client=client, + cursor_type=cursor_type, + api_endpoint=discovery_info.api_endpoint, + init_parameters=discovery_info.parameters, + id=connection_id, + autocommit=config.autocommit, + ) + + +def _string_value(data: Mapping[str, Any], *keys: str) -> Optional[str]: + for key in keys: + value = data.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _endpoint_from_mapping(data: Mapping[str, Any]) -> Optional[str]: + return _string_value( + data, + "engineUrl", + "engine_url", + "engineEndpoint", + "engine_endpoint", + "queryUrl", + "query_url", + "url", + "endpoint", + ) + + +def _resolve_endpoint(base_url: str, endpoint: str) -> str: + if "://" in endpoint or endpoint.startswith("/"): + return urljoin(base_url.rstrip("/") + "/", endpoint) + parsed_base = urlparse(base_url) + if ":" in endpoint or "." in endpoint: + return f"{parsed_base.scheme}://{endpoint}" + return urljoin(base_url.rstrip("/") + "/", endpoint) + + +def _extract_engine_url(discovery: Mapping[str, Any], base_url: str) -> str: + endpoint = _endpoint_from_mapping(discovery) + if endpoint: + return _resolve_endpoint(base_url, endpoint) + + endpoints = discovery.get("endpoints") + if isinstance(endpoints, Mapping): + for key in ("query", "sql", "engine", "http"): + value = endpoints.get(key) + if isinstance(value, str) and value: + return _resolve_endpoint(base_url, value) + if isinstance(value, Mapping): + endpoint = _endpoint_from_mapping(value) + if endpoint: + return _resolve_endpoint(base_url, endpoint) + + query = discovery.get("query") + if isinstance(query, Mapping): + endpoint = _endpoint_from_mapping(query) + if endpoint: + return _resolve_endpoint(base_url, endpoint) + + return base_url + + +def make_discovery_connection_info( + host: str, + ssl_mode: str, + discovery: Mapping[str, Any], + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, +) -> DiscoveryConnectionInfo: + base_url = normalize_host(host, ssl_mode) + verify = get_tls_verify(base_url, ssl_mode) + + engine_parameter = resolve_engine_name(engine, engine_name) + + endpoint = _extract_engine_url(discovery, base_url) + endpoint_url, endpoint_params = parse_url_and_params(endpoint) + + parameters: Dict[str, Any] = dict(endpoint_params) + if settings: + parameters.update(settings) + if database: + parameters["database"] = database + if engine_parameter: + parameters["engine"] = engine_parameter + + return DiscoveryConnectionInfo( + engine_url=endpoint_url, + api_endpoint=base_url, + parameters=parameters, + verify=verify, + ) + + +def _decode_discovery_response(response_text: str) -> Mapping[str, Any]: + try: + import json + + decoded = json.loads(response_text) + except JSONDecodeError as e: + raise InterfaceError("Unable to decode Firebolt discovery response.") from e + if not isinstance(decoded, Mapping): + raise InterfaceError("Firebolt discovery response must be a JSON object.") + return decoded + + +def _raise_if_discovery_failed(status_code: int, text: str, discovery_url: str) -> None: + if status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve Firebolt discovery document {discovery_url}: " + f"{status_code} {text}" + ) + + +def _make_info_from_response( + status_code: int, + text: str, + discovery_url: str, + host: str, + ssl_mode: str, + database: Optional[str], + engine: Optional[str], + engine_name: Optional[str], + settings: Optional[Dict[str, Any]], +) -> DiscoveryConnectionInfo: + _raise_if_discovery_failed(status_code, text, discovery_url) + return make_discovery_connection_info( + host=host, + ssl_mode=ssl_mode, + discovery=_decode_discovery_response(text), + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) + + +def discover( + host: str, + ssl_mode: str, + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, +) -> DiscoveryConnectionInfo: + base_url = normalize_host(host, ssl_mode) + verify = get_tls_verify(base_url, ssl_mode) + discovery_url = build_discovery_url(base_url) + + with HttpxClient( + verify=verify, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + response = client.get(discovery_url) + + return _make_info_from_response( + status_code=response.status_code, + text=response.text, + discovery_url=discovery_url, + host=host, + ssl_mode=ssl_mode, + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) + + +async def async_discover( + host: str, + ssl_mode: str, + database: Optional[str] = None, + engine: Optional[str] = None, + engine_name: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, +) -> DiscoveryConnectionInfo: + base_url = normalize_host(host, ssl_mode) + verify = get_tls_verify(base_url, ssl_mode) + discovery_url = build_discovery_url(base_url) + + async with HttpxAsyncClient( + verify=verify, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + response = await client.get(discovery_url) + + return _make_info_from_response( + status_code=response.status_code, + text=response.text, + discovery_url=discovery_url, + host=host, + ssl_mode=ssl_mode, + database=database, + engine=engine, + engine_name=engine_name, + settings=settings, + ) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index a0cc7a3c5fb..657138b8079 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -26,6 +26,13 @@ set_cached_system_engine_info, ) from firebolt.common.constants import DEFAULT_TIMEOUT_SECONDS +from firebolt.common.discovery import ( + DiscoveryConnectConfig, + discover, + make_connection_from_discovery, + resolve_engine_name, + validate_discovery_connect_config, +) from firebolt.db.cursor import Cursor, CursorV1, CursorV2 from firebolt.utils.cache import EngineInfo from firebolt.utils.exception import ( @@ -54,14 +61,23 @@ def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, database: Optional[str] = None, + engine: Optional[str] = None, engine_name: Optional[str] = None, engine_url: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, disable_cache: bool = False, url: Optional[str] = None, + host: Optional[str] = None, + ssl_mode: str = "strict", + settings: Optional[Dict[str, Any]] = None, autocommit: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: + if host: + return connect_discovery(DiscoveryConnectConfig.from_connect_kwargs(locals())) + + engine_name = resolve_engine_name(engine, engine_name) + # auth parameter is optional in function signature # but is required to connect. # PEP 249 recommends making it kwargs. @@ -118,6 +134,21 @@ def connect( raise ConfigurationError(f"Unsupported auth type: {type(auth)}") +def connect_discovery(config: DiscoveryConnectConfig) -> Connection: + """Connect using the discovery-based Firebolt session model.""" + validate_discovery_connect_config(config) + connection_id = uuid4().hex + discovery_info = discover(**config.discovery_kwargs()) + return make_connection_from_discovery( + discovery_info=discovery_info, + config=config, + connection_id=connection_id, + client_type=ClientV2, + cursor_type=CursorV2, + connection_type=Connection, + ) + + def connect_v2( auth: Auth, user_agent_header: str, diff --git a/tests/integration/dbapi/sync/V2/test_discovery.py b/tests/integration/dbapi/sync/V2/test_discovery.py new file mode 100644 index 00000000000..ded2db56694 --- /dev/null +++ b/tests/integration/dbapi/sync/V2/test_discovery.py @@ -0,0 +1,8 @@ +from firebolt.db import connect + + +def test_core_discovery_connection(core_url: str): + with connect(host=core_url, ssl_mode="none", database="firebolt") as connection: + cursor = connection.cursor() + cursor.execute("SELECT 42") + assert cursor.fetchone()[0] == 42 diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index f7a01db6883..1d871204440 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -17,6 +17,11 @@ FireboltError, ) from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.discovery_connection_helpers import ( + DISCOVERY_HOST, + DISCOVERY_SETTINGS, + mock_discovery_connection_flow, +) @mark.skip("__slots__ is broken on Connection class") @@ -108,6 +113,26 @@ async def test_connect( assert await connection.cursor().execute("select *") == len(python_query_data) +async def test_connect_discovery( + db_name: str, + engine_name: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + python_query_data: List[List[ColType]], +): + """Discovery connections pass database, engine and settings as query params.""" + mock_discovery_connection_flow(httpx_mock, db_name, engine_name, query_callback) + + async with await connect( + host=DISCOVERY_HOST, + ssl_mode="none", + database=db_name, + engine=engine_name, + settings=DISCOVERY_SETTINGS, + ) as connection: + assert await connection.cursor().execute("select *") == len(python_query_data) + + async def test_connect_database_failed( db_name: str, account_name: str, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 3602ebf5e59..64812c50d9f 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -21,6 +21,14 @@ FireboltError, ) from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.discovery_connection_helpers import ( + DISCOVERY_HOST, + DISCOVERY_SETTINGS, + assert_discovery_lookup_error, + assert_discovery_validation_errors, + mock_discovery_connection_flow, + mock_discovery_not_found, +) def test_connection_attributes(connection: Connection) -> None: @@ -112,6 +120,41 @@ def test_connect( assert connection.cursor().execute("select *") == len(python_query_data) +def test_connect_discovery( + db_name: str, + engine_name: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + python_query_data: List[List[ColType]], +): + """Discovery connections pass database, engine and settings as query params.""" + mock_discovery_connection_flow(httpx_mock, db_name, engine_name, query_callback) + + with connect( + host=DISCOVERY_HOST, + ssl_mode="none", + database=db_name, + engine=engine_name, + settings=DISCOVERY_SETTINGS, + ) as connection: + assert connection.cursor().execute("select *") == len(python_query_data) + + +def test_connect_discovery_rejects_legacy_parameters(auth: Auth): + def connect_discovery(**kwargs): + return connect(host=DISCOVERY_HOST, ssl_mode="none", **kwargs) + + assert_discovery_validation_errors(connect_discovery, auth) + + +def test_connect_discovery_validation(httpx_mock: HTTPXMock): + with raises(ConfigurationError, match="ssl_mode"): + connect(host=DISCOVERY_HOST, ssl_mode="invalid") + + mock_discovery_not_found(httpx_mock) + assert_discovery_lookup_error(lambda: connect(host=DISCOVERY_HOST, ssl_mode="none")) + + def test_connect_database_failed( db_name: str, account_name: str, diff --git a/tests/unit/discovery_connection_helpers.py b/tests/unit/discovery_connection_helpers.py new file mode 100644 index 00000000000..1cf1bcf3e7b --- /dev/null +++ b/tests/unit/discovery_connection_helpers.py @@ -0,0 +1,64 @@ +from typing import Callable + +from httpx import Request, codes +from pytest import raises +from pytest_httpx import HTTPXMock + +from firebolt.client.auth import Auth +from firebolt.utils.exception import ConfigurationError, InterfaceError +from tests.unit.response import Response + +DISCOVERY_HOST = "localhost:3473" +DISCOVERY_URL = f"http://{DISCOVERY_HOST}/.well-known/firebolt" +DISCOVERY_SETTINGS = {"custom_setting": "custom_value"} + + +def mock_discovery_connection_flow( + httpx_mock: HTTPXMock, + db_name: str, + engine_name: str, + query_callback: Callable, +) -> None: + httpx_mock.add_response( + method="GET", + url=DISCOVERY_URL, + json={"engineUrl": f"{DISCOVERY_HOST}/?discovered_param=value"}, + ) + + def query_with_discovery_params(request: Request, **kwargs) -> Response: + params = dict(request.url.params) + assert "authorization" not in request.headers + assert params["database"] == db_name + assert params["engine"] == engine_name + assert params["custom_setting"] == DISCOVERY_SETTINGS["custom_setting"] + assert params["discovered_param"] == "value" + return query_callback(request, **kwargs) + + httpx_mock.add_callback(query_with_discovery_params, method="POST") + + +def assert_discovery_validation_errors(connect_call: Callable, auth: Auth) -> None: + with raises(ConfigurationError, match="account_name"): + connect_call(account_name="account") + with raises(ConfigurationError, match="api_endpoint"): + connect_call(api_endpoint="api.example.com") + with raises(ConfigurationError, match="engine_url"): + connect_call(engine_url="engine.example.com") + with raises(ConfigurationError, match="url"): + connect_call(url=f"http://{DISCOVERY_HOST}") + with raises(ConfigurationError, match="auth"): + connect_call(auth=auth) + + +def mock_discovery_not_found(httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response( + method="GET", + url=DISCOVERY_URL, + status_code=codes.NOT_FOUND, + text="not found", + ) + + +def assert_discovery_lookup_error(connect_call: Callable) -> None: + with raises(InterfaceError, match="Unable to retrieve Firebolt discovery"): + connect_call()