diff --git a/.idea/python-libraries.iml b/.idea/python-libraries.iml
index ca971af..e576e58 100644
--- a/.idea/python-libraries.iml
+++ b/.idea/python-libraries.iml
@@ -17,6 +17,7 @@
+
diff --git a/database/CHANGELOG.md b/database/CHANGELOG.md
index 1979cde..cb72f06 100644
--- a/database/CHANGELOG.md
+++ b/database/CHANGELOG.md
@@ -1,5 +1,13 @@
# Changelog
+## [4.2.0] - 2026-06-17
+
+- Add `template_database` context manager to create a database with an existing
+ template.
+- Rename `temp_database` -> `temporary_database` to align names with
+ `template_database`.
+- Reorganize imports for database query utilities
+
## [4.1.1] - 2026-06-12
- Fix a bug in with escape sequences.
diff --git a/database/macrostrat/database/__init__.py b/database/macrostrat/database/__init__.py
index 6ea055a..af068da 100644
--- a/database/macrostrat/database/__init__.py
+++ b/database/macrostrat/database/__init__.py
@@ -12,9 +12,9 @@
from sqlalchemy.sql.expression import Insert
from macrostrat.utils import get_logger
-
from .mapper import DatabaseMapper
from .postgresql import on_conflict, prefix_inserts # noqa
+from .query import run_fixtures, run_query, run_sql, execute # noqa
from .utils import ( # noqa
create_database,
create_engine,
@@ -23,9 +23,6 @@
get_dataframe,
get_or_create,
reflect_table,
- run_fixtures,
- run_query,
- run_sql,
)
metadata = MetaData()
diff --git a/database/macrostrat/database/query.py b/database/macrostrat/database/query.py
index 8c7a02b..dc1c7b9 100644
--- a/database/macrostrat/database/query.py
+++ b/database/macrostrat/database/query.py
@@ -595,3 +595,27 @@ def run_sql(*args, **kwargs):
if kwargs.pop("yield_results", False):
return res
return list(res)
+
+
+def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
+ output_file = kwargs.pop("output_file", None)
+ output_mode = kwargs.pop("output_mode", None)
+ sql = format(sql, strip_comments=True).strip()
+ if sql == "":
+ return
+ try:
+ connectable.begin()
+ res = connectable.execute(text(sql), params=params)
+ if hasattr(connectable, "commit"):
+ connectable.commit()
+ pretty_print(sql, dim=True, file=output_file, mode=output_mode)
+ return res
+ except (ProgrammingError, IntegrityError) as err:
+ if hasattr(connectable, "rollback"):
+ connectable.rollback()
+ _print_error(sql, dim=True, file=output_file, mode=output_mode)
+ if stop_on_error:
+ return
+ finally:
+ if hasattr(connectable, "close"):
+ connectable.close()
diff --git a/database/macrostrat/database/utils.py b/database/macrostrat/database/utils.py
index 854c204..5512c1d 100644
--- a/database/macrostrat/database/utils.py
+++ b/database/macrostrat/database/utils.py
@@ -1,5 +1,7 @@
+import warnings
from contextlib import contextmanager
from time import sleep
+from uuid import uuid4
from click import echo
from sqlalchemy import MetaData
@@ -8,19 +10,16 @@
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
- IntegrityError,
OperationalError,
- ProgrammingError,
)
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import Table
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy_utils import create_database as _create_database
from sqlalchemy_utils import database_exists, drop_database
-from sqlparse import format
from macrostrat.utils import cmd, get_logger
-from .query import get_sql_text
+from .query import get_sql_text, execute # noqa
log = get_logger(__name__)
@@ -46,30 +45,6 @@ def db_session(engine):
return factory()
-def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
- output_file = kwargs.pop("output_file", None)
- output_mode = kwargs.pop("output_mode", None)
- sql = format(sql, strip_comments=True).strip()
- if sql == "":
- return
- try:
- connectable.begin()
- res = connectable.execute(text(sql), params=params)
- if hasattr(connectable, "commit"):
- connectable.commit()
- pretty_print(sql, dim=True, file=output_file, mode=output_mode)
- return res
- except (ProgrammingError, IntegrityError) as err:
- if hasattr(connectable, "rollback"):
- connectable.rollback()
- _print_error(sql, dim=True, file=output_file, mode=output_mode)
- if stop_on_error:
- return
- finally:
- if hasattr(connectable, "close"):
- connectable.close()
-
-
def get_or_create(session, model, defaults=None, **kwargs):
"""
Get an instance of a model, or create it if it doesn't
@@ -97,9 +72,13 @@ def get_db_model(db, model_name: str):
@contextmanager
-def temp_database(conn_string, drop=True, ensure_empty=False):
+def temporary_database(
+ conn_string, *, drop=True, ensure_empty=False, exists_ok=True, template=None
+):
"""Create a temporary database and tear it down after tests."""
- create_database(conn_string, exists_ok=True, replace=ensure_empty)
+ create_database(
+ conn_string, exists_ok=exists_ok, replace=ensure_empty, template=template
+ )
try:
engine = create_engine(conn_string)
yield engine
@@ -109,6 +88,33 @@ def temp_database(conn_string, drop=True, ensure_empty=False):
drop_database(conn_string)
+@contextmanager
+def temp_database(*args, **kwargs):
+ warnings.warn(
+ "temp_database is deprecated, use temporary_database instead",
+ DeprecationWarning,
+ )
+ with temporary_database(*args, **kwargs) as engine:
+ yield engine
+
+
+@contextmanager
+def template_database(engine: Engine, *, name: str = None):
+ """Create a temporary template database using an existing database as a template."""
+ db_name = engine.url.database
+ template_db_name = name
+ if name is None:
+ uid = str(uuid4())[:8]
+ template_db_name = db_name + "_template_" + uid
+ # Close connection to the database so we can create a new one based on the template
+ new_db_url = engine.url.set(database=template_db_name)
+ engine.dispose()
+ with temporary_database(
+ new_db_url, drop=True, exists_ok=False, template=db_name
+ ) as engine:
+ yield engine
+
+
def create_database(url, **kwargs):
"""Create a database if it doesn't exist.
diff --git a/database/pyproject.toml b/database/pyproject.toml
index f71cbc4..ccceb88 100644
--- a/database/pyproject.toml
+++ b/database/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "macrostrat.database"
-version = "4.1.1"
+version = "4.2.0"
description = "A SQLAlchemy-based database toolkit."
authors = [{ name = "Daven Quinn", email = "dev@davenquinn.com" }]
requires-python = ">=3.10,<4"
diff --git a/database/tests/test_database.py b/database/tests/test_database.py
index 47d5508..7017a12 100644
--- a/database/tests/test_database.py
+++ b/database/tests/test_database.py
@@ -27,7 +27,7 @@
infer_is_sql_text,
run_fixtures,
)
-from macrostrat.database.utils import temp_database
+from macrostrat.database.utils import temp_database, template_database
from macrostrat.utils import get_logger, relative_path
load_dotenv()
@@ -76,6 +76,15 @@ def test_database(db):
assert "geology_formation" in db.model
+def test_template_database(db):
+ with template_database(db.engine) as engine:
+ assert engine.url.database != db.engine.url.database
+ db1 = Database(engine)
+ db.automap(schemas=["public", "geology"])
+ assert "sample" in db.model
+ assert "geology_formation" in db.model
+
+
def test_database_mapper(db):
Sample = db.model.sample
Formation = db.model.geology_formation
diff --git a/database/uv.lock b/database/uv.lock
index 444d2cc..a377355 100644
--- a/database/uv.lock
+++ b/database/uv.lock
@@ -115,7 +115,7 @@ wheels = [
[[package]]
name = "macrostrat-database"
-version = "4.1.1"
+version = "4.2.0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },
diff --git a/uv.lock b/uv.lock
index 8dc3bee..a8bf920 100644
--- a/uv.lock
+++ b/uv.lock
@@ -647,7 +647,7 @@ dev = [
[[package]]
name = "macrostrat-database"
-version = "4.1.1"
+version = "4.2.0"
source = { editable = "database" }
dependencies = [
{ name = "aiofiles" },