"""
file: 
    rmad.py
description: 
    implementation of reverse mode autodiff by building a computational graph
    and performing a reverse topological sort on it.
url:
    https://kyscg.github.io/2025/05/18/autodiffpython
author:
    kyscg
"""

import math

class Node:
    """A scalar vertex of the computational graph.

    Attributes set by the constructor:
        value:    forward value carried by this node.
        grad:     gradient accumulator, starts at 0.0.
        children: list of operand nodes this node was computed from
                  (empty for a freshly constructed node).
        name:     optional label, used only for display.
        op:       tag of the operation that produced the node; None until
                  one of the module-level builders sets it.
        other:    extra slot, None by default; the module-level ``mul`` and
                  ``power`` builders store a non-Node operand here.
    """

    def __init__(self, value, name=None):
        """Create a node holding ``value`` with an optional ``name``."""
        self.value, self.grad = value, 0.0
        self.children = []
        self.name = name
        self.op = self.other = None

    def __add__(self, other):
        """``self + other``; delegates to the module-level ``add``."""
        return add(self, other)

    # The reflected form reuses the same implementation, so ``2 + node``
    # is handled exactly like ``node + 2``.
    __radd__ = __add__

    def __mul__(self, other):
        """``self * other``; delegates to the module-level ``mul``."""
        return mul(self, other)

    # Same aliasing as for addition: ``2 * node`` behaves like ``node * 2``.
    __rmul__ = __mul__

    def __pow__(self, other):
        """``self ** other``; delegates to the module-level ``power``."""
        return power(self, other)

    def __sin__(self):
        """Return ``sin(self)`` via the module-level ``sin`` builder.

        This is not a hook Python calls implicitly; it only runs when
        invoked explicitly.
        """
        return sin(self)

    def __repr__(self):
        """Show the value, accumulated gradient and name of the node."""
        return f'Node(value={self.value}, grad={self.grad}, name={self.name})'


def add(node1, node2):
    """Build the graph node for ``node1 + node2``.

    ``node1`` must be a Node. ``node2`` may be a Node or a plain number;
    a plain number contributes to the value but is not recorded as a child.

    Returns a new Node tagged with ``op == "add"`` whose ``children`` are
    the Node operands.
    """
    if isinstance(node2, Node):
        total, operands = node1.value + node2.value, [node1, node2]
    else:
        total, operands = node1.value + node2, [node1]

    result = Node(total)
    result.children.extend(operands)
    result.op = "add"
    return result


def mul(node1, node2):
    """Build the graph node for ``node1 * node2``.

    ``node1`` must be a Node. ``node2`` may be a Node or a plain number.
    When it is a plain number it is stored in ``other`` (so the backward
    pass can use it as the local derivative) instead of in ``children``.

    Returns a new Node tagged with ``op == "mul"``.
    """
    rhs_is_node = isinstance(node2, Node)
    factor = node2.value if rhs_is_node else node2

    result = Node(node1.value * factor)
    result.children.append(node1)
    if rhs_is_node:
        result.children.append(node2)
    else:
        result.other = node2
    result.op = "mul"
    return result


def power(node, exp):
    """Build the graph node for ``node ** exp``.

    ``exp`` is applied directly to ``node.value`` and kept in ``other``;
    only ``node`` is recorded as a child.

    Returns a new Node tagged with ``op == "power"``.
    """
    result = Node(node.value ** exp)
    result.children.append(node)
    result.op, result.other = "power", exp
    return result

def sin(node):
    """Build the graph node for ``sin(node)``.

    Returns a new Node tagged with ``op == "sin"`` whose only child is
    ``node``.
    """
    result = Node(math.sin(node.value))
    result.children.append(node)
    result.op = "sin"
    return result


