I want to persist all attributes of an object that is an instance of a dataclass, and then load that object back from the files I persisted.
Here is an example that fulfills the task:
```python
from dataclasses import dataclass
import pickle


@dataclass
class Circle:
    radius: float
    centre: tuple

    def save(self, path: str):
        name = ".".join(("radius", "pkl"))
        with open("/".join((path, name)), "wb") as f:
            pickle.dump(self.radius, f)
        name = ".".join(("centre", "pkl"))
        with open("/".join((path, name)), "wb") as f:
            pickle.dump(self.centre, f)

    @classmethod
    def load(cls, path):
        my_model = {}
        name = "radius"
        file_name = ".".join((name, "pkl"))
        with open("/".join((path, file_name)), "rb") as f:
            my_model[name] = pickle.load(f)
        name = "centre"
        file_name = ".".join((name, "pkl"))
        with open("/".join((path, file_name)), "rb") as f:
            my_model[name] = pickle.load(f)
        return cls(**my_model)
```
```python
>>> c = Circle(2, (0, 0))
>>> c.save(r".Circle")
>>> c_loaded = Circle.load(r".Circle")
>>> c_loaded == c
True
```
As you can see, I need to repeat the same code for every attribute. What is a better way to do it?
Answer
In the `save` method, use `self.__dict__`, which contains all attribute names and values as a dictionary. `load` is a classmethod, so there is no instance (and therefore no instance `__dict__`) at that stage. However, `cls.__annotations__` contains the attribute names and their types, also stored in a dictionary.
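To make the difference concrete, here is a small sketch (assuming Python 3.7+ and the same `Circle` fields as above, without the save/load methods) of what the two dictionaries contain:

```python
from dataclasses import dataclass


@dataclass
class Circle:
    radius: float
    centre: tuple


c = Circle(2, (0, 0))

# Instance attributes (name -> value): usable inside the instance method save()
print(c.__dict__)  # {'radius': 2, 'centre': (0, 0)}

# Class annotations (name -> declared type): usable inside the classmethod load()
print(Circle.__annotations__)  # {'radius': <class 'float'>, 'centre': <class 'tuple'>}
```

Both dictionaries list the fields in declaration order, which is why iterating over either one visits the same names.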
Here is the end result:
```python
from dataclasses import dataclass
import pickle


@dataclass
class Circle:
    radius: float
    centre: tuple

    def save(self, path):
        # self.__dict__ maps each attribute name to its current value
        for name, attribute in self.__dict__.items():
            name = ".".join((name, "pkl"))
            with open("/".join((path, name)), "wb") as f:
                pickle.dump(attribute, f)

    @classmethod
    def load(cls, path):
        my_model = {}
        # cls.__annotations__ maps each field name to its declared type
        for name in cls.__annotations__:
            file_name = ".".join((name, "pkl"))
            with open("/".join((path, file_name)), "rb") as f:
                my_model[name] = pickle.load(f)
        return cls(**my_model)
```
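A possible variant, sketched here as an alternative rather than as the answer's method: `dataclasses.fields()` accepts either an instance or the class, so `save` and `load` can share the same field list. Unlike `cls.__annotations__`, it also includes fields inherited from a parent dataclass and skips `ClassVar` annotations. The use of `os.path.join` and of a `tempfile.TemporaryDirectory` for the round trip are assumptions added for portability and for the demo.

```python
import os
import pickle
import tempfile
from dataclasses import dataclass, fields


@dataclass
class Circle:
    radius: float
    centre: tuple

    def save(self, path):
        # fields(self) yields one Field object per dataclass field, in declaration order
        for fld in fields(self):
            with open(os.path.join(path, fld.name + ".pkl"), "wb") as fh:
                pickle.dump(getattr(self, fld.name), fh)

    @classmethod
    def load(cls, path):
        values = {}
        # fields(cls) gives the same field list without needing an instance
        for fld in fields(cls):
            with open(os.path.join(path, fld.name + ".pkl"), "rb") as fh:
                values[fld.name] = pickle.load(fh)
        return cls(**values)


# Round trip through a temporary directory that is deleted afterwards
with tempfile.TemporaryDirectory() as tmp:
    c = Circle(2, (0, 0))
    c.save(tmp)
    c_loaded = Circle.load(tmp)
    print(c_loaded == c)  # True
```

One caveat applies to both versions: a field declared with `init=False` is still saved, but passing it to `cls(**values)` raises a `TypeError`, so such fields would need to be filtered out (for example by checking `fld.init`) before calling the constructor.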