"""
file: rmad.py
description: implementation of reverse mode autodiff by building a
    computational graph and performing a reverse topological sort on it.
url: https://kyscg.github.io/2025/05/18/autodiffpython
author: kyscg
"""

import math


class Node:
    """A scalar value in the computational graph, plus its accumulated gradient."""

    def __init__(self, value, name=None):
        self.value = value    # forward-pass result
        self.grad = 0.0       # d(seed)/d(self), filled in by backward()
        self.children = []    # operand Nodes this node was computed from
        self.name = name      # optional label, used only for printing
        self.op = None        # "add" / "mul" / "power" / "sin"; None for leaves
        self.other = None     # non-Node operand: constant factor (mul) or exponent (power)

    def __add__(self, other):
        return add(self, other)

    # Addition is commutative, so `c + node` can reuse add(node, c).
    __radd__ = __add__

    def __mul__(self, other):
        return mul(self, other)

    # Multiplication is commutative, so `c * node` can reuse mul(node, c).
    __rmul__ = __mul__

    def __pow__(self, other):
        return power(self, other)

    def __sin__(self):
        # Python has no `__sin__` protocol (math.sin will not call this);
        # it is only usable as an explicit method. Prefer sin(node).
        return sin(self)

    def __repr__(self):
        return f'Node(value={self.value}, grad={self.grad}, name={self.name})'


def add(node1, node2):
    """Return a Node for node1 + node2; node2 may be a Node or a plain number."""
    if isinstance(node2, Node):
        out = Node(node1.value + node2.value)
        out.children.extend([node1, node2])
    else:
        # Adding a constant: only node1 participates in the graph.
        out = Node(node1.value + node2)
        out.children.extend([node1])
    out.op = "add"
    return out


def mul(node1, node2):
    """Return a Node for node1 * node2; node2 may be a Node or a plain number."""
    if isinstance(node2, Node):
        out = Node(node1.value * node2.value)
        out.children.extend([node1, node2])
    else:
        out = Node(node1.value * node2)
        out.children.extend([node1])
        # Remember the constant factor; it is the local derivative w.r.t. node1.
        out.other = node2
    out.op = "mul"
    return out


def power(node, exp):
    """Return a Node for node ** exp, where exp is a plain (non-Node) number."""
    out = Node(node.value ** exp)
    out.children.extend([node])
    out.op = "power"
    out.other = exp
    return out


def sin(node):
    """Return a Node for sin(node)."""
    out = Node(math.sin(node.value))
    out.children.extend([node])
    out.op = "sin"
    return out


def _topological_order(root):
    """Return every node reachable from root, each listed after all of its children."""
    order = []
    visited = set()
    # Iterative DFS post-order (avoids the recursion limit on deep graphs).
    # A (node, True) entry means "all children done; emit node now".
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            order.append(node)
            continue
        if node in visited:
            continue
        visited.add(node)
        stack.append((node, True))
        for child in node.children:
            if child not in visited:
                stack.append((child, False))
    return order


def _propagate(node):
    """Apply one chain-rule step: add node.grad * local derivative to each child."""
    g = node.grad
    if node.op == "add":
        for child in node.children:
            child.grad += g * 1.0
    elif node.op == "mul":
        if len(node.children) == 2:
            left, right = node.children
            # If left is right (x * x), both lines hit the same Node -> 2*x*g.
            left.grad += g * right.value
            right.grad += g * left.value
        else:
            node.children[0].grad += g * node.other
    elif node.op == "power":
        base = node.children[0]
        exponent = node.other
        base.grad += g * exponent * (base.value ** (exponent - 1))
    elif node.op == "sin":
        node.children[0].grad += g * math.cos(node.children[0].value)
    # Leaves (op None) have no children to propagate to.


def backward(seed):
    """Run reverse-mode differentiation from seed.

    Sets seed.grad to 1.0, then visits nodes in reverse topological order.
    That order guarantees each node's gradient is complete before it is
    pushed to its children. BFS order does not guarantee this when a node is
    reachable through paths of different lengths.

    Only seed.grad is reset. Every other node's .grad is accumulated with
    +=, so calling backward twice on the same graph adds up gradients.
    """
    order = _topological_order(seed)
    seed.grad = 1.0
    for node in reversed(order):
        _propagate(node)


def main():
    """Differentiate f(x, y) = (x + 2y)^2 * sin(x * y) at (6, 1) and print the graph."""
    x_val = 6.0
    y_val = 1.0

    v1 = Node(x_val, name='v1')
    v2 = Node(y_val, name='v2')
    v2i = 2 * v2
    v2i.name = 'v2i'
    v3 = v1 + v2i
    v3.name = 'v3'
    v4 = v1 * v2
    v4.name = 'v4'
    v5 = v3 ** 2
    v5.name = 'v5'
    v6 = sin(v4)
    v6.name = 'v6'
    v7 = v5 * v6
    v7.name = 'v7'

    backward(seed=v7)

    print(f'Value of the function at ({x_val}, {y_val}): {v7.value}')
    print(f'Partial derivative df/dx at ({x_val}, {y_val}): {v1.grad}')
    print(f'Partial derivative df/dy at ({x_val}, {y_val}): {v2.grad}')

    nodes = [v1, v2, v3, v4, v5, v6, v7]
    print('---')
    for node in nodes:
        print(node)
    print('---')
    for node in nodes:
        print(f'children of {node.name}: {[ch.name for ch in node.children]}')


if __name__ == "__main__":
    main()