"""
file: 
    fmaddual.py
description: 
    implementation of forward-mode automatic differentiation using dual numbers.
url:
    https://kyscg.github.io/2025/05/18/autodiffpython
author:
    kyscg
"""

import math

class DN:
    """A dual number ``real + dual * eps`` where ``eps**2 == 0``.

    Pushing dual numbers through arithmetic carries the function value in
    ``real`` and its derivative along the seeded direction in ``dual``
    (forward-mode automatic differentiation). Plain numbers mixed into an
    expression are treated as constants (tangent 0).

    Attributes:
        real: the primal value.
        dual: the tangent (derivative) coefficient.
    """

    def __init__(self, real, dual):
        self.real = real
        self.dual = dual

    def __repr__(self):
        return f"{type(self).__name__}({self.real!r}, {self.dual!r})"

    def __add__(self, other):
        # (a + a'eps) + (b + b'eps) = (a + b) + (a' + b')eps
        if isinstance(other, DN):
            return DN(self.real + other.real, self.dual + other.dual)
        return DN(self.real + other, self.dual)

    __radd__ = __add__  # addition is commutative

    def __neg__(self):
        return DN(-self.real, -self.dual)

    def __sub__(self, other):
        if isinstance(other, DN):
            return DN(self.real - other.real, self.dual - other.dual)
        return DN(self.real - other, self.dual)

    def __rsub__(self, other):
        # Only reached when `other` is not a DN, i.e. a constant.
        return DN(other - self.real, -self.dual)

    def __mul__(self, other):
        # Product rule: (a + a'eps)(b + b'eps) = ab + (ab' + a'b)eps,
        # since the eps**2 term vanishes.
        if isinstance(other, DN):
            real = self.real * other.real
            dual = self.real * other.dual + self.dual * other.real
            return DN(real, dual)
        return DN(self.real * other, self.dual * other)

    __rmul__ = __mul__  # multiplication is commutative

    def __truediv__(self, other):
        # Quotient rule: d(a/b) = (a'b - ab') / b**2
        if isinstance(other, DN):
            real = self.real / other.real
            dual = (self.dual * other.real - self.real * other.dual) / other.real ** 2
            return DN(real, dual)
        return DN(self.real / other, self.dual / other)

    def __rtruediv__(self, other):
        # c / b for a constant c: d(c/b) = -c * b' / b**2
        return DN(other / self.real, -other * self.dual / self.real ** 2)

    def _pow_tangent(self, exponent):
        """Return the power-rule tangent n * a**(n-1) * a' for a constant exponent n."""
        if exponent == 0:
            # x**0 is constant. Short-circuit so that real == 0 does not
            # evaluate 0.0 ** -1 and raise ZeroDivisionError.
            return 0.0 * self.dual
        return exponent * (self.real ** (exponent - 1)) * self.dual

    def __pow__(self, power):
        if isinstance(power, DN):
            # d(a**b) = b * a**(b-1) * a' + a**b * ln(a) * b'
            real = self.real ** power.real
            dual = self._pow_tangent(power.real)
            if power.dual:
                # ln(a) is only needed (and only defined for a > 0) when the
                # exponent itself carries a tangent.
                dual = dual + real * math.log(self.real) * power.dual
            return DN(real, dual)
        return DN(self.real ** power, self._pow_tangent(power))

    def __rpow__(self, base):
        # c**b for a constant base c: d(c**b) = c**b * ln(c) * b'
        real = base ** self.real
        dual = real * math.log(base) * self.dual if self.dual else 0.0
        return DN(real, dual)

def sin_dual(d):
    """Return sin(d) for a dual number ``d``.

    By the chain rule, sin(a + a'eps) = sin(a) + cos(a) * a' eps.
    """
    a = d.real
    return DN(math.sin(a), math.cos(a) * d.dual)

def fmad(func, x, y, *others):
    """Return the partial derivatives of ``func`` at a point via forward-mode AD.

    One forward pass is made per input variable. The pass seeds that
    variable's tangent with 1.0 and every other variable's with 0.0, so the
    ``dual`` part of the result is the partial derivative with respect to it.

    Args:
        func: callable taking one ``DN`` per variable (positionally) and
            returning a ``DN``.
        x, y: values of the first two variables.
        *others: values of any additional variables (optional).

    Returns:
        Tuple of partial derivatives in argument order. With two variables
        this is ``(df_dx, df_dy)``.
    """
    point = (x, y) + others
    partials = []
    for seed, _ in enumerate(point):
        # Unit tangent on the `seed`-th variable, zero tangent elsewhere.
        args = [DN(value, 1.0 if i == seed else 0.0) for i, value in enumerate(point)]
        partials.append(func(*args).dual)
    return tuple(partials)


def z_dual(x_dual, y_dual):
    """Evaluate z(x, y) = (x + 2y)**2 * sin(x*y) on dual-number inputs.

    Mirrors ``z`` but calls ``sin_dual`` instead of ``math.sin``, so the
    tangent is propagated alongside the value.
    """
    squared = (x_dual + 2 * y_dual) ** 2
    return squared * sin_dual(x_dual * y_dual)


def z(x, y):
    """Evaluate z(x, y) = (x + 2y)**2 * sin(x*y) on plain numbers."""
    return (x + 2 * y) ** 2 * math.sin(x * y)


if __name__ == "__main__":
    # Demo: evaluate z and its gradient at a sample point. The guard keeps
    # importing this module from running the demo and printing output.
    x_val = 2.0
    y_val = 3.0
    df_dx, df_dy = fmad(z_dual, x_val, y_val)

    print(f'Value of the function at ({x_val}, {y_val}): {z(x_val, y_val)}')
    print(f'Partial derivative df/dx at ({x_val}, {y_val}): {df_dx}')
    print(f'Partial derivative df/dy at ({x_val}, {y_val}): {df_dy}')