
Passing a method of self to decorator as an argument

I have a class A with a lot of old, complex methods that I have rewritten.

class A:
    def __init__(self):
        self.x = 1

    def old_func(self):
        return 2

    @property
    def b(self):
        if not hasattr(self, '_b'):
            self._b = B(self)  # create the helper lazily, once per instance
        return self._b

class B:
    def __init__(self, a: A):
        self.a = a

    def new_func(self):
        return 2

I want to gradually replace a.old_func with a.b.new_func, but first I want to make sure that the new method always behaves the same way as the old one. So I’ve written a decorator to check that:

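import functools


def refactor_factory(new_func):
    def refactor(old_func):
        @functools.wraps(old_func)  # keep old_func's name and docstring
        def _wrapper(*args, **kwargs):
            old_return_value = old_func(*args, **kwargs)
            # Call the new implementation with exactly the same arguments
            new_return_value = new_func(*args, **kwargs)

            if old_return_value != new_return_value:
                # TODO: log complete details of the mismatch
                raise Exception(f"Mismatch: {old_return_value!r} != {new_return_value!r}")

            return old_return_value

        return _wrapper

    return refactor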

And I wanted to call it like this:

class A:
    def __init__(self):
        self.x = 1
    
    @refactor_factory(self.b.new_func)  # fails: there is no self here
    def old_func(self):
        return 2
    
    def new_func(self):
        return 2

The problem is that I can’t pass new_func to the decorator. I know I have access to self inside the wrapper (as args[0]), but the decorator's arguments are evaluated while the class body is executing, before any instance exists, so there is no self whose methods I could pass. Is there a way to achieve this?
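To see why, here is a minimal sketch using a stripped-down stand-in for refactor_factory (it just returns the decorated function unchanged). The expression in the decorator line is evaluated once, while Python executes the class body, so referring to self there fails with a NameError before A is ever instantiated.

```python
# Stripped-down stand-in for the question's decorator factory: it does nothing,
# it only exists so we can see *when* its argument expression is evaluated.
def refactor_factory(new_func):
    def refactor(old_func):
        return old_func
    return refactor


error_message = None
try:
    class A:
        # This line runs while the class body executes, before any
        # instance of A (and therefore any `self`) exists.
        @refactor_factory(self.b.new_func)
        def old_func(self):
            return 2
except NameError as exc:
    error_message = str(exc)

print(error_message)  # name 'self' is not defined
```

Inside the wrapper, by contrast, self is simply the first positional argument, because the wrapper only runs when the method is called on an instance.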

P.S. I know there are other designs that achieve what I want, like the one below; I just thought the first approach was cleaner.

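import functools


def refactor(old_func):
    @functools.wraps(old_func)
    def _wrapper(*args, **kwargs):
        self = args[0]
        # Pick the replacement at call time, when self is available
        # (assumes A has the b property shown in the first snippet)
        if old_func.__name__ == "old_func":
            new_func = self.b.new_func
        else:
            raise LookupError(f"No replacement registered for {old_func.__name__}")

        old_return_value = old_func(*args, **kwargs)
        # new_func is a bound method, so drop self from the positional arguments
        new_return_value = new_func(*args[1:], **kwargs)

        if old_return_value != new_return_value:
            raise Exception("Mismatch")  # TODO: log complete details

        return old_return_value

    return _wrapper


class A:
    def __init__(self):
        self.x = 1

    @refactor
    def old_func(self):
        return 2

    def new_func(self):
        return 2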
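The name-based lookup in the design above can be generalized a little. As a sketch (an alternative, not the question's code), the decorator can take the replacement's attribute path as a string and resolve it on the instance at call time with operator.attrgetter, which also accepts dotted paths such as "b.new_func". The parameter name new_func_path is hypothetical.

```python
import functools
import operator


def refactor(new_func_path):
    """Hypothetical variant: name the replacement by a (possibly dotted)
    attribute path and look it up on the instance when the method is called."""
    get_new_func = operator.attrgetter(new_func_path)

    def decorator(old_func):
        @functools.wraps(old_func)
        def _wrapper(self, *args, **kwargs):
            old_return_value = old_func(self, *args, **kwargs)
            # get_new_func(self) returns a bound method, so self is not passed again
            new_return_value = get_new_func(self)(*args, **kwargs)
            if old_return_value != new_return_value:
                raise ValueError(
                    f"Mismatch: {old_return_value!r} != {new_return_value!r}"
                )
            return old_return_value
        return _wrapper
    return decorator


class A:
    def __init__(self):
        self.x = 1

    @refactor("new_func")  # "b.new_func" would work the same way
    def old_func(self):
        return 2

    def new_func(self):
        return 2


print(A().old_func())  # 2
```

Since only a string is evaluated at class-definition time, nothing about self needs to exist yet; the trade-off is that a typo in the path only surfaces when the method is first called.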


Answer

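I think you can just pass the plain new_func function. In the wrapper, self will be the first of the *args, so as long as the wrapper forwards *args to new_func as well, it will be passed along correctly.

Note that the name A does not exist yet while the class body is executing, so A.new_func would raise a NameError. Instead, define new_func above old_func and refer to it by its bare name:

class A:
    def __init__(self):
        self.x = 1

    def new_func(self):
        return 2

    @refactor_factory(new_func)
    def old_func(self):
        return 2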
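Passing the plain function works when new_func is defined on A itself. When the new implementation lives on B, as in the question, one option (a sketch, not part of the answer) is to pass a callable that receives the instance and returns the bound replacement, for example lambda self: self.b.new_func, and to call it inside the wrapper. The parameter name get_new_func is hypothetical.

```python
import functools


def refactor_factory(get_new_func):
    """Compare old_func with a replacement resolved per instance.

    get_new_func (hypothetical name) is called with the instance and must
    return the bound replacement method, e.g. self.b.new_func.
    """
    def refactor(old_func):
        @functools.wraps(old_func)
        def _wrapper(self, *args, **kwargs):
            old_return_value = old_func(self, *args, **kwargs)
            # Resolve the replacement now, when self exists
            new_func = get_new_func(self)
            new_return_value = new_func(*args, **kwargs)
            if old_return_value != new_return_value:
                raise ValueError(
                    f"Mismatch in {old_func.__name__}: "
                    f"{old_return_value!r} != {new_return_value!r}"
                )
            return old_return_value
        return _wrapper
    return refactor


class B:
    def __init__(self, a):
        self.a = a

    def new_func(self):
        return 2


class A:
    def __init__(self):
        self.x = 1

    @property
    def b(self):
        # Create the B helper lazily and cache it on the instance
        if not hasattr(self, "_b"):
            self._b = B(self)
        return self._b

    # The lambda itself is created at class-definition time,
    # but it is only *called* inside the wrapper, where self is available.
    @refactor_factory(lambda self: self.b.new_func)
    def old_func(self):
        return 2


print(A().old_func())  # 2
```

The lambda defers the attribute lookup in the same way the wrapper defers the call, which is why it avoids the NameError from the question.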
User contributions licensed under: CC BY-SA