diff --git a/.idea/python-libraries.iml b/.idea/python-libraries.iml index ca971af..e576e58 100644 --- a/.idea/python-libraries.iml +++ b/.idea/python-libraries.iml @@ -17,6 +17,7 @@ + diff --git a/database/CHANGELOG.md b/database/CHANGELOG.md index 1979cde..cb72f06 100644 --- a/database/CHANGELOG.md +++ b/database/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [4.2.0] - 2026-06-17 + +- Add `template_database` context manager to create a database with an existing + template. +- Rename `temp_database` -> `temporary_database` to align names with + `template_database`. +- Reorganize imports for database query utilities + ## [4.1.1] - 2026-06-12 - Fix a bug in with escape sequences. diff --git a/database/macrostrat/database/__init__.py b/database/macrostrat/database/__init__.py index 6ea055a..af068da 100644 --- a/database/macrostrat/database/__init__.py +++ b/database/macrostrat/database/__init__.py @@ -12,9 +12,9 @@ from sqlalchemy.sql.expression import Insert from macrostrat.utils import get_logger - from .mapper import DatabaseMapper from .postgresql import on_conflict, prefix_inserts # noqa +from .query import run_fixtures, run_query, run_sql, execute # noqa from .utils import ( # noqa create_database, create_engine, @@ -23,9 +23,6 @@ get_dataframe, get_or_create, reflect_table, - run_fixtures, - run_query, - run_sql, ) metadata = MetaData() diff --git a/database/macrostrat/database/query.py b/database/macrostrat/database/query.py index 8c7a02b..dc1c7b9 100644 --- a/database/macrostrat/database/query.py +++ b/database/macrostrat/database/query.py @@ -595,3 +595,27 @@ def run_sql(*args, **kwargs): if kwargs.pop("yield_results", False): return res return list(res) + + +def execute(connectable, sql, params=None, stop_on_error=False, **kwargs): + output_file = kwargs.pop("output_file", None) + output_mode = kwargs.pop("output_mode", None) + sql = format(sql, strip_comments=True).strip() + if sql == "": + return + try: + connectable.begin() + res = connectable.execute(text(sql), params=params) + if hasattr(connectable, "commit"): + connectable.commit() + pretty_print(sql, dim=True, file=output_file, mode=output_mode) + return res + except (ProgrammingError, IntegrityError) as err: + if hasattr(connectable, "rollback"): + connectable.rollback() + _print_error(sql, dim=True, file=output_file, mode=output_mode) + if stop_on_error: + return + finally: + if hasattr(connectable, "close"): + connectable.close() diff --git a/database/macrostrat/database/utils.py b/database/macrostrat/database/utils.py index 854c204..5512c1d 100644 --- a/database/macrostrat/database/utils.py +++ b/database/macrostrat/database/utils.py @@ -1,5 +1,7 @@ +import warnings from contextlib import contextmanager from time import sleep +from uuid import uuid4 from click import echo from sqlalchemy import MetaData @@ -8,19 +10,16 @@ from sqlalchemy.engine import Engine from sqlalchemy.engine.url import make_url from sqlalchemy.exc import ( - IntegrityError, OperationalError, - ProgrammingError, ) from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import Table from sqlalchemy.sql.elements import ClauseElement from sqlalchemy_utils import create_database as _create_database from sqlalchemy_utils import database_exists, drop_database -from sqlparse import format from macrostrat.utils import cmd, get_logger -from .query import get_sql_text +from .query import get_sql_text, execute # noqa log = get_logger(__name__) @@ -46,30 +45,6 @@ def db_session(engine): return factory() -def execute(connectable, sql, params=None, stop_on_error=False, **kwargs): - output_file = kwargs.pop("output_file", None) - output_mode = kwargs.pop("output_mode", None) - sql = format(sql, strip_comments=True).strip() - if sql == "": - return - try: - connectable.begin() - res = connectable.execute(text(sql), params=params) - if hasattr(connectable, "commit"): - connectable.commit() - pretty_print(sql, dim=True, file=output_file, mode=output_mode) - return res - except (ProgrammingError, IntegrityError) as err: - if hasattr(connectable, "rollback"): - connectable.rollback() - _print_error(sql, dim=True, file=output_file, mode=output_mode) - if stop_on_error: - return - finally: - if hasattr(connectable, "close"): - connectable.close() - - def get_or_create(session, model, defaults=None, **kwargs): """ Get an instance of a model, or create it if it doesn't @@ -97,9 +72,13 @@ def get_db_model(db, model_name: str): @contextmanager -def temp_database(conn_string, drop=True, ensure_empty=False): +def temporary_database( + conn_string, *, drop=True, ensure_empty=False, exists_ok=True, template=None +): """Create a temporary database and tear it down after tests.""" - create_database(conn_string, exists_ok=True, replace=ensure_empty) + create_database( + conn_string, exists_ok=exists_ok, replace=ensure_empty, template=template + ) try: engine = create_engine(conn_string) yield engine @@ -109,6 +88,33 @@ def temp_database(conn_string, drop=True, ensure_empty=False): drop_database(conn_string) +@contextmanager +def temp_database(*args, **kwargs): + warnings.warn( + "temp_database is deprecated, use temporary_database instead", + DeprecationWarning, + ) + with temporary_database(*args, **kwargs) as engine: + yield engine + + +@contextmanager +def template_database(engine: Engine, *, name: str = None): + """Create a temporary template database using an existing database as a template.""" + db_name = engine.url.database + template_db_name = name + if name is None: + uid = str(uuid4())[:8] + template_db_name = db_name + "_template_" + uid + # Close connection to the database so we can create a new one based on the template + new_db_url = engine.url.set(database=template_db_name) + engine.dispose() + with temporary_database( + new_db_url, drop=True, exists_ok=False, template=db_name + ) as engine: + yield engine + + def create_database(url, **kwargs): """Create a database if it doesn't exist. diff --git a/database/pyproject.toml b/database/pyproject.toml index f71cbc4..ccceb88 100644 --- a/database/pyproject.toml +++ b/database/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "macrostrat.database" -version = "4.1.1" +version = "4.2.0" description = "A SQLAlchemy-based database toolkit." authors = [{ name = "Daven Quinn", email = "dev@davenquinn.com" }] requires-python = ">=3.10,<4" diff --git a/database/tests/test_database.py b/database/tests/test_database.py index 47d5508..7017a12 100644 --- a/database/tests/test_database.py +++ b/database/tests/test_database.py @@ -27,7 +27,7 @@ infer_is_sql_text, run_fixtures, ) -from macrostrat.database.utils import temp_database +from macrostrat.database.utils import temp_database, template_database from macrostrat.utils import get_logger, relative_path load_dotenv() @@ -76,6 +76,15 @@ def test_database(db): assert "geology_formation" in db.model +def test_template_database(db): + with template_database(db.engine) as engine: + assert engine.url.database != db.engine.url.database + db1 = Database(engine) + db.automap(schemas=["public", "geology"]) + assert "sample" in db.model + assert "geology_formation" in db.model + + def test_database_mapper(db): Sample = db.model.sample Formation = db.model.geology_formation diff --git a/database/uv.lock b/database/uv.lock index 444d2cc..a377355 100644 --- a/database/uv.lock +++ b/database/uv.lock @@ -115,7 +115,7 @@ wheels = [ [[package]] name = "macrostrat-database" -version = "4.1.1" +version = "4.2.0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, diff --git a/uv.lock b/uv.lock index 8dc3bee..a8bf920 100644 --- a/uv.lock +++ b/uv.lock @@ -647,7 +647,7 @@ dev = [ [[package]] name = "macrostrat-database" -version = "4.1.1" +version = "4.2.0" source = { editable = "database" } dependencies = [ { name = "aiofiles" },