######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# Copyright Spine Database API contributors
# This file is part of Spine Database API.
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
from multiprocessing import RLock
from enum import Enum, unique, auto
from difflib import SequenceMatcher
from .temp_id import TempId, resolve
from .exception import SpineDBAPIError
from .helpers import Asterisk
# TODO: Implement MappedItem.pop() to do lookup?
@unique
class Status(Enum):
    """Mapped item status."""

    # Default status of a freshly created mapped item.
    committed = auto()
    # Item added to the mapping and not yet committed.
    to_add = auto()
    # Item updated in the mapping and not yet committed.
    to_update = auto()
    # Committed or updated item that has been removed in the mapping.
    to_remove = auto()
    # Item that was added and then removed before being committed.
    added_and_removed = auto()
    # Committed item flagged when the mapping is refreshed.
    compromised = auto()
class DatabaseMappingBase:
    """An in-memory mapping of a DB, mapping item types (table names), to numeric ids, to items.

    This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema.

    When subclassing, you need to implement :meth:`item_types`, :meth:`all_item_types`, :meth:`item_factory`,
    :meth:`_make_sq`, and :meth:`_query_commit_count`.
    """

    def __init__(self):
        """Initializes empty caches and sorts the item types so that references come first.

        Raises:
            SpineDBAPIError: if the item types reference each other in a cycle.
        """
        from collections import deque

        self.closed = False
        self._mapped_tables = {}
        self._fetched = {}
        self._locks = {}
        self._commit_count = None
        # Work on a copy so the list returned by item_types() is never mutated.
        pending = deque(self.item_types())
        self._sorted_item_types = []
        deferred = 0  # Consecutive item types pushed back without placing any
        while pending:
            item_type = pending.popleft()
            if self.item_factory(item_type).ref_types() & set(pending):
                # Still refers to a type that hasn't been placed; try again later.
                pending.append(item_type)
                deferred += 1
                if deferred >= len(pending):
                    # Every pending type was deferred in a row: no progress is possible.
                    raise SpineDBAPIError(f"circular references among item types {sorted(pending)}")
            else:
                self._sorted_item_types.append(item_type)
                deferred = 0
@staticmethod
def item_types():
"""Returns a list of public item types from the DB mapping schema (equivalent to the table names).
:meta private:
Returns:
list(str)
"""
raise NotImplementedError()
@staticmethod
def all_item_types():
"""Returns a list of all item types from the DB mapping schema (equivalent to the table names).
:meta private:
Returns:
list(str)
"""
raise NotImplementedError()
@staticmethod
def item_factory(item_type):
"""Returns a subclass of :class:`.MappedItemBase` to make items of given type.
:meta private:
Args:
item_type (str)
Returns:
function
"""
raise NotImplementedError()
def _make_query(self, item_type, **kwargs):
    """Builds a :class:`~spinedb_api.query.Query` that fetches items of the given type.

    Each keyword argument becomes an equality filter, either on a column of the item's own subquery
    or, for external fields, on the corresponding column of the referenced item's subquery.

    Args:
        item_type (str)
        **kwargs: query filters

    Returns:
        :class:`~spinedb_api.query.Query` or None if the mapping is closed.
    """
    if self.closed:
        return None
    subquery = self._make_sq(item_type)
    query = self.query(subquery)
    for field, raw_value in kwargs.items():
        if isinstance(raw_value, tuple):
            # Tuple values are not turned into filters.
            continue
        wanted = resolve(raw_value)
        if hasattr(subquery.c, field):
            query = query.filter(getattr(subquery.c, field) == wanted)
            continue
        factory = self.item_factory(item_type)
        if field not in factory._external_fields:
            continue
        src_key, target_key = factory._external_fields[field]
        ref_type, ref_key = factory._references[src_key]
        ref_subquery = self._make_sq(ref_type)
        try:
            query = query.filter(
                getattr(subquery.c, src_key) == getattr(ref_subquery.c, ref_key),
                getattr(ref_subquery.c, target_key) == wanted,
            )
        except AttributeError:
            # One of the subqueries lacks the needed column: skip this filter.
            pass
    return query
def _make_sq(self, item_type):
"""Returns a :class:`~sqlalchemy.sql.expression.Alias` object representing a subquery
to collect items of given type.
Args:
item_type (str)
Returns:
:class:`~sqlalchemy.sql.expression.Alias`
"""
raise NotImplementedError()
def _query_commit_count(self):
"""Returns the number of rows in the commit table in the DB.
Returns:
int
"""
raise NotImplementedError()
def make_item(self, item_type, **item):
factory = self.item_factory(item_type)
return factory(self, item_type, **item)
def dirty_ids(self, item_type):
    """Returns the ids of the valid items of the given type that are waiting to be added or updated.

    Args:
        item_type (str)

    Returns:
        set
    """
    dirty_statuses = (Status.to_add, Status.to_update)
    ids = set()
    for item in self.mapped_table(item_type).valid_values():
        if item.status in dirty_statuses:
            ids.add(item["id"])
    return ids
def _dirty_items(self):
    """Collects the items that have been modified but not yet committed.

    Returns:
        list: tuples of the form (item_type, (to_add, to_update, to_remove)), in sorted item type order,
        for the item types that have at least one such item.
    """
    db_commit_count = self._query_commit_count()
    purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged}
    self._add_descendants(purged_item_types)
    result = []
    for item_type in self._sorted_item_types:
        self.do_fetch_all(item_type, commit_count=db_commit_count)  # To fix conflicts in add_item_from_db
        table = self.mapped_table(item_type)
        by_status = {Status.to_add: [], Status.to_update: []}
        for item in table.valid_values():
            bucket = by_status.get(item.status)
            if bucket is not None:
                bucket.append(item)
        to_add = by_status[Status.to_add]
        to_update = by_status[Status.to_update]
        if item_type in purged_item_types:
            # Purged type: the wildcard item goes first, followed by every item in the table.
            to_remove = [table.wildcard_item]
            to_remove.extend(table.values())
        else:
            to_remove = []
            for item in table.values():
                item.validate()
                if item.status == Status.to_remove:
                    to_remove.append(item)
        if to_add or to_update or to_remove:
            result.append((item_type, (to_add, to_update, to_remove)))
    return result
def _rollback(self):
"""Discards uncommitted changes.
Namely, removes all the added items, resets all the updated items, and restores all the removed items.
Returns:
bool: False if there is no uncommitted items, True if successful.
"""
dirty_items = self._dirty_items()
if not dirty_items:
return False
to_add_by_type = []
to_update_by_type = []
to_remove_by_type = []
for item_type, (to_add, to_update, to_remove) in reversed(dirty_items):
to_add_by_type.append((item_type, to_add))
to_update_by_type.append((item_type, to_update))
to_remove_by_type.append((item_type, to_remove))
for item_type, to_remove in to_remove_by_type:
mapped_table = self.mapped_table(item_type)
for item in to_remove:
mapped_table.restore_item(item["id"])
for item_type, to_update in to_update_by_type:
mapped_table = self.mapped_table(item_type)
for item in to_update:
mapped_table.update_item(item.backup)
for item_type, to_add in to_add_by_type:
mapped_table = self.mapped_table(item_type)
for item in to_add:
if mapped_table.remove_item(item) is not None:
item.invalidate_id()
return True
def _refresh(self):
"""Clears fetch progress, so the DB is queried again."""
if self._commit_count == self._query_commit_count():
return
self._fetched.clear()
for item_type in self.item_types():
mapped_table = self.mapped_table(item_type)
for item in mapped_table.values():
item.handle_refresh()
def _check_item_type(self, item_type):
if item_type not in self.all_item_types():
candidate = max(self.all_item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio())
raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?")
def mapped_table(self, item_type):
if item_type not in self._mapped_tables:
self._check_item_type(item_type)
self._mapped_tables[item_type] = _MappedTable(self, item_type)
return self._mapped_tables[item_type]
[docs] def reset(self, *item_types):
"""Resets the mapping for given item types as if nothing was fetched from the DB or modified in the mapping.
Any modifications in the mapping that aren't committed to the DB are lost after this.
"""
item_types = set(self.item_types()) if not item_types else set(item_types) & set(self.item_types())
self._add_descendants(item_types)
for item_type in item_types:
self._mapped_tables.pop(item_type, None)
self._fetched.pop(item_type, None)
def reset_purging(self):
    """Resets purging status for all item types.

    Fetching items of an item type that has been purged will automatically mark those items removed.
    Resetting the purge status lets fetched items to be added unmodified.
    """
    for table in self._mapped_tables.values():
        # A table counts as purged while its wildcard item is marked for removal.
        wildcard = table.wildcard_item
        wildcard.status = Status.committed
def _add_descendants(self, item_types):
while True:
changed = False
for item_type in set(self.item_types()) - item_types:
if self.item_factory(item_type).ref_types() & item_types:
item_types.add(item_type)
changed = True
if not changed:
break
def get_mapped_item(self, item_type, id_, fetch=True):
mapped_table = self.mapped_table(item_type)
return mapped_table.find_item_by_id(id_, fetch=fetch) or {}
def _get_next_chunk(self, item_type, offset, limit, **kwargs):
"""Gets chunk of items from the DB.
Returns:
list(dict): list of dictionary items.
"""
qry = self._make_query(item_type, **kwargs)
if not qry:
return []
if not limit:
return [dict(x) for x in qry]
return [dict(x) for x in qry.limit(limit).offset(offset)]
[docs] def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs):
"""Fetches items from the DB and adds them to the mapping.
Args:
item_type (str)
Returns:
list(MappedItem): items fetched from the DB.
"""
chunk = self._get_next_chunk(item_type, offset, limit, **kwargs)
if not chunk:
return []
real_commit_count = self._query_commit_count()
is_db_dirty = self._get_commit_count() != real_commit_count
if is_db_dirty:
# We need to fetch the most recent references because their ids might have changed in the DB
for ref_type in self.item_factory(item_type).ref_types():
if ref_type != item_type:
self.do_fetch_all(ref_type, commit_count=real_commit_count)
mapped_table = self.mapped_table(item_type)
items = []
new_items = []
# Add items first
for x in chunk:
item, new = mapped_table.add_item_from_db(x, not is_db_dirty)
if new:
new_items.append(item)
else:
item.handle_refetch()
items.append(item)
# Once all items are added, add the unique key values
# Otherwise items that refer to other items that come later in the query will be seen as corrupted
for item in new_items:
mapped_table.add_unique(item)
return items
def _get_commit_count(self):
"""Returns current commit count.
Returns:
int
"""
if self._commit_count is None:
self._commit_count = self._query_commit_count()
return self._commit_count
[docs] def do_fetch_all(self, item_type, commit_count=None):
"""Fetches all items of given type, but only once for each commit_count.
In other words, the second time this method is called with the same commit_count, it does nothing.
If not specified, commit_count defaults to the result of self._get_commit_count().
Args:
item_type (str)
commit_count (int,optional)
"""
if commit_count is None:
commit_count = self._get_commit_count()
with self._locks.setdefault(item_type, RLock()):
if self._fetched.get(item_type, -1) < commit_count:
self._fetched[item_type] = commit_count
self.do_fetch_more(item_type, offset=0, limit=None)
class _MappedTable(dict):
    """A dictionary of the mapped items of one item type, keyed by id."""

    def __init__(self, db_map, item_type, *args, **kwargs):
        """
        Args:
            db_map (DatabaseMappingBase): the DB mapping where this mapped table belongs.
            item_type (str): the item type, equal to a table name
            *args: passed to the dict constructor
            **kwargs: passed to the dict constructor
        """
        super().__init__(*args, **kwargs)
        self._db_map = db_map
        self._item_type = item_type
        self._ids_by_unique_key_value = {}
        self._temp_id_lookup = {}
        # Stands for 'all the items in this table'; its status tells whether the table is purged.
        self.wildcard_item = MappedItemBase(db_map, item_type, id=Asterisk)
@property
def purged(self):
    """Whether this table is purged, i.e., its wildcard item is marked for removal.

    Returns:
        bool
    """
    return self.wildcard_item.status == Status.to_remove


@purged.setter
def purged(self, purged):
    """Marks the wildcard item for removal if ``purged`` is true, or as committed otherwise.

    Args:
        purged (bool)
    """
    if purged:
        self.wildcard_item.status = Status.to_remove
    else:
        self.wildcard_item.status = Status.committed
def get(self, id_, default=None):
id_ = self._temp_id_lookup.get(id_, id_)
return super().get(id_, default)
def _new_id(self):
    """Returns a new unique :class:`TempId` for this table's item type,
    created with this table's temp id lookup."""
    temp_id = TempId.new_unique(self._item_type, self._temp_id_lookup)
    return temp_id
def _unique_key_value_to_id(self, key, value, fetch=True):
"""Returns the id that has the given value for the given unique key, or None.
Args:
key (tuple)
value (tuple)
fetch (bool): whether to fetch the DB until found.
Returns:
int or None
"""
value = tuple(tuple(x) if isinstance(x, list) else x for x in value)
ids = self._ids_by_unique_key_value.get(key, {}).get(value, [])
if not ids and fetch:
self._db_map.do_fetch_all(self._item_type)
ids = self._ids_by_unique_key_value.get(key, {}).get(value, [])
return None if not ids else ids[-1]
def _unique_key_value_to_item(self, key, value, fetch=True, valid_only=True):
id_ = self._unique_key_value_to_id(key, value, fetch=fetch)
mapped_item = self.get(id_)
if mapped_item is None:
return None
if valid_only and not mapped_item.is_valid():
return None
return mapped_item
def valid_values(self):
return (x for x in self.values() if x.is_valid())
def _make_item(self, item):
"""Returns a mapped item.
Args:
item (dict): the 'db item' to use as base
Returns:
MappedItem
"""
return self._db_map.make_item(self._item_type, **item)
def find_item(self, item, skip_keys=(), fetch=True):
"""Returns a MappedItemBase that matches the given dictionary-item.
Args:
item (dict)
Returns:
MappedItemBase or None
"""
id_ = item.get("id")
if id_ is not None:
return self.find_item_by_id(id_, fetch=fetch)
return self._find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch)
def find_item_by_id(self, id_, fetch=True):
current_item = self.get(id_, {})
if not current_item and fetch:
self._db_map.do_fetch_all(self._item_type)
current_item = self.get(id_, {})
return current_item
def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, valid_only=True):
for key, value in self._db_map.item_factory(self._item_type).unique_values_for_item(item, skip_keys=skip_keys):
current_item = self._unique_key_value_to_item(key, value, fetch=fetch, valid_only=valid_only)
if current_item:
return current_item
# Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too...
mapped_item = self._make_item(item)
error = mapped_item.resolve_internal_fields(skip_keys=item.keys())
if error:
return {}
error = mapped_item.polish()
if error:
return {}
for key, value in mapped_item.unique_key_values(skip_keys=skip_keys):
current_item = self._unique_key_value_to_item(key, value, fetch=fetch, valid_only=valid_only)
if current_item:
return current_item
return {}
def checked_item_and_error(self, item, for_update=False):
if for_update:
current_item = self.find_item(item)
if not current_item:
return None, f"no {self._item_type} matching {item} to update"
full_item, merge_error = current_item.merge(item)
if full_item is None:
return None, merge_error
else:
current_item = None
full_item, merge_error = item, None
candidate_item = self._make_item(full_item)
error = self._prepare_item(candidate_item, current_item, item)
if error:
return None, error
valid_types = (type(None),) if for_update else ()
self.check_fields(candidate_item._asdict(), valid_types=valid_types)
return candidate_item, merge_error
def _prepare_item(self, candidate_item, current_item, original_item):
"""Prepares item for insertion or update, returns any errors.
Args:
candidate_item (MappedItem)
current_item (MappedItem)
original_item (dict)
Returns:
str or None: errors if any.
"""
error = candidate_item.resolve_internal_fields(skip_keys=original_item.keys())
if error:
return error
error = candidate_item.check_mutability()
if error:
return error
error = candidate_item.polish()
if error:
return error
first_invalid_key = candidate_item.first_invalid_key()
if first_invalid_key:
return f"invalid {first_invalid_key} for {self._item_type}"
try:
for key, value in candidate_item.unique_key_values():
empty = {k for k, v in zip(key, value) if v == ""}
if empty:
return f"invalid empty keys {empty} for {self._item_type}"
unique_item = self._unique_key_value_to_item(key, value)
if unique_item not in (None, current_item) and unique_item.is_valid():
return f"there's already a {self._item_type} with {dict(zip(key, value))}"
except KeyError as e:
return f"missing {e} for {self._item_type}"
def item_to_remove_and_error(self, id_):
    """Returns the item to remove for the given id, together with any error preventing the removal.

    Args:
        id_: the item's id, or Asterisk to get the wildcard item

    Returns:
        tuple(MappedItemBase or None, str or None): the item, or None if not found, and any error.
    """
    if id_ is Asterisk:
        return self.wildcard_item, None
    current_item = self.find_item({"id": id_})
    if current_item:
        return current_item, current_item.check_mutability()
    return None, None
