I have a class `A` with a lot of old, complex methods that I have rewritten.
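```python
class A:
    def __init__(self):
        self.x = 1

    def old_func(self):
        return 2

    @property
    def b(self):
        # Create the helper lazily. Check the private attribute: hasattr(self, 'b')
        # would call this property again and recurse forever.
        if not hasattr(self, '_b'):
            self._b = B(self)
        return self._b


class B:
    def __init__(self, a: A):
        self.a = a  # back-reference to the owning A instance

    def new_func(self):
        return 2
```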
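As an aside, the hand-written lazy property can also be expressed with `functools.cached_property` (Python 3.8+). This is a swapped-in standard-library alternative, not part of the original code: it runs the method on first access and stores the result in the instance `__dict__`, so later accesses return the same `B` object.

```python
from functools import cached_property


class B:
    def __init__(self, a):
        self.a = a  # back-reference to the owning A instance

    def new_func(self):
        return 2


class A:
    def __init__(self):
        self.x = 1

    def old_func(self):
        return 2

    @cached_property
    def b(self):
        # Runs only on first access; the result is cached on the instance
        # under the name "b" and returned directly afterwards.
        return B(self)


a = A()
print(a.b is a.b)  # True: the same B instance every time
```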
I want to gradually replace `a.old_func` with `a.new_func`. But first, I want to make sure that the new method always behaves the same way as the old one, so I've written a decorator to check that:
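```python
import functools


def refactor_factory(new_func):
    def refactor(old_func):
        @functools.wraps(old_func)  # keep old_func's name and docstring
        def _wrapper(*args, **kwargs):
            old_return_value = old_func(*args, **kwargs)
            # Forward positional args too; passing only **kwargs drops them.
            new_return_value = new_func(*args, **kwargs)
            if old_return_value != new_return_value:
                raise Exception("Mismatch")  # Add a complete log info
            return old_return_value
        return _wrapper
    return refactor
```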
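To show what the decorator does on its own, here is a small sketch on plain functions (no class involved). The function names `old_square`, `new_square`, `old_abs` and `new_abs_buggy` are hypothetical; `new_abs_buggy` is a deliberately broken "rewrite" so the mismatch check fires.

```python
import functools


def refactor_factory(new_func):
    def refactor(old_func):
        @functools.wraps(old_func)
        def _wrapper(*args, **kwargs):
            old_return_value = old_func(*args, **kwargs)
            new_return_value = new_func(*args, **kwargs)
            if old_return_value != new_return_value:
                raise Exception("Mismatch")
            return old_return_value
        return _wrapper
    return refactor


def new_square(n):
    return n * n


@refactor_factory(new_square)
def old_square(n):
    return n ** 2


def new_abs_buggy(n):
    return n  # hypothetical broken rewrite: wrong for negative n


@refactor_factory(new_abs_buggy)
def old_abs(n):
    return abs(n)


print(old_square(4))  # 16
print(old_abs(3))     # 3: old and new agree for non-negative input
try:
    old_abs(-3)       # old returns 3, new returns -3
except Exception as exc:
    print(exc)        # Mismatch
```

The caller always gets the old return value, so the check can stay in place while the new implementation is being validated.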
And I wanted to call it like this:
```python
class A:
    def __init__(self):
        self.x = 1

    # Does not work: `self` does not exist while the class body is executed.
    @refactor_factory(self.b.new_func)
    def old_func(self):
        return 2

    def new_func(self):
        return 2
```
The problem is that I can't pass `new_func` to the decorator. I know I have access to `self` inside the wrapper when the method is called, but the decorator's arguments are evaluated when the class body runs, where there is no `self`, so I can't pass its methods. Is there a way to achieve this?
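A minimal sketch of the timing involved, using hypothetical names (`show_when_called`, `Demo`, `definition_log`): it records when the decorator argument is evaluated versus when the wrapper actually runs. Names defined earlier in the class body are usable by their bare name at decoration time, while `self` only appears once the method is called.

```python
import functools

definition_log = []


def show_when_called(label):
    # The decorator *argument* is evaluated when the class body runs,
    # before any instance (and therefore any `self`) exists.
    definition_log.append(f"decorating with {label!r}")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Only here, at call time, is `self` available.
            definition_log.append(f"calling {func.__name__} on {type(self).__name__}")
            return func(self, *args, **kwargs)
        return wrapper
    return decorator


class Demo:
    def helper(self):
        return "helper"

    # `helper` is a plain function in the class namespace at this point,
    # so its bare name works here; `self` and `Demo` do not.
    @show_when_called(helper.__name__)
    def method(self):
        return self.helper()


print(definition_log)      # ["decorating with 'helper'"]  (no instance yet)
print(Demo().method())     # helper
print(definition_log[-1])  # calling method on Demo
```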
P.S. I know there are other designs that achieve what I want, such as the one below; I just thought the first approach was cleaner.
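```python
import functools


def refactor(old_func):
    @functools.wraps(old_func)
    def _wrapper(self, *args, **kwargs):
        # `self` is the instance, so its attributes can be looked up here at
        # call time. isinstance() needs a type, so compare names instead.
        if old_func.__name__ == "old_func":
            new_func = self.b.new_func  # bound method: don't pass self again
        else:
            raise NotImplementedError(f"No replacement for {old_func.__name__}")
        old_return_value = old_func(self, *args, **kwargs)
        new_return_value = new_func(*args, **kwargs)
        if old_return_value != new_return_value:
            raise Exception("Mismatch")  # Add a complete log info
        return old_return_value
    return _wrapper


class A:
    def __init__(self):
        self.x = 1

    @property
    def b(self):
        # Lazily created helper (class B as defined above).
        if not hasattr(self, "_b"):
            self._b = B(self)
        return self._b

    @refactor
    def old_func(self):
        return 2

    def new_func(self):
        return 2
```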
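A possible middle ground between the two designs, sketched under the assumption that naming the replacement by a string is acceptable: pass the attribute path (here `"b.new_func"`) to the decorator and resolve it against the instance at call time with `operator.attrgetter`, which accepts dotted names. The string path is an illustrative choice, not something from the question.

```python
import functools
import operator


def refactor_factory(new_func_path):
    # attrgetter("b.new_func")(obj) evaluates obj.b.new_func
    get_new_func = operator.attrgetter(new_func_path)

    def refactor(old_func):
        @functools.wraps(old_func)
        def _wrapper(self, *args, **kwargs):
            old_return_value = old_func(self, *args, **kwargs)
            new_func = get_new_func(self)  # bound method, e.g. self.b.new_func
            new_return_value = new_func(*args, **kwargs)
            if old_return_value != new_return_value:
                raise Exception("Mismatch")  # Add a complete log info
            return old_return_value
        return _wrapper
    return refactor


class B:
    def __init__(self, a):
        self.a = a

    def new_func(self):
        return 2


class A:
    def __init__(self):
        self.x = 1

    @property
    def b(self):
        if not hasattr(self, "_b"):
            self._b = B(self)
        return self._b

    @refactor_factory("b.new_func")
    def old_func(self):
        return 2


print(A().old_func())  # 2
```

Because the lookup happens inside the wrapper, nothing about `self` is needed at class-definition time; a typo in the path only shows up as an `AttributeError` on the first call.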
Answer
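I think you can just pass the plain function `new_func` (the function that `A.new_func` refers to once the class exists). Inside the class body, `A` itself is not defined yet, but any function defined earlier in the body is available by its bare name, so define `new_func` above `old_func`. In the wrapper, `self` is the first element of `*args`, so as long as the wrapper calls `new_func(*args, **kwargs)`, it will be passed along correctly.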
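```python
class A:
    def __init__(self):
        self.x = 1

    # Defined first, so its bare name is in scope for the decorator below.
    def new_func(self):
        return 2

    # Assumes refactor_factory calls new_func(*args, **kwargs), so the
    # instance (args[0]) is passed along as new_func's `self`.
    @refactor_factory(new_func)
    def old_func(self):
        return 2
```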
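If the replacement really lives on the helper, as in the question's `self.b.new_func`, passing `B.new_func` directly would hand it an `A` instance as `self`. A sketch of one way around that, assuming the wrapper forwards `*args` as above: wrap the call in a small lambda, which is created at class-definition time but only looks up `self.b.new_func` when it is called.

```python
import functools


def refactor_factory(new_func):
    def refactor(old_func):
        @functools.wraps(old_func)
        def _wrapper(*args, **kwargs):
            old_return_value = old_func(*args, **kwargs)
            # args[0] is the instance, so new_func receives it as `self`.
            new_return_value = new_func(*args, **kwargs)
            if old_return_value != new_return_value:
                raise Exception("Mismatch")
            return old_return_value
        return _wrapper
    return refactor


class B:
    def __init__(self, a):
        self.a = a

    def new_func(self):
        return 2


class A:
    def __init__(self):
        self.x = 1

    @property
    def b(self):
        if not hasattr(self, "_b"):
            self._b = B(self)
        return self._b

    # The lambda body (self.b.new_func) is evaluated only when the wrapper
    # calls it, at which point `self` is a real A instance.
    @refactor_factory(lambda self, *args, **kwargs: self.b.new_func(*args, **kwargs))
    def old_func(self):
        return 2


print(A().old_func())  # 2
```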