69 lines
2.0 KiB
Python
69 lines
2.0 KiB
Python
import os
|
|
import sqlite3 as sql
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
from src import DATABASE_DIR, settings
|
|
from src.shared.logging import log
|
|
|
|
# Directory holding the `.sql` migration files, resolved relative to this module's own location.
MIGRATIONS_DIR = Path(__file__).parent / "migrations"
|
|
|
|
|
|
def _ensure_migrations_table(conn: sql.Connection) -> None:
    """Create the ``schema_migrations`` bookkeeping table if it is missing.

    Each row holds a migration ``id`` (primary key) and the ``applied_at``
    timestamp, which defaults to the time the row is inserted. The DDL is
    committed before returning; calling this on an existing table is a no-op.
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS schema_migrations (
            id TEXT PRIMARY KEY,
            applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    conn.commit()
|
|
|
|
|
|
def _applied_migrations(conn: sql.Connection) -> List[str]:
    """Return the ids recorded in ``schema_migrations``, sorted by ``id``.

    Assumes the table already exists (see ``_ensure_migrations_table``).
    """
    cur = conn.execute("SELECT id FROM schema_migrations ORDER BY id")
    # Iterating the cursor yields the same rows fetchall() would.
    return [row[0] for row in cur]
|
|
|
|
|
|
def _apply_sql_file(conn: sql.Connection, path: Path) -> None:
    """Run the SQL script at *path* and record it as applied.

    The file is read as UTF-8 and executed with ``executescript``. Its file
    name is then upserted into ``schema_migrations`` (``INSERT OR REPLACE``)
    and the connection is committed.

    NOTE(review): the script and the bookkeeping INSERT do not share one
    transaction, so a failure part-way through the script may leave a
    partially applied migration that is not recorded.
    """
    log.info(f"Applying migration {path.name}")
    script = path.read_text(encoding="utf-8")
    cur = conn.cursor()
    cur.executescript(script)
    record_query = "INSERT OR REPLACE INTO schema_migrations (id) VALUES (?)"
    cur.execute(record_query, (path.name,))
    conn.commit()
|
|
|
|
|
|
def run_migrations(db_path: Path) -> None:
    """Run all unapplied migrations from the migrations directory against the database at db_path.

    Every regular ``.sql`` file in ``MIGRATIONS_DIR`` is considered in sorted
    order. Files whose name is already recorded in ``schema_migrations`` are
    skipped; the rest are executed and recorded one by one. If
    ``MIGRATIONS_DIR`` does not exist, the function logs and returns.

    Args:
        db_path: Location of the SQLite database file. Its parent directory
            is created if it does not exist yet.

    Raises:
        Any exception from reading or executing a migration propagates;
        later migrations are not attempted. The connection is always closed.
    """
    if not MIGRATIONS_DIR.exists():
        log.debug("Migrations directory does not exist, skipping migrations")
        return

    # Ensure the directory that will hold the database file exists; sqlite3
    # cannot create the file otherwise. Derive it from db_path itself rather
    # than from settings, which may name a different directory.
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    conn = sql.connect(db_path)
    try:
        _ensure_migrations_table(conn)
        applied = set(_applied_migrations(conn))

        # is_file() keeps a directory that happens to end in ".sql" from
        # reaching read_text() and crashing the run.
        migration_files = sorted(
            p for p in MIGRATIONS_DIR.iterdir() if p.is_file() and p.suffix == ".sql"
        )
        for migration in migration_files:
            if migration.name in applied:
                log.debug(f"Skipping already applied migration {migration.name}")
                continue
            _apply_sql_file(conn, migration)
    finally:
        conn.close()
|