def add_unique(self, item):
id_ = item["id"]
for key, value in item.unique_key_values():
self._ids_by_unique_key_value.setdefault(key, {}).setdefault(value, []).append(id_)
def remove_unique(self, item):
id_ = item["id"]
for key, value in item.unique_key_values():
ids = self._ids_by_unique_key_value.get(key, {}).get(value, [])
if id_ in ids:
ids.remove(id_)
def _make_and_add_item(self, item):
    """Adds the given item to this table under a new temp id.

    A dictionary is first turned into a polished mapped item. If the item carries a valid id,
    the new temp id is resolved to it.

    Args:
        item (dict or MappedItemBase)

    Returns:
        MappedItemBase: the item that was added.
    """
    if not isinstance(item, MappedItemBase):
        item = self._make_item(item)
        item.polish()
    db_id = None
    if item.has_valid_id:
        db_id = item.pop("id", None)
    temp_id = self._new_id()
    item["id"] = temp_id
    if db_id is not None:
        temp_id.resolve(db_id)
    self[temp_id] = item
    return item
def add_item_from_db(self, item, is_db_clean):
    """Adds an item fetched from the DB.

    If the mapping already has an equivalent item, found by unique key or by id, that one is returned instead.

    Args:
        item (dict): item from the DB.
        is_db_clean (Bool)

    Returns:
        tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict.
    """
    same_key_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False)
    if same_key_item and (is_db_clean or self._same_item(same_key_item, item)):
        same_key_item.force_id(item["id"])
        return same_key_item, False
    same_id_item = self.get(item["id"])
    if same_id_item and (is_db_clean or self._same_item(same_id_item.db_equivalent(), item)):
        return same_id_item, False
    # An item with the same id but different unique keys is in the way.
    conflicting_item = self.get(item["id"])
    if conflicting_item is not None:
        conflicting_item.handle_id_steal()
    new_item = self._make_and_add_item(item)
    if self.purged:
        # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes.
        new_item.cascade_remove()
    return new_item, True
def _same_item(self, mapped_item, db_item):
"""Whether the two given items have the same unique keys.
Args:
mapped_item (MappedItemBase): an item in the in-memory mapping
db_item (dict): an item just fetched from the DB
"""
db_item = self._db_map.make_item(self._item_type, **db_item)
db_item.polish()
return dict(mapped_item.unique_key_values()) == dict(db_item.unique_key_values())
def check_fields(self, item, valid_types=()):
factory = self._db_map.item_factory(self._item_type)
def _error(key, value, valid_types):
if key in set(factory._internal_fields) | set(factory._external_fields) | factory._private_fields | {
"id",
"commit_id",
}:
# The user seems to know what they're doing
return
f_dict = factory.fields.get(key)
if f_dict is None:
valid_args = ", ".join(factory.fields)
return f"invalid keyword argument '{key}' for '{self._item_type}' - valid arguments are {valid_args}."
valid_types = valid_types + (f_dict["type"],)
if f_dict.get("optional", False):
valid_types = valid_types + (type(None),)
if not isinstance(value, valid_types):
return (
f"invalid type for '{key}' of '{self._item_type}' - "
f"got {type(value).__name__}, expected {f_dict['type'].__name__}."
)
errors = list(filter(lambda x: x is not None, (_error(key, value, valid_types) for key, value in item.items())))
if errors:
raise SpineDBAPIError("\n".join(errors))
def add_item(self, item):
    """Adds the given item to this table, registers its unique key values and marks it as to be added.

    Args:
        item (dict or MappedItemBase)

    Returns:
        MappedItemBase: the added item.
    """
    added_item = self._make_and_add_item(item)
    self.add_unique(added_item)
    added_item.status = Status.to_add
    return added_item
def update_item(self, item):
current_item = self.find_item(item)
current_item.cascade_remove_unique()
current_item.update(item)
current_item.cascade_add_unique()
current_item.cascade_update()
return current_item
def remove_item(self, item):
if not item:
return None
if item is self.wildcard_item:
self.purged = True
for current_item in self.valid_values():
current_item.cascade_remove()
return self.wildcard_item
item.cascade_remove()
return item
def restore_item(self, id_):
    """Restores the item with the given id in cascade.

    Restoring Asterisk clears the table's purged status and restores all its items.

    Args:
        id_: the item's id or Asterisk

    Returns:
        MappedItemBase: the restored item, or a falsy value if not found.
    """
    if id_ is not Asterisk:
        found = self.find_item({"id": id_})
        if found:
            found.cascade_restore()
        return found
    self.purged = False
    for item in self.values():
        item.cascade_restore()
    return self.wildcard_item
class MappedItemBase(dict):
    """A dictionary that represents a db item."""

    # Restored: the description below had lost its attribute, which other code reads as ``fields``.
    fields = {}
    """A dictionary mapping fields to another dict mapping "type" to a Python type,
    "value" to a description of the value for the key, and "optional" to a bool."""
    _defaults = {}
    """A dictionary mapping fields to their default values"""
    _unique_keys = ()
    """A tuple where each element is itself a tuple of fields corresponding to a unique key"""
    _references = {}
    """A dictionary mapping source fields, to a tuple of reference item type and reference field.
    Used to access external fields.
    """
    _external_fields = {}
    """A dictionary mapping fields that are not in the original dictionary, to a tuple of source field
    and target field.
    When accessing fields in _external_fields, we first find the reference pointed at by the source field,
    and then return the target field of that reference.
    """
    _alt_references = {}
    """A dictionary mapping source fields, to a tuple of reference item type and reference fields.
    Used only to resolve internal fields at item creation.
    """
    _internal_fields = {}
    """A dictionary mapping fields that are not in the original dictionary, to a tuple of source field
    and target field.
    When resolving fields in _internal_fields, we first find the alt_reference pointed at by the source field,
    and then use the target field of that reference.
    """
    _private_fields = set()
    """A set with fields that should be ignored in validations."""
    is_protected = False
def __init__(self, db_map, item_type, **kwargs):
    """
    Args:
        db_map (DatabaseMappingBase): the DB where this item belongs.
        item_type (str): the item's type
        **kwargs: the item's fields
    """
    super().__init__(**kwargs)
    self._db_map = db_map
    self._item_type = item_type
    # Items referring to this one, keyed by their (item_type, id) key
    self._referrers = {}
    self._weak_referrers = {}
    # Callbacks
    self.restore_callbacks = set()
    self.update_callbacks = set()
    self.remove_callbacks = set()
    # State
    self._has_valid_id = True
    self._removed = False
    self._valid = None  # None means 'not validated yet'
    self._status = Status.committed
    self._removal_source = None
    self._status_when_removed = None
    self._status_when_committed = None
    self._backup = None
    self.public_item = PublicItem(self._db_map, self)
def handle_refetch(self):
    """Called when an equivalent item is fetched from the DB.

    1. If this item is compromised, then mark it as committed.
    2. If this item is committed, then assume the one from the DB is newer and reset the state.
       Otherwise assume *this* is newer and do nothing.
    """
    if self.status == Status.compromised:
        self.status = Status.committed
    if not self.is_committed():
        return
    self._removed = False
    self._valid = None
def handle_refresh(self):
    """Called when the mapping is refreshed.

    If this item is committed, then set it as compromised.
    Items in any other status are left untouched.
    """
    if self.status != Status.committed:
        return
    self.status = Status.compromised
@classmethod
[docs] def ref_types(cls):
"""Returns a set of item types that this class refers.
Returns:
set(str)
"""
return set(ref_type for ref_type, _ref_key in cls._references.values())
@property
[docs] def status(self):
"""Returns the status of this item.
Returns:
Status
"""
return self._status
@status.setter
def status(self, status):
"""Sets the status of this item.
Args:
status (Status)
"""
self._status = status
@property
[docs] def backup(self):
"""Returns the committed version of this item.
Returns:
dict or None
"""
return self._backup
@property
[docs] def removed(self):
"""Returns whether or not this item has been removed.
Returns:
bool
"""
return self._removed
@property
[docs] def item_type(self):
"""Returns this item's type
Returns:
str
"""
return self._item_type
@property
[docs] def key(self):
"""Returns a tuple (item_type, id) for convenience, or None if this item doesn't yet have an id.
TODO: When does the latter happen?
Returns:
tuple(str,int) or None
"""
id_ = dict.get(self, "id")
if id_ is None:
return None
return (self._item_type, id_)
@property
def has_valid_id(self):
return self._has_valid_id
[docs] def invalidate_id(self):
"""Sets id as invalid."""
self._has_valid_id = False
def _extended(self):
"""Returns a dict from this item's original fields plus all the references resolved statically.
Returns:
dict
"""
d = self._asdict()
d.update({key: self[key] for key in self._external_fields})
return d
def _asdict(self):
"""Returns a dict from this item's original fields.
Returns:
dict
"""
return dict(self)
def resolve(self):
    """Returns a dict from this item's original fields, with each value passed through ``resolve``.

    Returns:
        dict
    """
    resolved = {}
    for field, value in self._asdict().items():
        resolved[field] = resolve(value)
    return resolved
[docs] def merge(self, other):
"""Merges this item with another and returns the merged item together with any errors.
Used for updating items.
Args:
other (dict): the item to merge into this.
Returns:
dict: merged item.
str: error description if any.
"""
if not self._something_to_update(other):
# Nothing to update, that's fine
return None, ""
merged = {**self._extended(), **other}
if not isinstance(merged["id"], int):
merged["id"] = self["id"]
return merged, ""
def _something_to_update(self, other):
    """Whether the given dictionary has any value that differs from this item's.

    Lists are compared as tuples and values are passed through ``resolve`` before comparing.

    Args:
        other (dict)

    Returns:
        bool
    """

    def _comparable(x):
        return resolve(tuple(x) if isinstance(x, list) else x)

    for key, value in other.items():
        if value is None and not self.fields.get(key, {}).get("optional", False):
            # Ignore mandatory fields that are None
            continue
        if not (_comparable(self.get(key)) == _comparable(value)):
            return True
    return False
def db_equivalent(self):
    """The equivalent of this item in the DB.

    For an item waiting to be updated, that's a polished item made from its backup;
    otherwise it's the item itself.

    Returns:
        MappedItemBase
    """
    if self.status != Status.to_update:
        return self
    db_item = self._db_map.make_item(self._item_type, **self.backup)
    db_item.polish()
    return db_item
[docs] def first_invalid_key(self):
"""Goes through the ``_references`` class attribute and returns the key of the first reference
that cannot be resolved.
Returns:
str or None: unresolved reference's key if any.
"""
return next((src_key for src_key, ref in self._resolve_refs() if not ref), None)
def _resolve_refs(self):
"""Goes through the ``_references`` class attribute and tries to resolve them.
If successful, replace source fields referring to db-ids with the reference's TempId.
Yields:
tuple(str,MappedItem or None): the source field and resolved ref.
"""
for src_key, (ref_type, ref_key) in self._references.items():
ref = self._get_full_ref(src_key, ref_type, ref_key)
if isinstance(ref, tuple):
for r in ref:
yield src_key, r
else:
yield src_key, ref
def _get_full_ref(self, src_key, ref_type, ref_key, strong=True):
try:
src_val = self[src_key]
except KeyError:
return {}
if isinstance(src_val, tuple):
ref = tuple(self._get_ref(ref_type, {ref_key: x}, strong=strong) for x in src_val)
if all(ref) and ref_key == "id":
self[src_key] = tuple(r["id"] for r in ref)
return ref
ref = self._get_ref(ref_type, {ref_key: src_val}, strong=strong)
if ref and ref_key == "id":
self[src_key] = ref["id"]
return ref
@classmethod
def unique_values_for_item(cls, item, skip_keys=()):
for key in cls._unique_keys:
if key not in skip_keys:
value = tuple(item.get(k) for k in key)
if None not in value:
yield key, value
[docs] def unique_key_values(self, skip_keys=()):
"""Yields tuples of unique keys and their values.
Args:
skip_keys: Don't yield these keys
Yields:
tuple(tuple,tuple): the first element is the unique key, the second is the values.
"""
yield from self.unique_values_for_item(self, skip_keys=skip_keys)
[docs] def resolve_internal_fields(self, skip_keys=()):
"""Goes through the ``_internal_fields`` class attribute and updates this item
by resolving those references.
Returns any error.
Args:
skip_keys (tuple): don't resolve references for these keys.
Returns:
str or None: error description if any.
"""
for key in self._internal_fields:
if key in skip_keys:
continue
error = self._do_resolve_internal_field(key)
if error:
return error
def _do_resolve_internal_field(self, key):
src_key, target_key = self._internal_fields[key]
src_val = tuple(dict.pop(self, k, None) or self.get(k) for k in src_key)
if None in src_val:
return
ref_type, ref_key = self._alt_references[src_key]
mapped_table = self._db_map.mapped_table(ref_type)
if all(isinstance(v, (tuple, list)) for v in src_val):
refs = []
for v in zip(*src_val):
ref = mapped_table.find_item(dict(zip(ref_key, v)))
if not ref:
return f"can't find {ref_type} with {dict(zip(ref_key, v))}"
refs.append(ref)
self[key] = tuple(ref[target_key] for ref in refs)
else:
ref = mapped_table.find_item(dict(zip(ref_key, src_val)))
if not ref:
return f"can't find {ref_type} with {dict(zip(ref_key, src_val))}"
self[key] = ref[target_key]
[docs] def polish(self):
"""Polishes this item once all it's references have been resolved. Returns any error.
The base implementation sets defaults but subclasses can do more work if needed.
Returns:
str or None: error description if any.
"""
for key, default_value in self._defaults.items():
self.setdefault(key, default_value)
return ""
[docs] def check_mutability(self):
"""Called before adding, updating, or removing this item. Returns any errors that prevent that.
Returns:
str or None: error description if any.
"""
return ""
def _get_ref(self, ref_type, key_val, strong=True):
"""Collects a reference from the in-memory mapping.
Adds this item to the reference's list of referrers if strong is True;
or weak referrers if strong is False.
Args:
ref_type (str): The reference's type
key_val (dict): The reference's key and value to match
strong (bool): True if the reference corresponds to a foreign key, False otherwise
Returns:
MappedItemBase or dict
"""
mapped_table = self._db_map.mapped_table(ref_type)
ref = mapped_table.find_item(key_val, fetch=True)
if not ref:
return {}
if strong:
ref.add_referrer(self)
else:
ref.add_weak_referrer(self)
if ref.removed:
return {}
return ref
def _invalidate_ref(self, ref_type, key_val):
"""Invalidates a reference previously collected from the in-memory mapping.
Args:
ref_type (str): The reference's type
key_val (dict): The reference's key and value to match
"""
mapped_table = self._db_map.mapped_table(ref_type)
ref = mapped_table.find_item(key_val)
ref.remove_referrer(self)
def is_valid(self):
    """Checks if this item has all its references.
    Removes the item from the in-memory mapping if not valid by calling ``cascade_remove``.

    Compromised items are never valid and are not validated.

    Returns:
        bool
    """
    if self.status != Status.compromised:
        self.validate()
        return self._valid
    return False
[docs] def validate(self):
"""Resolves all references and checks if the item is valid.
The item is valid if it's not removed, has all of its references, and none of them is removed."""
if self._valid is not None:
return
refs = [ref for _, ref in self._resolve_refs()]
self._valid = not self._removed and all(ref and not ref.removed for ref in refs)
if not self._valid:
self.cascade_remove()
[docs] def add_referrer(self, referrer):
"""Adds a strong referrer to this item. Strong referrers are removed, updated and restored
in cascade with this item.
Args:
referrer (MappedItemBase)
"""
key = referrer.key
if key is None:
return
self._referrers[key] = self._weak_referrers.pop(key, referrer)
[docs] def remove_referrer(self, referrer):
"""Removes a strong referrer.
Args:
referrer (MappedItemBase)
"""
key = referrer.key
if key is not None:
self._referrers.pop(key, None)
[docs] def add_weak_referrer(self, referrer):
"""Adds a weak referrer to this item.
Weak referrers' update callbacks are called whenever this item changes.
Args:
referrer (MappedItemBase)
"""
key = referrer.key
if key is None:
return
if key not in self._referrers:
self._weak_referrers[key] = referrer
def _update_weak_referrers(self):
    """Calls the update callbacks of every weak referrer of this item."""
    for referrer in self._weak_referrers.values():
        referrer.call_update_callbacks()
def cascade_restore(self, source=None):
    """Restores this item (if removed) and all its referrers in cascade.

    Also, updates items' status and calls their restore callbacks.
    Nothing happens unless ``source`` is the very object recorded as the source of the removal.

    Args:
        source (object, optional): What triggers the restoration; referrers get this item as their source.

    Raises:
        RuntimeError: If the item's status is not one a removed item can be restored from.
    """
    if not self._removed or source is not self._removal_source:
        return
    status = self.status
    if status == Status.added_and_removed or status == Status.to_remove:
        # Go back to whatever the status was right before the removal.
        self._status = self._status_when_removed
    elif status == Status.committed:
        self._status = Status.to_add
    else:
        raise RuntimeError("invalid status for item being restored")
    self._removed = False
    self._valid = None
    # First restore this, then referrers.
    # Callbacks run over a snapshot; those returning a falsy value are dropped afterwards.
    obsolete = {callback for callback in tuple(self.restore_callbacks) if not callback(self)}
    self.restore_callbacks -= obsolete
    for referrer in self._referrers.values():
        referrer.cascade_restore(source=self)
    self._update_weak_referrers()
def cascade_remove(self, source=None):
    """Removes this item and all its referrers in cascade.

    Also, updates items' status and calls their remove callbacks.
    Nothing happens if the item is already removed.

    Args:
        source (object, optional): What triggers the removal; it's recorded so that only the same source
            can restore the item later. Referrers get this item as their source.

    Raises:
        RuntimeError: If the item's status is not one an item can be removed from.
    """
    if self._removed:
        return
    previous = self._status
    # Remember the status so a later restore can bring it back.
    self._status_when_removed = previous
    if previous == Status.to_add:
        self._status = Status.added_and_removed
    elif previous == Status.committed or previous == Status.to_update:
        self._status = Status.to_remove
    else:
        raise RuntimeError("invalid status for item being removed")
    self._removal_source = source
    self._removed = True
    self._valid = None
    # First remove referrers, then this.
    for referrer in self._referrers.values():
        referrer.cascade_remove(source=self)
    self._update_weak_referrers()
    # Callbacks run over a snapshot; those returning a falsy value are dropped afterwards.
    obsolete = {callback for callback in tuple(self.remove_callbacks) if not callback(self)}
    self.remove_callbacks -= obsolete
def cascade_update(self):
    """Updates this item and all its referrers in cascade.

    Also, calls items' update callbacks. Removed items are left alone.
    """
    if not self._removed:
        self.call_update_callbacks()
        for referrer in self._referrers.values():
            referrer.cascade_update()
        self._update_weak_referrers()
def call_update_callbacks(self):
    """Calls every update callback with this item, then drops the ones that returned a falsy value."""
    # Iterate over a snapshot so the set of callbacks is free to change while they run.
    obsolete = {callback for callback in tuple(self.update_callbacks) if not callback(self)}
    self.update_callbacks -= obsolete
def cascade_add_unique(self):
    """Adds the unique keys and ids of this item and of all its referrers, in cascade."""
    self._db_map.mapped_table(self._item_type).add_unique(self)
    for referrer in self._referrers.values():
        referrer.cascade_add_unique()
def cascade_remove_unique(self):
    """Removes the unique keys and ids of this item and of all its referrers, in cascade."""
    self._db_map.mapped_table(self._item_type).remove_unique(self)
    for referrer in self._referrers.values():
        referrer.cascade_remove_unique()
def is_committed(self):
    """Tells whether this item's status is ``Status.committed``, i.e., the item is committed to the DB.

    Returns:
        bool
    """
    status = self._status
    return status == Status.committed
def commit(self, commit_id):
    """Marks this item as committed, remembering the status it had until now.

    Args:
        commit_id: The commit id to store under the "commit_id" key; left untouched if falsy.
    """
    self._status_when_committed, self._status = self._status, Status.committed
    if not commit_id:
        return
    self["commit_id"] = commit_id
def __repr__(self):
    """Overridden to return a more verbose representation: the item type followed by the extended item."""
    return "{}{}".format(self._item_type, self._extended())
def __getattr__(self, name):
    """Overridden so that attributes not found the regular way fall back to the dictionary key
    of the same name; yields None if there's no such key."""
    # FIXME: We should try and get rid of this one
    value = self.get(name)
    return value
def __getitem__(self, key):
    """Overridden to return references.

    Keys listed in ``_external_fields`` are looked up in the item(s) referred to by this one;
    any other key is looked up in this item itself.
    """
    external = self._external_fields.get(key)
    if not external:
        return super().__getitem__(key)
    source_key, target_key = external
    ref_type, ref_key = self._references[source_key]
    ref = self._get_full_ref(source_key, ref_type, ref_key)
    if not isinstance(ref, tuple):
        return ref.get(target_key)
    # Several references: collect the target value from each of them.
    return tuple(r.get(target_key) for r in ref)
def __setitem__(self, key, value):
    """Overridden to flag the id as valid whenever the 'id' key is set."""
    is_id = key == "id"
    if is_id:
        self._has_valid_id = True
    super().__setitem__(key, value)
def get(self, key, default=None):
    """Overridden to return references.

    Goes through item access (``self[key]``) and falls back to ``default`` on ``KeyError``.
    """
    try:
        value = self[key]
    except KeyError:
        value = default
    return value
def update(self, other):
    """Overridden to update the item status and also to invalidate references that become obsolete.

    Args:
        other (dict): The new values.

    Raises:
        RuntimeError: If the item's status is ``to_remove`` or ``added_and_removed``.
    """
    status = self._status
    if status in (Status.to_remove, Status.added_and_removed):
        raise RuntimeError("invalid status of item being updated")
    if status == Status.committed:
        self._status = Status.to_update
        # Keep a snapshot to be able to tell later whether the item went back to its committed content.
        self._backup = self._asdict()
    for src_key, (ref_type, ref_key) in self._references.items():
        src_val = self[src_key]
        changed = src_key in other and other[src_key] != src_val
        if not changed:
            continue
        # The reference source is about to change: invalidate what it currently points at.
        stale_values = src_val if isinstance(src_val, tuple) else (src_val,)
        for stale in stale_values:
            self._invalidate_ref(ref_type, {ref_key: stale})
    # The id must survive the update no matter what ``other`` carries.
    preserved_id = self["id"]
    super().update(other)
    self["id"] = preserved_id
    if self._asdict() == self._backup:
        self._status = Status.committed
def force_id(self, id_):
    """Makes sure this item has the given id_, corresponding to the new id of the item
    in the DB after some external changes.

    Args:
        id_ (int): The most recent id_ of the item as fetched from the DB.
    """
    mapped_id = self["id"]
    already_matches = mapped_id == id_
    if already_matches:
        return
    # Resolve the TempId to the new db id (and commit the item if pending)
    mapped_id.resolve(id_)
    pending = self.status == Status.to_add
    if pending:
        self.status = Status.committed
def handle_id_steal(self):
    """Called when a new item is fetched from the DB with this item's id.

    Unresolves this item's id and rewinds its status: a committed item goes back to the status
    it had when committed; then ``to_update`` becomes ``to_add``, whereas ``to_remove`` becomes
    ``committed`` with ``to_add`` recorded as the status when removed.
    """
    self["id"].unresolve()
    # TODO: Test if the below works...
    if self.is_committed():
        self._status = self._status_when_committed
    status = self._status
    if status == Status.to_update:
        self._status = Status.to_add
    elif status == Status.to_remove:
        self._status = Status.committed
        self._status_when_removed = Status.to_add
class PublicItem:
    """A public wrapper around a mapped item.

    Read access and callbacks are delegated to the wrapped item, whereas
    updating, removing and restoring go through the DB mapping.
    """

    def __init__(self, db_map, mapped_item):
        """
        Args:
            db_map: The DB mapping; must provide ``update_item``, ``remove_item`` and ``restore_item``.
            mapped_item: The wrapped item.
        """
        self._db_map = db_map
        self._mapped_item = mapped_item

    @property
    def item_type(self):
        """The wrapped item's type."""
        return self._mapped_item.item_type

    def __getitem__(self, key):
        """Returns the wrapped item's value for the given key."""
        return self._mapped_item[key]

    def __eq__(self, other):
        """Compares the wrapped item against dictionaries; anything else gets the default comparison."""
        # NOTE: defining __eq__ without __hash__ makes instances of this class unhashable.
        if isinstance(other, dict):
            return self._mapped_item == other
        return super().__eq__(other)

    def __repr__(self):
        return repr(self._mapped_item)

    def __str__(self):
        return str(self._mapped_item)

    def get(self, key, default=None):
        """Returns the wrapped item's value for the given key, or the given default."""
        return self._mapped_item.get(key, default)

    def validate(self):
        """Validates the wrapped item."""
        self._mapped_item.validate()

    def is_valid(self):
        """Returns whether the wrapped item is valid."""
        return self._mapped_item.is_valid()

    def is_committed(self):
        """Returns whether the wrapped item is committed."""
        return self._mapped_item.is_committed()

    def _asdict(self):
        return self._mapped_item._asdict()

    def _extended(self):
        return self._mapped_item._extended()

    def update(self, **kwargs):
        """Updates this item through the DB mapping.

        Args:
            **kwargs: Fields to update, passed on to the mapping's ``update_item`` along with this item's id.

        Returns:
            Whatever the mapping's ``update_item`` returns, so the caller can check the outcome
            the same way as with :meth:`remove` and :meth:`restore`.
        """
        return self._db_map.update_item(self.item_type, id=self["id"], **kwargs)

    def remove(self):
        """Removes this item through the DB mapping and returns the mapping's ``remove_item`` result."""
        return self._db_map.remove_item(self.item_type, self["id"])

    def restore(self):
        """Restores this item through the DB mapping and returns the mapping's ``restore_item`` result."""
        return self._db_map.restore_item(self.item_type, self["id"])

    def add_update_callback(self, callback):
        """Adds a callback to the wrapped item's update callbacks."""
        self._mapped_item.update_callbacks.add(callback)

    def add_remove_callback(self, callback):
        """Adds a callback to the wrapped item's remove callbacks."""
        self._mapped_item.remove_callbacks.add(callback)

    def add_restore_callback(self, callback):
        """Adds a callback to the wrapped item's restore callbacks."""
        self._mapped_item.restore_callbacks.add(callback)

    def resolve(self):
        """Returns the result of resolving the wrapped item."""
        return self._mapped_item.resolve()