docs for muutils v0.6.21
muutils, stylized as "μutils", is a collection of miscellaneous python utilities, meant to be small and with no dependencies outside of standard python.
PyPi: muutils
pip install muutils
Note that for using mlutils, tensor_utils,
nbutils.configure_notebook, or the array serialization
features of json_serialize, you will need to install with
optional array dependencies:
pip install muutils[array]
hosted html docs: https://miv.name/muutils
statcounter: an extension of collections.Counter that provides "smart" computation of stats (mean, variance, median, other percentiles) from the counter object without using Counter.elements()

dictmagic: has utilities for working with dictionaries, like:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
    {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}

- DefaulterDict, which works like a defaultdict but can generate the default value based on the key
- condense_tensor_dict takes a dict of dotlist-tensors and gives a more human-readable summary:

    >>> model = MyGPT()
    >>> print(condense_tensor_dict(model.named_parameters(), 'yaml'))
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
    <...>

kappa: anonymous __getitem__, so you can do things like

    >>> k = Kappa(lambda x: x**2)
    >>> k[2]
    4

sysinfo: utility for getting a bunch of system information. useful for logging.
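The Kappa example above amounts to wrapping a function so it can be called via subscripting. A minimal standalone sketch of such a class (a re-implementation for illustration, not the library's actual source):

```python
from typing import Callable, Generic, TypeVar

K = TypeVar("K")
V = TypeVar("V")

class Kappa(Generic[K, V]):
    """wrap a function so it can be called via subscripting: k[x] == func(x)"""

    def __init__(self, func: Callable[[K], V]) -> None:
        self.func = func

    def __getitem__(self, key: K) -> V:
        return self.func(key)

k = Kappa(lambda x: x**2)
assert k[2] == 4
```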
misc: contains a few utilities:
- stable_hash() uses hashlib.sha256 to compute a hash of an object that is stable across runs of python
- list_join and list_split, which behave like str.join and str.split but for lists
- sanitize_fname and dict_to_filename for simplifying the creation of unique filenames
- shorten_numerical_to_str() and str_to_numeric, which turn numbers like 123456789 into "123M" and back
- freeze, which prevents an object from being modified. Also see gelidum
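The docs describe stable_hash() as sha256-based; python's builtin hash() is randomized per process, so run-stable hashing needs a deterministic serialization first. A hypothetical sketch of the idea (name and details are illustrative, not muutils' exact implementation):

```python
import hashlib
import json

def stable_hash_sketch(obj) -> int:
    # serialize deterministically (sorted dict keys), then hash the bytes;
    # unlike builtin hash(), this is identical across python runs
    data = json.dumps(obj, sort_keys=True).encode("utf-8")
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")

# dict key order does not matter; list order does
assert stable_hash_sketch({"b": 2, "a": 1}) == stable_hash_sketch({"a": 1, "b": 2})
```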
nbutils: contains utilities for working with jupyter notebooks, such as nbutils.configure_notebook

json_serialize: a tool for serializing and loading arbitrary python objects into json. plays nicely with ZANJ

tensor_utils: contains minor utilities for working with pytorch tensors and numpy arrays, mostly for making type conversions easier

group_equiv: groups elements from a sequence according to a given equivalence relation, without assuming that the equivalence relation obeys the transitive property

jsonlines: an extremely simple utility for reading/writing jsonl files

ZANJ: a human-readable and simple format for ML models, datasets, and arbitrary objects. It's built around a zip file containing json and npy files, and has been spun off into its own project.
There are a couple work-in-progress utilities in _wip
that aren’t ready for anything, but nothing in this repo is suitable for
production. Use at your own risk!
muutils.console_unicode

def get_console_safe_str(default: str, fallback: str) -> str

Determine a console-safe string based on the preferred encoding.
This function attempts to encode a given default string
using the system’s preferred encoding. If encoding is successful, it
returns the default string; otherwise, it returns a
fallback string.
Parameters:
- default : str The primary string intended for use, to be tested against the system's preferred encoding.
- fallback : str The alternative string to be used if default cannot be encoded in the system's preferred encoding.

Returns:
- str Either default or fallback, based on whether default can be encoded safely.

    >>> get_console_safe_str("café", "cafe")
    "café"  # This result may vary based on the system's preferred encoding.
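The behavior described above (return `default` if it survives the locale's preferred encoding, else `fallback`) can be sketched in a few lines; this is an illustrative re-implementation under those assumptions, not the module's actual source:

```python
import locale

def get_console_safe_str_sketch(default: str, fallback: str) -> str:
    # try encoding `default` in the locale's preferred encoding;
    # fall back only if that raises UnicodeEncodeError
    try:
        default.encode(locale.getpreferredencoding())
        return default
    except UnicodeEncodeError:
        return fallback
```

On a UTF-8 locale this returns "café" unchanged; on a restrictive encoding like cp437 variants it may return the fallback, which is why the doc example's output is locale-dependent.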
muutils.dictmagic

making working with dictionaries easier

- DefaulterDict: like a defaultdict, but default_factory is passed the key as an argument
- condense_nested_dicts: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
- condense_tensor_dict: convert a dictionary of tensors to a dictionary of shapes
- kwargs_to_nested_dict: given kwargs from fire, convert them to a nested dict

class DefaulterDict(typing.Dict[~_KT, ~_VT], typing.Generic[~_KT, ~_VT]):

like a defaultdict, but default_factory is passed the key as an argument

    default_factory: Callable[[~_KT], ~_VT]

def defaultdict_to_dict_recursive(
    dd: Union[collections.defaultdict, muutils.dictmagic.DefaulterDict]
) -> dict

Convert a defaultdict or DefaulterDict to a normal dict, recursively
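The DefaulterDict behavior described above (a key-aware default_factory) hinges on dict's `__missing__` hook. A minimal standalone sketch, with an illustrative name rather than muutils' actual code:

```python
from typing import Callable, Dict, TypeVar

K = TypeVar("K")
V = TypeVar("V")

class DefaulterDictSketch(Dict[K, V]):
    """like defaultdict, but the factory receives the missing key"""

    def __init__(self, default_factory: Callable[[K], V], *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory

    def __missing__(self, key: K) -> V:
        # called by dict.__getitem__ on a missing key; cache and return
        value = self.default_factory(key)
        self[key] = value
        return value

d = DefaulterDictSketch(lambda k: len(k))
assert d["hello"] == 5
```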
def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = '.') -> Dict[str, Any]

Convert a dict with dot-separated keys to a nested dict

Example:

    >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
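The conversion above can be sketched by walking each dotted key and creating intermediate dicts on the way down; a simplified illustration under the documented behavior, not the library's exact implementation:

```python
from typing import Any, Dict

def dotlist_to_nested_dict_sketch(dot_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in dot_dict.items():
        node = out
        *parents, leaf = key.split(sep)
        for part in parents:
            # descend, creating intermediate dicts as needed
            node = node.setdefault(part, {})
        node[leaf] = value
    return out

assert dotlist_to_nested_dict_sketch(
    {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
) == {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
```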
def nested_dict_to_dotlist(
nested_dict: Dict[str, Any],
sep: str = '.',
allow_lists: bool = False
) -> dict[str, typing.Any]

def update_with_nested_dict(
    original: dict[str, typing.Any],
    update: dict[str, typing.Any]
) -> dict[str, typing.Any]

Update a dict with a nested dict

Example:

    >>> update_with_nested_dict({'a': {'b': 1}, 'c': -1}, {'a': {'b': 2}})
    {'a': {'b': 2}, 'c': -1}

Parameters:
- original: dict[str, Any] the dict to update (will be modified in-place)
- update: dict[str, Any] the dict to update with

Returns:
- dict the updated dict

def kwargs_to_nested_dict(
    kwargs_dict: dict[str, typing.Any],
    sep: str = '.',
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: muutils.errormode.ErrorMode = ErrorMode.Warn,
    transform_key: Optional[Callable[[str], str]] = None
) -> dict[str, typing.Any]

given kwargs from fire, convert them to a nested dict
if strip_prefix is not None, then all keys must start with the prefix. by default, will warn if an unknown prefix is found, but can be set to raise an error or ignore it via when_unknown_prefix: ErrorMode

Example:

    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)

running the above script will give:

    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}

Parameters:
- kwargs_dict: dict[str, Any] the kwargs dict to convert
- sep: str = "." the separator to use for nested keys
- strip_prefix: Optional[str] = None if not None, then all keys must start with this prefix
- when_unknown_prefix: ErrorMode = ErrorMode.WARN what to do when an unknown prefix is found
- transform_key: Callable[[str], str] | None = None a function to apply to each key before adding it to the dict (applied after stripping the prefix)

def is_numeric_consecutive(lst: list[str]) -> bool

Check if the list of keys is numeric and consecutive.
def condense_nested_dicts_numeric_keys(data: dict[str, typing.Any]) -> dict[str, typing.Any]

condense a nested dict, by condensing numeric keys with matching values to ranges

    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {'1': {'[1-2]': 'a'}, '2': 'b'}

def condense_nested_dicts_matching_values(
    data: dict[str, typing.Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None
) -> dict[str, typing.Any]

condense a nested dict, by condensing keys with matching values

Parameters:
- data : dict[str, Any] data to process
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)

def condense_nested_dicts(
    data: dict[str, typing.Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None
) -> dict[str, typing.Any]

condense a nested dict, by condensing numeric or matching keys with matching values to ranges

combines the functionality of condense_nested_dicts_numeric_keys() and condense_nested_dicts_matching_values()

it's not reversible because types are lost to make the printing pretty

Parameters:
- data : dict[str, Any] data to process
- condense_numeric_keys : bool whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") (defaults to True)
- condense_matching_values : bool whether to condense keys with matching values (defaults to True)
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable) (defaults to None)

def tuple_dims_replace(
    t: tuple[int, ...],
    dims_names_map: Optional[dict[int, str]] = None
) -> tuple[typing.Union[int, str], ...]

TensorDict = typing.Dict[str, ForwardRef('torch.Tensor|np.ndarray')]
TensorIterable = typing.Iterable[typing.Tuple[str, ForwardRef('torch.Tensor|np.ndarray')]]
TensorDictFormats = typing.Literal['dict', 'json', 'yaml', 'yml']
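The numeric-key condensing shown earlier amounts to collapsing runs of consecutive integer keys that share an equal value into a single '[lo-hi]' key. A flat (non-recursive) sketch of that core step, for illustration only:

```python
from typing import Any, Dict

def condense_numeric_keys_sketch(data: Dict[str, Any]) -> Dict[str, Any]:
    # flat sketch: the real function also recurses into nested dicts
    items = sorted(data.items(), key=lambda kv: int(kv[0]))
    out: Dict[str, Any] = {}

    def flush(lo: int, hi: int, val: Any) -> None:
        # single keys keep their plain form; runs become '[lo-hi]'
        out[f"[{lo}-{hi}]" if lo != hi else str(lo)] = val

    run_start = prev = None
    prev_val: Any = None
    for key, val in items:
        k = int(key)
        if run_start is None:
            run_start = prev = k
            prev_val = val
        elif k == prev + 1 and val == prev_val:
            prev = k  # extend the current run
        else:
            flush(run_start, prev, prev_val)
            run_start = prev = k
            prev_val = val
    if run_start is not None:
        flush(run_start, prev, prev_val)
    return out

assert condense_numeric_keys_sketch(
    {'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2}
) == {'[1-3]': 1, '[4-6]': 2}
```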
def condense_tensor_dict(
data: 'TensorDict | TensorIterable',
fmt: Literal['dict', 'json', 'yaml', 'yml'] = 'dict',
*args,
shapes_convert: Callable[[tuple], Any] = <function _default_shapes_convert>,
drop_batch_dims: int = 0,
sep: str = '.',
dims_names_map: Optional[dict[int, str]] = None,
condense_numeric_keys: bool = True,
condense_matching_values: bool = True,
val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
return_format: Optional[Literal['dict', 'json', 'yaml', 'yml']] = None
) -> Union[str, dict[str, str | tuple[int, ...]]]

Convert a dictionary of tensors to a dictionary of shapes.

by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set shapes_convert = lambda x: x or shapes_convert = None.

Parameters:
- data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]] either a TensorDict mapping strings to tensors, or a TensorIterable of (key, tensor) pairs (like you might get from a dict().items())
- fmt : TensorDictFormats format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. (defaults to 'dict')
- shapes_convert : Callable[[tuple], Any] conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes: lambda x: str(x).replace('"', '').replace("'", ''))
- drop_batch_dims : int number of leading dimensions to drop from the shape (defaults to 0)
- sep : str separator to use for nested keys (defaults to '.')
- dims_names_map : dict[int, str] | None convert certain dimension values in shape. not perfect, can be buggy (defaults to None)
- condense_numeric_keys : bool whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to condense_nested_dicts (defaults to True)
- condense_matching_values : bool whether to condense keys with matching values, passed on to condense_nested_dicts (defaults to True)
- val_condense_fallback_mapping : Callable[[Any], Hashable] | None a function to apply to each value before adding it to the dict (if it's not hashable), passed on to condense_nested_dicts (defaults to None)
- return_format : TensorDictFormats | None legacy alias for the fmt kwarg

Returns:
- str | dict[str, str | tuple[int, ...]] a dict if fmt='dict', a string for json or yaml output

Example:

>>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
>>> print(condense_tensor_dict(model.named_parameters(), return_format='yaml'))
embed:
W_E: (50257, 768)
pos_embed:
W_pos: (1024, 768)
blocks:
'[0-11]':
attn:
'[W_Q, W_K, W_V]': (12, 768, 64)
W_O: (12, 64, 768)
'[b_Q, b_K, b_V]': (12, 64)
b_O: (768,)
mlp:
W_in: (768, 3072)
b_in: (3072,)
W_out: (3072, 768)
b_out: (768,)
unembed:
W_U: (768, 50257)
b_U: (50257,)

Raises:
- ValueError : if return_format is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
provides ErrorMode enum for handling errors
consistently
pass an error_mode: ErrorMode to a function to specify how to handle a certain kind of exception. That function then, instead of raising or calling warnings.warn, calls error_mode.process with the message and the exception.

you can also specify the exception class to raise, the warning class to use, and the source of the exception/warning.

muutils.errormode
class WarningFunc(typing.Protocol):Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
WarningFunc(*args, **kwargs)

LoggingFunc = typing.Callable[[str], NoneType]

def GLOBAL_WARN_FUNC(unknown)

Issue a warning, or maybe ignore it or raise an exception.
Parameters:
- message Text of the warning message.
- category The Warning category subclass. Defaults to UserWarning.
- stacklevel How far up the call stack to make this warning appear. A value of 2 for example attributes the warning to the caller of the code calling warn().
- source If supplied, the destroyed object which emitted a ResourceWarning
- skip_file_prefixes An optional tuple of module filename prefixes indicating frames to skip during stacklevel computations for stack frame attribution.
def GLOBAL_LOG_FUNC(*args, sep=' ', end='\n', file=None, flush=False)

Prints the values to a stream, or to sys.stdout by default.

Parameters:
- sep string inserted between values, default a space.
- end string appended after the last value, default a newline.
- file a file-like object (stream); defaults to the current sys.stdout.
- flush whether to forcibly flush the stream.
def custom_showwarning(
message: Warning | str,
category: Optional[Type[Warning]] = None,
filename: str | None = None,
lineno: int | None = None,
file: Optional[TextIO] = None,
line: Optional[str] = None
) -> None

class ErrorMode(enum.Enum):

Enum for handling errors consistently

pass one of the instances of this enum to a function to specify how to handle a certain kind of exception. That function then, instead of raising or calling warnings.warn, calls error_mode.process with the message and the exception.
EXCEPT = ErrorMode.Except
WARN = ErrorMode.Warn
LOG = ErrorMode.Log
IGNORE = ErrorMode.Ignore
def process(
self,
msg: str,
except_cls: Type[Exception] = <class 'ValueError'>,
warn_cls: Type[Warning] = <class 'UserWarning'>,
except_from: Optional[Exception] = None,
warn_func: muutils.errormode.WarningFunc | None = None,
log_func: Optional[Callable[[str], NoneType]] = None
)

process an exception or warning according to the error mode
Parameters:
- msg : str message to pass to except_cls or warn_func
- except_cls : typing.Type[Exception] exception class to raise, must be a subclass of Exception (defaults to ValueError)
- warn_cls : typing.Type[Warning] warning class to use, must be a subclass of Warning (defaults to UserWarning)
- except_from : typing.Optional[Exception] will raise except_cls(msg) from except_from if not None (defaults to None)
- warn_func : WarningFunc | None function to use for warnings, must have the signature warn_func(msg: str, category: typing.Type[Warning], source: typing.Any = None) -> None (defaults to None)
- log_func : LoggingFunc | None function to use for logging, must have the signature log_func(msg: str) -> None (defaults to None)

def from_any(
cls,
mode: str | muutils.errormode.ErrorMode,
allow_aliases: bool = True,
allow_prefix: bool = True
) -> muutils.errormode.ErrorMode

initialize an ErrorMode from a string or an ErrorMode instance

def serialize(self) -> str

def load(cls, data: str) -> muutils.errormode.ErrorMode

ERROR_MODE_ALIASES: dict[str, muutils.errormode.ErrorMode] = {'except': ErrorMode.Except, 'warn': ErrorMode.Warn, 'log': ErrorMode.Log, 'ignore': ErrorMode.Ignore, 'e': ErrorMode.Except, 'error': ErrorMode.Except, 'err': ErrorMode.Except, 'raise': ErrorMode.Except, 'w': ErrorMode.Warn, 'warning': ErrorMode.Warn, 'l': ErrorMode.Log, 'print': ErrorMode.Log, 'output': ErrorMode.Log, 'show': ErrorMode.Log, 'display': ErrorMode.Log, 'i': ErrorMode.Ignore, 'silent': ErrorMode.Ignore, 'quiet': ErrorMode.Ignore, 'nothing': ErrorMode.Ignore}
map of string aliases to ErrorMode instances
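The ErrorMode pattern described above can be sketched as a small self-contained enum; this is an illustration of the pattern (with an illustrative name and a reduced set of modes), not muutils' actual class:

```python
import warnings
from enum import Enum

class ErrorModeSketch(Enum):
    EXCEPT = "except"
    WARN = "warn"
    IGNORE = "ignore"

    def process(
        self,
        msg: str,
        except_cls: type = ValueError,
        warn_cls: type = UserWarning,
    ) -> None:
        # the caller writes `mode.process(msg)` instead of choosing between
        # raise / warnings.warn / silence at every call site
        if self is ErrorModeSketch.EXCEPT:
            raise except_cls(msg)
        if self is ErrorModeSketch.WARN:
            warnings.warn(msg, warn_cls)
        # IGNORE (and, in the real class, LOG): fall through
```

A function taking `error_mode: ErrorModeSketch` can then handle a bad input with one line, and the caller decides the severity.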
muutils.group_equiv

group items by assuming that eq_func defines an equivalence relation
def group_by_equivalence(
items_in: Sequence[~T],
eq_func: Callable[[~T, ~T], bool]
) -> list[list[~T]]

group items by assuming that eq_func implies an equivalence relation but might not be transitive
so, if f(a,b) and f(b,c) then f(a,c) might be false, but we still want to put [a,b,c] in the same class
note that lists are used to avoid the need for hashable items, and to allow for duplicates
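One way to implement this without assuming transitivity is to merge each new item into every existing class it matches, unioning any classes it bridges. A sketch under those assumptions (illustrative, not the library's actual code):

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")

def group_by_equivalence_sketch(
    items_in: Sequence[T],
    eq_func: Callable[[T, T], bool],
) -> List[List[T]]:
    classes: List[List[T]] = []
    for item in items_in:
        # find every existing class this item is equivalent to some member of
        matched = [c for c in classes if any(eq_func(item, x) for x in c)]
        merged: List[T] = [item]
        for c in matched:
            # union the bridged classes: transitivity is never assumed
            merged.extend(c)
            classes.remove(c)
        classes.append(merged)
    return classes

# eq is "within 1": eq(1, 3) is False, yet 2 bridges them into one class
groups = group_by_equivalence_sketch([1, 2, 3, 10], lambda a, b: abs(a - b) <= 1)
assert sorted(sorted(g) for g in groups) == [[1, 2, 3], [10]]
```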
Parameters:
- items_in: Sequence[T] the items to group
- eq_func: Callable[[T, T], bool] a function that returns true if two items are equivalent. need not be transitive
muutils.interval

represents a mathematical Interval over the real numbers

Number = typing.Union[float, int]

class Interval:

Represents a mathematical interval, open by default.
The Interval class can represent both open and closed intervals, as well as half-open intervals. It supports various initialization methods and provides containment checks.
Examples:
>>> i1 = Interval(1, 5) # Default open interval (1, 5)
>>> 3 in i1
True
>>> 1 in i1
False
>>> i2 = Interval([1, 5]) # Closed interval [1, 5]
>>> 1 in i2
True
>>> i3 = Interval(1, 5, closed_L=True) # Half-open interval [1, 5)
>>> str(i3)
'[1, 5)'
>>> i4 = ClosedInterval(1, 5) # Closed interval [1, 5]
>>> i5 = OpenInterval(1, 5) # Open interval (1, 5)
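The open/closed containment semantics shown in the examples reduce to strict vs. non-strict bound comparisons. A minimal sketch of just that containment logic (an illustration under the documented semantics, not the Interval class itself):

```python
from dataclasses import dataclass

@dataclass
class IntervalSketch:
    lower: float
    upper: float
    closed_L: bool = False  # open on both ends by default, like Interval
    closed_R: bool = False

    def __contains__(self, x: float) -> bool:
        # closed bound: inclusive comparison; open bound: strict comparison
        lo_ok = x >= self.lower if self.closed_L else x > self.lower
        hi_ok = x <= self.upper if self.closed_R else x < self.upper
        return lo_ok and hi_ok

assert 3 in IntervalSketch(1, 5)
assert 1 not in IntervalSketch(1, 5)
assert 1 in IntervalSketch(1, 5, closed_L=True)
```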
Interval(
*args: Union[Sequence[Union[float, int]], float, int],
is_closed: Optional[bool] = None,
closed_L: Optional[bool] = None,
closed_R: Optional[bool] = None
)

Attributes:
    lower: Union[float, int]
    upper: Union[float, int]
    closed_L: bool
    closed_R: bool
    singleton_set: Optional[set[Union[float, int]]]
    is_closed: bool
    is_open: bool
    is_half_open: bool
    is_singleton: bool
    is_empty: bool
    is_finite: bool
    singleton: Union[float, int]

def get_empty() -> muutils.interval.Interval

def get_singleton(value: Union[float, int]) -> muutils.interval.Interval

def numerical_contained(self, item: Union[float, int]) -> bool

def interval_contained(self, item: muutils.interval.Interval) -> bool

def from_str(cls, input_str: str) -> muutils.interval.Interval

def copy(self) -> muutils.interval.Interval

def size(self) -> float

Returns the size of the interval.

Returns:
- float the size of the interval

def clamp(self, value: Union[int, float], epsilon: float = 1e-10) -> float

Clamp the given value to the interval bounds.

For open bounds, the clamped value will be slightly inside the interval (by epsilon).

Parameters:
- value : Union[int, float] the value to clamp.
- epsilon : float margin for open bounds (defaults to _EPSILON)

Returns:
- float the clamped value

Raises:
- ValueError : If the input value is NaN.

def intersection(
    self,
    other: muutils.interval.Interval
) -> Optional[muutils.interval.Interval]

def union(self, other: muutils.interval.Interval) -> muutils.interval.Interval

class ClosedInterval(Interval):

Represents a mathematical interval, open by default.
(see Interval above for the shared docstring and examples)

ClosedInterval(*args: Union[Sequence[float], float], **kwargs: Any)

class OpenInterval(Interval):

Represents a mathematical interval, open by default.

(see Interval above for the shared docstring and examples)

OpenInterval(*args: Union[Sequence[float], float], **kwargs: Any)
muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping to JSONitem.
The goal of this is if you want to just be able to store something as
relatively human-readable JSON, and don’t care as much about recovering
it, you can throw it into json_serialize and it will just
work. If you want to do so in a recoverable way, check out ZANJ.
it will do so by looking in DEFAULT_HANDLERS, which will
keep it as-is if its already valid, then try to find a
.serialize() method on the object, and then have a bunch of
special cases. You can add handlers by initializing a
JsonSerializer object and passing a sequence of them to
handlers_pre
additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and serializable_field in place of dataclasses.field when defining non-standard fields.
This module plays nicely with and is a dependency of the ZANJ library,
which extends this to support saving things to disk in a more efficient
way than just plain json (arrays are saved as npy files, for example),
and automatically detecting how to load saved objects into their
original classes.
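The handler chain described above (keep valid JSONitems as-is, try a .serialize() method, then fall back to special cases) can be sketched standalone; this is an illustration of the idea, not muutils' actual handler list:

```python
from typing import Any

def json_serialize_sketch(obj: Any) -> Any:
    # 1) already a valid JSONitem: keep as-is
    if isinstance(obj, (bool, int, float, str, type(None))):
        return obj
    # 2) object knows how to serialize itself
    if hasattr(obj, "serialize"):
        return json_serialize_sketch(obj.serialize())
    # 3) special cases: containers recurse; everything else degrades to str
    if isinstance(obj, dict):
        return {str(k): json_serialize_sketch(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [json_serialize_sketch(x) for x in obj]
    return str(obj)  # human-readable fallback, not recoverable
```

The last branch is what makes the output "relatively human-readable but not recoverable"; the real JsonSerializer lets you prepend handlers via handlers_pre to intercept types before the fallback.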
def json_serialize(
obj: Any,
path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]

serialize object to json-serializable object with default config
def serializable_dataclass(
_cls=None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
properties_to_serialize: Optional[list[str]] = None,
register_handler: bool = True,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except,
on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn,
**kwargs
)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass
types will be validated (like pydantic) unless
on_typecheck_mismatch is set to
ErrorMode.IGNORE
behavior of most kwargs matches that of
dataclasses.dataclass, but with some additional kwargs
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine
fields.
If init is true, an __init__() method is added to the
class. If repr is true, a __repr__() method is added. If
order is true, rich comparison dunder methods are added. If unsafe_hash
is true, a __hash__() method function is added. If frozen
is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:
- _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
- init : bool (defaults to True)
- repr : bool (defaults to True)
- order : bool (defaults to False)
- unsafe_hash : bool (defaults to False)
- frozen : bool (defaults to False)
- properties_to_serialize : Optional[list[str]] SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
- register_handler : bool SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
- on_typecheck_error : ErrorMode SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false
- on_typecheck_mismatch : ErrorMode SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True

Returns:
- _type_ the decorated class

Raises:
- KWOnlyError : only raised if kw_only is True and python version is <3.9, since dataclasses.dataclass does not support this
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field

def serializable_field(
*_args,
default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: Optional[mappingproxy] = None,
kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
**kwargs: Any
) -> Any

Create a new SerializableField:
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
### ----------------------------------------------------------------------
### new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
)

New parameters (not in `dataclasses.Field`):

- `serialize`: whether to serialize this field when serializing the class'
- `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
- `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

`loading_fn` takes the dict of the class, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

class MyClass:
my_field: int = serializable_field(
serialization_fn=lambda x: str(x),
    loading_fn=lambda x: int(x["my_field"])
)

using deserialize_fn instead:
class MyClass:
my_field: int = serializable_field(
serialization_fn=lambda x: str(x),
deserialize_fn=lambda x: int(x)
)

In the above code, my_field is an int but will be
serialized as a string.
note that if not using ZANJ, and you have a class inside a container,
you MUST provide serialization_fn and
loading_fn to serialize and load the container. ZANJ will
automatically do this for you.
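To make the container note concrete, here is a stdlib-only sketch of the shape such a pair of functions takes. `Box`, its string encoding, and the helper names are all hypothetical illustrations, not part of muutils:

```python
import dataclasses
import json

@dataclasses.dataclass
class Box:
    # hypothetical container field: a list of int pairs
    points: list

    def serialize(self) -> dict:
        # what a serialization_fn for the container might produce:
        # encode each pair as a plain string so the dict is json-safe
        return {"points": [f"{x},{y}" for x, y in self.points]}

    @classmethod
    def load(cls, data: dict) -> "Box":
        # the matching loading_fn: decode the strings back into pairs
        return cls(points=[tuple(map(int, s.split(","))) for s in data["points"]])

box = Box(points=[(1, 2), (3, 4)])
roundtrip = Box.load(json.loads(json.dumps(box.serialize())))
assert roundtrip == box
```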
`custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test

def arr_metadata(arr) -> dict[str, list[int] | str | int]

get metadata for a numpy array
def load_array(
arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType],
array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Any

load a json-serialized array, infer the mode if not specified

BASE_HANDLERS = (
    SerializerHandler(uid='base types', desc='base types (bool, int, float, str, None)'),
    SerializerHandler(uid='dictionaries', desc='dictionaries'),
    SerializerHandler(uid='(list, tuple) -> list', desc='lists and tuples as lists'),
)

(the `check` and `serialize_func` lambda fields are elided above)
JSONitem = typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]
class JsonSerializer:

Json serialization class (holds configs)

Parameters:

- `array_mode : ArrayMode` how to write arrays (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode` what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]` handlers to use before the default handlers (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]` default handlers to use (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool` changes `"__format__"` keys in output to `"__write_format__"` (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to `False`)

Raises:

- `ValueError`: on init, if `args` is not empty
- `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

JsonSerializer(
    *args,
    array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta',
    error_mode: muutils.errormode.ErrorMode = ErrorMode.Except,
    handlers_pre: tuple = (),
    handlers_default: tuple = DEFAULT_HANDLERS,
    write_only_format: bool = False
)

array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
error_mode: muutils.errormode.ErrorMode
write_only_format: bool
handlers: tuple

def json_serialize(
    self,
    obj: Any,
    path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]

def hashify(
    self,
    obj: Any,
    path: tuple[typing.Union[str, int], ...] = (),
    force: bool = True
) -> Union[bool, int, float, str, tuple]

try to turn any object into something hashable
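A minimal sketch of what such a hashify conversion can look like (illustrative only, not the library's implementation; `hashify_sketch` is a hypothetical name):

```python
from typing import Any

def hashify_sketch(obj: Any) -> Any:
    # recursively replace unhashable containers with tuples
    if isinstance(obj, dict):
        # sort by key so equal dicts map to equal tuples
        return tuple(sorted((k, hashify_sketch(v)) for k, v in obj.items()))
    if isinstance(obj, (list, tuple, set)):
        return tuple(hashify_sketch(x) for x in obj)
    return obj

result = hashify_sketch({"a": [1, {"b": 2}]})
assert isinstance(hash(result), int)  # the converted value is hashable
```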
def try_catch(func: Callable)

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception
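A sketch of the wrapping behavior described above (illustrative; the error-message format here is an assumption, and the library may serialize errors differently):

```python
import functools
from typing import Any, Callable

def try_catch_sketch(func: Callable) -> Callable:
    # on success, return the normal result; on exception, return a string
    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{type(e).__name__}: {e}"
    return wrapped

@try_catch_sketch
def inverse(x: float) -> float:
    return 1 / x

assert inverse(4) == 0.25
assert inverse(0) == "ZeroDivisionError: division by zero"
```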
def dc_eq(
dc1,
dc2,
except_when_class_mismatch: bool = False,
false_when_class_mismatch: bool = True,
except_when_field_mismatch: bool = False
) -> bool

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch : bool` if `True`, will throw `TypeError` if the classes are different. if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False` (default: `False`)
- `false_when_class_mismatch : bool` only relevant if `except_when_class_mismatch` is `False`. if `True`, will return `False` if the classes are different. if `False`, will attempt to compare the fields.
- `except_when_field_mismatch : bool` only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. if `True`, will throw `TypeError` if the fields are different. (default: `False`)

Returns:

- `bool`: `True` if the dataclasses are equal, `False` otherwise

Raises:

- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields

[START]
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
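The flow above can be sketched in plain Python (an illustrative reimplementation under simplifying assumptions: the real `dc_eq` also handles numpy arrays, which this sketch reduces to `==`):

```python
import dataclasses

def dc_eq_sketch(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    if dc1 is dc2:
        return True
    if type(dc1) is not type(dc2):
        if except_when_class_mismatch:
            raise TypeError("classes do not match")
        if false_when_class_mismatch:
            return False
    keys1 = {f.name for f in dataclasses.fields(dc1)}
    keys2 = {f.name for f in dataclasses.fields(dc2)}
    if keys1 != keys2:
        if except_when_field_mismatch:
            raise AttributeError("field keys do not match")
        return False
    # compare field values pairwise
    return all(getattr(dc1, k) == getattr(dc2, k) for k in keys1)

@dataclasses.dataclass
class Point:
    a: int
    b: str

assert dc_eq_sketch(Point(1, "q"), Point(1, "q"))
assert not dc_eq_sketch(Point(1, "q"), Point(2, "q"))
```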
class SerializableDataclass(abc.ABC):

Base class for serializable dataclasses
only for linting and type checking, still need to call
serializable_dataclass decorator
@serializable_dataclass
class MyClass(SerializableDataclass):
a: int
    b: str

and then you can call my_obj.serialize() to get a dict
that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
x: str
y: MyClass
act_fun: torch.nn.Module = serializable_field(
default=torch.nn.ReLU(),
serialization_fn=lambda x: str(x),
deserialize_fn=lambda x: getattr(torch.nn, x)(),
)

which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]

returns the class as a dict, implemented by using @serializable_dataclass decorator

def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
def validate_fields_types(
self,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a
SerializableDataclass. calls
SerializableDataclass__validate_field_type for each
field
def validate_field_type(
self,
field: muutils.json_serialize.serializable_field.SerializableField | str,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

given a dataclass, check the field matches the type hint
def diff(
self,
other: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
of_serialized: bool = False
) -> dict[str, typing.Any]

get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

- `other : SerializableDataclass` other instance to compare against
- `of_serialized : bool` if true, compare serialized data and not raw values (defaults to `False`)

Returns:

- `dict[str, Any]`

Raises:

- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`

def update_from_nested_dict(self, nested_dict: dict[str, typing.Any])

update the instance from a nested dict, useful for configuration from command line args

- `nested_dict : dict[str, Any]` nested dict to update the instance with
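A sketch of the nested-update behavior (illustrative; `update_from_nested_dict_sketch` is a hypothetical name, and the real method operates on SerializableDataclass fields):

```python
import dataclasses

def update_from_nested_dict_sketch(obj, nested: dict) -> None:
    # recurse into sub-objects for dict values, set plain attributes otherwise
    for key, value in nested.items():
        current = getattr(obj, key)
        if isinstance(value, dict) and dataclasses.is_dataclass(current):
            update_from_nested_dict_sketch(current, value)
        else:
            setattr(obj, key, value)

@dataclasses.dataclass
class Inner:
    lr: float = 0.1

@dataclasses.dataclass
class Config:
    name: str = "run"
    inner: Inner = dataclasses.field(default_factory=Inner)

cfg = Config()
update_from_nested_dict_sketch(cfg, {"name": "exp1", "inner": {"lr": 0.01}})
assert cfg.name == "exp1" and cfg.inner.lr == 0.01
```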
muutils.json_serialize.array

this utilities module handles serialization and loading of numpy and torch arrays as json

- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
- `array_b64_meta` is the most efficient, but is not human readable.
- `external` is mostly for use in ZANJ

ArrayMode = typing.Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']

def array_n_elements(arr) -> int

get the number of elements in an array

def arr_metadata(arr) -> dict[str, list[int] | str | int]

get metadata for a numpy array
def serialize_array(
jser: "'JsonSerializer'",
arr: numpy.ndarray,
path: Union[str, Sequence[str | int]],
array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]

serialize a numpy or pytorch array in one of several modes

if the object is zero-dimensional, simply get the unique item

array_mode: ArrayMode can be one of:

- `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
- `array_list_meta`: serialize dict with metadata, actual list under the key `data`
- `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
- `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`
for array_list_meta, array_hex_meta, and
array_b64_meta, the serialized object is:
{
"__format__": <array_list_meta|array_hex_meta>,
"shape": arr.shape,
"dtype": str(arr.dtype),
"data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
}
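As an illustration of the `array_b64_meta` layout, here is a stdlib-only sketch that packs floats with `struct` in place of `arr.tobytes()` (the helper names are hypothetical, not library functions):

```python
import base64
import json
import struct

def fake_serialize_b64(values: list, shape: list) -> dict:
    # pack little-endian float64 values, mimicking arr.tobytes()
    raw = struct.pack(f"<{len(values)}d", *values)
    return {
        "__format__": "array_b64_meta",
        "shape": shape,
        "dtype": "float64",
        "data": base64.b64encode(raw).decode(),
    }

def fake_load_b64(obj: dict) -> list:
    # invert the encoding: base64 -> bytes -> floats
    raw = base64.b64decode(obj["data"])
    return list(struct.unpack(f"<{len(raw) // 8}d", raw))

payload = fake_serialize_b64([1.0, 2.5], [2])
json.dumps(payload)  # the payload is plain json
assert fake_load_b64(payload) == [1.0, 2.5]
```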
Parameters:

- `arr : Any` array to serialize
- `array_mode : ArrayMode` mode in which to serialize the array (defaults to `None` and inheriting from `jser: JsonSerializer`)

Returns:

- `JSONitem` json serialized array

Raises:

- `KeyError` : if the array mode is not valid

def infer_array_mode(
    arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType]
) -> Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']

given a serialized array, infer the mode
assumes the array was serialized via
serialize_array()
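The inference can be sketched as a check on the serialized form (illustrative; the real function distinguishes more cases and validates the payload):

```python
from typing import Any

def infer_array_mode_sketch(arr: Any) -> str:
    # dicts produced by serialize_array carry their mode in __format__
    if isinstance(arr, dict) and "__format__" in arr:
        return arr["__format__"]
    # a bare list means the metadata-free 'list' mode
    if isinstance(arr, list):
        return "list"
    # a bare scalar came from a zero-dimensional array
    return "zero_dim"

assert infer_array_mode_sketch([1, 2, 3]) == "list"
assert infer_array_mode_sketch({"__format__": "array_b64_meta", "data": ""}) == "array_b64_meta"
assert infer_array_mode_sketch(3.14) == "zero_dim"
```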
def load_array(
arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType],
array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None
) -> Any

load a json-serialized array, infer the mode if not specified
muutils.json_serialize.json_serialize

provides the basic framework for json serialization of objects

notably:

- `SerializerHandler` defines how to serialize a specific type of object
- `JsonSerializer` handles configuration for which handlers to use
- `json_serialize` provides the default configuration if you don't care -- call it on any object!

SERIALIZER_SPECIAL_KEYS: None = ('__name__', '__doc__', '__module__', '__class__', '__dict__', '__annotations__')
SERIALIZER_SPECIAL_FUNCS: dict[str, typing.Callable] = {'str': <class 'str'>, 'dir': <built-in function dir>, 'type': <function <lambda>>, 'repr': <function <lambda>>, 'code': <function <lambda>>, 'sourcefile': <function <lambda>>}
SERIALIZE_DIRECT_AS_STR: Set[str] = {"<class 'torch.dtype'>", "<class 'torch.device'>"}
ObjectPath = tuple[typing.Union[str, int], ...]
class SerializerHandler:

a handler for a specific type of object
- `check : Callable[[JsonSerializer, Any], bool]` takes a JsonSerializer and an object, returns whether to use this handler
- `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]` takes a JsonSerializer, an object, and the current path, returns the serialized object
- `desc : str` description of the handler (optional)
SerializerHandler(
check: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], bool],
serialize_func: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, list, Dict[str, Any], NoneType]],
uid: str,
desc: str
)

check: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], bool]
serialize_func: Callable[[muutils.json_serialize.json_serialize.JsonSerializer, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, list, Dict[str, Any], NoneType]]
uid: str
desc: str
def serialize(self) -> dict

serialize the handler info
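The serializer walks its handlers in order and applies the first one whose check passes. A stdlib-only sketch of that dispatch (hypothetical names, not the library implementation; the real checks also receive the serializer and path):

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class HandlerSketch:
    check: Callable
    serialize_func: Callable
    uid: str

def dispatch(obj: Any) -> Any:
    # first matching handler wins; the fallback always matches
    for h in HANDLERS:
        if h.check(obj):
            return h.serialize_func(obj)
    raise ValueError("unreachable: fallback matches everything")

HANDLERS = (
    HandlerSketch(lambda o: isinstance(o, (bool, int, float, str, type(None))), lambda o: o, "base types"),
    HandlerSketch(lambda o: isinstance(o, dict), lambda o: {str(k): dispatch(v) for k, v in o.items()}, "dictionaries"),
    HandlerSketch(lambda o: isinstance(o, (list, tuple)), lambda o: [dispatch(x) for x in o], "(list, tuple) -> list"),
    HandlerSketch(lambda o: True, lambda o: repr(o), "fallback"),
)

assert dispatch({"a": (1, 2)}) == {"a": [1, 2]}
assert dispatch({3: None}) == {"3": None}
```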
BASE_HANDLERS = (
    SerializerHandler(uid='base types', desc='base types (bool, int, float, str, None)'),
    SerializerHandler(uid='dictionaries', desc='dictionaries'),
    SerializerHandler(uid='(list, tuple) -> list', desc='lists and tuples as lists'),
)

DEFAULT_HANDLERS = BASE_HANDLERS + (
    SerializerHandler(uid='.serialize override', desc='objects with .serialize method'),
    SerializerHandler(uid='namedtuple -> dict', desc='namedtuples as dicts'),
    SerializerHandler(uid='dataclass -> dict', desc='dataclasses as dicts'),
    SerializerHandler(uid='path -> str', desc='Path objects as posix strings'),
    SerializerHandler(uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'),
    SerializerHandler(uid='numpy.ndarray', desc='numpy arrays'),
    SerializerHandler(uid='torch.Tensor', desc='pytorch tensors'),
    SerializerHandler(uid='pandas.DataFrame', desc='pandas DataFrames'),
    SerializerHandler(uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'),
    SerializerHandler(uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings'),
)

(the `check` and `serialize_func` lambda fields are elided above)
GLOBAL_JSON_SERIALIZER: muutils.json_serialize.json_serialize.JsonSerializer = <muutils.json_serialize.json_serialize.JsonSerializer object>

def json_serialize(
    obj: Any,
    path: tuple[typing.Union[str, int], ...] = ()
) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]

serialize object to json-serializable object with default config
muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way
d = dataclasses.asdict(my_obj) will give you a dict, but
if some fields are not json-serializable, you will get an error when you
call json.dumps(d). This module provides a way around
that.
Instead, you define your class:
@serializable_dataclass
class MyClass(SerializableDataclass):
a: int
    b: str

and then you can call my_obj.serialize() to get a dict
that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
x: str
y: MyClass
act_fun: torch.nn.Module = serializable_field(
default=torch.nn.ReLU(),
serialization_fn=lambda x: str(x),
deserialize_fn=lambda x: getattr(torch.nn, x)(),
)

which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def dataclass_transform(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
frozen_default: bool = False,
field_specifiers: tuple[typing.Union[type[typing.Any], typing.Callable[..., typing.Any]], ...] = (),
**kwargs: Any
) -> <class '_IdentityCallable'>

Decorator to mark an object as providing dataclass-like behaviour.
The decorator can be applied to a function, class, or metaclass.
Example usage with a decorator function::
@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
...
return cls
@create_model
class CustomerModel:
id: int
name: str
On a base class::
@dataclass_transform()
class ModelBase: ...
class CustomerModel(ModelBase):
id: int
name: str
On a metaclass::
@dataclass_transform()
class ModelMeta(type): ...
class ModelBase(metaclass=ModelMeta): ...
class CustomerModel(ModelBase):
id: int
name: str
The CustomerModel classes defined above will be treated
by type checkers similarly to classes created with
@dataclasses.dataclass. For example, type checkers will
assume these classes have __init__ methods that accept
id and name.
The arguments to this decorator can be used to customize this behavior:

- `eq_default` indicates whether the `eq` parameter is assumed to be True or False if it is omitted by the caller.
- `order_default` indicates whether the `order` parameter is assumed to be True or False if it is omitted by the caller.
- `kw_only_default` indicates whether the `kw_only` parameter is assumed to be True or False if it is omitted by the caller.
- `frozen_default` indicates whether the `frozen` parameter is assumed to be True or False if it is omitted by the caller.
- `field_specifiers` specifies a static list of supported classes or functions that describe fields, similar to `dataclasses.field()`.
- Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.
At runtime, this decorator records its arguments in the
__dataclass_transform__ attribute on the decorated object.
It has no other runtime effect.
See PEP 681 for more details.
class CantGetTypeHintsWarning(builtins.UserWarning):

special warning for when we can't get type hints

class ZanjMissingWarning(builtins.UserWarning):

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

def zanj_register_loader_serializable_dataclass(cls: Type[~T])

Register a serializable dataclass with the ZANJ import
this allows ZANJ().read() to load the class and not just
return plain dicts
class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):

Base class for warnings generated by user code.
def SerializableDataclass__validate_field_type(
self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
field: muutils.json_serialize.serializable_field.SerializableField | str,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

given a dataclass, check the field matches the type hint

this function is written to SerializableDataclass.validate_field_type
Parameters:

- `self : SerializableDataclass` SerializableDataclass instance
- `field : SerializableField | str` field to validate, will get from `self.__dataclass_fields__` if a str
- `on_typecheck_error : ErrorMode` what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

Returns:

- `bool` if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`

def SerializableDataclass__validate_fields_types__dict(
    self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
    on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> dict[str, bool]

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict of field names to bools, where the bool is if the field type is valid
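A simplified sketch of per-field validation against the class's type hints (illustrative; the real checker handles generics, unions, and the error modes described above):

```python
import dataclasses
import typing

def validate_fields_types_dict_sketch(obj) -> dict:
    # map each field name to whether its value matches the (plain) type hint
    hints = typing.get_type_hints(type(obj))
    result = {}
    for f in dataclasses.fields(obj):
        expected = hints[f.name]
        value = getattr(obj, f.name)
        # only plain classes are checked in this sketch; anything else passes
        result[f.name] = isinstance(value, expected) if isinstance(expected, type) else True
    return result

@dataclasses.dataclass
class Cfg:
    a: int
    b: str

assert validate_fields_types_dict_sketch(Cfg(a=1, b="x")) == {"a": True, "b": True}
assert validate_fields_types_dict_sketch(Cfg(a="oops", b="x")) == {"a": False, "b": True}
```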
def SerializableDataclass__validate_fields_types(
self: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> bool

validate the types of all the fields on a
SerializableDataclass. calls
SerializableDataclass__validate_field_type for each
field
def serialize(self) -> dict[str, typing.Any]returns the class as a dict, implemented by using
@serializable_dataclass decorator
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~Ttakes in an appropriately structured dict and returns an instance of
the class, implemented by using @serializable_dataclass
decorator
def validate_fields_types(
self,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> boolvalidate the types of all the fields on a
SerializableDataclass. calls
SerializableDataclass__validate_field_type for each
field
def validate_field_type(
self,
field: muutils.json_serialize.serializable_field.SerializableField | str,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except
) -> boolgiven a dataclass, check the field matches the type hint
def diff(
self,
other: muutils.json_serialize.serializable_dataclass.SerializableDataclass,
of_serialized: bool = False
) -> dict[str, typing.Any]get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

- `other : SerializableDataclass`
    other instance to compare against
- `of_serialized : bool`
    if true, compare serialized data and not raw values
    (defaults to `False`)

Returns: `dict[str, Any]`

Raises:

- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`

def update_from_nested_dict(self, nested_dict: dict[str, typing.Any])update the instance from a nested dict, useful for configuration from command line args
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
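As a sketch of what this enables (shown here with plain stdlib dataclasses so it runs standalone — the class names and the helper are illustrative assumptions, not the muutils implementation), a recursive nested-dict update looks like:

```python
from dataclasses import dataclass

@dataclass
class Inner:
    a: int = 1
    b: int = 2

@dataclass
class Outer:
    x: str = "q"
    y: Inner = None

def update_from_nested(obj, nested: dict) -> None:
    # recursively set attributes from a nested dict:
    # dict values recurse into sub-objects, everything else is assigned directly
    for key, val in nested.items():
        if isinstance(val, dict):
            update_from_nested(getattr(obj, key), val)
        else:
            setattr(obj, key, val)

cfg = Outer(y=Inner())
update_from_nested(cfg, {"x": "hello", "y": {"b": 99}})
# cfg is now Outer(x='hello', y=Inner(a=1, b=99))
```

Combined with `dotlist_to_nested_dict` (see `muutils.dictmagic`), this is what makes overriding config fields from command-line `a.b.c=1` style arguments convenient.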
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]cached typing.get_type_hints for a class
def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]helper function to get type hints for a class
class KWOnlyError(builtins.NotImplementedError):kw-only dataclasses are not supported in python <3.9
class FieldError(builtins.ValueError):base class for field errors
class NotSerializableFieldException(FieldError):field is not a SerializableField
class FieldSerializationError(FieldError):error while serializing a field
class FieldLoadingError(FieldError):error while loading a field
class FieldTypeMismatchError(FieldError, builtins.TypeError):error when a field type does not match the type hint
def serializable_dataclass(
_cls=None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
properties_to_serialize: Optional[list[str]] = None,
register_handler: bool = True,
on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except,
on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn,
**kwargs
)decorator to make a dataclass serializable. must also make it inherit
from SerializableDataclass
types will be validated (like pydantic) unless
on_typecheck_mismatch is set to
ErrorMode.IGNORE
behavior of most kwargs matches that of
dataclasses.dataclass, but with some additional kwargs
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
Examines PEP 526 __annotations__ to determine
fields.
If init is true, an __init__() method is added to the
class. If repr is true, a __repr__() method is added. If
order is true, rich comparison dunder methods are added. If unsafe_hash
is true, a __hash__() method is added. If frozen
is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
a: int
b: str>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

- `_cls : _type_`
    class to decorate. don’t pass this arg, just use this as a decorator
    (defaults to `None`)
- `init : bool`
    (defaults to `True`)
- `repr : bool`
    (defaults to `True`)
- `order : bool`
    (defaults to `False`)
- `unsafe_hash : bool`
    (defaults to `False`)
- `frozen : bool`
    (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
    `SerializableDataclass` only: which properties to add to the serialized data dict
    (defaults to `None`)
- `register_handler : bool`
    `SerializableDataclass` only: if true, register the class with ZANJ for loading
    (defaults to `True`)
- `on_typecheck_error : ErrorMode`
    `SerializableDataclass` only: what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
- `on_typecheck_mismatch : ErrorMode`
    `SerializableDataclass` only: what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

Returns: `_type_` the decorated class

Raises:

- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
extends dataclasses.Field for use with
SerializableDataclass
In particular, instead of using dataclasses.field, use
serializable_field to define fields in a
SerializableDataclass. You provide information on how the
field should be serialized and loaded (as well as anything that goes
into dataclasses.field) when you define the field, and the
SerializableDataclass will automatically use those
functions.
muutils.json_serialize.serializable_fieldextends dataclasses.Field for use with
SerializableDataclass
In particular, instead of using dataclasses.field, use
serializable_field to define fields in a
SerializableDataclass. You provide information on how the
field should be serialized and loaded (as well as anything that goes
into dataclasses.field) when you define the field, and the
SerializableDataclass will automatically use those
functions.
class SerializableField(dataclasses.Field):extension of dataclasses.Field with additional
serialization properties
SerializableField(
default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
default_factory: Union[Callable[[], Any], dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: Optional[mappingproxy] = None,
kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[<member 'type' of 'SerializableField' objects>], bool]] = None
)serialize: bool
serialization_fn: Optional[Callable[[Any], Any]]
loading_fn: Optional[Callable[[Any], Any]]
deserialize_fn: Optional[Callable[[Any], Any]]
assert_type: bool
custom_typecheck_fn: Optional[Callable[[<member 'type' of 'SerializableField' objects>], bool]]
def from_Field(
cls,
field: dataclasses.Field
) -> muutils.json_serialize.serializable_field.SerializableFieldcopy all values from a dataclasses.Field to new
SerializableField
name
type
default
default_factory
init
repr
hash
compare
metadata
kw_only
def serializable_field(
*_args,
default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: Optional[mappingproxy] = None,
kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>,
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
**kwargs: Any
) -> AnyCreate a new SerializableField
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
### ----------------------------------------------------------------------
### new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
- serialize: whether to serialize this field when serializing the class
- serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in <a href="json_serialize.html">muutils.json_serialize.json_serialize</a>
- loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- deserialize_fn: new alternative to loading_fn. takes only the field’s value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
- assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
- custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

Note that loading_fn takes the dict of the whole class, not just the field. if you wanted a loading_fn that does nothing, you’d write:

class MyClass:
my_field: int = serializable_field(
serialization_fn=lambda x: str(x),
loading_fn=lambda x: int(x["my_field"])
)using deserialize_fn instead:
class MyClass:
my_field: int = serializable_field(
serialization_fn=lambda x: str(x),
deserialize_fn=lambda x: int(x)
)In the above code, my_field is an int but will be
serialized as a string.
note that if not using ZANJ, and you have a class inside a container,
you MUST provide serialization_fn and
loading_fn to serialize and load the container. ZANJ will
automatically do this for you.
custom_value_check_fn: function taking the value of the
field and returning whether the value itself is valid. if not provided,
any value is valid as long as it passes the type test
utilities for json_serialize
muutils.json_serialize.utilutilities for json_serialize
JSONitem = typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]
JSONdict = typing.Dict[str, typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]]
Hashableitem = typing.Union[bool, int, float, str, tuple]
class UniversalContainer:contains everything – x in UniversalContainer() is
always True
def isinstance_namedtuple(x: Any) -> boolchecks if x is a namedtuple
credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
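The usual heuristic behind such a check (this sketch is an assumption about the approach, following the linked StackOverflow answer, not the muutils source) is that namedtuples are tuple subclasses carrying a `_fields` attribute:

```python
from collections import namedtuple

def isinstance_namedtuple_sketch(x) -> bool:
    # heuristic: namedtuple classes are tuple subclasses (but not tuple itself)
    # whose class defines a `_fields` tuple of field names
    t = type(x)
    return (
        issubclass(t, tuple)
        and t is not tuple
        and isinstance(getattr(t, "_fields", None), tuple)
    )

Point = namedtuple("Point", ["x", "y"])
```

Note this is structural, not nominal: any tuple subclass that defines a `_fields` tuple will pass, which is exactly the trade-off the linked answer discusses.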
def try_catch(func: Callable)wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
class SerializationException(builtins.Exception):Common base class for all non-exit exceptions.
def string_as_lines(s: str | None) -> list[str]for easier reading of long strings in json, split up by newlines
sort of like how jupyter notebooks do it
def safe_getsource(func) -> list[str]def array_safe_eq(a: Any, b: Any) -> boolcheck if two objects are equal, accounting for numpy arrays and torch tensors
def dc_eq(
dc1,
dc2,
except_when_class_mismatch: bool = False,
false_when_class_mismatch: bool = True,
except_when_field_mismatch: bool = False
) -> boolchecks if two dataclasses which (might) hold numpy arrays are equal
dc1: the first dataclassdc2: the second dataclassexcept_when_class_mismatch: bool if True,
will throw TypeError if the classes are different. if not,
will return false by default or attempt to compare the fields if
false_when_class_mismatch is False (default:
False)false_when_class_mismatch: bool only relevant if
except_when_class_mismatch is False. if
True, will return False if the classes are
different. if False, will attempt to compare the
fields.except_when_field_mismatch: bool only relevant if
except_when_class_mismatch is False and
false_when_class_mismatch is False. if
True, will throw TypeError if the fields are
different. (default: True)bool: True if the dataclasses are equal, False
otherwiseTypeError: if the dataclasses are of different
classesAttributeError: if the dataclasses have different
fields [START]
▼
┌───────────┐ ┌─────────┐
│dc1 is dc2?├─►│ classes │
└──┬────────┘No│ match? │
──── │ ├─────────┤
(True)◄──┘Yes │No │Yes
──── ▼ ▼
┌────────────────┐ ┌────────────┐
│ except when │ │ fields keys│
│ class mismatch?│ │ match? │
├───────────┬────┘ ├───────┬────┘
│Yes │No │No │Yes
▼ ▼ ▼ ▼
─────────── ┌──────────┐ ┌────────┐
{ raise } │ except │ │ field │
{ TypeError } │ when │ │ values │
─────────── │ field │ │ match? │
│ mismatch?│ ├────┬───┘
├───────┬──┘ │ │Yes
│Yes │No │No ▼
▼ ▼ │ ────
─────────────── ───── │ (True)
{ raise } (False)◄┘ ────
{ AttributeError} ─────
───────────────
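The plain-dataclass path through this flowchart can be sketched in a few lines (stdlib only; the real `dc_eq` additionally handles numpy arrays, torch tensors, and the field-key-mismatch branch, so this is a simplification, not the actual implementation):

```python
from dataclasses import dataclass, fields, is_dataclass

def dc_eq_sketch(dc1, dc2, except_when_class_mismatch: bool = False) -> bool:
    # identity short-circuit, then class check, then field-by-field comparison
    if dc1 is dc2:
        return True
    if type(dc1) is not type(dc2):
        if except_when_class_mismatch:
            raise TypeError(f"class mismatch: {type(dc1)} vs {type(dc2)}")
        return False
    assert is_dataclass(dc1), "expected a dataclass"
    return all(getattr(dc1, f.name) == getattr(dc2, f.name) for f in fields(dc1))

@dataclass
class Point2D:
    x: int
    y: int
```

The array-aware version replaces the `==` in the last line with `array_safe_eq`, which is what lets `dc_eq` compare dataclasses that hold numpy arrays without hitting the ambiguous-truth-value error.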
class MonoTuple:tuple type hint, but for a tuple of any length with all the same type
utilities for reading and writing jsonlines files, including gzip support
muutils.jsonlinesutilities for reading and writing jsonlines files, including gzip support
def jsonl_load(
path: str,
/,
*,
use_gzip: bool | None = None
) -> list[typing.Union[bool, int, float, str, list, typing.Dict[str, typing.Any], NoneType]]def jsonl_load_log(path: str, /, *, use_gzip: bool | None = None) -> list[dict]def jsonl_write(
path: str,
items: Sequence[Union[bool, int, float, str, list, Dict[str, Any], NoneType]],
use_gzip: bool | None = None,
gzip_compresslevel: int = 2
) -> None
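A minimal stdlib round-trip in the spirit of these functions (the real `jsonl_load`/`jsonl_write` infer gzip from the filename when `use_gzip` is `None`; here it is an explicit flag, and the helper names and temp path are illustrative assumptions):

```python
import gzip
import json
import os
import tempfile

def jsonl_write_sketch(path: str, items, use_gzip: bool = False) -> None:
    # jsonlines format: one JSON value per line; optionally gzip-compressed
    opener = gzip.open if use_gzip else open
    with opener(path, "wt") as f:
        for item in items:
            f.write(json.dumps(item) + "\n")

def jsonl_load_sketch(path: str, use_gzip: bool = False) -> list:
    opener = gzip.open if use_gzip else open
    with opener(path, "rt") as f:
        return [json.loads(line) for line in f if line.strip()]

items = [{"a": 1}, {"b": [2, 3]}, "plain string"]
path = os.path.join(tempfile.gettempdir(), "demo.jsonl.gz")
jsonl_write_sketch(path, items, use_gzip=True)
assert jsonl_load_sketch(path, use_gzip=True) == items
```

Because each record is a separate line, jsonlines files can be appended to and streamed line-by-line, which is why the logger module below writes its logs in this format.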
anonymous getitem class
util for constructing a class which has a getitem method which just calls a function
a lambda is an anonymous function: kappa is the letter
before lambda in the greek alphabet, hence the name of this class
muutils.kappaanonymous getitem class
util for constructing a class which has a getitem method which just calls a function
a lambda is an anonymous function: kappa is the letter
before lambda in the greek alphabet, hence the name of this class
class Kappa(typing.Mapping[~_kappa_K, ~_kappa_V]):A Mapping is a generic container for associating key/value pairs.
This class provides concrete generic implementations of all methods except for getitem, iter, and len.
Kappa(func_getitem: Callable[[~_kappa_K], ~_kappa_V])func_getitem
doc
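A minimal sketch of the idea (the real `Kappa` subclasses `typing.Mapping` and so also supplies the other mapping methods; this toy version only implements `__getitem__`):

```python
class KappaSketch:
    # anonymous getitem: subscripting the object just calls the wrapped function
    def __init__(self, func_getitem):
        self.func_getitem = func_getitem

    def __getitem__(self, key):
        return self.func_getitem(key)

k = KappaSketch(lambda x: x ** 2)
```

This is useful anywhere an API expects a subscriptable mapping-like object but you want the values computed on demand from the key.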
(deprecated) experimenting with logging utilities
muutils.logger(deprecated) experimenting with logging utilities
class Logger(muutils.logger.simplelogger.SimpleLogger):logger with more features, including log levels and streams
- `log_path : str | None`
default log file path
(defaults to `None`)
- `log_file : AnyIO | None`
default log io, should have a `.write()` method (pass only this or `log_path`, not both)
(defaults to `None`)
- `timestamp : bool`
whether to add timestamps to every log message (under the `_timestamp` key)
(defaults to `True`)
- `default_level : int`
default log level for streams/messages that don't specify a level
(defaults to `0`)
- `console_print_threshold : int`
log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
(defaults to `50`)
- `level_header : HeaderFunction`
function for formatting log messages when printing to console
(defaults to `HEADER_FUNCTIONS["md"]`)
- `keep_last_msg_time : bool`
    whether to keep the last message time
    (defaults to `True`)
- `ValueError` : raised if both `log_path` and `log_file` are passed
Logger(
log_path: str | None = None,
log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
default_level: int = 0,
console_print_threshold: int = 50,
level_header: muutils.logger.headerfuncs.HeaderFunction = <function md_header_function>,
streams: Union[dict[str | None, muutils.logger.loggingstream.LoggingStream], Sequence[muutils.logger.loggingstream.LoggingStream]] = (),
keep_last_msg_time: bool = True,
timestamp: bool = True,
**kwargs
)def log(
self,
msg: Union[bool, int, float, str, list, Dict[str, Any], NoneType] = None,
lvl: int | None = None,
stream: str | None = None,
console_print: bool = False,
extra_indent: str = '',
**kwargs
)logging function
msg : JSONitem message (usually string or dict) to be
loggedlvl : int | None level of message (lower levels are
more important) (defaults to None)console_print : bool override
console_print_threshold setting (defaults to
False)stream : str | None whether to log to a stream
(defaults to None), which logs to the default
None stream (defaults to None)def log_elapsed_last(
self,
lvl: int | None = None,
stream: str | None = None,
console_print: bool = True,
**kwargs
) -> floatlogs the time elapsed since the last message was printed to the console (in any stream)
def flush_all(self)flush all streams
class LoggingStream:properties of a logging stream
name: str name of the streamaliases: set[str] aliases for the stream (calls to
these names will be redirected to this stream. duplicate aliases will
result in errors) TODO: perhaps duplicate aliases should result in
duplicate writes?file: str|bool|AnyIO|None file to write to - if
None, will write to standard log - if True,
will write to name + ".log" - if False will
“write” to NullIO (throw it away) - if a string, will write
to that file - if a fileIO type object, will write to that objectdefault_level: int|None default level for this
streamdefault_contents: dict[str, Callable[[], Any]] default
contents for this streamlast_msg: tuple[float, Any]|None last message written
to this stream (timestamp, message)LoggingStream(
name: str | None,
aliases: set[str | None] = <factory>,
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
default_level: int | None = None,
default_contents: dict[str, typing.Callable[[], typing.Any]] = <factory>,
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
)name: str | None
aliases: set[str | None]
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
default_level: int | None = None
default_contents: dict[str, typing.Callable[[], typing.Any]]
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
def make_handler(self) -> Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType]class SimpleLogger:logs training data to a jsonl file
SimpleLogger(
log_path: str | None = None,
log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
timestamp: bool = True
)def log(
self,
msg: Union[bool, int, float, str, list, Dict[str, Any], NoneType],
console_print: bool = False,
**kwargs
)log a message to the log file, and optionally to the console
class TimerContext:context manager for timing code
start_time: float
end_time: float
elapsed_time: float
muutils.logger.exception_contextclass ExceptionContext:context manager which catches all exceptions happening while the
context is open, .write() the exception trace to the given
stream, and then raises the exception
for example:
errorfile = open('error.log', 'w')
with ExceptionContext(errorfile):
# do something that might throw an exception
# if it does, the exception trace will be written to errorfile
# and then the exception will be raised
ExceptionContext(stream)
stream
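Such a context manager can be sketched with only the stdlib (the class and method bodies here are assumptions matching the behavior described above, not muutils' actual implementation):

```python
import io
import traceback

class ExceptionContextSketch:
    # write the traceback to `stream` on error, then let the exception propagate
    def __init__(self, stream):
        self.stream = stream

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if exc_type is not None:
            self.stream.write(
                "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
            )
        return False  # returning False means the exception is re-raised

buf = io.StringIO()
try:
    with ExceptionContextSketch(buf):
        raise ValueError("boom")
except ValueError:
    pass  # the exception still propagates; the trace is now also in `buf`
```

Returning `False` from `__exit__` is the key design point: the trace is recorded as a side effect, but callers still see the original exception.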
muutils.logger.headerfuncsclass HeaderFunction(typing.Protocol):Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
HeaderFunction(*args, **kwargs)def md_header_function(
msg: Any,
lvl: int,
stream: str | None = None,
indent_lvl: str = ' ',
extra_indent: str = '',
**kwargs
) -> strstandard header function. will output
# {msg}
for levels in [0, 9]## {msg}
for levels in [10, 19], and so on[{stream}] # {msg}
for a non-`None` stream, with level headers as before!WARNING! [{stream}] {msg}
for level in [-9, -1]!!WARNING!! [{stream}] {msg}
for level in [-19, -10] and so onHEADER_FUNCTIONS: dict[str, muutils.logger.headerfuncs.HeaderFunction] = {'md': <function md_header_function>}
muutils.logger.log_utildef get_any_from_stream(stream: list[dict], key: str) -> Noneget the first value of a key from a stream. errors if not found
def gather_log(file: str) -> dict[str, list[dict]]gathers and sorts all streams from a log
def gather_stream(file: str, stream: str) -> list[dict]gets all entries from a specific stream in a log file
def gather_val(
file: str,
stream: str,
keys: tuple[str],
allow_skip: bool = True
) -> list[list]gather specific keys from a specific stream in a log file
example: if “log.jsonl” has contents:
{"a": 1, "b": 2, "c": 3, "_stream": "s1"}
{"a": 4, "b": 5, "c": 6, "_stream": "s1"}
{"a": 7, "b": 8, "c": 9, "_stream": "s2"}
then gather_val("log.jsonl", "s1", ("a", "b")) will
return
[
[1, 2],
[4, 5]
]
logger with streams & levels, and a timer context manager
SimpleLogger is an extremely simple logger that can
write to both console and a fileLogger class handles levels in a slightly different way
than default python logging, and also has “streams” which
allow for different sorts of output in the same logger this was mostly
made with training models in mind and storing both metadata and
lossTimerContext is a context manager that can be used to
time the duration of a block of codemuutils.logger.loggerlogger with streams & levels, and a timer context manager
SimpleLogger is an extremely simple logger that can
write to both console and a fileLogger class handles levels in a slightly different way
than default python logging, and also has “streams” which
allow for different sorts of output in the same logger this was mostly
made with training models in mind and storing both metadata and
lossTimerContext is a context manager that can be used to
time the duration of a block of codedef decode_level(level: int) -> strclass Logger(muutils.logger.simplelogger.SimpleLogger):logger with more features, including log levels and streams
- `log_path : str | None`
default log file path
(defaults to `None`)
- `log_file : AnyIO | None`
default log io, should have a `.write()` method (pass only this or `log_path`, not both)
(defaults to `None`)
- `timestamp : bool`
whether to add timestamps to every log message (under the `_timestamp` key)
(defaults to `True`)
- `default_level : int`
default log level for streams/messages that don't specify a level
(defaults to `0`)
- `console_print_threshold : int`
log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
(defaults to `50`)
- `level_header : HeaderFunction`
function for formatting log messages when printing to console
(defaults to `HEADER_FUNCTIONS["md"]`)
- `keep_last_msg_time : bool`
    whether to keep the last message time
    (defaults to `True`)
- `ValueError` : raised if both `log_path` and `log_file` are passed
Logger(
log_path: str | None = None,
log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
default_level: int = 0,
console_print_threshold: int = 50,
level_header: muutils.logger.headerfuncs.HeaderFunction = <function md_header_function>,
streams: Union[dict[str | None, muutils.logger.loggingstream.LoggingStream], Sequence[muutils.logger.loggingstream.LoggingStream]] = (),
keep_last_msg_time: bool = True,
timestamp: bool = True,
**kwargs
)def log(
self,
msg: Union[bool, int, float, str, list, Dict[str, Any], NoneType] = None,
lvl: int | None = None,
stream: str | None = None,
console_print: bool = False,
extra_indent: str = '',
**kwargs
)logging function
msg : JSONitem message (usually string or dict) to be
loggedlvl : int | None level of message (lower levels are
more important) (defaults to None)console_print : bool override
console_print_threshold setting (defaults to
False)stream : str | None whether to log to a stream
(defaults to None), which logs to the default
None stream (defaults to None)def log_elapsed_last(
self,
lvl: int | None = None,
stream: str | None = None,
console_print: bool = True,
**kwargs
) -> floatlogs the time elapsed since the last message was printed to the console (in any stream)
def flush_all(self)flush all streams
muutils.logger.loggingstreamclass LoggingStream:properties of a logging stream
name: str name of the streamaliases: set[str] aliases for the stream (calls to
these names will be redirected to this stream. duplicate aliases will
result in errors) TODO: perhaps duplicate aliases should result in
duplicate writes?file: str|bool|AnyIO|None file to write to - if
None, will write to standard log - if True,
will write to name + ".log" - if False will
“write” to NullIO (throw it away) - if a string, will write
to that file - if a fileIO type object, will write to that objectdefault_level: int|None default level for this
streamdefault_contents: dict[str, Callable[[], Any]] default
contents for this streamlast_msg: tuple[float, Any]|None last message written
to this stream (timestamp, message)LoggingStream(
name: str | None,
aliases: set[str | None] = <factory>,
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
default_level: int | None = None,
default_contents: dict[str, typing.Callable[[], typing.Any]] = <factory>,
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
)name: str | None
aliases: set[str | None]
file: Union[str, bool, TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
default_level: int | None = None
default_contents: dict[str, typing.Callable[[], typing.Any]]
handler: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None
def make_handler(self) -> Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType]
muutils.logger.simpleloggerclass NullIO:null IO class
def write(self, msg: str) -> intwrite to nothing! this throws away the message
def flush(self) -> Noneflush nothing! this is a no-op
def close(self) -> Noneclose nothing! this is a no-op
AnyIO = typing.Union[typing.TextIO, muutils.logger.simplelogger.NullIO]class SimpleLogger:logs training data to a jsonl file
SimpleLogger(
log_path: str | None = None,
log_file: Union[TextIO, muutils.logger.simplelogger.NullIO, NoneType] = None,
timestamp: bool = True
)def log(
self,
msg: Union[bool, int, float, str, list, Dict[str, Any], NoneType],
console_print: bool = False,
**kwargs
)log a message to the log file, and optionally to the console
muutils.logger.timingclass TimerContext:context manager for timing code
start_time: float
end_time: float
elapsed_time: float
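The three fields above suggest the following minimal implementation (a sketch of the described behavior, not the actual muutils source):

```python
import time

class TimerContextSketch:
    # records start/end/elapsed wall-clock times around a with-block
    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, *exc_info):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        return False  # never swallow exceptions from the timed block

with TimerContextSketch() as timer:
    sum(range(1000))  # some work to time
```

After the block exits, `timer.elapsed_time` holds the duration in seconds.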
def filter_time_str(time: str) -> strassuming format h:mm:ss, clips off the hours if it's
0
class ProgressEstimator:estimates progress and can give a progress bar
ProgressEstimator(
n_total: int,
pbar_fill: str = '█',
pbar_empty: str = ' ',
pbar_bounds: tuple[str, str] = ('|', '|')
)n_total: int
starttime: float
pbar_fill: str
pbar_empty: str
pbar_bounds: tuple[str, str]
total_str_len: int
def get_timing_raw(self, i: int) -> dict[str, float]returns dict(elapsed, per_iter, remaining, percent)
def get_pbar(self, i: int, width: int = 30) -> strreturns a progress bar
def get_progress_default(self, i: int) -> strreturns a progress string
miscellaneous utilities
stable_hash for hashing that is stable across runsmuutils.misc.sequence for sequence manipulation,
applying mappings, and string-like operations on listsmuutils.misc.string for sanitizing things for
filenames, adjusting docstrings, and converting dicts to filenamesmuutils.misc.numerical for turning numbers into nice
strings and backmuutils.misc.freezing for freezing thingsmuutils.misc.classes for some weird class
utilities
muutils.miscmiscellaneous utilities
stable_hash for hashing that is stable across runs<a href="misc/sequence.html">muutils.misc.sequence</a>
for sequence manipulation, applying mappings, and string-like operations
on lists<a href="misc/string.html">muutils.misc.string</a>
for sanitizing things for filenames, adjusting docstrings, and
converting dicts to filenames<a href="misc/numerical.html">muutils.misc.numerical</a>
for turning numbers into nice strings and back<a href="misc/freezing.html">muutils.misc.freezing</a>
for freezing things<a href="misc/classes.html">muutils.misc.classes</a>
for some weird class utilitiesdef stable_hash(s: str | bytes) -> intReturns a stable hash of the given string. not cryptographically secure, but stable between runs
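Why this matters: Python's built-in `hash()` for strings is salted per process, so it changes between runs. A sha256-based sketch in the spirit of `stable_hash` (the 8-byte truncation here is an illustrative choice, not necessarily what muutils does):

```python
import hashlib

def stable_hash_sketch(s) -> int:
    # built-in hash() is salted per process; sha256 of the bytes is
    # deterministic, so the derived int is stable across runs
    if isinstance(s, str):
        s = s.encode("utf-8")
    return int.from_bytes(hashlib.sha256(s).digest()[:8], "big")

h = stable_hash_sketch("hello")
```

This makes the hash safe to use as, e.g., a cache key or filename component that must agree between separate invocations of a script.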
WhenMissing = typing.Literal['except', 'skip', 'include']def empty_sequence_if_attr_false(itr: Iterable[Any], attr_owner: Any, attr_name: str) -> Iterable[Any]Returns itr if attr_owner has the attribute
attr_name and it boolean casts to True.
Returns an empty sequence otherwise.
Particularly useful for optionally inserting delimiters into a
sequence depending on a TokenizerElement attribute.
- `itr: Iterable[Any]` The iterable to return if the attribute is `True`.
- `attr_owner: Any` The object to check for the attribute.
- `attr_name: str` The name of the attribute to check.

Returns:

- `itr: Iterable` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`
- `()` an empty sequence if the attribute is `False` or not present.

def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> GeneratorFlattens an arbitrarily nested iterable. Flattens all iterable data
types except for str and bytes.
Generator over the flattened sequence.
it: Any arbitrarily nested iterable.levels_to_flatten: Number of levels to flatten by,
starting at the outermost layer. If None, performs full
flattening.def list_split(lst: list, val: Any) -> list[list]split a list into sublists by val. similar to
“a_b_c”.split(“_“)
>>> list_split([1,2,3,0,4,5,0,6], 0)
[[1, 2, 3], [4, 5], [6]]
>>> list_split([0,1,2,3], 0)
[[], [1, 2, 3]]
>>> list_split([1,2,3], 0)
[[1, 2, 3]]
>>> list_split([], 0)
[[]]

def list_join(lst: list, factory: Callable) -> list

Add a new instance of factory() between each element of lst.
>>> list_join([1,2,3], lambda : 0)
[1,0,2,0,3]
>>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
[1, 1600000000.0, 2, 1600000000.1, 3]

def apply_mapping(
    mapping: Mapping[~_AM_K, ~_AM_V],
    iter: Iterable[~_AM_K],
    when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]

Given an iterable and a mapping, apply the mapping to the iterable, with certain options.

Gotcha: an invalid when_missing raises no error until a missing key is actually encountered.
Note: you can use this with
<a href="kappa.html#Kappa">muutils.kappa.Kappa</a>
if you want to pass a function instead of a dict
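A minimal reimplementation sketch of the documented when_missing semantics (apply_mapping_sketch is a hypothetical name, not the muutils source):

```python
from typing import Any, Iterable, Mapping

def apply_mapping_sketch(
    mapping: Mapping[Any, Any],
    it: Iterable[Any],
    when_missing: str = "skip",
) -> list:
    out = []
    for item in it:
        if item in mapping:
            out.append(mapping[item])
            continue
        # the gotcha: when_missing is only inspected once a key is missing
        if when_missing == "skip":
            continue
        elif when_missing == "include":
            out.append(item)  # pass the key through unconverted
        elif when_missing == "except":
            raise KeyError(item)
        else:
            raise ValueError(f"invalid when_missing: {when_missing!r}")
    return out
```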
Parameters:
- `mapping : Mapping[_AM_K, _AM_V]` must have `__contains__` and `__getitem__`, both of which take `_AM_K`; the latter returns `_AM_V`
- `iter : Iterable[_AM_K]` the iterable to apply the mapping to
- `when_missing : WhenMissing` what to do when a key is missing from the mapping; this is what distinguishes this function from `map`. Choose from `"skip"`, `"include"` (without converting), and `"except"` (defaults to `"skip"`)

Returns one of:
- `list[_AM_V]` if when_missing is `"skip"` or `"except"`
- `list[Union[_AM_K, _AM_V]]` if when_missing is `"include"`

Raises:
- `KeyError` if an item is missing from the mapping and when_missing is `"except"`
- `ValueError` if when_missing is invalid

def apply_mapping_chain(
mapping: Mapping[~_AM_K, Iterable[~_AM_V]],
iter: Iterable[~_AM_K],
when_missing: Literal['except', 'skip', 'include'] = 'skip'
) -> list[typing.Union[~_AM_K, ~_AM_V]]

Given an iterable and a mapping whose values are iterables, apply the mapping and chain the resulting iterables together.

Gotcha: an invalid when_missing raises no error until a missing key is actually encountered.
Note: you can use this with
<a href="kappa.html#Kappa">muutils.kappa.Kappa</a>
if you want to pass a function instead of a dict
Parameters:
- `mapping : Mapping[_AM_K, Iterable[_AM_V]]` must have `__contains__` and `__getitem__`, both of which take `_AM_K`; the latter returns `Iterable[_AM_V]`
- `iter : Iterable[_AM_K]` the iterable to apply the mapping to
- `when_missing : WhenMissing` what to do when a key is missing from the mapping; this is what distinguishes this function from `map`. Choose from `"skip"`, `"include"` (without converting), and `"except"` (defaults to `"skip"`)

Returns one of:
- `list[_AM_V]` if when_missing is `"skip"` or `"except"`
- `list[Union[_AM_K, _AM_V]]` if when_missing is `"include"`

Raises:
- `KeyError` if an item is missing from the mapping and when_missing is `"except"`
- `ValueError` if when_missing is invalid

def sanitize_name(
name: str | None,
additional_allowed_chars: str = '',
replace_invalid: str = '',
when_none: str | None = '_None_',
leading_digit_prefix: str = ''
) -> str

Sanitize a string, leaving only alphanumerics and additional_allowed_chars.

Parameters:
- `name : str | None` input string
- `additional_allowed_chars : str` additional characters to allow, none by default (defaults to `""`)
- `replace_invalid : str` character to replace invalid characters with (defaults to `""`)
- `when_none : str | None` string to return if `name` is `None`; if `None`, raises an exception (defaults to `"_None_"`)
- `leading_digit_prefix : str` character to prefix the string with if it starts with a digit (defaults to `""`)

Returns:
- `str` sanitized string

def sanitize_fname(fname: str | None, **kwargs) -> str

Sanitize a filename to POSIX standards. Allows `_` (underscore), `-` (dash), and `.` (period).

def sanitize_identifier(fname: str | None, **kwargs) -> str

Sanitize an identifier (variable or function name). Allows `_` (underscore); prefixes with `_` if it starts with a digit.

def dict_to_filename(
data: dict,
format_str: str = '{key}_{val}',
separator: str = '.',
max_length: int = 255
)

def dynamic_docstring(**doc_params)

def shorten_numerical_to_str(
num: int | float,
small_as_decimal: bool = True,
precision: int = 1
) -> str

Shorten a large numerical value to a string, e.g. 1234 -> "1K".

Precision guaranteed to 1 in 10, but can be higher. Reverse of str_to_numeric.
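A sketch of the shortening logic, using the same scale suffixes as the module's _SHORTEN_MAP constant (the real function also takes a small_as_decimal option and has different rounding guarantees; shorten_sketch is a hypothetical name):

```python
# scale suffixes matching the documented _SHORTEN_MAP
SCALES = {1e3: "K", 1e6: "M", 1e9: "B", 1e12: "t", 1e15: "q", 1e18: "Q"}

def shorten_sketch(num: float, precision: int = 1) -> str:
    # try the largest scale first, so 5_000_000 becomes "5M", not "5000K"
    for scale in sorted(SCALES, reverse=True):
        if abs(num) >= scale:
            return f"{round(num / scale, precision):g}{SCALES[scale]}"
    return str(num)
```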
def str_to_numeric(
quantity: str,
mapping: None | bool | dict[str, int | float] = True
) -> int | float

Convert a string representing a quantity to a numeric value. The string can represent an integer, a Python float, a fraction, or a number shortened via shorten_numerical_to_str.
>>> str_to_numeric("5")
5
>>> str_to_numeric("0.1")
0.1
>>> str_to_numeric("1/5")
0.2
>>> str_to_numeric("-1K")
-1000.0
>>> str_to_numeric("1.5M")
1500000.0
>>> str_to_numeric("1.2e2")
120.0
_SHORTEN_MAP = {1000.0: 'K', 1000000.0: 'M', 1000000000.0: 'B', 1000000000000.0: 't', 1000000000000000.0: 'q', 1e+18: 'Q'}

class FrozenDict(builtins.dict)

class FrozenList(builtins.list)
def append(self, value)
Append object to the end of the list.

def extend(self, iterable)
Extend list by appending elements from the iterable.

def insert(self, index, value)
Insert object before index.

def remove(self, value)
Remove first occurrence of value. Raises ValueError if the value is not present.

def pop(self, index=-1)
Remove and return item at index (default last). Raises IndexError if list is empty or index is out of range.

def clear(self)
Remove all items from list.
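The overridden mutators listed above suggest the following sketch of the FrozenList idea: a list subclass whose mutating methods raise. The real class may use a different exception type (AttributeError here is an illustrative choice), and this sketch does not block item assignment:

```python
class FrozenListSketch(list):
    def _blocked(self, *args, **kwargs):
        # every mutator is rebound to this method, so mutation raises
        raise AttributeError("FrozenListSketch cannot be modified")
    append = extend = insert = remove = pop = clear = _blocked
```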
def freeze(instance: object) -> object

Recursively freeze an object in-place so that its attributes and elements cannot be changed.

Messy in the sense that sometimes the object is modified in place, but you can't rely on that: always use the return value.

The gelidum package is a more complete implementation of this idea.
def is_abstract(cls: type) -> bool

Returns whether a class is abstract.

def get_all_subclasses(class_: type, include_self=False) -> set[type]

Returns a set containing all child classes in the subclass graph of class_, i.e. it includes subclasses of subclasses, etc.

Parameters:
- `class_`: superclass
- `include_self`: whether to include class_ itself in the returned set

Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic. It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.
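The recursion described above can be sketched via type.__subclasses__ (a sketch with a hypothetical name, not the muutils source):

```python
def get_all_subclasses_sketch(class_: type, include_self: bool = False) -> set[type]:
    # walk __subclasses__ transitively, so subclasses of subclasses are included
    out = {class_} if include_self else set()
    for sub in class_.__subclasses__():
        out |= get_all_subclasses_sketch(sub, include_self=True)
    return out
```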
def isinstance_by_type_name(o: object, type_name: str)

Behaves like stdlib isinstance except it accepts a string representation of the type rather than the type itself. This is a hacky function intended to circumvent the need to import a type into a module. It is susceptible to type name collisions.

Parameters:
- `o`: object (not the type itself) whose type to interrogate
- `type_name`: the string returned by `type_.__name__`. Generic types are not supported, only types that would appear in `type_.__mro__`.
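A sketch of the documented semantics, comparing type_name against the class names in the object's MRO (hypothetical helper, not the muutils source):

```python
def isinstance_by_type_name_sketch(o: object, type_name: str) -> bool:
    # any class in the MRO with a matching __name__ counts as a match,
    # which is exactly the documented susceptibility to name collisions
    return any(t.__name__ == type_name for t in type(o).__mro__)
```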
class IsDataclass(typing.Protocol)

Base class for protocol classes; see PEP 544 for details on structural subtyping.
IsDataclass(*args, **kwargs)

def get_hashable_eq_attrs(dc: muutils.misc.classes.IsDataclass) -> tuple[typing.Any]

Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself. The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical. Essentially used to generate a hashable dataclass representation for equality comparison, even if the dataclass is not frozen.
def dataclass_set_equals(
coll1: Iterable[muutils.misc.classes.IsDataclass],
coll2: Iterable[muutils.misc.classes.IsDataclass]
) -> bool

Compares two collections of dataclass instances as if they were sets; duplicates are ignored in the same manner as a set. Unfrozen dataclasses can't be placed in sets since they're not hashable, but collections of them may be compared using this function.
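The idea behind get_hashable_eq_attrs and dataclass_set_equals can be sketched as follows (hypothetical names; assumes fields marked compare=False are excluded, as in dataclass equality):

```python
from dataclasses import dataclass, fields
from typing import Iterable

def eq_key(dc) -> tuple:
    # include the type itself so identical fields on different dataclasses
    # still compare unequal, matching the documented behavior
    return (type(dc), *(getattr(dc, f.name) for f in fields(dc) if f.compare))

def dataclass_set_equals_sketch(coll1: Iterable, coll2: Iterable) -> bool:
    # unfrozen dataclasses aren't hashable, but their eq_key tuples are
    return {eq_key(x) for x in coll1} == {eq_key(y) for y in coll2}
```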
muutils.misc.hashing

def base64_hash(s: str | bytes) -> str

Returns a base64 representation of the hash of the given string. Not cryptographically secure.
muutils.mlutils

Miscellaneous utilities for ML pipelines.
ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
def get_device(device: Union[str, torch.device, NoneType] = None) -> torch.device

Get the torch.device instance on which torch.Tensors should be allocated.
def set_reproducibility(seed: int = 42)

Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Set use_deterministic_algorithms to False to regain performance at the cost of reproducibility.
def chunks(it, chunk_size)

Yield successive chunks from an iterator.
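A common way to implement this behavior, shown as a sketch rather than the muutils source:

```python
from itertools import islice

def chunks_sketch(it, chunk_size):
    # lazily pull fixed-size chunks from any iterable; the final chunk
    # may be shorter if the input doesn't divide evenly
    it = iter(it)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        yield chunk
```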
def get_checkpoint_paths_for_run(
run_path: pathlib.Path,
extension: Literal['pt', 'zanj'],
checkpoints_format: str = 'checkpoints/model.iter_*.{extension}'
) -> list[tuple[int, pathlib.Path]]

Get checkpoints of the given format from the run_path.

Note that checkpoints_format should contain a glob pattern with:
- an unresolved "{extension}" format term for the extension
- a wildcard for the iteration number
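The pattern handling described above can be sketched as follows (hypothetical helper; the exact parsing and ordering in muutils may differ):

```python
import re
from pathlib import Path

def checkpoint_paths_sketch(run_path: Path, extension: str,
                            fmt: str = "checkpoints/model.iter_*.{extension}"):
    # resolve the "{extension}" format term, then glob for checkpoints
    pattern = fmt.format(extension=extension)
    out = []
    for p in run_path.glob(pattern):
        # pull the iteration number out of the wildcard portion
        m = re.search(r"iter_(\d+)", p.name)
        if m:
            out.append((int(m.group(1)), p))
    return sorted(out)  # sorted by iteration (a choice made in this sketch)
```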
def register_method(
method_dict: dict[str, typing.Callable[..., typing.Any]],
custom_name: Optional[str] = None
) -> Callable[[~F], ~F]

Decorator to add a method to the method_dict.

def pprint_summary(summary: dict)
muutils.nbutils

Utilities for working with notebooks: configure_notebook, convert_ipynb_to_script, run_notebook_tests, mermaid, print_tex.

def mm(graph)

For plotting mermaid.js diagrams.
muutils.nbutils.configure_notebook

Shared utilities for setting up a notebook.
class PlotlyNotInstalledWarning(builtins.UserWarning)

Warning for when plotly is not installed.
PLOTLY_IMPORTED: bool = True
PlottingMode = typing.Literal['ignore', 'inline', 'widget', 'save']
PLOT_MODE: Literal['ignore', 'inline', 'widget', 'save'] = 'inline'
CONVERSION_PLOTMODE_OVERRIDE: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None
FIG_COUNTER: int = 0
FIG_OUTPUT_FMT: str | None = None
FIG_NUMBERED_FNAME: str = 'figure-{num}'
FIG_CONFIG: dict | None = None
FIG_BASEPATH: str | None = None
CLOSE_AFTER_PLOTSHOW: bool = False
MATPLOTLIB_FORMATS = ['pdf', 'png', 'jpg', 'jpeg', 'svg', 'eps', 'ps', 'tif', 'tiff']
TIKZPLOTLIB_FORMATS = ['tex', 'tikz']
class UnknownFigureFormatWarning(builtins.UserWarning)

Warning for when a figure format is not recognized.
def universal_savefig(fname: str, fmt: str | None = None) -> None

def setup_plots(
plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline',
fig_output_fmt: str | None = 'pdf',
fig_numbered_fname: str = 'figure-{num}',
fig_config: dict | None = None,
fig_basepath: str | None = None,
close_after_plotshow: bool = False
) -> None

Set up plot saving/rendering options.
def configure_notebook(
*args,
seed: int = 42,
device: Any = None,
dark_mode: bool = True,
plot_mode: Literal['ignore', 'inline', 'widget', 'save'] = 'inline',
fig_output_fmt: str | None = 'pdf',
fig_numbered_fname: str = 'figure-{num}',
fig_config: dict | None = None,
fig_basepath: str | None = None,
close_after_plotshow: bool = False
) -> torch.device | None

Shared Jupyter notebook setup steps.
Parameters:
- `seed : int` random seed across libraries including torch, numpy, and random (defaults to 42)
- `device : typing.Any` pytorch device to use (defaults to None)
- `dark_mode : bool` figures in dark mode (defaults to True)
- `plot_mode : PlottingMode` how to display plots, one of `["ignore", "inline", "widget", "save"]` (defaults to `"inline"`)
- `fig_output_fmt : str | None` format for saving figures (defaults to `"pdf"`)
- `fig_numbered_fname : str` format for saving figures with numbers (if they aren't named) (defaults to `"figure-{num}"`)
- `fig_config : dict | None` metadata to save with the figures (defaults to None)
- `fig_basepath : str | None` base path for saving figures (defaults to None)
- `close_after_plotshow : bool` close figures after showing them (defaults to False)

Returns:
- `torch.device | None` the device set, if torch is installed

def plotshow(
fname: str | None = None,
plot_mode: Optional[Literal['ignore', 'inline', 'widget', 'save']] = None,
fmt: str | None = None
)

Show the active plot, depending on global configs.
muutils.nbutils.convert_ipynb_to_script

Fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting.
DISABLE_PLOTS: dict[str, list[str]] = {'matplotlib': ['\n# ------------------------------------------------------------\n# Disable matplotlib plots, done during processing byconvert_ipynb_to_script.py\nimport matplotlib.pyplot as plt\nplt.show = lambda: None\n# ------------------------------------------------------------\n'], 'circuitsvis': ['\n# ------------------------------------------------------------\n# Disable circuitsvis plots, done during processing byconvert_ipynb_to_script.py\nfrom circuitsvis.utils.convert_props import PythonProperty, convert_props\nfrom circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local\n\ndef new_render(\n react_element_name: str,\n **kwargs: PythonProperty\n) -> RenderedHTML:\n "return a visualization as raw HTML"\n local_src = render_local(react_element_name, **kwargs)\n cdn_src = render_cdn(react_element_name, **kwargs)\n # return as string instead of RenderedHTML for CI\n return str(RenderedHTML(local_src, cdn_src))\n\nrender = new_render\n# ------------------------------------------------------------\n'], 'muutils': ['import muutils.nbutils.configure_notebook as nb_conf\nnb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"\n']}
DISABLE_PLOTS_WARNING: list[str] = ["# ------------------------------------------------------------\n# WARNING: this script is auto-generated byconvert_ipynb_to_script.py\n# showing plots has been disabled, so this is presumably in a temp dict for CI or something\n# so don't modify this code, it will be overwritten!\n# ------------------------------------------------------------\n"]
def disable_plots_in_script(script_lines: list[str]) -> list[str]

Disable plots in a script by adding cursed things after the import statements.
def convert_ipynb(
notebook: dict,
strip_md_cells: bool = False,
header_comment: str = '#%%',
disable_plots: bool = False,
filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
) -> str

Convert a Jupyter Notebook to a script, doing some basic filtering and formatting.
- `notebook: dict`: Jupyter Notebook loaded as json.
- `strip_md_cells: bool = False`: Remove markdown cells from the output script.
- `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
- `disable_plots: bool = False`: Disable plots in the output script.
- `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
if a string is passed, it will be split by char and each char will be treated as a separate filter.
- `str`: Converted script.
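The core of the conversion can be sketched as joining code cells with the header_comment separator (a simplified sketch with a hypothetical name; the real function also handles markdown stripping, line filtering, and plot disabling):

```python
def convert_ipynb_sketch(notebook: dict, header_comment: str = "#%%") -> str:
    # a notebook's JSON stores each cell's source as a list of lines
    out = []
    for cell in notebook["cells"]:
        if cell["cell_type"] == "code":
            out.append(header_comment + "\n" + "".join(cell["source"]))
    return "\n\n".join(out)
```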
def process_file(
in_file: str,
out_file: str | None = None,
strip_md_cells: bool = False,
header_comment: str = '#%%',
disable_plots: bool = False,
filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
)

def process_dir(
input_dir: str,
output_dir: str,
strip_md_cells: bool = False,
header_comment: str = '#%%',
disable_plots: bool = False,
filter_out_lines: Union[str, Sequence[str]] = ('%', '!')
)

Convert all Jupyter Notebooks in a directory to scripts.
- `input_dir: str`: Input directory.
- `output_dir: str`: Output directory.
- `strip_md_cells: bool = False`: Remove markdown cells from the output script.
- `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
- `disable_plots: bool = False`: Disable plots in the output script.
- `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
if a string is passed, it will be split by char and each char will be treated as a separate filter.
muutils.nbutils.mermaid

Display mermaid.js diagrams in Jupyter notebooks via the mermaid.ink/img service.
def mm(graph)

For plotting mermaid.js diagrams.
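The mermaid.ink/img approach can be sketched as base64-encoding the diagram source into an image URL (a hypothetical helper; the real mm() presumably displays the resulting image via IPython):

```python
import base64

def mermaid_img_url_sketch(graph: str) -> str:
    # mermaid.ink serves a rendered PNG for a base64-encoded diagram source
    encoded = base64.b64encode(graph.encode("utf-8")).decode("ascii")
    return "https://mermaid.ink/img/" + encoded
```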
muutils.nbutils.print_tex

Quickly print a sympy expression in LaTeX.
def print_tex(
expr: sympy.core.expr.Expr,
name: str | None = None,
plain: bool = False,
rendered: bool = True
)

Function for easily rendering a sympy expression in LaTeX.
muutils.nbutils.run_notebook_tests

Turn a folder of notebooks into scripts, run them, and make sure they work. Made to be called as:

    python -m muutils.nbutils.run_notebook_tests --notebooks-dir <notebooks_dir> --converted-notebooks-temp-dir <converted_notebooks_temp_dir>

class NotebookTestError(builtins.Exception)

Common base class for all non-exit exceptions.
SUCCESS_STR: str = '[OK]'
FAILURE_STR: str = '[!!]'
def run_notebook_tests(
notebooks_dir: pathlib.Path,
converted_notebooks_temp_dir: pathlib.Path,
CI_output_suffix: str = '.CI-output.txt',
run_python_cmd: Optional[str] = None,
run_python_cmd_fmt: str = '{python_tool} run python',
python_tool: str = 'poetry',
exit_on_first_fail: bool = False
)

Run converted Jupyter notebooks as Python scripts and verify they execute successfully.
Takes a directory of notebooks and their corresponding converted Python scripts, executes each script, and captures the output. Failures are collected and reported, with optional early exit on first failure.
Parameters:
- `notebooks_dir : Path` directory containing the original .ipynb notebook files
- `converted_notebooks_temp_dir : Path` directory containing the corresponding converted .py files
- `CI_output_suffix : str` suffix to append to output files capturing execution results (defaults to `".CI-output.txt"`)
- `run_python_cmd : str | None` custom command to run Python scripts; overrides python_tool and run_python_cmd_fmt if provided (defaults to None)
- `run_python_cmd_fmt : str` format string for constructing the Python run command (defaults to `"{python_tool} run python"`)
- `python_tool : str` tool used to run Python (e.g. poetry, uv) (defaults to `"poetry"`)
- `exit_on_first_fail : bool` whether to raise an exception immediately on first notebook failure (defaults to False)

Returns: None

Raises:
- `NotebookTestError` if any notebooks fail to execute, or if input directories are invalid
- `TypeError` if run_python_cmd is provided but not a string

>>> run_notebook_tests(
... notebooks_dir=Path("notebooks"),
... converted_notebooks_temp_dir=Path("temp/converted"),
... python_tool="poetry"
... )
### testing notebooks in 'notebooks'
### reading converted notebooks from 'temp/converted'
Running 1/2: temp/converted/notebook1.py
Output in temp/converted/notebook1.CI-output.txt
{SUCCESS_STR} Run completed with return code 0
muutils.parallel

class ProgressBarFunction(typing.Protocol)

A protocol for a progress bar function.

ProgressBarFunction(*args, **kwargs)

ProgressBarOption = typing.Literal['tqdm', 'spinner', 'none', None]

def DEFAULT_PBAR_FN(*_, **__)

Decorate an iterable object, returning an iterator which acts exactly like the original iterable, but prints a dynamically updating progressbar every time a value is requested.
iterable : iterable, optional
    Iterable to decorate with a progressbar. Leave blank to manually manage the updates.
desc : str, optional
    Prefix for the progressbar.
total : int or float, optional
    The number of expected iterations. If unspecified, len(iterable) is used if possible. If float("inf") or as a last resort, only basic progress statistics are displayed (no ETA, no progressbar). If gui is True and this parameter needs subsequent updating, specify an initial arbitrary large positive number, e.g. 9e9.
leave : bool, optional
    If [default: True], keeps all traces of the progressbar upon termination of iteration. If None, will leave only if position is 0.
file : io.TextIOWrapper or io.StringIO, optional
    Specifies where to output the progress messages (default: sys.stderr). Uses file.write(str) and file.flush() methods. For encoding, see write_bytes.
ncols : int, optional
    The width of the entire output message. If specified, dynamically resizes the progressbar to stay within this bound. If unspecified, attempts to use environment width. The fallback is a meter width of 10 and no limit for the counter and statistics. If 0, will not print any meter (only stats).
mininterval : float, optional
    Minimum progress display update interval [default: 0.1] seconds.
maxinterval : float, optional
    Maximum progress display update interval [default: 10] seconds. Automatically adjusts miniters to correspond to mininterval after long display update lag. Only works if dynamic_miniters or monitor thread is enabled.
miniters : int or float, optional
    Minimum progress display update interval, in iterations. If 0 and dynamic_miniters, will automatically adjust to equal mininterval (more CPU efficient, good for tight loops). If > 0, will skip display of specified number of iterations. Tweak this and mininterval to get very efficient loops. If your progress is erratic with both fast and slow iterations (network, skipping items, etc) you should set miniters=1.
ascii : bool or str, optional
    If unspecified or False, use unicode (smooth blocks) to fill the meter. The fallback is to use ASCII characters " 123456789#".
disable : bool, optional
    Whether to disable the entire progressbar wrapper [default: False]. If set to None, disable on non-TTY.
unit : str, optional
    String that will be used to define the unit of each iteration [default: it].
unit_scale : bool or int or float, optional
    If 1 or True, the number of iterations will be reduced/scaled automatically and a metric prefix following the International System of Units standard will be added (kilo, mega, etc.) [default: False]. If any other non-zero number, will scale total and n.
dynamic_ncols : bool, optional
    If set, constantly alters ncols and nrows to the environment (allowing for window resizes) [default: False].
smoothing : float, optional
    Exponential moving average smoothing factor for speed estimates (ignored in GUI mode). Ranges from 0 (average speed) to 1 (current/instantaneous speed) [default: 0.3].
bar_format : str, optional
    Specify a custom bar string formatting. May impact performance. [default: '{l_bar}{bar}{r_bar}'], where l_bar='{desc}: {percentage:3.0f}%|' and r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'. Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, percentage, elapsed, elapsed_s, ncols, nrows, desc, unit, rate, rate_fmt, rate_noinv, rate_noinv_fmt, rate_inv, rate_inv_fmt, postfix, unit_divisor, remaining, remaining_s, eta. Note that a trailing ": " is automatically removed after {desc} if the latter is empty.
initial : int or float, optional
    The initial counter value. Useful when restarting a progress bar [default: 0]. If using float, consider specifying {n:.3f} or similar in bar_format, or specifying unit_scale.
position : int, optional
    Specify the line offset to print this bar (starting from 0). Automatic if unspecified. Useful to manage multiple bars at once (eg, from threads).
postfix : dict or *, optional
    Specify additional stats to display at the end of the bar. Calls set_postfix(**postfix) if possible (dict).
unit_divisor : float, optional
    [default: 1000], ignored unless unit_scale is True.
write_bytes : bool, optional
    Whether to write bytes. If (default: False) will write unicode.
lock_args : tuple, optional
    Passed to refresh for intermediate output (initialisation, iterating, and updating).
nrows : int, optional
    The screen height. If specified, hides nested bars outside this bound. If unspecified, attempts to use environment height. The fallback is 20.
colour : str, optional
    Bar colour (e.g. 'green', '#00ff00').
delay : float, optional
    Don't display until [default: 0] seconds have elapsed.
gui : bool, optional
    WARNING: internal parameter - do not use. Use tqdm.gui.tqdm(...) instead. If set, will attempt to use matplotlib animations for a graphical output [default: False].

returns: out : decorated iterator.
def spinner_fn_wrap(x: Iterable, **kwargs) -> List

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable
fallback to no progress bar
def set_up_progress_bar_fn(
pbar: Union[muutils.parallel.ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]],
pbar_kwargs: Optional[Dict[str, Any]] = None,
**extra_kwargs
) -> muutils.parallel.ProgressBarFunction

def run_maybe_parallel(
func: Callable[[~InputType], ~OutputType],
iterable: Iterable[~InputType],
parallel: Union[bool, int],
pbar_kwargs: Optional[Dict[str, Any]] = None,
chunksize: Optional[int] = None,
keep_ordered: bool = True,
use_multiprocess: bool = False,
pbar: Union[muutils.parallel.ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = <function tqdm>
) -> List[~OutputType]
a function to make it easier to sometimes parallelize an operation

- if parallel is False, the function will run in serial, running map(func, iterable)
- if parallel is True, the function will run in parallel with the maximum number of processes
- if parallel is an int, it must be greater than 1, and the function will run in parallel with the number of processes specified by parallel

the maximum number of processes is given by min(len(iterable), multiprocessing.cpu_count())
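The dispatch described above can be sketched with only the standard library. This `run_maybe_parallel_sketch` is a hypothetical simplification, not the library's implementation; the real function also handles progress bars, chunk sizes, ordering, and the optional multiprocess backend:

```python
from multiprocessing import Pool, cpu_count
from typing import Callable, Iterable, List, TypeVar, Union

T = TypeVar("T")
U = TypeVar("U")

def run_maybe_parallel_sketch(
    func: Callable[[T], U],
    iterable: Iterable[T],
    parallel: Union[bool, int],
) -> List[U]:
    items = list(iterable)
    if parallel is False:
        # serial path: just map
        return list(map(func, items))
    if parallel is True:
        n_proc = cpu_count()
    else:
        if not (isinstance(parallel, int) and parallel > 1):
            raise ValueError("parallel must be False, True, or an int > 1")
        n_proc = parallel
    # cap the worker count at the number of items
    n_proc = min(len(items), n_proc)
    with Pool(n_proc) as pool:
        return list(pool.map(func, items))
```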
func : Callable[[InputType], OutputType]
    function passed to either map or Pool.imap
iterable : Iterable[InputType]
    iterable passed to either map or Pool.imap
parallel : bool | int
pbar_kwargs : Dict[str, Any]

returns: List[OutputType]

raises: ValueError
decorator spinner_decorator and context manager
SpinnerContext to display a spinner
using the base Spinner class while some code is
running.
muutils.spinner
DecoratedFunction = ~DecoratedFunction
generic type for the decorated function

class SpinnerConfig:
SpinnerConfig(working: List[str] = <factory>, success: str = '✔️', fail: str = '❌')

working: List[str]
success: str = '✔️'
fail: str = '❌'
def is_ascii(self) -> bool
whether all characters are ascii

def eq_lens(self) -> bool
whether all working characters are the same length

def is_valid(self) -> bool
whether the spinner config is valid

def from_any(
    cls,
    arg: Union[str, List[str], muutils.spinner.SpinnerConfig, dict]
) -> muutils.spinner.SpinnerConfig

SpinnerConfigArg = typing.Union[str, typing.List[str], muutils.spinner.SpinnerConfig, dict]
SPINNERS: Dict[str, muutils.spinner.SpinnerConfig] = {'default': SpinnerConfig(working=['|', '/', '-', '\\'], success='#', fail='X'), 'dots': SpinnerConfig(working=['. ', '.. ', '...'], success='***', fail='xxx'), 'bars': SpinnerConfig(working=['| ', '|| ', '|||'], success='|||', fail='///'), 'arrows': SpinnerConfig(working=['<', '^', '>', 'v'], success='►', fail='✖'), 'arrows_2': SpinnerConfig(working=['←', '↖', '↑', '↗', '→', '↘', '↓', '↙'], success='→', fail='↯'), 'bouncing_bar': SpinnerConfig(working=['[ ]', '[= ]', '[== ]', '[=== ]', '[ ===]', '[ ==]', '[ =]'], success='[====]', fail='[XXXX]'), 'bar': SpinnerConfig(working=['[ ]', '[- ]', '[--]', '[ -]'], success='[==]', fail='[xx]'), 'bouncing_ball': SpinnerConfig(working=['( ● )', '( ● )', '( ● )', '( ● )', '( ●)', '( ● )', '( ● )', '( ● )', '( ● )', '(● )'], success='(●●●●●●)', fail='( ✖ )'), 'ooo': SpinnerConfig(working=['.', 'o', 'O', 'o'], success='O', fail='x'), 'braille': SpinnerConfig(working=['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'], success='⣿', fail='X'), 'clock': SpinnerConfig(working=['🕛', '🕐', '🕑', '🕒', '🕓', '🕔', '🕕', '🕖', '🕗', '🕘', '🕙', '🕚'], success='✔️', fail='❌'), 'hourglass': SpinnerConfig(working=['⏳', '⌛'], success='✔️', fail='❌'), 'square_corners': SpinnerConfig(working=['◰', '◳', '◲', '◱'], success='◼', fail='✖'), 'triangle': SpinnerConfig(working=['◢', '◣', '◤', '◥'], success='◆', fail='✖'), 'square_dot': SpinnerConfig(working=['⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽', '⣾'], success='⣿', fail='❌'), 'box_bounce': SpinnerConfig(working=['▌', '▀', '▐', '▄'], success='■', fail='✖'), 'hamburger': SpinnerConfig(working=['☱', '☲', '☴'], success='☰', fail='✖'), 'earth': SpinnerConfig(working=['🌍', '🌎', '🌏'], success='✔️', fail='❌'), 'growing_dots': SpinnerConfig(working=['⣀', '⣄', '⣤', '⣦', '⣶', '⣷', '⣿'], success='⣿', fail='✖'), 'dice': SpinnerConfig(working=['⚀', '⚁', '⚂', '⚃', '⚄', '⚅'], success='🎲', fail='✖'), 'wifi': SpinnerConfig(working=['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 
success='✔️', fail='❌'), 'bounce': SpinnerConfig(working=['⠁', '⠂', '⠄', '⠂'], success='⠿', fail='⢿'), 'arc': SpinnerConfig(working=['◜', '◠', '◝', '◞', '◡', '◟'], success='○', fail='✖'), 'toggle': SpinnerConfig(working=['⊶', '⊷'], success='⊷', fail='⊗'), 'toggle2': SpinnerConfig(working=['▫', '▪'], success='▪', fail='✖'), 'toggle3': SpinnerConfig(working=['□', '■'], success='■', fail='✖'), 'toggle4': SpinnerConfig(working=['■', '□', '▪', '▫'], success='■', fail='✖'), 'toggle5': SpinnerConfig(working=['▮', '▯'], success='▮', fail='✖'), 'toggle7': SpinnerConfig(working=['⦾', '⦿'], success='⦿', fail='✖'), 'toggle8': SpinnerConfig(working=['◍', '◌'], success='◍', fail='✖'), 'toggle9': SpinnerConfig(working=['◉', '◎'], success='◉', fail='✖'), 'arrow2': SpinnerConfig(working=['⬆️ ', '↗️ ', '➡️ ', '↘️ ', '⬇️ ', '↙️ ', '⬅️ ', '↖️ '], success='➡️', fail='❌'), 'point': SpinnerConfig(working=['∙∙∙', '●∙∙', '∙●∙', '∙∙●', '∙∙∙'], success='●●●', fail='xxx'), 'layer': SpinnerConfig(working=['-', '=', '≡'], success='≡', fail='✖'), 'speaker': SpinnerConfig(working=['🔈 ', '🔉 ', '🔊 ', '🔉 '], success='🔊', fail='🔇'), 'orangePulse': SpinnerConfig(working=['🔸 ', '🔶 ', '🟠 ', '🟠 ', '🔷 '], success='🟠', fail='❌'), 'bluePulse': SpinnerConfig(working=['🔹 ', '🔷 ', '🔵 ', '🔵 ', '🔷 '], success='🔵', fail='❌'), 'satellite_signal': SpinnerConfig(working=['📡 ', '📡· ', '📡·· ', '📡···', '📡 ··', '📡 ·'], success='📡 ✔️ ', fail='📡 ❌ '), 'rocket_orbit': SpinnerConfig(working=['🌍🚀 ', '🌏 🚀 ', '🌎 🚀'], success='🌍 ✨', fail='🌍 💥'), 'ogham': SpinnerConfig(working=['ᚁ ', 'ᚂ ', 'ᚃ ', 'ᚄ', 'ᚅ'], success='᚛᚜', fail='✖'), 'eth': SpinnerConfig(working=['᛫', '፡', '፥', '፤', '፧', '።', '፨'], success='፠', fail='✖')}
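The SpinnerConfigArg union above maps onto SpinnerConfig.from_any. A minimal sketch of how such a coercion can work, with an illustrative dataclass and registry (this is not the library's exact implementation):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union

@dataclass
class SpinnerConfigSketch:
    working: List[str] = field(default_factory=lambda: ["|", "/", "-", "\\"])
    success: str = "#"
    fail: str = "X"

# tiny stand-in for the SPINNERS registry
SPINNERS_SKETCH: Dict[str, SpinnerConfigSketch] = {
    "default": SpinnerConfigSketch(),
}

def config_from_any(
    arg: Union[str, List[str], SpinnerConfigSketch, dict],
) -> SpinnerConfigSketch:
    if isinstance(arg, str):
        return SPINNERS_SKETCH[arg]              # named config lookup
    if isinstance(arg, SpinnerConfigSketch):
        return arg                               # already a config
    if isinstance(arg, list):
        return SpinnerConfigSketch(working=arg)  # bare list of frames
    if isinstance(arg, dict):
        return SpinnerConfigSketch(**arg)        # dict of fields
    raise TypeError(f"cannot make a spinner config from {type(arg)}")
```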
class Spinner:
displays a spinner, and optionally elapsed time and a mutable value while a function is running.

update_interval : float
    how often to update the spinner display in seconds (defaults to 0.1)
initial_value : str
    initial value to display with the spinner (defaults to "")
message : str
    message to display with the spinner (defaults to "")
format_string : str
    string to format the spinner with. must have "\r" prepended to clear the line. allowed keys are spinner, elapsed_time, message, and value (defaults to "\r{spinner} ({elapsed_time:.2f}s) {message}{value}")
output_stream : TextIO
    stream to write the spinner to (defaults to sys.stdout)
format_string_when_updated : Union[bool, str]
    whether to use a different format string when the value is updated. if True, use the default format string with a newline appended. if a string, use that string. this is useful if you want update_value to print to console and be preserved. (defaults to False)
spinner_chars : Union[str, Sequence[str]]
    sequence of strings, or key to look up in SPINNER_CHARS, to use as the spinner characters (defaults to "default")
spinner_complete : str
    string to display when the spinner is complete (defaults to looking up spinner_chars in SPINNER_COMPLETE or "#")

update_value(value: Any) -> None
    update the current value displayed by the spinner

example usage as a context manager (update_value is called on sp, the name bound by the with statement):

    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")

example usage as a decorator (here mutable_kwarg_key passes the spinner into the function; the keyword name "spinner" is illustrative):

    @spinner_decorator(mutable_kwarg_key="spinner")
    def long_running_function(spinner=None):
        for i in range(1):
            time.sleep(0.1)
            spinner.update_value(f"Step {i+1}")
        return "Function completed"

Spinner(
*args,
config: Union[str, List[str], muutils.spinner.SpinnerConfig, dict] = 'default',
update_interval: float = 0.1,
initial_value: str = '',
message: str = '',
format_string: str = '\r{spinner} ({elapsed_time:.2f}s) {message}{value}',
output_stream: TextIO = sys.stdout,
format_string_when_updated: Union[str, bool] = False,
spinner_chars: Union[str, Sequence[str], NoneType] = None,
spinner_complete: Optional[str] = None,
**kwargs: Any
)

config: muutils.spinner.SpinnerConfig
format_string_when_updated: Optional[str]
format string to use when the value is updated
update_interval: float
message: str
current_value: Any
format_string: str
output_stream: <class 'TextIO'>
start_time: float
for measuring elapsed time
stop_spinner: threading.Event
    to stop the spinner
spinner_thread: Optional[threading.Thread]
    the thread running the spinner
value_changed: bool
    whether the value has been updated since the last display
term_width: int
    width of the terminal, for padding with spaces
state: Literal['initialized', 'running', 'success', 'fail']

def spin(self) -> None
    function to run in a separate thread, displaying the spinner and optional information

def update_value(self, value: Any) -> None
    update the current value displayed by the spinner

def start(self) -> None
    start the spinner

def stop(self, failed: bool = False) -> None
    stop the spinner

class NoOpContextManager(typing.ContextManager):
A context manager that does nothing.
NoOpContextManager(*args, **kwargs)

class SpinnerContext(Spinner, typing.ContextManager):
displays a spinner, and optionally elapsed time and a mutable value while a function is running.
parameters and usage are the same as for Spinner above.

inherited members from Spinner: config, format_string_when_updated, update_interval, message, current_value, format_string, output_stream, start_time, stop_spinner, spinner_thread, value_changed, term_width, state, spin, update_value, start, stop

def spinner_decorator(
*args,
config: Union[str, List[str], muutils.spinner.SpinnerConfig, dict] = 'default',
update_interval: float = 0.1,
initial_value: str = '',
message: str = '',
format_string: str = '{spinner} ({elapsed_time:.2f}s) {message}{value}',
output_stream: TextIO = sys.stdout,
mutable_kwarg_key: Optional[str] = None,
spinner_chars: Union[str, Sequence[str], NoneType] = None,
spinner_complete: Optional[str] = None,
**kwargs
) -> Callable[[~DecoratedFunction], ~DecoratedFunction]
displays a spinner, and optionally elapsed time and a mutable value while a function is running.

parameters and usage are the same as for Spinner above.
StatCounter class for counting and calculating
statistics on numbers
cleaner and more efficient than just using a Counter or
array
muutils.statcounter
NumericSequence = typing.Sequence[typing.Union[float, int, ForwardRef('NumericSequence')]]

def universal_flatten(
arr: Union[Sequence[Union[float, int, Sequence[Union[float, int, ForwardRef('NumericSequence')]]]], float, int],
require_rectangular: bool = True
) -> Sequence[Union[float, int, ForwardRef('NumericSequence')]]
flattens any iterable
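A stdlib-only sketch of what universal_flatten does for nested lists and tuples (omitting the require_rectangular check and array support; illustrative, not the library's exact code):

```python
from typing import Any, List, Sequence, Union

def universal_flatten_sketch(arr: Union[Sequence[Any], float, int]) -> List[Union[float, int]]:
    # a bare number flattens to a single-element list
    if not isinstance(arr, (list, tuple)):
        return [arr]
    out: List[Union[float, int]] = []
    for item in arr:
        # recurse into nested sequences
        out.extend(universal_flatten_sketch(item))
    return out
```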
class StatCounter(collections.Counter):
Counter, but with some stat calculation methods which assume the keys are numerical

works best when the keys are ints

def validate(self) -> bool
    validate the counter as being all floats or ints

def min(self)
    minimum value

def max(self)
    maximum value

def total(self)
    sum of the counts

keys_sorted: list
    return the keys

def percentile(self, p: float)
    return the value at the given percentile. this could be log time if we did binary search, but that would be a lot of added complexity

def median(self) -> float

def mean(self) -> float
    return the mean of the values

def mode(self) -> float

def std(self) -> float
    return the standard deviation of the values
def summary(
self,
typecast: Callable = <function StatCounter.<lambda>>,
*,
extra_percentiles: Optional[list[float]] = None
) -> dict[str, typing.Union[float, int]]
return a summary of the stats, without the raw data. human readable and small
def serialize(
self,
typecast: Callable = <function StatCounter.<lambda>>,
*,
extra_percentiles: Optional[list[float]] = None
) -> dict
return a json-serializable version of the counter
includes both the output of summary and the raw
data:
{
"StatCounter": { <keys, values from raw data> },
"summary": self.summary(typecast, extra_percentiles=extra_percentiles),
}
def load(cls, data: dict) -> muutils.statcounter.StatCounter
load from the output of StatCounter.serialize
def from_list_arrays(
cls,
arr,
map_func: Callable = <class 'float'>
) -> muutils.statcounter.StatCounter
calls map_func on each element of universal_flatten(arr)
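The core idea of StatCounter, computing stats from a Counter without expanding Counter.elements(), can be sketched with the standard library. These helpers are illustrative, not StatCounter's exact methods (the real percentile may interpolate between keys):

```python
from collections import Counter
from typing import Union

Num = Union[int, float]

def counter_mean(c: Counter) -> float:
    # weight each key by its count instead of materializing elements()
    total = sum(c.values())
    return sum(k * n for k, n in c.items()) / total

def counter_percentile(c: Counter, p: float) -> Num:
    # walk the sorted keys, accumulating counts until we reach p * total
    total = sum(c.values())
    target = p * total
    seen = 0
    for k in sorted(c):
        seen += c[k]
        if seen >= target:
            return k
    return max(c)
```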
utilities for getting information about the system, see
SysInfo class
muutils.sysinfo
class SysInfo:
getters for various information about the system

def python() -> dict
    details about python version

def pip() -> dict
    installed packages info

def pytorch() -> dict
    pytorch and cuda information

def platform() -> dict

def git_info(with_log: bool = False) -> dict

def get_all(
cls,
include: Optional[tuple[str, ...]] = None,
exclude: tuple[str, ...] = ()
) -> dict
utilities for working with tensors and arrays.
notably:
- TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
- DTYPE_MAP : mapping string representations of types to their type
- TORCH_DTYPE_MAP : mapping string representations of types to torch types
- compare_state_dicts : for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match

muutils.tensor_utils

TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool_'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}
dict mapping python, numpy, and torch types to jaxtyping types
def jaxtype_factory(
name: str,
array_type: type,
default_jax_dtype=<class 'jaxtyping.Float'>,
legacy_mode: muutils.errormode.ErrorMode = ErrorMode.Warn
) -> type

usage:

    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]

ATensor = <class 'muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray'>

NDArray = <class 'muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray'>
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype
convert numpy dtype to torch dtype
DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool_'>]list of all the python, numpy, and torch numerical types I could think of
DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool_'>": <class 'numpy.bool_'>, 'float64': <class 'numpy.float64'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'uint8': <class 'numpy.uint8'>, 'bool_': <class 'numpy.bool_'>, 'bool': <class 'numpy.bool_'>}mapping from string representations of types to their type
TORCH_DTYPE_MAP: dict = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int32, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool_'>": torch.bool, 'float64': torch.float64, 'float16': torch.float16, 'float32': torch.float32, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'uint8': torch.uint8, 'bool_': torch.bool, 'bool': torch.bool}mapping from string representations of types to specifically torch types
TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}

def pad_tensor(
tensor: jaxtyping.Shaped[Tensor, 'dim1'],
padded_length: int,
pad_value: float = 0.0,
rpad: bool = False
) -> jaxtyping.Shaped[Tensor, 'padded_length']
pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
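For plain Python lists, the same left/right padding behavior looks like this (a sketch only; the real pad_tensor and pad_array operate on torch tensors and numpy arrays):

```python
from typing import List

def pad_list(
    xs: List[float],
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> List[float]:
    # build the padding, then attach it on the right if rpad, else the left
    pad = [pad_value] * (padded_length - len(xs))
    return xs + pad if rpad else pad + xs
```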
def lpad_tensor(
tensor: torch.Tensor,
padded_length: int,
pad_value: float = 0.0
) -> torch.Tensor
pad a 1-d tensor on the left with pad_value to length padded_length
def rpad_tensor(
tensor: torch.Tensor,
pad_length: int,
pad_value: float = 0.0
) -> torch.Tensor
pad a 1-d tensor on the right with pad_value to length pad_length
def pad_array(
array: jaxtyping.Shaped[ndarray, 'dim1'],
padded_length: int,
pad_value: float = 0.0,
rpad: bool = False
) -> jaxtyping.Shaped[ndarray, 'padded_length']
pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
def lpad_array(
array: numpy.ndarray,
padded_length: int,
pad_value: float = 0.0
) -> numpy.ndarray
pad a 1-d array on the left with pad_value to length padded_length
def rpad_array(
array: numpy.ndarray,
pad_length: int,
pad_value: float = 0.0
) -> numpy.ndarray
pad a 1-d array on the right with pad_value to length pad_length
def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]
given a state dict or cache dict, compute the shapes and put them in a nested dict

def string_dict_shapes(d: dict[str, torch.Tensor]) -> str
printable version of get_dict_shapes

class StateDictCompareError(builtins.AssertionError):
raised when state dicts don’t match

class StateDictKeysError(StateDictCompareError):
raised when state dict keys don’t match

class StateDictShapeError(StateDictCompareError):
raised when state dict shapes don’t match

class StateDictValueError(StateDictCompareError):
raised when state dict values don’t match
def compare_state_dicts(
d1: dict,
d2: dict,
rtol: float = 1e-05,
atol: float = 1e-08,
verbose: bool = True
) -> None
compare two dicts of tensors
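The keys, then shapes, then values check order can be sketched for dicts of plain number lists using math.isclose. This is a stdlib analogue for illustration; the real function compares torch tensors and raises the StateDict*Error subclasses:

```python
import math

def compare_dicts_sketch(d1: dict, d2: dict, rtol: float = 1e-05, atol: float = 1e-08) -> None:
    # 1) keys must match exactly
    if d1.keys() != d2.keys():
        raise AssertionError(f"keys don't match: {d1.keys() ^ d2.keys()}")
    for k in d1:
        a, b = d1[k], d2[k]
        # 2) then shapes (here: lengths)
        if len(a) != len(b):
            raise AssertionError(f"shape mismatch for {k}: {len(a)} vs {len(b)}")
        # 3) finally values, within tolerance
        for x, y in zip(a, b):
            if not math.isclose(x, y, rel_tol=rtol, abs_tol=atol):
                raise AssertionError(f"value mismatch for {k}: {x} vs {y}")
```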
d1 : dict
d2 : dict
rtol : float (defaults to 1e-5)
atol : float (defaults to 1e-8)
verbose : bool (defaults to True)

raises:

StateDictKeysError : keys don’t match
StateDictShapeError : shapes don’t match (but keys do)
StateDictValueError : values don’t match (but keys and shapes do)
timeit_fancy is just a fancier version of timeit with
more options
muutils.timeit_fancy
class FancyTimeitResult(typing.NamedTuple):
return type of timeit_fancy
FancyTimeitResult(
timings: ForwardRef('StatCounter'),
return_value: ForwardRef('T'),
profile: ForwardRef('Union[pstats.Stats, None]')
)
Create new instance of FancyTimeitResult(timings, return_value, profile)

timings: muutils.statcounter.StatCounter
    Alias for field number 0
return_value: ~T
    Alias for field number 1
profile: Optional[pstats.Stats]
    Alias for field number 2
def timeit_fancy(
cmd: Callable[[], ~T],
setup: Union[str, Callable[[], Any]] = <function <lambda>>,
repeats: int = 5,
namespace: Optional[dict[str, Any]] = None,
get_return: bool = True,
do_profiling: bool = False
) -> muutils.timeit_fancy.FancyTimeitResult
Wrapper for timeit to get the fastest run of a callable, with more customization options.

Approximates the functionality of the %timeit magic or command line interface in a Python callable.
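A stripped-down version of the same idea using only the timeit module (timeit_fancy additionally collects a StatCounter of timings and an optional cProfile run; this sketch keeps just the fastest run and the return value):

```python
import timeit
from typing import Any, Callable, Tuple

def timeit_sketch(cmd: Callable[[], Any], repeats: int = 5, number: int = 100) -> Tuple[float, Any]:
    # time `number` calls, `repeats` times, keeping the fastest batch
    timer = timeit.Timer(cmd)
    best = min(timer.repeat(repeat=repeats, number=number))
    # run once more to capture the return value, like get_return=True
    return best, cmd()
```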
cmd: Callable[[], T] | str
    The callable to time. If a string, it will be passed to timeit.Timer as the stmt argument.
setup: str
    The setup code to run before cmd. If a string, it will be passed to timeit.Timer as the setup argument.
repeats: int
    The number of times to run cmd to get a reliable measurement.
namespace: dict[str, Any]
    Passed to the timeit.Timer constructor. If cmd or setup use local or global variables, they must be passed here. See the timeit documentation for details.
get_return: bool
    Whether to return the value returned from cmd. If True, the return value is included in the result, for speed and convenience, so that cmd doesn't need to be run again in the calling scope if the return value is needed. (default: True, matching the signature above)
do_profiling: bool
    Whether to return a pstats.Stats object in addition to the time and return value. (default: False)

Returns a FancyTimeitResult, which is a NamedTuple with the following fields:

timings: StatCounter
    The times, in seconds, of the repeated runs of cmd.
return_value: T|None
    The return value of cmd if get_return is True, otherwise None.
profile: pstats.Stats|None
    A pstats.Stats object if do_profiling is True, otherwise None.
experimental utility for validating types in python, see
validate_type
muutils.validate_type
GenericAliasTypes: tuple = (<class 'types.GenericAlias'>, <class 'typing._GenericAlias'>, <class 'typing._UnionGenericAlias'>, <class 'typing._BaseGenericAlias'>)

class IncorrectTypeException(builtins.TypeError):
Inappropriate argument type.

class TypeHintNotImplementedError(builtins.NotImplementedError):
Method or function hasn’t been implemented yet.

class InvalidGenericAliasError(builtins.TypeError):
Inappropriate argument type.
def validate_type(value: Any, expected_type: Any, do_except: bool = False) -> bool
Validate that a value is of the expected_type

value: the value to check the type of
expected_type: the type to check against. Not all types are supported
do_except: if True, raise an exception if the type is incorrect (instead of returning False) (default: False)

returns:

bool: True if the value is of the expected type, False otherwise.

raises:

IncorrectTypeException(TypeError): if the type is incorrect and do_except is True
TypeHintNotImplementedError(NotImplementedError): if the type hint is not implemented
InvalidGenericAliasError(TypeError): if the generic alias is invalid

use typeguard for a more robust solution:
https://github.com/agronholm/typeguard
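A sketch of the kind of recursion validate_type performs, built on typing.get_origin / get_args. This handles only a few hint shapes (plain classes, Union/Optional, list, dict) and is illustrative, not the library's implementation:

```python
from typing import Any, Dict, List, Union, get_args, get_origin

def validate_type_sketch(value: Any, expected_type: Any) -> bool:
    if expected_type is Any:
        return True
    origin = get_origin(expected_type)
    if origin is None:
        # a plain class: ordinary isinstance check
        return isinstance(value, expected_type)
    if origin is Union:
        # Union / Optional: any branch may match
        return any(validate_type_sketch(value, arg) for arg in get_args(expected_type))
    if origin is list:
        # list[T]: every element must match T
        (item_type,) = get_args(expected_type)
        return isinstance(value, list) and all(
            validate_type_sketch(item, item_type) for item in value
        )
    if origin is dict:
        # dict[K, V]: every key and value must match
        key_type, val_type = get_args(expected_type)
        return isinstance(value, dict) and all(
            validate_type_sketch(k, key_type) and validate_type_sketch(v, val_type)
            for k, v in value.items()
        )
    raise NotImplementedError(f"unsupported type hint: {expected_type!r}")
```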
def get_fn_allowed_kwargs(fn: Callable) -> Set[str]
Get the allowed kwargs for a function, raising an exception if the signature cannot be determined.
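This kind of lookup can be implemented with inspect.signature; a minimal sketch (the real function's error handling may differ):

```python
import inspect
from typing import Callable, Set

def allowed_kwargs_sketch(fn: Callable) -> Set[str]:
    try:
        sig = inspect.signature(fn)
    except (ValueError, TypeError) as e:
        # some builtins have no introspectable signature
        raise ValueError(f"cannot determine signature of {fn!r}") from e
    # keep only parameters that can be passed by keyword
    return {
        name
        for name, param in sig.parameters.items()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }
```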