chore(all): run formatting on repo, start work on porting webrequest over to api library

2025-11-27 14:29:33 +01:00
parent 04010815a9
commit 539e1331a0
10 changed files with 925 additions and 233 deletions


@@ -1,8 +1,9 @@
import re
import time
import xml.etree.ElementTree as ET
from collections.abc import Iterable
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any
import requests
from requests.adapters import HTTPAdapter
@@ -24,7 +25,7 @@ MARC = "http://www.loc.gov/MARC21/slim"
NS = {"zs": ZS, "marc": MARC}
def _text(elem: Optional[ET.Element]) -> str:
def _text(elem: ET.Element | None) -> str:
return (elem.text or "") if elem is not None else ""
@@ -36,32 +37,31 @@ def _req_text(parent: ET.Element, path: str) -> str:
def parse_marc_record(record_el: ET.Element) -> MarcRecord:
"""
record_el is the <marc:record> element (default ns MARC in your sample)
"""record_el is the <marc:record> element (default ns MARC in your sample)
"""
# leader
leader_text = _req_text(record_el, "marc:leader")
# controlfields
controlfields: List[ControlField] = []
controlfields: list[ControlField] = []
for cf in record_el.findall("marc:controlfield", NS):
tag = cf.get("tag", "").strip()
controlfields.append(ControlField(tag=tag, value=_text(cf)))
# datafields
datafields: List[DataField] = []
datafields: list[DataField] = []
for df in record_el.findall("marc:datafield", NS):
tag = df.get("tag", "").strip()
ind1 = df.get("ind1") or " "
ind2 = df.get("ind2") or " "
subfields: List[SubField] = []
subfields: list[SubField] = []
for sf in df.findall("marc:subfield", NS):
code = sf.get("code", "")
subfields.append(SubField(code=code, value=_text(sf)))
datafields.append(DataField(tag=tag, ind1=ind1, ind2=ind2, subfields=subfields))
return MarcRecord(
leader=leader_text, controlfields=controlfields, datafields=datafields
leader=leader_text, controlfields=controlfields, datafields=datafields,
)
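A minimal usage sketch for parse_marc_record (the XML snippet and the asserted values are illustrative, not taken from this repo):

import xml.etree.ElementTree as ET

snippet = """
<record xmlns="http://www.loc.gov/MARC21/slim">
  <leader>00000nam a2200000 c 4500</leader>
  <controlfield tag="001">123456789</controlfield>
  <datafield tag="245" ind1="1" ind2="0">
    <subfield code="a">Example title</subfield>
  </datafield>
</record>
"""
rec = parse_marc_record(ET.fromstring(snippet))
assert rec.controlfields[0].value == "123456789"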
@@ -92,7 +92,7 @@ def parse_record(zs_record_el: ET.Element) -> Record:
)
def parse_echoed_request(root: ET.Element) -> Optional[EchoedSearchRequest]:
def parse_echoed_request(root: ET.Element) -> EchoedSearchRequest | None:
el = root.find("zs:echoedSearchRetrieveRequest", NS)
if el is None:
return None
@@ -119,7 +119,7 @@ def parse_echoed_request(root: ET.Element) -> Optional[EchoedSearchRequest]:
def parse_search_retrieve_response(
xml_str: Union[str, bytes],
xml_str: str | bytes,
) -> SearchRetrieveResponse:
root = ET.fromstring(xml_str)
@@ -128,7 +128,7 @@ def parse_search_retrieve_response(
numberOfRecords = int(_req_text(root, "zs:numberOfRecords") or "0")
records_parent = root.find("zs:records", NS)
records: List[Record] = []
records: list[Record] = []
if records_parent is not None:
for r in records_parent.findall("zs:record", NS):
record = parse_record(r)
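A quick sketch of feeding the response parser; the stub assumes ZS is the standard SRU namespace "http://www.loc.gov/zing/srw/" and that the response object exposes the field names used below:

resp_xml = (
    '<zs:searchRetrieveResponse xmlns:zs="http://www.loc.gov/zing/srw/">'
    "<zs:version>1.1</zs:version>"
    "<zs:numberOfRecords>0</zs:numberOfRecords>"
    "</zs:searchRetrieveResponse>"
)
sr = parse_search_retrieve_response(resp_xml)
assert sr.numberOfRecords == 0 and sr.records == []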
@@ -150,9 +150,9 @@ def parse_search_retrieve_response(
def iter_datafields(
rec: MarcRecord,
tag: Optional[str] = None,
ind1: Optional[str] = None,
ind2: Optional[str] = None,
tag: str | None = None,
ind1: str | None = None,
ind2: str | None = None,
) -> Iterable[DataField]:
"""Yield datafields, optionally filtered by tag/indicators."""
for df in rec.datafields:
@@ -170,11 +170,11 @@ def subfield_values(
tag: str,
code: str,
*,
ind1: Optional[str] = None,
ind2: Optional[str] = None,
) -> List[str]:
ind1: str | None = None,
ind2: str | None = None,
) -> list[str]:
"""All values for subfield `code` in every `tag` field (respecting indicators)."""
out: List[str] = []
out: list[str] = []
for df in iter_datafields(rec, tag, ind1, ind2):
out.extend(sf.value for sf in df.subfields if sf.code == code)
return out
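The filtering contract these helpers provide, sketched with hypothetical field values:

isbns = subfield_values(rec, "020", "a")        # every 020 $a across the record
fields_245 = list(iter_datafields(rec, "245"))  # all 245 fields, any indicators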
@@ -185,10 +185,10 @@ def first_subfield_value(
tag: str,
code: str,
*,
ind1: Optional[str] = None,
ind2: Optional[str] = None,
default: Optional[str] = None,
) -> Optional[str]:
ind1: str | None = None,
ind2: str | None = None,
default: str | None = None,
) -> str | None:
"""First value for subfield `code` in `tag` (respecting indicators)."""
for df in iter_datafields(rec, tag, ind1, ind2):
for sf in df.subfields:
@@ -201,25 +201,24 @@ def find_datafields_with_subfields(
rec: MarcRecord,
tag: str,
*,
where_all: Optional[Dict[str, str]] = None,
where_any: Optional[Dict[str, str]] = None,
where_all: dict[str, str] | None = None,
where_any: dict[str, str] | None = None,
casefold: bool = False,
ind1: Optional[str] = None,
ind2: Optional[str] = None,
) -> List[DataField]:
"""
Return datafields of `tag` whose subfields match constraints:
ind1: str | None = None,
ind2: str | None = None,
) -> list[DataField]:
"""Return datafields of `tag` whose subfields match constraints:
- where_all: every (code -> exact value) must be present
- where_any: at least one (code -> exact value) present
Set `casefold=True` for case-insensitive comparison.
"""
where_all = where_all or {}
where_any = where_any or {}
matched: List[DataField] = []
matched: list[DataField] = []
for df in iter_datafields(rec, tag, ind1, ind2):
# Map code -> list of values (with optional casefold applied)
vals: Dict[str, List[str]] = {}
vals: dict[str, list[str]] = {}
for sf in df.subfields:
v = sf.value.casefold() if casefold else sf.value
vals.setdefault(sf.code, []).append(v)
@@ -246,8 +245,8 @@ def find_datafields_with_subfields(
def controlfield_value(
rec: MarcRecord, tag: str, default: Optional[str] = None
) -> Optional[str]:
rec: MarcRecord, tag: str, default: str | None = None,
) -> str | None:
"""Get the first controlfield value by tag (e.g., '001', '005')."""
for cf in rec.controlfields:
if cf.tag == tag:
@@ -256,8 +255,8 @@ def controlfield_value(
def datafields_value(
data: List[DataField], code: str, default: Optional[str] = None
) -> Optional[str]:
data: list[DataField], code: str, default: str | None = None,
) -> str | None:
"""Get the first value for a specific subfield code in a list of datafields."""
for df in data:
for sf in df.subfields:
@@ -267,8 +266,8 @@ def datafields_value(
def datafield_value(
df: DataField, code: str, default: Optional[str] = None
) -> Optional[str]:
df: DataField, code: str, default: str | None = None,
) -> str | None:
"""Get the first value for a specific subfield code in a datafield."""
for sf in df.subfields:
if sf.code == code:
@@ -276,9 +275,8 @@ def datafield_value(
return default
def _smart_join_title(a: str, b: Optional[str]) -> str:
"""
Join 245 $a and $b with MARC-style punctuation.
def _smart_join_title(a: str, b: str | None) -> str:
"""Join 245 $a and $b with MARC-style punctuation.
If $b is present, join with ' : ' unless either side already supplies punctuation.
"""
a = a.strip()
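The intended behavior per the docstring (illustrative; the rest of the body falls outside this hunk):

# _smart_join_title("Linear algebra", "a first course")   -> "Linear algebra : a first course"
# _smart_join_title("Linear algebra :", "a first course") -> existing punctuation is kept
# _smart_join_title("Linear algebra", None)                -> "Linear algebra"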
@@ -293,7 +291,7 @@ def _smart_join_title(a: str, b: Optional[str]) -> str:
def subfield_values_from_fields(
fields: Iterable[DataField],
code: str,
) -> List[str]:
) -> list[str]:
"""All subfield values with given `code` across a list of DataField."""
return [sf.value for df in fields for sf in df.subfields if sf.code == code]
@@ -301,8 +299,8 @@ def subfield_values_from_fields(
def first_subfield_value_from_fields(
fields: Iterable[DataField],
code: str,
default: Optional[str] = None,
) -> Optional[str]:
default: str | None = None,
) -> str | None:
"""First subfield value with given `code` across a list of DataField."""
for df in fields:
for sf in df.subfields:
@@ -314,12 +312,11 @@ def first_subfield_value_from_fields(
def subfield_value_pairs_from_fields(
fields: Iterable[DataField],
code: str,
) -> List[Tuple[DataField, str]]:
"""
Return (DataField, value) pairs for all subfields with `code`.
) -> list[tuple[DataField, str]]:
"""Return (DataField, value) pairs for all subfields with `code`.
Useful if you need to know which field a value came from.
"""
out: List[Tuple[DataField, str]] = []
out: list[tuple[DataField, str]] = []
for df in fields:
for sf in df.subfields:
if sf.code == code:
@@ -340,13 +337,13 @@ def book_from_marc(rec: MarcRecord, library_identifier: str) -> BookData:
# Signature = 924 where $9 == "Frei 129" → take that field's $g
frei_fields = find_datafields_with_subfields(
rec, "924", where_all={"9": "Frei 129"}
rec, "924", where_all={"9": "Frei 129"},
)
signature = first_subfield_value_from_fields(frei_fields, "g")
# Year = 264 $c (prefer ind2="1" publication; fallback to any 264)
year = first_subfield_value(rec, "264", "c", ind2="1") or first_subfield_value(
rec, "264", "c"
rec, "264", "c",
)
isbn = subfield_values(rec, "020", "a")
mediatype = first_subfield_value(rec, "338", "a")
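The `or` chain above prefers 264 with ind2="1" (the publication statement) and falls back to any 264 $c; spelled out for clarity:

year = first_subfield_value(rec, "264", "c", ind2="1")
if not year:  # missing or empty -> same fallback the `or` expresses
    year = first_subfield_value(rec, "264", "c")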
@@ -378,10 +375,9 @@ RVK_ALLOWED = r"[A-Z0-9.\-\/]" # conservative char set typically seen in RVK no
def find_newer_edition(
swb_result: BookData, dnb_result: List[BookData]
) -> Optional[List[BookData]]:
"""
New edition if:
swb_result: BookData, dnb_result: list[BookData],
) -> list[BookData] | None:
"""New edition if:
- year > swb.year OR
- edition_number > swb.edition_number
@@ -393,7 +389,7 @@ def find_newer_edition(
edition_number desc, best-signature-match desc, has-signature desc).
"""
def norm_sig(s: Optional[str]) -> str:
def norm_sig(s: str | None) -> str:
if not s:
return ""
# normalize: lowercase, collapse whitespace, keep alnum + a few separators
@@ -427,7 +423,7 @@ def find_newer_edition(
swb_sig_norm = norm_sig(getattr(swb_result, "signature", None))
# 1) Filter to same-work AND newer
candidates: List[BookData] = []
candidates: list[BookData] = []
for b in dnb_result:
# Skip if both signatures exist and don't match (different work)
b_sig = getattr(b, "signature", None)
@@ -443,7 +439,7 @@ def find_newer_edition(
return None
# 2) Dedupe by PPN, preferring signature (and matching signature if possible)
by_ppn: dict[Optional[str], BookData] = {}
by_ppn: dict[str | None, BookData] = {}
for b in candidates:
key = getattr(b, "ppn", None)
prev = by_ppn.get(key)
@@ -477,7 +473,7 @@ def find_newer_edition(
class QueryTransformer:
def __init__(self, api_schema: Type[Enum], arguments: Union[Iterable[str], str]):
def __init__(self, api_schema: type[Enum], arguments: Iterable[str] | str):
self.api_schema = api_schema
if isinstance(arguments, str):
self.arguments = [arguments]
@@ -485,8 +481,8 @@ class QueryTransformer:
self.arguments = arguments
self.drop_empty = True
def transform(self) -> Dict[str, Any]:
arguments: List[str] = []
def transform(self) -> list[str]:
arguments: list[str] = []
schema = self.api_schema
for arg in self.arguments:
if "=" not in arg:
@@ -497,16 +493,16 @@ class QueryTransformer:
if hasattr(schema, key.upper()):
api_key = getattr(schema, key.upper()).value
if key.upper() == "AUTHOR" and hasattr(schema, "AUTHOR_SCHEMA"):
author_schema = getattr(schema, "AUTHOR_SCHEMA").value
author_schema = schema.AUTHOR_SCHEMA.value
if author_schema == "SpaceAfterComma":
value = value.replace(",", ", ")
elif author_schema == "NoSpaceAfterComma":
value = value.replace(", ", ",")
value = value.replace("  ", " ")
if key.upper() == "TITLE" and hasattr(
schema, "ENCLOSE_TITLE_IN_QUOTES"
schema, "ENCLOSE_TITLE_IN_QUOTES",
):
if getattr(schema, "ENCLOSE_TITLE_IN_QUOTES"):
if schema.ENCLOSE_TITLE_IN_QUOTES.value:
value = f'"{value}"'
arguments.append(f"{api_key}={value}")
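A sketch of the schema contract QueryTransformer expects. The enum below is hypothetical (the real schemas live elsewhere in the repo), but it exercises the AUTHOR_SCHEMA and ENCLOSE_TITLE_IN_QUOTES branches above:

from enum import Enum

class ExampleSchema(Enum):
    AUTHOR = "per"
    TITLE = "tit"
    AUTHOR_SCHEMA = "SpaceAfterComma"
    ENCLOSE_TITLE_IN_QUOTES = True

qt = QueryTransformer(api_schema=ExampleSchema, arguments=["author=Doe,Jane", "title=Example"])
print(qt.transform())  # ['per=Doe, Jane', 'tit="Example"']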
@@ -519,10 +515,10 @@ class Api:
self,
site: str,
url: str,
prefix: Type[Enum],
prefix: type[Enum],
library_identifier: str,
notsupported_args: Optional[List[str]] = None,
replace: Optional[Dict[str, str]] = None,
notsupported_args: list[str] | None = None,
replace: dict[str, str] | None = None,
):
self.site = site
self.url = url
@@ -554,7 +550,7 @@ class Api:
# Best-effort cleanup
self.close()
def get(self, query_args: Union[Iterable[str], str]) -> List[Record]:
def get(self, query_args: Iterable[str] | str) -> list[Record]:
start_time = time.monotonic()
# if any query_arg ends with =, remove it
if isinstance(query_args, str):
@@ -566,7 +562,7 @@ class Api:
if not any(qa.startswith(na + "=") for na in self.notsupported_args)
]
query_args = QueryTransformer(
api_schema=self.prefix, arguments=query_args
api_schema=self.prefix, arguments=query_args,
).transform()
query = "+and+".join(query_args)
for old, new in self.replace.items():
@@ -579,12 +575,12 @@ class Api:
"Accept-Charset": "latin1,utf-8;q=0.7,*;q=0.3",
}
# Use persistent session, enforce 1 req/sec, and retry up to 5 times
last_error: Optional[Exception] = None
last_error: Exception | None = None
for attempt in range(1, self._max_retries + 1):
# Abort if overall timeout exceeded before starting attempt
if time.monotonic() - start_time > self._overall_timeout_seconds:
last_error = requests.exceptions.Timeout(
f"Overall timeout {self._overall_timeout_seconds}s exceeded before attempt {attempt}"
f"Overall timeout {self._overall_timeout_seconds}s exceeded before attempt {attempt}",
)
break
# Enforce rate limit relative to last request end
@@ -596,21 +592,20 @@ class Api:
try:
# Per-attempt read timeout capped at remaining overall budget (but at most 30s)
remaining = max(
0.0, self._overall_timeout_seconds - (time.monotonic() - start_time)
0.0, self._overall_timeout_seconds - (time.monotonic() - start_time),
)
read_timeout = min(30.0, remaining if remaining > 0 else 0.001)
resp = self._session.get(
url, headers=headers, timeout=(3.05, read_timeout)
url, headers=headers, timeout=(3.05, read_timeout),
)
self._last_request_time = time.monotonic()
if resp.status_code == 200:
# Parse using raw bytes (original behavior) to preserve encoding edge cases
sr = parse_search_retrieve_response(resp.content)
return sr.records
else:
last_error = Exception(
f"Error fetching data from {self.site}: HTTP {resp.status_code} (attempt {attempt}/{self._max_retries})"
)
last_error = Exception(
f"Error fetching data from {self.site}: HTTP {resp.status_code} (attempt {attempt}/{self._max_retries})",
)
except requests.exceptions.ReadTimeout as e:
last_error = e
except requests.exceptions.Timeout as e:
@@ -625,9 +620,9 @@ class Api:
# If we exit the loop, all attempts failed
raise last_error if last_error else Exception("Unknown request failure")
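An end-to-end sketch; the URL is a placeholder and ExampleSchema is the hypothetical enum from the QueryTransformer sketch above:

api = Api(
    site="ExampleSRU",
    url="https://example.org/sru",  # placeholder endpoint; real URLs live in the repo's config
    prefix=ExampleSchema,
    library_identifier="example",
)
records = api.get(["title=Example"])  # 1 req/sec, up to 5 attempts, bounded overall timeout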
def getBooks(self, query_args: Union[Iterable[str], str]) -> List[BookData]:
def getBooks(self, query_args: Iterable[str] | str) -> list[BookData]:
try:
records: List[Record] = self.get(query_args)
records: list[Record] = self.get(query_args)
except requests.exceptions.ReadTimeout:
# Return a list with a single empty BookData object on read timeout
return [BookData()]
@@ -638,7 +633,7 @@ class Api:
# Propagate other errors (could also choose to return empty list)
raise
# Avoid printing on hot paths; rely on logger if needed
books: List[BookData] = []
books: list[BookData] = []
# extract title from query_args if present
title = None
for arg in query_args: