chore(all): run formatting on repo, start work on porting webrequest over to api library
@@ -1,8 +1,9 @@
 import re
 import time
 import xml.etree.ElementTree as ET
+from collections.abc import Iterable
 from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any

 import requests
 from requests.adapters import HTTPAdapter
@@ -24,7 +25,7 @@ MARC = "http://www.loc.gov/MARC21/slim"
 NS = {"zs": ZS, "marc": MARC}


-def _text(elem: Optional[ET.Element]) -> str:
+def _text(elem: ET.Element | None) -> str:
     return (elem.text or "") if elem is not None else ""

@@ -36,32 +37,31 @@ def _req_text(parent: ET.Element, path: str) -> str:


 def parse_marc_record(record_el: ET.Element) -> MarcRecord:
-    """
-    record_el is the <marc:record> element (default ns MARC in your sample)
+    """record_el is the <marc:record> element (default ns MARC in your sample)
     """
     # leader
     leader_text = _req_text(record_el, "marc:leader")

     # controlfields
-    controlfields: List[ControlField] = []
+    controlfields: list[ControlField] = []
     for cf in record_el.findall("marc:controlfield", NS):
         tag = cf.get("tag", "").strip()
         controlfields.append(ControlField(tag=tag, value=_text(cf)))

     # datafields
-    datafields: List[DataField] = []
+    datafields: list[DataField] = []
     for df in record_el.findall("marc:datafield", NS):
         tag = df.get("tag", "").strip()
         ind1 = df.get("ind1") or " "
         ind2 = df.get("ind2") or " "
-        subfields: List[SubField] = []
+        subfields: list[SubField] = []
         for sf in df.findall("marc:subfield", NS):
             code = sf.get("code", "")
             subfields.append(SubField(code=code, value=_text(sf)))
         datafields.append(DataField(tag=tag, ind1=ind1, ind2=ind2, subfields=subfields))

     return MarcRecord(
-        leader=leader_text, controlfields=controlfields, datafields=datafields
+        leader=leader_text, controlfields=controlfields, datafields=datafields,
     )

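As a quick check of the parser above — a minimal sketch, assuming the module is importable (the import path `marc_api` is hypothetical) and that the input uses the MARC namespace as its default, as the docstring says:

    import xml.etree.ElementTree as ET

    from marc_api import parse_marc_record  # hypothetical import path

    MARCXML = """<record xmlns="http://www.loc.gov/MARC21/slim">
      <leader>00000nam a2200000 c 4500</leader>
      <controlfield tag="001">123456789X</controlfield>
      <datafield tag="245" ind1="1" ind2="0">
        <subfield code="a">Example title</subfield>
      </datafield>
    </record>"""

    rec = parse_marc_record(ET.fromstring(MARCXML))
    assert rec.controlfields[0].value == "123456789X"

The `findall("marc:controlfield", NS)` lookups match by namespace URI, so a default-namespaced document works without a `marc:` prefix in the source XML.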
@@ -92,7 +92,7 @@ def parse_record(zs_record_el: ET.Element) -> Record:
     )


-def parse_echoed_request(root: ET.Element) -> Optional[EchoedSearchRequest]:
+def parse_echoed_request(root: ET.Element) -> EchoedSearchRequest | None:
     el = root.find("zs:echoedSearchRetrieveRequest", NS)
     if el is None:
         return None
@@ -119,7 +119,7 @@ def parse_echoed_request(root: ET.Element) -> EchoedSearchRequest | None:


 def parse_search_retrieve_response(
-    xml_str: Union[str, bytes],
+    xml_str: str | bytes,
 ) -> SearchRetrieveResponse:
     root = ET.fromstring(xml_str)

@@ -128,7 +128,7 @@ def parse_search_retrieve_response(
     numberOfRecords = int(_req_text(root, "zs:numberOfRecords") or "0")

     records_parent = root.find("zs:records", NS)
-    records: List[Record] = []
+    records: list[Record] = []
     if records_parent is not None:
         for r in records_parent.findall("zs:record", NS):
             record = parse_record(r)
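This is the entry point for a raw SRU response. A hedged usage sketch — the attribute names `numberOfRecords` and `records` follow the parsing code above, and the file name is illustrative:

    with open("searchRetrieve.xml", "rb") as fh:  # illustrative file name
        sr = parse_search_retrieve_response(fh.read())
    print(sr.numberOfRecords, len(sr.records))

Passing bytes rather than decoded text lets ElementTree honor the XML declaration's encoding, which matters for the latin1 responses the Api class below anticipates.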
@@ -150,9 +150,9 @@ def parse_search_retrieve_response(

 def iter_datafields(
     rec: MarcRecord,
-    tag: Optional[str] = None,
-    ind1: Optional[str] = None,
-    ind2: Optional[str] = None,
+    tag: str | None = None,
+    ind1: str | None = None,
+    ind2: str | None = None,
 ) -> Iterable[DataField]:
     """Yield datafields, optionally filtered by tag/indicators."""
     for df in rec.datafields:
@@ -170,11 +170,11 @@ def subfield_values(
     tag: str,
     code: str,
     *,
-    ind1: Optional[str] = None,
-    ind2: Optional[str] = None,
-) -> List[str]:
+    ind1: str | None = None,
+    ind2: str | None = None,
+) -> list[str]:
     """All values for subfield `code` in every `tag` field (respecting indicators)."""
-    out: List[str] = []
+    out: list[str] = []
     for df in iter_datafields(rec, tag, ind1, ind2):
         out.extend(sf.value for sf in df.subfields if sf.code == code)
     return out
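A small sketch of these accessors, building a record by hand with the same keyword signatures `parse_marc_record` uses above (import path hypothetical):

    from marc_api import (  # hypothetical import path
        ControlField, DataField, MarcRecord, SubField, subfield_values,
    )

    rec = MarcRecord(
        leader="",
        controlfields=[ControlField(tag="001", value="0815")],
        datafields=[
            DataField(
                tag="020", ind1=" ", ind2=" ",
                subfields=[SubField(code="a", value="978-3-16-148410-0")],
            ),
        ],
    )
    # All $a values from 020 fields (the ISBN field):
    assert subfield_values(rec, "020", "a") == ["978-3-16-148410-0"]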
@@ -185,10 +185,10 @@ def first_subfield_value(
     tag: str,
     code: str,
     *,
-    ind1: Optional[str] = None,
-    ind2: Optional[str] = None,
-    default: Optional[str] = None,
-) -> Optional[str]:
+    ind1: str | None = None,
+    ind2: str | None = None,
+    default: str | None = None,
+) -> str | None:
     """First value for subfield `code` in `tag` (respecting indicators)."""
     for df in iter_datafields(rec, tag, ind1, ind2):
         for sf in df.subfields:
@@ -201,25 +201,24 @@ def find_datafields_with_subfields(
     rec: MarcRecord,
     tag: str,
     *,
-    where_all: Optional[Dict[str, str]] = None,
-    where_any: Optional[Dict[str, str]] = None,
+    where_all: dict[str, str] | None = None,
+    where_any: dict[str, str] | None = None,
     casefold: bool = False,
-    ind1: Optional[str] = None,
-    ind2: Optional[str] = None,
-) -> List[DataField]:
-    """
-    Return datafields of `tag` whose subfields match constraints:
+    ind1: str | None = None,
+    ind2: str | None = None,
+) -> list[DataField]:
+    """Return datafields of `tag` whose subfields match constraints:
     - where_all: every (code -> exact value) must be present
     - where_any: at least one (code -> exact value) present
     Set `casefold=True` for case-insensitive comparison.
     """
     where_all = where_all or {}
     where_any = where_any or {}
-    matched: List[DataField] = []
+    matched: list[DataField] = []

     for df in iter_datafields(rec, tag, ind1, ind2):
         # Map code -> list of values (with optional casefold applied)
-        vals: Dict[str, List[str]] = {}
+        vals: dict[str, list[str]] = {}
         for sf in df.subfields:
             v = sf.value.casefold() if casefold else sf.value
             vals.setdefault(sf.code, []).append(v)
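The 924/"Frei 129" lookup in `book_from_marc` further down is the canonical use of this filter. In isolation, reusing the hand-built `rec` from the sketch above:

    # Fields tagged 924 whose subfield $9 equals "Frei 129", then the
    # first $g from those fields (the local shelf mark / signature):
    frei_fields = find_datafields_with_subfields(rec, "924", where_all={"9": "Frei 129"})
    signature = first_subfield_value_from_fields(frei_fields, "g")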
@@ -246,8 +245,8 @@ def find_datafields_with_subfields(


 def controlfield_value(
-    rec: MarcRecord, tag: str, default: Optional[str] = None
-) -> Optional[str]:
+    rec: MarcRecord, tag: str, default: str | None = None,
+) -> str | None:
     """Get the first controlfield value by tag (e.g., '001', '005')."""
     for cf in rec.controlfields:
         if cf.tag == tag:
@@ -256,8 +255,8 @@ def controlfield_value(


 def datafields_value(
-    data: List[DataField], code: str, default: Optional[str] = None
-) -> Optional[str]:
+    data: list[DataField], code: str, default: str | None = None,
+) -> str | None:
     """Get the first value for a specific subfield code in a list of datafields."""
     for df in data:
         for sf in df.subfields:
@@ -267,8 +266,8 @@ def datafields_value(


 def datafield_value(
-    df: DataField, code: str, default: Optional[str] = None
-) -> Optional[str]:
+    df: DataField, code: str, default: str | None = None,
+) -> str | None:
     """Get the first value for a specific subfield code in a datafield."""
     for sf in df.subfields:
         if sf.code == code:
@@ -276,9 +275,8 @@ def datafield_value(
     return default


-def _smart_join_title(a: str, b: Optional[str]) -> str:
-    """
-    Join 245 $a and $b with MARC-style punctuation.
+def _smart_join_title(a: str, b: str | None) -> str:
+    """Join 245 $a and $b with MARC-style punctuation.
     If $b is present, join with ' : ' unless either side already supplies punctuation.
     """
     a = a.strip()
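Per the docstring, the expected behavior is roughly the following — a sketch only, since the rest of the body falls outside this hunk and may handle more punctuation cases:

    assert _smart_join_title("Algorithms", "an introduction") == "Algorithms : an introduction"
    # Assumed: with $b absent, only the stripped $a comes back.
    assert _smart_join_title("Algorithms ", None) == "Algorithms"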
@@ -293,7 +291,7 @@ def _smart_join_title(a: str, b: str | None) -> str:
 def subfield_values_from_fields(
     fields: Iterable[DataField],
     code: str,
-) -> List[str]:
+) -> list[str]:
     """All subfield values with given `code` across a list of DataField."""
     return [sf.value for df in fields for sf in df.subfields if sf.code == code]

@@ -301,8 +299,8 @@ def subfield_values_from_fields(
 def first_subfield_value_from_fields(
     fields: Iterable[DataField],
     code: str,
-    default: Optional[str] = None,
-) -> Optional[str]:
+    default: str | None = None,
+) -> str | None:
     """First subfield value with given `code` across a list of DataField."""
     for df in fields:
         for sf in df.subfields:
@@ -314,12 +312,11 @@ def first_subfield_value_from_fields(
 def subfield_value_pairs_from_fields(
     fields: Iterable[DataField],
     code: str,
-) -> List[Tuple[DataField, str]]:
-    """
-    Return (DataField, value) pairs for all subfields with `code`.
+) -> list[tuple[DataField, str]]:
+    """Return (DataField, value) pairs for all subfields with `code`.
     Useful if you need to know which field a value came from.
     """
-    out: List[Tuple[DataField, str]] = []
+    out: list[tuple[DataField, str]] = []
     for df in fields:
         for sf in df.subfields:
             if sf.code == code:
@@ -340,13 +337,13 @@ def book_from_marc(rec: MarcRecord, library_identifier: str) -> BookData:

     # Signature = 924 where $9 == "Frei 129" → take that field's $g
     frei_fields = find_datafields_with_subfields(
-        rec, "924", where_all={"9": "Frei 129"}
+        rec, "924", where_all={"9": "Frei 129"},
     )
     signature = first_subfield_value_from_fields(frei_fields, "g")

     # Year = 264 $c (prefer ind2="1" publication; fallback to any 264)
     year = first_subfield_value(rec, "264", "c", ind2="1") or first_subfield_value(
-        rec, "264", "c"
+        rec, "264", "c",
     )
     isbn = subfield_values(rec, "020", "a")
     mediatype = first_subfield_value(rec, "338", "a")
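The ind2 preference can be seen with a record carrying two 264 fields — a sketch in the hand-built-record style from earlier:

    # A copyright 264 (ind2="4") and a publication 264 (ind2="1"):
    rec = MarcRecord(
        leader="", controlfields=[],
        datafields=[
            DataField(tag="264", ind1=" ", ind2="4",
                      subfields=[SubField(code="c", value="© 2020")]),
            DataField(tag="264", ind1=" ", ind2="1",
                      subfields=[SubField(code="c", value="2021")]),
        ],
    )
    # The ind2="1" publication date wins over the copyright date:
    assert first_subfield_value(rec, "264", "c", ind2="1") == "2021"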
@@ -378,10 +375,9 @@ RVK_ALLOWED = r"[A-Z0-9.\-\/]"  # conservative char set typically seen in RVK no


 def find_newer_edition(
-    swb_result: BookData, dnb_result: List[BookData]
-) -> Optional[List[BookData]]:
-    """
-    New edition if:
+    swb_result: BookData, dnb_result: list[BookData],
+) -> list[BookData] | None:
+    """New edition if:
     - year > swb.year OR
     - edition_number > swb.edition_number

@@ -393,7 +389,7 @@ def find_newer_edition(
     edition_number desc, best-signature-match desc, has-signature desc).
     """

-    def norm_sig(s: Optional[str]) -> str:
+    def norm_sig(s: str | None) -> str:
         if not s:
             return ""
         # normalize: lowercase, collapse whitespace, keep alnum + a few separators
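The rest of `norm_sig` falls outside this hunk. Based on its comment and the `RVK_ALLOWED` character set above, the normalization is presumably along these lines — a hypothetical stand-in, not the committed body:

    import re

    def norm_sig_sketch(s: str | None) -> str:
        # Hypothetical reimplementation following norm_sig's comment: lowercase,
        # collapse whitespace, keep alphanumerics plus a few separators.
        if not s:
            return ""
        s = s.casefold()
        s = re.sub(r"[^a-z0-9.\-/ ]+", "", s)  # keep alnum, . - / and spaces
        return re.sub(r"\s+", " ", s).strip()  # collapse whitespace runs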
@@ -427,7 +423,7 @@ def find_newer_edition(
     swb_sig_norm = norm_sig(getattr(swb_result, "signature", None))

     # 1) Filter to same-work AND newer
-    candidates: List[BookData] = []
+    candidates: list[BookData] = []
     for b in dnb_result:
         # Skip if both signatures exist and don't match (different work)
         b_sig = getattr(b, "signature", None)
@@ -443,7 +439,7 @@ def find_newer_edition(
         return None

     # 2) Dedupe by PPN, preferring signature (and matching signature if possible)
-    by_ppn: dict[Optional[str], BookData] = {}
+    by_ppn: dict[str | None, BookData] = {}
     for b in candidates:
         key = getattr(b, "ppn", None)
         prev = by_ppn.get(key)
@@ -477,7 +473,7 @@ def find_newer_edition(


 class QueryTransformer:
-    def __init__(self, api_schema: Type[Enum], arguments: Union[Iterable[str], str]):
+    def __init__(self, api_schema: type[Enum], arguments: Iterable[str] | str):
         self.api_schema = api_schema
         if isinstance(arguments, str):
             self.arguments = [arguments]
@@ -485,8 +481,8 @@ class QueryTransformer:
             self.arguments = arguments
         self.drop_empty = True

-    def transform(self) -> Dict[str, Any]:
-        arguments: List[str] = []
+    def transform(self) -> dict[str, Any]:
+        arguments: list[str] = []
         schema = self.api_schema
         for arg in self.arguments:
             if "=" not in arg:
@@ -497,16 +493,16 @@ class QueryTransformer:
             if hasattr(schema, key.upper()):
                 api_key = getattr(schema, key.upper()).value
                 if key.upper() == "AUTHOR" and hasattr(schema, "AUTHOR_SCHEMA"):
-                    author_schema = getattr(schema, "AUTHOR_SCHEMA").value
+                    author_schema = schema.AUTHOR_SCHEMA.value
                     if author_schema == "SpaceAfterComma":
                         value = value.replace(",", ", ")
                     elif author_schema == "NoSpaceAfterComma":
                         value = value.replace(", ", ",")
                     value = value.replace("  ", " ")
                 if key.upper() == "TITLE" and hasattr(
-                    schema, "ENCLOSE_TITLE_IN_QUOTES"
+                    schema, "ENCLOSE_TITLE_IN_QUOTES",
                 ):
-                    if getattr(schema, "ENCLOSE_TITLE_IN_QUOTES"):
+                    if schema.ENCLOSE_TITLE_IN_QUOTES:
                         value = f'"{value}"'

                 arguments.append(f"{api_key}={value}")
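A sketch of how a schema enum drives `transform` — the enum below is hypothetical, and the sketch assumes `transform` returns the built list (the annotation says `dict[str, Any]`, but the body accumulates `"key=value"` strings that `Api.get` later joins). Note also that `schema.ENCLOSE_TITLE_IN_QUOTES` is an Enum member and therefore truthy whenever defined, so defining it at all turns quoting on:

    from enum import Enum

    class DemoSchema(Enum):  # hypothetical schema for illustration
        TITLE = "tit"
        AUTHOR = "per"
        AUTHOR_SCHEMA = "NoSpaceAfterComma"
        ENCLOSE_TITLE_IN_QUOTES = "yes"

    qt = QueryTransformer(DemoSchema, ["title=Faust", "author=Goethe, Johann"])
    # Expected: ['tit="Faust"', 'per=Goethe,Johann'] — title quoted, and the
    # comma-space collapsed by the NoSpaceAfterComma rule.
    print(qt.transform())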
@@ -519,10 +515,10 @@ class Api:
         self,
         site: str,
         url: str,
-        prefix: Type[Enum],
+        prefix: type[Enum],
         library_identifier: str,
-        notsupported_args: Optional[List[str]] = None,
-        replace: Optional[Dict[str, str]] = None,
+        notsupported_args: list[str] | None = None,
+        replace: dict[str, str] | None = None,
     ):
         self.site = site
         self.url = url
@@ -554,7 +550,7 @@ class Api:
         # Best-effort cleanup
         self.close()

-    def get(self, query_args: Union[Iterable[str], str]) -> List[Record]:
+    def get(self, query_args: Iterable[str] | str) -> list[Record]:
         start_time = time.monotonic()
         # if any query_arg ends with =, remove it
         if isinstance(query_args, str):
@@ -566,7 +562,7 @@ class Api:
             if not any(qa.startswith(na + "=") for na in self.notsupported_args)
         ]
         query_args = QueryTransformer(
-            api_schema=self.prefix, arguments=query_args
+            api_schema=self.prefix, arguments=query_args,
         ).transform()
         query = "+and+".join(query_args)
         for old, new in self.replace.items():
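Continuing the DemoSchema example above, the joined query would come out like this (a sketch; the transformed parts are joined with the SRU-style `+and+` separator before substitution via `self.replace`):

    query = "+and+".join(['tit="Faust"', 'per=Goethe,Johann'])
    assert query == 'tit="Faust"+and+per=Goethe,Johann'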
@@ -579,12 +575,12 @@ class Api:
             "Accept-Charset": "latin1,utf-8;q=0.7,*;q=0.3",
         }
         # Use persistent session, enforce 1 req/sec, and retry up to 5 times
-        last_error: Optional[Exception] = None
+        last_error: Exception | None = None
         for attempt in range(1, self._max_retries + 1):
             # Abort if overall timeout exceeded before starting attempt
             if time.monotonic() - start_time > self._overall_timeout_seconds:
                 last_error = requests.exceptions.Timeout(
-                    f"Overall timeout {self._overall_timeout_seconds}s exceeded before attempt {attempt}"
+                    f"Overall timeout {self._overall_timeout_seconds}s exceeded before attempt {attempt}",
                 )
                 break
             # Enforce rate limit relative to last request end
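The body of that rate-limit gate falls outside this hunk; given the "1 req/sec" comment above and the `_last_request_time` bookkeeping below, it is presumably along these lines (hypothetical sketch):

    # Hypothetical sketch of the 1 req/sec gate, measured from the
    # end of the previous request:
    elapsed = time.monotonic() - self._last_request_time
    if elapsed < 1.0:
        time.sleep(1.0 - elapsed)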
@@ -596,21 +592,20 @@ class Api:
             try:
                 # Per-attempt read timeout capped at remaining overall budget (but at most 30s)
                 remaining = max(
-                    0.0, self._overall_timeout_seconds - (time.monotonic() - start_time)
+                    0.0, self._overall_timeout_seconds - (time.monotonic() - start_time),
                 )
                 read_timeout = min(30.0, remaining if remaining > 0 else 0.001)
                 resp = self._session.get(
-                    url, headers=headers, timeout=(3.05, read_timeout)
+                    url, headers=headers, timeout=(3.05, read_timeout),
                 )
                 self._last_request_time = time.monotonic()
                 if resp.status_code == 200:
                     # Parse using raw bytes (original behavior) to preserve encoding edge cases
                     sr = parse_search_retrieve_response(resp.content)
                     return sr.records
                 else:
-                    last_error = Exception(
-                        f"Error fetching data from {self.site}: HTTP {resp.status_code} (attempt {attempt}/{self._max_retries})"
-                    )
+                    last_error = Exception(
+                        f"Error fetching data from {self.site}: HTTP {resp.status_code} (attempt {attempt}/{self._max_retries})",
+                    )
             except requests.exceptions.ReadTimeout as e:
                 last_error = e
             except requests.exceptions.Timeout as e:
@@ -625,9 +620,9 @@ class Api:
         # If we exit the loop, all attempts failed
         raise last_error if last_error else Exception("Unknown request failure")

-    def getBooks(self, query_args: Union[Iterable[str], str]) -> List[BookData]:
+    def getBooks(self, query_args: Iterable[str] | str) -> list[BookData]:
         try:
-            records: List[Record] = self.get(query_args)
+            records: list[Record] = self.get(query_args)
         except requests.exceptions.ReadTimeout:
             # Return a list with a single empty BookData object on read timeout
             return [BookData()]
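End to end, the client would be used roughly like this — the endpoint URL, schema enum, and library identifier are placeholders; the keyword names follow the `__init__` signature above:

    api = Api(
        site="DNB",                     # label used in error messages
        url="https://example.org/sru",  # placeholder endpoint
        prefix=DemoSchema,              # hypothetical schema enum from above
        library_identifier="DE-25",     # placeholder identifier
    )
    books = api.getBooks(["title=Faust", "author=Goethe, Johann"])

A read timeout yields `[BookData()]` rather than raising, so callers always get a list back.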
@@ -638,7 +633,7 @@ class Api:
             # Propagate other errors (could also choose to return empty list)
             raise
         # Avoid printing on hot paths; rely on logger if needed
-        books: List[BookData] = []
+        books: list[BookData] = []
         # extract title from query_args if present
         title = None
         for arg in query_args: