torchtree.core.utils

Attributes

REGISTERED_CLASSES

Exceptions

JSONParseError

Exception raised when a JSON object cannot be parsed.

AlternativeAttributeError

Custom exception for debugging conflicts between @property and __getattr__.

Classes

TensorEncoder

Extensible JSON <http://json.org> encoder for Python data structures.

TensorDecoder

Simple JSON <http://json.org> decoder.

SignalHandler

Functions

as_tensor(dct[, dtype])

tensor_rand(distribution, shape[, dtype, device, ...])

Create a tensor with the given dtype and shape and initialize it using a distribution.

get_class(full_name) → type

process_objects(data, dic[, force_list, key])

process_object_with_key(key, data, dic[, default])

process_object(data, dic)

validate(data, rules)

remove_comments(obj)

Remove comments in dictionary representation of objects.

replace_wildcard_with_str(obj, wildcard, value)

replace_star_with_str(obj, value)

expand_plates(obj[, parent, idx])

update_parameters(json_object, parameters) → None

Recursively replace tensors in json_object with the tensors present in parameters.

print_graph(g[, level]) → None

Print computation graph.

string_to_list_index(index_str) → Union[int, slice]

package_contents(package_name)

register_class(_cls[, name])

Module Contents

torchtree.core.utils.REGISTERED_CLASSES

class torchtree.core.utils.TensorEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

Bases: json.JSONEncoder

Extensible JSON <http://json.org> encoder for Python data structures.

Supports the following objects and types by default:

  Python        JSON
  dict          object
  list, tuple   array
  str           string
  int, float    number
  True          true
  False         false
  None          null

To extend this to recognize other objects, subclass it and implement a .default() method that returns a serializable object for o if possible; otherwise, it should call the superclass implementation (to raise TypeError).

default(obj)

Implement this method in a subclass so that it returns a serializable object for obj, or calls the base implementation (to raise a TypeError).

For example, to support arbitrary iterators, you could implement default like this:

import json

class IterableEncoder(json.JSONEncoder):
    def default(self, o):
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, o)
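
The inherited docstring does not say how TensorEncoder itself handles tensors. As a hedged illustration of the same extension mechanism, the sketch below defines a hypothetical TolistEncoder (not part of torchtree) that serializes any object with a callable tolist() method, which torch.Tensor provides, and defers everything else to the base class.

```python
import json
from typing import Any


class TolistEncoder(json.JSONEncoder):
    """Hypothetical encoder: serialize any object exposing .tolist() as a list."""

    def default(self, obj: Any) -> Any:
        # torch.Tensor, numpy.ndarray and array.array all provide .tolist()
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        # Let the base class raise TypeError for unsupported objects
        return super().default(obj)
```

With this encoder, json.dumps(torch.ones(2), cls=TolistEncoder) yields a plain JSON list. Duck typing on tolist() keeps the sketch free of a torch import; an encoder that must also record dtype or shape would need to return a dictionary instead.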

class torchtree.core.utils.TensorDecoder(*args, **kwargs)

Bases: json.JSONDecoder

Simple JSON <http://json.org> decoder.

Performs the following translations in decoding by default:

  JSON            Python
  object          dict
  array           list
  string          str
  number (int)    int
  number (real)   float
  true            True
  false           False
  null            None

It also understands NaN, Infinity, and -Infinity as their corresponding float values, which is outside the JSON spec.

object_hook(dic)

torchtree.core.utils.as_tensor(dct, dtype=torch.float64)
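
TensorDecoder exposes an object_hook(dic) method, and as_tensor(dct, dtype) accepts a dictionary, which suggests that decoded JSON objects are turned into tensors; the exact dictionary layout is not documented here. The sketch below shows only the object_hook mechanism, with a hypothetical ArrayDecoder that converts objects whose sole key is "values" (an assumed marker) into array('d') instances standing in for float64 tensors.

```python
import json
from array import array
from typing import Any, Dict


class ArrayDecoder(json.JSONDecoder):
    """Hypothetical decoder: objects of the form {"values": [...]} become array('d')."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Route every decoded JSON object through self.object_hook
        kwargs.setdefault("object_hook", self.object_hook)
        super().__init__(*args, **kwargs)

    def object_hook(self, dct: Dict[str, Any]) -> Any:
        # Stand-in for a call such as as_tensor(dct); the "values" marker is assumed
        if set(dct) == {"values"}:
            return array("d", dct["values"])
        return dct
```

The hook runs bottom-up, once per JSON object, so nested markers are converted before their parent dictionary is built.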

torchtree.core.utils.tensor_rand(distribution, shape, dtype=None, device=None, requires_grad=False)

Create a tensor with the given dtype and shape and initialize it using a distribution.

Supported continuous distributions: normal, log_normal, uniform. Supported discrete distributions: random, bernoulli.

Parameters:
  • distribution (str) – distribution as a string (e.g. 'normal(1.0,2.0)', 'normal', 'normal()').

  • shape (Sequence[int]) – shape of the tensor

  • dtype (torch.dtype) – dtype of the tensor

  • device (torch.device) – device of the tensor

  • requires_grad (bool) – if True, autograd records operations on the returned tensor

Returns:

tensor

Return type:

torch.Tensor

Example:

>>> _ = torch.manual_seed(0)
>>> t1 = tensor_rand('normal(1.0, 2.0)', (1,2), dtype=torch.float64)
>>> t1
tensor([[4.0820, 0.4131]], dtype=torch.float64)
>>> _ = torch.manual_seed(0)
>>> t2 = tensor_rand('normal(0.0, 1.0)', (1,2), dtype=torch.float64)
>>> _ = torch.manual_seed(0)
>>> t3 = tensor_rand('normal()', (1,2), dtype=torch.float64)
>>> t2 == t3
tensor([[True, True]])
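
The doctest shows that 'normal()' produces the same values as 'normal(0.0, 1.0)', so the distribution string carries a name plus optional arguments that fall back to defaults. A minimal parsing sketch, assuming comma-separated numeric arguments; parse_distribution is a hypothetical helper, not torchtree's own parser:

```python
import re
from typing import Tuple

# Name made of word characters, optionally followed by a parenthesized argument list
_DIST_PATTERN = re.compile(r"^\s*(\w+)\s*(?:\((.*)\))?\s*$")


def parse_distribution(spec: str) -> Tuple[str, Tuple[float, ...]]:
    """Split 'normal(1.0, 2.0)' into ('normal', (1.0, 2.0)) (hypothetical helper)."""
    match = _DIST_PATTERN.match(spec)
    if match is None:
        raise ValueError(f"cannot parse distribution string: {spec!r}")
    name, arg_text = match.group(1), match.group(2)
    # 'normal' and 'normal()' both yield an empty argument tuple
    if arg_text is None or not arg_text.strip():
        return name, ()
    return name, tuple(float(a) for a in arg_text.split(","))
```

Mapping an empty argument tuple to per-distribution defaults (for example (0.0, 1.0) for normal) and dispatching to the matching torch sampling call would be the next step.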

torchtree.core.utils.get_class(full_name: str) → type
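
get_class takes a fully qualified name and returns a type. A common way to implement that contract uses importlib; the sketch below assumes the name has the form package.module.ClassName and is not necessarily how torchtree resolves names (it may, for instance, consult REGISTERED_CLASSES first):

```python
import importlib


def get_class_sketch(full_name: str) -> type:
    """Resolve 'package.module.ClassName' to the class object (sketch)."""
    module_name, _, class_name = full_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a fully qualified name, got {full_name!r}")
    # Import the containing module, then look the class up as an attribute
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```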

exception torchtree.core.utils.JSONParseError

Bases: Exception

Exception raised when a JSON object cannot be parsed.

torchtree.core.utils.process_objects(data, dic, force_list=False, key=None)

torchtree.core.utils.process_object_with_key(key, data, dic, default=None)

torchtree.core.utils.process_object(data, dic)

class torchtree.core.utils.SignalHandler

stop = False

exit(signum, frame)
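
SignalHandler exposes a stop flag and an exit(signum, frame) method whose signature matches what signal.signal expects of a handler. A sketch of that cooperative-shutdown pattern, assuming SIGINT and SIGTERM are the signals of interest (the class name GracefulStop is hypothetical):

```python
import signal
from typing import Any


class GracefulStop:
    """Hypothetical re-implementation of the SignalHandler pattern."""

    def __init__(self) -> None:
        self.stop = False
        # Route Ctrl-C (SIGINT) and SIGTERM to self.exit instead of killing the process
        signal.signal(signal.SIGINT, self.exit)
        signal.signal(signal.SIGTERM, self.exit)

    def exit(self, signum: int, frame: Any) -> None:
        # Only record the request; the main loop checks self.stop and exits cleanly
        self.stop = True
```

A training loop can then poll the flag, for example while not handler.stop: step(), and save its state before returning instead of being killed mid-iteration. Note that signal.signal may only be called from the main thread.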

torchtree.core.utils.validate(data, rules)

torchtree.core.utils.remove_comments(obj)

Remove comments in dictionary representation of objects.

  • Any key/value pair whose key starts with an underscore is removed.

  • Any dictionary that contains the key ignore with the value True is removed.
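
The two rules above can be expressed as one recursive pass. This sketch returns a cleaned copy rather than mutating obj, and it does not drop a top-level dictionary flagged with ignore; both are assumptions, since the in-place behavior of remove_comments is not documented here:

```python
from typing import Any


def _is_ignored(value: Any) -> bool:
    # Rule 2: a dictionary carrying "ignore": True is dropped entirely
    return isinstance(value, dict) and value.get("ignore") is True


def remove_comments_sketch(obj: Any) -> Any:
    """Return a copy of obj with comment entries stripped (hypothetical sketch)."""
    if isinstance(obj, dict):
        return {
            key: remove_comments_sketch(value)
            for key, value in obj.items()
            # Rule 1: keys starting with an underscore are comments
            if not (isinstance(key, str) and key.startswith("_")) and not _is_ignored(value)
        }
    if isinstance(obj, list):
        return [remove_comments_sketch(item) for item in obj if not _is_ignored(item)]
    return obj
```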

torchtree.core.utils.replace_wildcard_with_str(obj, wildcard, value)

torchtree.core.utils.replace_star_with_str(obj, value)
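
replace_wildcard_with_str and replace_star_with_str are undocumented. Assuming they substitute value for every occurrence of the wildcard inside the string values of a nested dict/list structure, with '*' as the star wildcard, a sketch looks like this (dictionary keys are deliberately left untouched, which is another assumption):

```python
from typing import Any


def replace_wildcard_sketch(obj: Any, wildcard: str, value: str) -> Any:
    """Return a copy of obj with wildcard replaced by value in every string."""
    if isinstance(obj, str):
        return obj.replace(wildcard, value)
    if isinstance(obj, dict):
        # Keys are left as-is; only values are rewritten
        return {k: replace_wildcard_sketch(v, wildcard, value) for k, v in obj.items()}
    if isinstance(obj, list):
        return [replace_wildcard_sketch(item, wildcard, value) for item in obj]
    # Numbers, booleans and None pass through unchanged
    return obj


def replace_star_sketch(obj: Any, value: str) -> Any:
    """Special case of replace_wildcard_sketch for the '*' wildcard."""
    return replace_wildcard_sketch(obj, "*", value)
```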

torchtree.core.utils.expand_plates(obj, parent=None, idx=None)

torchtree.core.utils.update_parameters(json_object, parameters) → None

Recursively replace tensors in json_object with the tensors present in parameters.

Parameters:
  • json_object (dict) – JSON object

  • parameters (list(Parameter)) – list of Parameter objects
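
A sketch of the recursive replacement described above. How update_parameters matches JSON entries to parameters is not stated, so this version assumes each Parameter is identified by an id and that the JSON object stores values under a "tensor" key next to a matching "id"; FakeParameter is a hypothetical stand-in, and plain lists stand in for tensors:

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class FakeParameter:
    """Hypothetical stand-in for a torchtree Parameter: an id plus its values."""

    id: str
    tensor: List[float]


def update_parameters_sketch(json_object: Any, parameters: List[FakeParameter]) -> None:
    """Overwrite, in place, the 'tensor' entry of every dict whose 'id' matches."""
    lookup: Dict[str, FakeParameter] = {p.id: p for p in parameters}

    def visit(node: Any) -> None:
        if isinstance(node, dict):
            key = node.get("id")
            if isinstance(key, str) and key in lookup and "tensor" in node:
                node["tensor"] = list(lookup[key].tensor)
            # Descend into nested objects and lists
            for value in node.values():
                visit(value)
        elif isinstance(node, list):
            for item in node:
                visit(item)

    visit(json_object)
```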

torchtree.core.utils.print_graph(g: torch.Tensor, level: int = 0) → None

Print computation graph.

Parameters:
  • g (torch.Tensor) – tensor whose computation graph is printed

  • level (int) – indentation level
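
PyTorch records the computation graph as autograd nodes reachable from a tensor's grad_fn, and each node lists its inputs in next_functions as (node, index) pairs. A depth-first sketch of printing that graph with one indentation step per level; graph_lines is a hypothetical helper that collects lines instead of printing so it can be tested, and it relies only on those two attributes:

```python
from typing import Any, List, Optional


def graph_lines(node: Optional[Any], level: int = 0) -> List[str]:
    """Return one indented line per autograd node, visited depth first (sketch)."""
    if node is None:
        # Inputs that do not require grad appear as None in next_functions
        return []
    lines = ["  " * level + type(node).__name__]
    # torch autograd nodes expose next_functions: a tuple of (node, index) pairs
    for child, _ in getattr(node, "next_functions", ()):
        lines.extend(graph_lines(child, level + 1))
    return lines
```

For a tensor loss built from differentiable operations, print("\n".join(graph_lines(loss.grad_fn))) prints node names such as MulBackward0 and AccumulateGrad. Nodes shared by several branches are printed once per path, since the sketch does not track visited nodes.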

exception torchtree.core.utils.AlternativeAttributeError

Bases: Exception

Custom exception for debugging conflicts between @property and __getattr__.

https://stackoverflow.com/questions/36575068/attributeerrors-undesired-interaction-between-property-and-getattr

classmethod wrapper(f)

Wraps a function so that any AttributeError it raises is re-raised as an AlternativeAttributeError.
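
The linked Stack Overflow question describes the problem: when a @property getter raises AttributeError by mistake, Python falls back to __getattr__, which hides the real error. A minimal re-implementation of the wrapper idea, using a hypothetical AlternativeAttributeErrorSketch class rather than torchtree's own, and a Model class whose property contains a deliberate bug:

```python
import functools
from typing import Any, Callable, Optional, Type


class AlternativeAttributeErrorSketch(Exception):
    """Hypothetical stand-in; deliberately NOT a subclass of AttributeError."""

    @classmethod
    def wrapper(cls, f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            try:
                return f(*args, **kwargs)
            except AttributeError as exc:
                # Re-raise under a different type so __getattr__ is not consulted
                raise cls(str(exc)) from exc
        return wrapped


class Model:
    def __init__(self) -> None:
        self._data = [1.0, 2.0]

    @property
    @AlternativeAttributeErrorSketch.wrapper
    def shape(self) -> Any:
        # Bug on purpose: lists have no .shape, so this raises AttributeError
        return self._data.shape

    def __getattr__(self, name: str) -> str:
        # Fallback that would otherwise swallow the property's AttributeError
        return f"fallback for {name}"


def raised_type(fn: Callable[[], Any]) -> Optional[Type[BaseException]]:
    """Return the type of exception raised by fn(), or None (demo helper)."""
    try:
        fn()
    except Exception as exc:
        return type(exc)
    return None
```

Without the wrapper, Model().shape would silently return "fallback for shape"; with it, the original AttributeError surfaces as the chained cause of a different exception type.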

torchtree.core.utils.string_to_list_index(index_str) → int | slice
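
The return type int | slice suggests that the function converts strings such as "3" or "1:4" into list indices. A sketch under that assumption; parse_index is hypothetical, and the accepted syntax (Python's start:stop:step) is a guess:

```python
from typing import List, Optional, Union


def parse_index(index_str: str) -> Union[int, slice]:
    """Parse '3' into 3 and 'start:stop[:step]' into a slice (hypothetical helper)."""
    text = index_str.strip()
    if ":" not in text:
        return int(text)
    parts = text.split(":")
    if len(parts) > 3:
        raise ValueError(f"invalid slice string: {index_str!r}")
    # Empty fields (e.g. the start in ':2') become None, as in Python's own slicing
    fields: List[Optional[int]] = [int(p) if p.strip() else None for p in parts]
    return slice(*fields)
```

For example, [10, 20, 30, 40][parse_index("1:3")] evaluates to [20, 30].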

torchtree.core.utils.package_contents(package_name)
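
package_contents is undocumented. A common way to list the modules in a package uses pkgutil.iter_modules on the package's __path__; the sketch below assumes the function returns the names of direct submodules:

```python
import importlib
import pkgutil
from typing import List


def list_modules(package_name: str) -> List[str]:
    """Return the sorted names of the direct submodules of a package (sketch)."""
    package = importlib.import_module(package_name)
    # Only packages have __path__; a plain module has no submodules to list
    search_path = getattr(package, "__path__", [])
    return sorted(info.name for info in pkgutil.iter_modules(search_path))
```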

torchtree.core.utils.register_class(_cls, name=None)
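
register_class(_cls, name=None) and the module-level REGISTERED_CLASSES point to a name-to-class registry. A sketch of that pattern, assuming the default key is the class's __name__; REGISTRY and register_class_sketch are hypothetical names:

```python
from typing import Dict, Optional

# Hypothetical stand-in for REGISTERED_CLASSES
REGISTRY: Dict[str, type] = {}


def register_class_sketch(_cls: type, name: Optional[str] = None) -> type:
    """Store _cls under name (default: its __name__) and return it unchanged."""
    REGISTRY[name if name is not None else _cls.__name__] = _cls
    return _cls


@register_class_sketch  # usable as a plain decorator because name defaults to None
class Rate:
    pass


# An explicit alias, e.g. a short name referenced from a JSON configuration
register_class_sketch(Rate, name="rate")
```

A loader can then resolve short type names from a configuration through the registry before falling back to a fully qualified import.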