Tech With Tim

Overloading Methods

We often take for granted that we can use operators like +, - and == on Python's built-in data types. In reality, this behavior is coded into those classes through special methods (often called "dunder" methods because their names start and end with double underscores, such as __add__ and __eq__). This means we can give our own classes the same behavior by defining these special methods ourselves.
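To see that operators really are method calls, we can call the special methods of a few built-in values directly. This is a simplified view: Python actually looks the method up on the type, and for binary operators it may also try the right-hand operand's reflected method (such as __radd__).

```python
# Operators on built-in types are ordinary method calls in disguise
a, b = 3, 4

print(a + b)                  # 7
print(a.__add__(b))           # 7, the method that + calls for ints
print([1, 2] == [1, 2])       # True
print([1, 2].__eq__([1, 2]))  # True, the method that == calls for lists
print("ab".__mul__(3))        # ababab, the method that * calls for strings
```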

Take, for example, the following class and objects:

class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y
        self.coords = (self.x, self.y)

    def move(self, x, y):
        self.x += x
        self.y += y
        self.coords = (self.x, self.y)  # keep coords in sync after moving

p1 = Point(3, 4)
p2 = Point(3, 2)
p3 = Point(1, 3)
p4 = Point(0, 1)

If we wanted to compare two points for equality, we would have to write something like this:

isSame = p1.x == p2.x and p1.y == p2.y

This is far from elegant, and we would have to repeat it every time we compare two points. To solve this problem, we can overload the special method __eq__, which Python calls whenever we use ==.

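class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y
        self.coords = (self.x, self.y)

    def move(self, x, y):
        self.x += x
        self.y += y
        self.coords = (self.x, self.y)  # keep coords in sync after moving

    def __eq__(self, other):
        # Let Python handle comparisons with objects that are not Points
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y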

p1 = Point(3, 4)
p2 = Point(3, 2)
p3 = Point(1, 3)
p4 = Point(0, 1)

# Now we can compare points using ==

isSame = p1 == p2
print(isSame)  # Prints False
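Defining __eq__ has a few side effects worth knowing. In Python 3, != is derived from __eq__ automatically, and returning NotImplemented for foreign types lets Python fall back to its default identity check. However, a class that defines __eq__ without __hash__ becomes unhashable, so its objects can no longer go into sets or be used as dict keys. The trimmed-down Point below (no move or coords, for brevity) adds a __hash__ based on the same fields __eq__ compares. Because Point objects are mutable, this is only safe if you never move a point while it sits in a set or dict.

```python
class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented  # let Python fall back to its default behavior
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        # Equal points must hash equally, so hash the same fields __eq__ compares
        return hash((self.x, self.y))


p1 = Point(3, 4)
p2 = Point(3, 4)

print(p1 == p2)       # True
print(p1 != p2)       # False: Python 3 derives != from __eq__
print(p1 == (3, 4))   # False: both sides return NotImplemented, so identity is used
print(len({p1, p2}))  # 1: equal points collapse into a single set entry
```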

There are many other special methods we can overload. Some of the most commonly used are __add__ (+), __sub__ (-) and __mul__ (*), shown below.

class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y
        self.coords = (self.x, self.y)

    def move(self, x, y):
        self.x += x
        self.y += y
        self.coords = (self.x, self.y)  # keep coords in sync after moving

    def __add__(self, other):  # p + q
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):  # p - q
        return Point(self.x - other.x, self.y - other.y)

    def __mul__(self, other):  # p * q, defined here as the dot product
        return self.x * other.x + self.y * other.y


p1 = Point(3, 4)
p2 = Point(3, 2)
p3 = Point(1, 3)
p4 = Point(0, 1)

p5 = p1 + p2  # Point(6, 6)
p6 = p4 - p1  # Point(-3, -3)
p7 = p2 * p3  # 3*1 + 2*3 = 9
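Binary operators also have "reflected" versions. When Python evaluates 3 * p, it first tries int.__mul__, which returns NotImplemented for a Point, and then calls p.__rmul__(3). The sketch below is a hypothetical extension of the __mul__ above: it keeps the dot product for Point * Point and, as an assumption of this example, treats a plain number as a scale factor. The class is trimmed down to the parts needed for the demonstration.

```python
class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

    def __mul__(self, other):
        if isinstance(other, Point):         # p * q: dot product, as above
            return self.x * other.x + self.y * other.y
        if isinstance(other, (int, float)):  # p * 3: scale the point (assumed behavior)
            return Point(self.x * other, self.y * other)
        return NotImplemented                # unsupported type: let Python raise TypeError

    def __rmul__(self, other):
        # 3 * p: int.__mul__ returns NotImplemented for a Point,
        # so Python calls p.__rmul__(3) instead
        return self.__mul__(other)


p = Point(1, 2)
print(p * Point(3, 4))  # 11 (1*3 + 2*4)
q = p * 3               # handled by Point.__mul__
r = 3 * p               # handled by Point.__rmul__
print(q.x, q.y)         # 3 6
print(r.x, r.y)         # 3 6
```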

Now you may notice that when we try to print one of our point objects, we get some cryptic text that looks like this:

<__main__.Point object at 0x...>

This is because we have not defined how our point should be represented as a string. To do this, we overload the __str__ method, which print() and str() call to get a readable version of the object.

class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y
        self.coords = (self.x, self.y)

    def move(self, x, y):
        self.x += x
        self.y += y
        self.coords = (self.x, self.y)  # keep coords in sync after moving

    def __str__(self):
        return f"Point({self.x}, {self.y})"

p1 = Point(3, 4)
print(p1)  # This prints Point(3, 4)
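__str__ alone does not cover every case: when a Point sits inside a list, printing the list calls repr() on each element, so the cryptic default text comes back. Defining __repr__ fixes this, and because the default __str__ falls back to __repr__, a single method covers both print(p) and print([p]). The sketch below is trimmed down to the parts needed to show this.

```python
class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

    def __repr__(self):
        # Used by repr(), the interactive prompt and containers such as lists.
        # str() and print() fall back to it because __str__ is not defined here.
        return f"Point({self.x}, {self.y})"


points = [Point(3, 4), Point(0, 1)]
print(points[0])  # Point(3, 4)
print(points)     # [Point(3, 4), Point(0, 1)]
```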

We can also overload the following methods to implement the comparison operators >, >=, < and <=.

import math


class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y
        self.coords = (self.x, self.y)

    def move(self, x, y):
        self.x += x
        self.y += y
        self.coords = (self.x, self.y)  # keep coords in sync after moving

    def length(self):
        # Distance from the origin (0, 0)
        return math.sqrt(self.x ** 2 + self.y ** 2)

    def __gt__(self, other):  # greater than
        return self.length() > other.length()

    def __ge__(self, other):  # greater than or equal to
        return self.length() >= other.length()

    def __lt__(self, other):  # less than
        return self.length() < other.length()

    def __le__(self, other):  # less than or equal to
        return self.length() <= other.length()


# We are going to compare points based on their lengths

p1 = Point(3, 4)
p2 = Point(3, 2)
p3 = Point(1, 3)
p4 = Point(0, 1)

isLess = p1 <= p2  # This is False
print(isLess)
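Writing all four comparison methods by hand is repetitive. The standard library's functools.total_ordering decorator fills in the missing ones from __eq__ and one ordering method such as __lt__. The derived <= is computed as "< or ==", so in this sketch __eq__ also compares lengths; otherwise Point(3, 4) <= Point(4, 3) would be False even though both have length 5. Note that this makes equality mean "same length", which is a design choice that differs from the coordinate-based __eq__ earlier in this article.

```python
import math
from functools import total_ordering


@total_ordering
class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

    def length(self):
        return math.hypot(self.x, self.y)  # same as sqrt(x**2 + y**2)

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.length() == other.length()

    def __lt__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.length() < other.length()


p1 = Point(3, 4)
p2 = Point(3, 2)
print(p1 <= p2)                   # False, derived from __lt__ and __eq__
print(p1 > p2)                    # True, also derived
print(max([p1, p2, Point(0, 1)]) is p1)  # True: max() uses the derived >
```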