import contextlib
import dataclasses
from dataclasses import fields, is_dataclass
import sys
from typing import Any, Callable, Optional, Type, Union, overload
from typing_extensions import dataclass_transform # type: ignore
from cornflakes.common import recursive_update
from cornflakes.decorator.dataclasses._add_dataclass_slots import add_slots
from cornflakes.decorator.dataclasses._enforce_types import enforce_types
from cornflakes.decorator.dataclasses._field import Field, field
from cornflakes.decorator.dataclasses._helper import dc_field_without_default
from cornflakes.decorator.dataclasses._helper import dict_factory as d_factory
from cornflakes.decorator.dataclasses._helper import is_index
from cornflakes.decorator.dataclasses._helper import tuple_factory as t_factory
from cornflakes.decorator.dataclasses._helper import value_factory as v_factory
from cornflakes.decorator.dataclasses._validate import check_dataclass_kwargs, validate_dataclass_kwargs
from cornflakes.types import _T, Constants, CornflakesDataclass, MappingWrapper
if sys.version_info >= (3, 10):
    # Typing-only overloads of ``dataclass``; the runtime implementation follows below.
    # PEP 681: ``dataclass_transform`` may decorate at most ONE overload of a function and then
    # applies to the function as a whole, so it is placed on the first overload only.

    @dataclass_transform(field_specifiers=(field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        kw_only: bool = False,
        slots: bool = False,
        match_args: bool = True,
        dict_factory: Optional[Callable] = None,
        tuple_factory: Optional[Callable] = None,
        value_factory: Optional[Callable] = None,
        eval_env: bool = False,
        validate: bool = False,
        updatable: bool = False,
        ignore_none: bool = False,
        **kwargs: Any,
    ) -> Callable[[Type[_T]], Union[Type[CornflakesDataclass], MappingWrapper[_T]]]:
        """Decorator-factory form: ``@dataclass(...)``."""
        ...

    @overload
    def dataclass(
        _cls: Type[_T],
        /,
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        kw_only: bool = False,
        slots: bool = False,
        match_args: bool = True,
        dict_factory: Optional[Callable] = None,
        tuple_factory: Optional[Callable] = None,
        value_factory: Optional[Callable] = None,
        eval_env: bool = False,
        validate: bool = False,
        updatable: bool = False,
        ignore_none: bool = False,
        **kwargs: Any,
    ) -> Union[Type[CornflakesDataclass], MappingWrapper[_T]]:
        """Bare decorator form: ``@dataclass``."""
        ...

else:
    # Python < 3.10: ``kw_only`` and ``match_args`` are ignored by the implementation (they are
    # not forwarded on these versions), so they are not advertised here. ``slots`` IS supported
    # by the implementation (emulated with ``add_slots``), so it is part of these signatures.

    @dataclass_transform(field_specifiers=(field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        slots: bool = False,
        dict_factory: Optional[Callable] = None,
        tuple_factory: Optional[Callable] = None,
        value_factory: Optional[Callable] = None,
        eval_env: bool = False,
        validate: bool = False,
        updatable: bool = False,
        ignore_none: bool = False,
        **kwargs: Any,
    ) -> Callable[[Type[_T]], Union[Type[CornflakesDataclass], MappingWrapper[_T]]]:
        """Decorator-factory form: ``@dataclass(...)``."""
        ...

    @overload
    def dataclass(
        _cls: Type[_T],  # type: ignore
        /,
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        slots: bool = False,
        dict_factory: Optional[Callable] = None,
        tuple_factory: Optional[Callable] = None,
        value_factory: Optional[Callable] = None,
        eval_env: bool = False,
        validate: bool = False,
        updatable: bool = False,
        ignore_none: bool = False,
        **kwargs: Any,
    ) -> Union[Type[CornflakesDataclass], MappingWrapper[_T]]:
        """Bare decorator form: ``@dataclass``."""
        ...
def dataclass(
    cls: Optional[Type[_T]] = None,
    /,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    kw_only: bool = False,
    slots: bool = False,
    match_args: bool = True,
    dict_factory: Optional[Callable] = None,
    tuple_factory: Optional[Callable] = None,
    value_factory: Optional[Callable] = None,
    eval_env: bool = False,
    validate: bool = False,
    updatable: bool = False,
    ignore_none: bool = False,
    **kwargs: Any,
) -> Union[
    Callable[[Type[_T]], Union[Type[CornflakesDataclass], MappingWrapper[_T]]],
    Type[CornflakesDataclass],
    MappingWrapper[_T],
]:
    """Wrapper around built-in dataclasses dataclass.

    Usable both bare (``@dataclass``) and with arguments (``@dataclass(...)``).

    :param cls: The decorated class when used bare; ``None`` when called with arguments.
    :param init: Forwarded to :func:`dataclasses.dataclass` (likewise ``repr``, ``eq``,
        ``order``, ``unsafe_hash`` and ``frozen``).
    :param kw_only: Forwarded on Python >= 3.10 only; ignored on older versions.
    :param slots: Forwarded on Python >= 3.10; on older versions slots are added afterwards
        with ``add_slots``.
    :param match_args: Forwarded on Python >= 3.10 only; ignored on older versions.
    :param dict_factory: Passed to ``_wrap_custom_dataclass`` (likewise ``tuple_factory``,
        ``value_factory`` and ``eval_env``).
    :param validate: If true, the resulting class is wrapped with ``enforce_types``.
    :param updatable: If true, an ``update(new, merge_lists=False)`` method is attached that
        returns a NEW instance.
    :param ignore_none: Passed to ``_wrap_mapping``.
    :param kwargs: Accepted but not forwarded anywhere; extra keyword arguments are discarded.
    :returns: The wrapped class when ``cls`` is given, otherwise a decorator producing it.
    :raises AttributeError: If ``init=False`` is combined with ``slots=True``, or
        ``frozen=True`` with ``updatable=True``.
    """
    # Options that dataclasses.dataclass only understands on Python >= 3.10. This is a separate
    # name on purpose: the caller's ``**kwargs`` are intentionally not forwarded.
    if sys.version_info >= (3, 10):
        dc_kwargs = dict(kw_only=kw_only, slots=slots, match_args=match_args)
    else:
        dc_kwargs = {}

    def create_dataclass(w_cls: Type[_T]) -> Union[Type[CornflakesDataclass], MappingWrapper[_T]]:
        """
        Create a Cornflakes dataclass from a regular dataclass.

        :param w_cls: The class to create the Cornflakes dataclass from.
        :type w_cls: type
        :returns: A Cornflakes dataclass.
        :rtype: type
        :raises AttributeError: On an unsupported combination of options.
        """
        # Reject invalid option combinations before dataclasses starts mutating ``w_cls``.
        if not init and slots:
            # this is not supported by dataclasses
            raise AttributeError("Cannot specify both init=False and slots=True")
        if frozen and updatable:
            raise AttributeError("Cannot set both frozen=True and updatable=True")

        dc_cls = _wrap_custom_dataclass(
            w_cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            dict_factory=dict_factory,
            tuple_factory=tuple_factory,
            value_factory=value_factory,
            eval_env=eval_env,
            **dc_kwargs,
        )
        if slots and sys.version_info < (3, 10):
            # dataclasses.dataclass has no ``slots`` option before 3.10, so emulate it.
            dc_cls = add_slots(dc_cls)

        if updatable:

            def _update(self, new, merge_lists=False):
                """Return a new instance built from ``self`` with ``new`` applied via ``recursive_update``.

                ``{**self}`` relies on the mapping interface installed by ``_wrap_mapping``.
                An ``AttributeError`` raised by ``recursive_update`` is suppressed.
                """
                current = {**self}
                with contextlib.suppress(AttributeError):
                    recursive_update(current, new, merge_lists=merge_lists)
                return type(self)(**current)

            dc_cls.update = _update

        if validate:
            dc_cls = enforce_types(dc_cls)
        # Keep the decorated class's identity metadata on whatever class object the wrappers
        # above returned (they may not return ``w_cls`` itself).
        dc_cls.__doc__ = w_cls.__doc__
        dc_cls.__module__ = w_cls.__module__
        dc_cls.__qualname__ = w_cls.__qualname__
        return _wrap_mapping(dc_cls, ignore_none)

    # Bare ``@dataclass`` passes the class positionally; ``@dataclass(...)`` passes nothing.
    return create_dataclass(cls) if cls is not None else create_dataclass  # type: ignore
def _zero_copy_astuple_inner(obj, value_factory=None):
    """Tuple counterpart of ``_zero_copy_asdict_inner``: convert ``obj`` without copying values.

    Dataclass instances are converted with their class's tuple factory (``t_factory``) applied
    to the list of converted field values; field values are converted recursively using the
    instance's own value factory (``v_factory``). Namedtuples, lists, tuples and dicts are
    rebuilt with the same type, and their items are converted recursively WITHOUT a value
    factory. Any other value is returned as-is, or passed through ``value_factory`` when given.

    :param obj: Value to convert.
    :param value_factory: Optional callable applied to a leaf value.
    :returns: The converted value.
    """
    # ``is_dataclass`` is also True for dataclass *classes*; like the standard library, only
    # instances are expanded -- a class stored in a field is treated as a plain leaf value.
    if is_dataclass(obj) and not isinstance(obj, type):
        result = [_zero_copy_astuple_inner(getattr(obj, f.name), v_factory(obj)) for f in fields(obj)]
        return t_factory(obj)(result)
    if is_index(obj):
        # Objects recognised by ``is_index`` get their class's ``reset()`` called and are
        # returned unchanged.
        type(obj).reset()
        return obj
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # obj is a namedtuple: rebuild the same type, but its __init__ takes positional
        # arguments rather than a single iterable (see bpo-34363).
        return type(obj)(*[_zero_copy_astuple_inner(v) for v in obj])
    if isinstance(obj, (list, tuple)):
        # Assume the type can be built from a generator (not true for namedtuples, see above).
        return type(obj)(_zero_copy_astuple_inner(v) for v in obj)
    if isinstance(obj, dict):
        return type(obj)((_zero_copy_astuple_inner(k), _zero_copy_astuple_inner(v)) for k, v in obj.items())
    return value_factory(obj) if value_factory else obj
def to_tuple(self) -> Any:
    """Convert this dataclass instance (slotted or not) to a tuple without copying values.

    Fields are converted recursively by ``_zero_copy_astuple_inner``; the outer container is
    built by the class's configured tuple factory.
    """
    return _zero_copy_astuple_inner(self)
def _zero_copy_asdict_inner(obj, value_factory=None):
    """Patched version of dataclasses._asdict_inner that does not copy the dataclass values.

    Dataclass instances are converted with their class's dict factory (``d_factory``) applied
    to a list of ``(field_name, value)`` pairs; field values are converted recursively using the
    instance's own value factory (``v_factory``). Namedtuples, lists, tuples and dicts are
    rebuilt with the same type, and their items are converted recursively WITHOUT a value
    factory. Any other value is returned as-is, or passed through ``value_factory`` when given.

    :param obj: Value to convert.
    :param value_factory: Optional callable applied to a leaf value.
    :returns: The converted value.
    """
    # ``is_dataclass`` is also True for dataclass *classes*; like the standard library, only
    # instances are expanded -- a class stored in a field is treated as a plain leaf value.
    if is_dataclass(obj) and not isinstance(obj, type):
        result = [(f.name, _zero_copy_asdict_inner(getattr(obj, f.name), v_factory(obj))) for f in fields(obj)]
        return d_factory(obj)(result)
    if is_index(obj):
        # Objects recognised by ``is_index`` get their class's ``reset()`` called and are
        # returned unchanged.
        type(obj).reset()
        return obj
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # obj is a namedtuple: rebuild the same type, but its __init__ takes positional
        # arguments rather than a single iterable (see bpo-34363).
        # namedtuple._asdict() is deliberately not used: it does not recurse into the fields,
        # and returning dicts here would break namedtuples used as dict keys. json.dumps
        # serialises namedtuples as lists anyway.
        return type(obj)(*[_zero_copy_asdict_inner(v) for v in obj])
    if isinstance(obj, (list, tuple)):
        # Assume the type can be built from a generator (not true for namedtuples, see above).
        return type(obj)(_zero_copy_asdict_inner(v) for v in obj)
    if isinstance(obj, dict):
        return type(obj)((_zero_copy_asdict_inner(k), _zero_copy_asdict_inner(v)) for k, v in obj.items())
    return value_factory(obj) if value_factory else obj
def to_dict(self) -> dict:
    """Convert this dataclass instance (slotted or not) to a dict without copying values.

    Fields are converted recursively by ``_zero_copy_asdict_inner``; the outer container is
    built by the class's configured dict factory.
    """
    return _zero_copy_asdict_inner(self)
def _new_getattr_dict(self, key: str):
    """Look up field ``key`` on ``self`` and convert it with the dict-style converter.

    The instance's own value factory is applied to leaf values.
    """
    raw_value = getattr(self, key)
    return _zero_copy_asdict_inner(raw_value, v_factory(self))
def _new_getattr_tuple(self, index: int):
    """Look up the field at position ``index`` of ``self.keys()`` and convert it tuple-style.

    The instance's own value factory is applied to leaf values.
    """
    field_name = self.keys()[index]
    return _zero_copy_astuple_inner(getattr(self, field_name), v_factory(self))
def _new_getattr(self, index):
    """Implement ``self[index]``: ints select a field by position, anything else by name."""
    lookup = _new_getattr_tuple if isinstance(index, int) else _new_getattr_dict
    return lookup(self, index)
def _wrap_custom_dataclass(
    w_cls,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    dict_factory: Optional[Callable] = None,
    tuple_factory: Optional[Callable] = None,
    value_factory: Optional[Callable] = None,
    eval_env: bool = False,
    **kwargs: Any,
):
    """Apply :func:`dataclasses.dataclass` to ``w_cls`` and attach cornflakes metadata.

    Besides the standard dataclass processing, this stores the factories, ``eval_env`` and
    several derived field lists as class attributes (names taken from
    ``Constants.dataclass_decorator``). It also attaches ``to_dict``/``to_tuple`` methods and
    the ``validate_kwargs``/``check_kwargs`` classmethods.

    :param w_cls: Class to convert.
    :param dict_factory: Stored as a staticmethod if callable, otherwise ``dict`` is stored.
    :param tuple_factory: Stored as a staticmethod if callable, otherwise ``tuple`` is stored.
    :param value_factory: Stored as a staticmethod if callable, otherwise ``None`` is stored.
    :param eval_env: Stored unchanged.
    :param kwargs: Forwarded to :func:`dataclasses.dataclass`.
    :returns: The class returned by :func:`dataclasses.dataclass`.
    """
    dc_cls = dataclasses.dataclass(  # type: ignore[call-overload]
        w_cls,
        init=init,
        repr=repr,
        eq=eq,
        order=order,
        unsafe_hash=unsafe_hash,
        frozen=frozen,
        **kwargs,
    )

    def _as_static(func, fallback):
        # Wrapped in staticmethod so the factory is not bound when read through an instance.
        return staticmethod(func) if callable(func) else fallback  # type: ignore

    # Cornflakes ``Field`` objects carrying ``aliases`` that are still visible as class
    # attributes are (re)registered as dataclass fields.
    aliased_fields = {}
    for attr_name in dir(w_cls):
        candidate = getattr(w_cls, attr_name)
        if isinstance(candidate, Field) and hasattr(candidate, "aliases"):
            aliased_fields[attr_name] = candidate
    dc_cls.__dataclass_fields__.update(aliased_fields)

    markers = Constants.dataclass_decorator
    field_specs = dc_cls.__dataclass_fields__
    class_metadata = {
        markers.EVAL_ENV: eval_env,
        markers.DICT_FACTORY: _as_static(dict_factory, dict),
        markers.TUPLE_FACTORY: _as_static(tuple_factory, tuple),
        markers.VALUE_FACTORY: _as_static(value_factory, None),
        # Fields flagged with a truthy ``ignore`` attribute (the same flag that excludes a
        # field from the mapping keys in ``_wrap_mapping``).
        markers.IGNORED_SLOTS: [spec.name for spec in dataclasses.fields(dc_cls) if getattr(spec, "ignore", False)],
        # Field name -> callable ``validator`` attribute, for fields that define one.
        markers.VALIDATORS: {
            name: validator
            for name, spec in field_specs.items()
            if callable(validator := getattr(spec, "validator", None))
        },
        markers.REQUIRED_KEYS: [name for name, spec in field_specs.items() if dc_field_without_default(spec)],
        markers.INIT_EXCLUDE_KEYS: [name for name, spec in field_specs.items() if not getattr(spec, "init", True)],
    }
    for attr_name, attr_value in class_metadata.items():
        setattr(dc_cls, attr_name, attr_value)

    dc_cls.to_dict = to_dict
    dc_cls.to_tuple = to_tuple
    dc_cls.validate_kwargs = classmethod(validate_dataclass_kwargs)
    dc_cls.check_kwargs = classmethod(check_dataclass_kwargs)
    return dc_cls
def _wrap_mapping(dc_cls, ignore_none):
    """Give ``dc_cls`` a read-only mapping-style interface (modified in place).

    Installs ``__getitem__`` (``_new_getattr``: int -> by position, otherwise by field name),
    ``keys`` and ``__len__``. Fields whose dataclass field has a truthy ``ignore`` attribute are
    never part of the keys.

    :param dc_cls: Dataclass to modify.
    :param ignore_none: If true, ``keys``/``__len__`` are evaluated per instance and skip
        fields whose current value is ``None``; otherwise they are static classmethods.
    :returns: ``dc_cls`` itself.
    """
    dc_cls.__getitem__ = _new_getattr
    # Immutable so no caller can alter the key set shared by all instances.
    static_keys = tuple(f.name for f in dataclasses.fields(dc_cls) if not getattr(f, "ignore", False))

    if not ignore_none:

        def keys(_):
            # Fresh list on every call: mutating the result must not affect other calls.
            return list(static_keys)

        def _len(_):
            return len(static_keys)

        dc_cls.__len__ = classmethod(_len)
        dc_cls.keys = classmethod(keys)
    else:

        def keys(self):
            return [key for key in static_keys if getattr(self, key) is not None]

        def _len(self):
            return len(keys(self))

        dc_cls.keys = keys  # not classmethod: depends on the instance's values
        dc_cls.__len__ = _len
    return dc_cls