Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .idea/python-libraries.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions database/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Changelog

## [4.2.0] - 2026-06-17

- Add `template_database` context manager to create a database with an existing
template.
- Rename `temp_database` -> `temporary_database` to align names with
`template_database`.
- Reorganize imports for database query utilities

## [4.1.1] - 2026-06-12

- Fix a bug in with escape sequences.
Expand Down
5 changes: 1 addition & 4 deletions database/macrostrat/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from sqlalchemy.sql.expression import Insert

from macrostrat.utils import get_logger

from .mapper import DatabaseMapper
from .postgresql import on_conflict, prefix_inserts # noqa
from .query import run_fixtures, run_query, run_sql, execute # noqa
from .utils import ( # noqa
create_database,
create_engine,
Expand All @@ -23,9 +23,6 @@
get_dataframe,
get_or_create,
reflect_table,
run_fixtures,
run_query,
run_sql,
)

metadata = MetaData()
Expand Down
24 changes: 24 additions & 0 deletions database/macrostrat/database/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,27 @@ def run_sql(*args, **kwargs):
if kwargs.pop("yield_results", False):
return res
return list(res)


def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
output_file = kwargs.pop("output_file", None)
output_mode = kwargs.pop("output_mode", None)
sql = format(sql, strip_comments=True).strip()
if sql == "":
return
try:
connectable.begin()
res = connectable.execute(text(sql), params=params)
if hasattr(connectable, "commit"):
connectable.commit()
pretty_print(sql, dim=True, file=output_file, mode=output_mode)
return res
except (ProgrammingError, IntegrityError) as err:
if hasattr(connectable, "rollback"):
connectable.rollback()
_print_error(sql, dim=True, file=output_file, mode=output_mode)
if stop_on_error:
return
finally:
if hasattr(connectable, "close"):
connectable.close()
66 changes: 36 additions & 30 deletions database/macrostrat/database/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import warnings
from contextlib import contextmanager
from time import sleep
from uuid import uuid4

from click import echo
from sqlalchemy import MetaData
Expand All @@ -8,19 +10,16 @@
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
IntegrityError,
OperationalError,
ProgrammingError,
)
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import Table
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy_utils import create_database as _create_database
from sqlalchemy_utils import database_exists, drop_database
from sqlparse import format

from macrostrat.utils import cmd, get_logger
from .query import get_sql_text
from .query import get_sql_text, execute # noqa

log = get_logger(__name__)

Expand All @@ -46,30 +45,6 @@ def db_session(engine):
return factory()


def execute(connectable, sql, params=None, stop_on_error=False, **kwargs):
output_file = kwargs.pop("output_file", None)
output_mode = kwargs.pop("output_mode", None)
sql = format(sql, strip_comments=True).strip()
if sql == "":
return
try:
connectable.begin()
res = connectable.execute(text(sql), params=params)
if hasattr(connectable, "commit"):
connectable.commit()
pretty_print(sql, dim=True, file=output_file, mode=output_mode)
return res
except (ProgrammingError, IntegrityError) as err:
if hasattr(connectable, "rollback"):
connectable.rollback()
_print_error(sql, dim=True, file=output_file, mode=output_mode)
if stop_on_error:
return
finally:
if hasattr(connectable, "close"):
connectable.close()


def get_or_create(session, model, defaults=None, **kwargs):
"""
Get an instance of a model, or create it if it doesn't
Expand Down Expand Up @@ -97,9 +72,13 @@ def get_db_model(db, model_name: str):


@contextmanager
def temp_database(conn_string, drop=True, ensure_empty=False):
def temporary_database(
conn_string, *, drop=True, ensure_empty=False, exists_ok=True, template=None
):
"""Create a temporary database and tear it down after tests."""
create_database(conn_string, exists_ok=True, replace=ensure_empty)
create_database(
conn_string, exists_ok=exists_ok, replace=ensure_empty, template=template
)
try:
engine = create_engine(conn_string)
yield engine
Expand All @@ -109,6 +88,33 @@ def temp_database(conn_string, drop=True, ensure_empty=False):
drop_database(conn_string)


@contextmanager
def temp_database(*args, **kwargs):
warnings.warn(
"temp_database is deprecated, use temporary_database instead",
DeprecationWarning,
)
with temporary_database(*args, **kwargs) as engine:
yield engine


@contextmanager
def template_database(engine: Engine, *, name: str = None):
"""Create a temporary template database using an existing database as a template."""
db_name = engine.url.database
template_db_name = name
if name is None:
uid = str(uuid4())[:8]
template_db_name = db_name + "_template_" + uid
# Close connection to the database so we can create a new one based on the template
new_db_url = engine.url.set(database=template_db_name)
engine.dispose()
with temporary_database(
new_db_url, drop=True, exists_ok=False, template=db_name
) as engine:
yield engine


def create_database(url, **kwargs):
"""Create a database if it doesn't exist.

Expand Down
2 changes: 1 addition & 1 deletion database/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "macrostrat.database"
version = "4.1.1"
version = "4.2.0"
description = "A SQLAlchemy-based database toolkit."
authors = [{ name = "Daven Quinn", email = "dev@davenquinn.com" }]
requires-python = ">=3.10,<4"
Expand Down
11 changes: 10 additions & 1 deletion database/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
infer_is_sql_text,
run_fixtures,
)
from macrostrat.database.utils import temp_database
from macrostrat.database.utils import temp_database, template_database
from macrostrat.utils import get_logger, relative_path

load_dotenv()
Expand Down Expand Up @@ -76,6 +76,15 @@ def test_database(db):
assert "geology_formation" in db.model


def test_template_database(db):
with template_database(db.engine) as engine:
assert engine.url.database != db.engine.url.database
db1 = Database(engine)
db.automap(schemas=["public", "geology"])
assert "sample" in db.model
assert "geology_formation" in db.model


def test_database_mapper(db):
Sample = db.model.sample
Formation = db.model.geology_formation
Expand Down
2 changes: 1 addition & 1 deletion database/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading