Skip to content

Commit

Permalink
Merge pull request #2861 from bagerard/warning_dup_Document_class
Browse files Browse the repository at this point in the history
refactor _document_registry + log a warning when user register multip…
  • Loading branch information
bagerard authored Oct 4, 2024
2 parents 11943d9 + f0de61e commit e77daa6
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 61 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Development
- make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api
- run_in_transaction context manager relies on Pymongo coreAPI, it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions
- Using .count() in a transaction will always use Collection.count_document (as estimated_document_count is not supported in transactions)
- BREAKING CHANGE: wrap _document_registry (normally not used by end users) with _DocumentRegistry which acts as a singleton to access the registry
- Log a warning in case users creates multiple Document classes with the same name as it can lead to unexpected behavior #1778
- Fix use of $geoNear or $collStats in aggregate #2493
- BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface
- BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+
Expand All @@ -21,6 +23,7 @@ Development
- BREAKING CHANGE: Remove LongField as it's equivalent to IntField since we drop support to Python2 long time ago (User should simply switch to IntField) #2309
- BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858


Changes in 0.29.0
=================
- Fix weakref in EmbeddedDocumentListField (causing brief mem leak in certain circumstances) #2827
Expand Down
3 changes: 1 addition & 2 deletions mongoengine/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
__all__ = (
# common
"UPDATE_OPERATORS",
"_document_registry",
"get_document",
"_DocumentRegistry",
# datastructures
"BaseDict",
"BaseList",
Expand Down
77 changes: 54 additions & 23 deletions mongoengine/base/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings

from mongoengine.errors import NotRegistered

__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry")
__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")


UPDATE_OPERATORS = {
Expand All @@ -25,28 +27,57 @@
_document_registry = {}


def get_document(name):
"""Get a registered Document class by name."""
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
single_end = name.split(".")[-1]
compound_end = ".%s" % single_end
possible_match = [
k for k in _document_registry if k.endswith(compound_end) or k == single_end
]
if len(possible_match) == 1:
doc = _document_registry.get(possible_match.pop(), None)
if not doc:
raise NotRegistered(
"""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip()
% name
)
return doc
class _DocumentRegistry:
"""Wrapper for the document registry (providing a singleton pattern).
This is part of MongoEngine's internals, not meant to be used directly by end-users
"""

@staticmethod
def get(name):
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
single_end = name.split(".")[-1]
compound_end = ".%s" % single_end
possible_match = [
k
for k in _document_registry
if k.endswith(compound_end) or k == single_end
]
if len(possible_match) == 1:
doc = _document_registry.get(possible_match.pop(), None)
if not doc:
raise NotRegistered(
"""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip()
% name
)
return doc

@staticmethod
def register(DocCls):
ExistingDocCls = _document_registry.get(DocCls._class_name)
if (
ExistingDocCls is not None
and ExistingDocCls.__module__ != DocCls.__module__
):
# A sign that a codebase may have named two different classes with the same name accidentally,
# this could cause issues with dereferencing because MongoEngine makes the assumption that a Document
# class name is unique.
warnings.warn(
f"Multiple Document classes named `{DocCls._class_name}` were registered, "
f"first from: `{ExistingDocCls.__module__}`, then from: `{DocCls.__module__}`. "
"this may lead to unexpected behavior during dereferencing.",
stacklevel=4,
)
_document_registry[DocCls._class_name] = DocCls

@staticmethod
def unregister(doc_cls_name):
_document_registry.pop(doc_cls_name)


def _get_documents_by_db(connection_alias, default_connection_alias):
Expand Down
6 changes: 3 additions & 3 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bson import SON, DBRef, ObjectId, json_util

from mongoengine import signals
from mongoengine.base.common import get_document
from mongoengine.base.common import _DocumentRegistry
from mongoengine.base.datastructures import (
BaseDict,
BaseList,
Expand Down Expand Up @@ -500,7 +500,7 @@ def __expand_dynamic_values(self, name, value):
# If the value is a dict with '_cls' in it, turn it into a document
is_dict = isinstance(value, dict)
if is_dict and "_cls" in value:
cls = get_document(value["_cls"])
cls = _DocumentRegistry.get(value["_cls"])
return cls(**value)

if is_dict:
Expand Down Expand Up @@ -802,7 +802,7 @@ def _from_son(cls, son, _auto_dereference=True, created=False):

# Return correct subclass for document type
if class_name != cls._class_name:
cls = get_document(class_name)
cls = _DocumentRegistry.get(class_name)

errors_dict = {}

Expand Down
4 changes: 2 additions & 2 deletions mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import warnings

from mongoengine.base.common import _document_registry
from mongoengine.base.common import _DocumentRegistry
from mongoengine.base.fields import (
BaseField,
ComplexBaseField,
Expand Down Expand Up @@ -169,7 +169,7 @@ def __new__(mcs, name, bases, attrs):
new_class._collection = None

# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
_DocumentRegistry.register(new_class)

# Handle delete rules
for field in new_class._fields.values():
Expand Down
20 changes: 10 additions & 10 deletions mongoengine/dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
BaseList,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
_DocumentRegistry,
)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import _get_session, get_db
Expand Down Expand Up @@ -131,9 +131,9 @@ def _find_references(self, items, depth=0):
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
elif isinstance(v, (dict, SON)) and "_ref" in v:
reference_map.setdefault(get_document(v["_cls"]), set()).add(
v["_ref"].id
)
reference_map.setdefault(
_DocumentRegistry.get(v["_cls"]), set()
).add(v["_ref"].id)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
field_cls = getattr(
getattr(field, "field", None), "document_type", None
Expand All @@ -151,9 +151,9 @@ def _find_references(self, items, depth=0):
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
elif isinstance(item, (dict, SON)) and "_ref" in item:
reference_map.setdefault(get_document(item["_cls"]), set()).add(
item["_ref"].id
)
reference_map.setdefault(
_DocumentRegistry.get(item["_cls"]), set()
).add(item["_ref"].id)
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1)
for key, refs in references.items():
Expand Down Expand Up @@ -198,9 +198,9 @@ def _fetch_objects(self, doc_type=None):
)
for ref in references:
if "_cls" in ref:
doc = get_document(ref["_cls"])._from_son(ref)
doc = _DocumentRegistry.get(ref["_cls"])._from_son(ref)
elif doc_type is None:
doc = get_document(
doc = _DocumentRegistry.get(
"".join(x.capitalize() for x in collection.split("_"))
)._from_son(ref)
else:
Expand Down Expand Up @@ -235,7 +235,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
(items["_ref"].collection, items["_ref"].id), items
)
elif "_cls" in items:
doc = get_document(items["_cls"])._from_son(items)
doc = _DocumentRegistry.get(items["_cls"])._from_son(items)
_cls = doc._data.pop("_cls", None)
del items["_cls"]
doc._data = self._attach_objects(doc._data, depth, doc, None)
Expand Down
6 changes: 3 additions & 3 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DocumentMetaclass,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
_DocumentRegistry,
)
from mongoengine.base.utils import NonOrderedList
from mongoengine.common import _import_class
Expand Down Expand Up @@ -851,12 +851,12 @@ def register_delete_rule(cls, document_cls, field_name, rule):
object.
"""
classes = [
get_document(class_name)
_DocumentRegistry.get(class_name)
for class_name in cls._subclasses
if class_name != cls.__name__
] + [cls]
documents = [
get_document(class_name)
_DocumentRegistry.get(class_name)
for class_name in document_cls._subclasses
if class_name != document_cls.__name__
] + [document_cls]
Expand Down
20 changes: 10 additions & 10 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
GeoJsonBaseField,
LazyReference,
ObjectIdField,
get_document,
_DocumentRegistry,
)
from mongoengine.base.utils import LazyRegexCompiler
from mongoengine.common import _import_class
Expand Down Expand Up @@ -725,7 +725,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
resolved_document_type = self.owner_document
else:
resolved_document_type = get_document(self.document_type_obj)
resolved_document_type = _DocumentRegistry.get(self.document_type_obj)

if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
Expand Down Expand Up @@ -801,7 +801,7 @@ def prepare_query_value(self, op, value):

def to_python(self, value):
if isinstance(value, dict):
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
value = doc_cls._from_son(value)

return value
Expand Down Expand Up @@ -879,7 +879,7 @@ def to_mongo(self, value, use_db_field=True, fields=None):

def to_python(self, value):
if isinstance(value, dict) and "_cls" in value:
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
if "_ref" in value:
value = doc_cls._get_db().dereference(
value["_ref"], session=_get_session()
Expand Down Expand Up @@ -1171,7 +1171,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

@staticmethod
Expand All @@ -1195,7 +1195,7 @@ def __get__(self, instance, owner):
if auto_dereference and isinstance(ref_value, DBRef):
if hasattr(ref_value, "cls"):
# Dereference using the class type specified in the reference
cls = get_document(ref_value.cls)
cls = _DocumentRegistry.get(ref_value.cls)
else:
cls = self.document_type

Expand Down Expand Up @@ -1335,7 +1335,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

@staticmethod
Expand Down Expand Up @@ -1498,7 +1498,7 @@ def __get__(self, instance, owner):

auto_dereference = instance._fields[self.name]._auto_dereference
if auto_dereference and isinstance(value, dict):
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
instance._data[self.name] = self._lazy_load_ref(doc_cls, value["_ref"])

return super().__get__(instance, owner)
Expand Down Expand Up @@ -2443,7 +2443,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

def build_lazyref(self, value):
Expand Down Expand Up @@ -2584,7 +2584,7 @@ def build_lazyref(self, value):
elif value is not None:
if isinstance(value, (dict, SON)):
value = LazyReference(
get_document(value["_cls"]),
_DocumentRegistry.get(value["_cls"]),
value["_ref"].id,
passthrough=self.passthrough,
)
Expand Down
6 changes: 4 additions & 2 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymongo.read_concern import ReadConcern

from mongoengine import signals
from mongoengine.base import get_document
from mongoengine.base import _DocumentRegistry
from mongoengine.common import _import_class
from mongoengine.connection import _get_session, get_db
from mongoengine.context_managers import (
Expand Down Expand Up @@ -1956,7 +1956,9 @@ def _fields_to_dbfields(self, fields):
"""Translate fields' paths to their db equivalents."""
subclasses = []
if self._document._meta["allow_inheritance"]:
subclasses = [get_document(x) for x in self._document._subclasses][1:]
subclasses = [_DocumentRegistry.get(x) for x in self._document._subclasses][
1:
]

db_field_paths = []
for field in fields:
Expand Down
8 changes: 4 additions & 4 deletions tests/document/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from mongoengine import *
from mongoengine import signals
from mongoengine.base import _document_registry, get_document
from mongoengine.base import _DocumentRegistry
from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.errors import (
Expand Down Expand Up @@ -392,7 +392,7 @@ class NicePlace(Place):

# Mimic Place and NicePlace definitions being in a different file
# and the NicePlace model not being imported in at query time.
del _document_registry["Place.NicePlace"]
_DocumentRegistry.unregister("Place.NicePlace")

with pytest.raises(NotRegistered):
list(Place.objects.all())
Expand All @@ -407,8 +407,8 @@ class Area(Location):

Location.drop_collection()

assert Area == get_document("Area")
assert Area == get_document("Location.Area")
assert Area == _DocumentRegistry.get("Area")
assert Area == _DocumentRegistry.get("Location.Area")

def test_creation(self):
"""Ensure that document may be created using keyword arguments."""
Expand Down
4 changes: 2 additions & 2 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from mongoengine.base import (
BaseField,
EmbeddedDocumentList,
_document_registry,
_DocumentRegistry,
)
from mongoengine.base.fields import _no_dereference_for_fields
from mongoengine.errors import DeprecatedError
Expand Down Expand Up @@ -1678,7 +1678,7 @@ class User(Document):

# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del _document_registry["Link"]
_DocumentRegistry.unregister("Link")

user = User.objects.first()
try:
Expand Down

0 comments on commit e77daa6

Please sign in to comment.