def backward(seed):
    """Run reverse-mode differentiation starting from ``seed``.

    ``seed.grad`` is set to 1.0 and the gradient is then pushed down the
    graph, adding each node's contribution to the ``grad`` of its children
    (the operand nodes it was computed from).

    Nodes are processed in reverse topological order, so a node only
    propagates to its operands once every node that consumes it has already
    contributed to its ``grad``. A plain breadth-first walk does not
    guarantee this when a node is reachable through paths of different
    lengths (e.g. ``c = a + sin(a)``).

    Gradients are accumulated with ``+=`` and never reset here, so calling
    this repeatedly on the same graph keeps adding to existing ``grad``
    values of non-seed nodes.

    Args:
        seed: output node to differentiate. Every reachable node must expose
            ``value``, ``grad``, ``children``, ``op`` and ``other``.
            Recognised ``op`` tags are "add", "mul", "power" and "sin";
            nodes with any other tag (e.g. leaves with ``op is None``)
            propagate nothing.

    Returns:
        None. Results are left in the ``grad`` attribute of each node.
    """
    seed.grad = 1.0

    # Iterative post-order DFS: a node is appended only after all of its
    # operands, so the reversed list is a topological order from the seed
    # down to the leaves. Nodes are tracked by identity.
    postorder = []
    seen = set()
    stack = [(seed, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            postorder.append(node)
            continue
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.append((node, True))
        stack.extend((child, False) for child in node.children)

    for node in reversed(postorder):
        operands = node.children
        upstream = node.grad

        if node.op == "add":
            # d(a + b)/da = d(a + b)/db = 1; a scalar operand has no node.
            for operand in operands:
                operand.grad += upstream * 1.0

        elif node.op == "mul":
            if len(operands) == 2:
                left, right = operands
                left.grad += upstream * right.value
                right.grad += upstream * left.value
            else:
                # Node times a plain number: the number was kept in `other`.
                operands[0].grad += upstream * node.other

        elif node.op == "power":
            base, exponent = operands[0], node.other
            if exponent == 0:
                # x ** 0 is constant, so its derivative is 0. Skipping also
                # avoids evaluating 0.0 ** -1, which raises
                # ZeroDivisionError.
                continue
            base.grad += upstream * exponent * (base.value ** (exponent - 1))

        elif node.op == "sin":
            operand = operands[0]
            operand.grad += upstream * math.cos(operand.value)


# Demo: differentiate f(x, y) = (x + 2*y)**2 * sin(x*y) at (x, y) = (6, 1).
x_val, y_val = 6.0, 1.0

# Input nodes.
v1 = Node(x_val, name='v1')
v2 = Node(y_val, name='v2')

# Forward pass: each operation records a new node in the graph.
v2i = 2 * v2
v3 = v1 + v2i
v4 = v1 * v2
v5 = v3 ** 2
v6 = sin(v4)
v7 = v5 * v6

# Label the intermediate nodes so the printed output is readable.
v2i.name = 'v2i'
v3.name = 'v3'
v4.name = 'v4'
v5.name = 'v5'
v6.name = 'v6'
v7.name = 'v7'

# Reverse pass from the output node.
backward(seed=v7)


print(f'Value of the function at ({x_val}, {y_val}): {v7.value}')
print(f'Partial derivative df/dx at ({x_val}, {y_val}): {v1.grad}')
print(f'Partial derivative df/dy at ({x_val}, {y_val}): {v2.grad}')

print('---')

# One node per output line.
print(v1, v2, v3, v4, v5, v6, v7, sep='\n')

print('---')

# Operand names of each node, one node per output line.
print(
    f'children of v1: {[ch.name for ch in v1.children]}',
    f'children of v2: {[ch.name for ch in v2.children]}',
    f'children of v3: {[ch.name for ch in v3.children]}',
    f'children of v4: {[ch.name for ch in v4.children]}',
    f'children of v5: {[ch.name for ch in v5.children]}',
    f'children of v6: {[ch.name for ch in v6.children]}',
    f'children of v7: {[ch.name for ch in v7.children]}',
    sep='\n',
)