#!python

import os
import sys
import datetime
import yaml
import calendar
import json
import time
import requests
import math
import pwd
from pushover import Client
from monzo_api import Monzo
from monzo_db import DB

# Name of the row in the `provider` table whose accounts/transactions this script checks.
PROVIDER = 'Monzo'

class PaymentFundsCheck:
    """Check that an account (or one of its pots) can cover this salary period's payments.

    Payments are described in ``~/.monzo/<account_name>.yaml`` and matched against
    transactions stored in the database configured in ``~/.monzo/config.yaml``.
    All running money totals inside :meth:`main` are kept in integer pence.
    """

    def __init__(self, account_name):
        """Load both config files, connect to the database and look up the account.

        Writes a message to stderr and exits with status 1 when a config file is
        missing or unparseable, a required key is absent, push notifications are
        enabled without Pushover credentials, or the provider/account is not found.
        """
        self.account_name = account_name

        homedir = pwd.getpwuid(os.getuid()).pw_dir
        monzo_dir = f"{homedir}/.monzo"

        if not os.path.exists(monzo_dir):
            os.mkdir(monzo_dir, 0o755)

        config_file = f"{monzo_dir}/config.yaml"

        if not os.path.exists(config_file):
            sys.stderr.write(f"Cannot find config file: {config_file}\n")
            sys.exit(1)

        with open(config_file) as f:
            self.monzo_config = yaml.safe_load(f.read())

        account_config_file = f"{monzo_dir}/{account_name}.yaml"

        if not os.path.exists(account_config_file):
            sys.stderr.write(f"Cannot find account config file: {account_config_file}\n")
            sys.exit(1)

        try:
            with open(account_config_file) as f:
                self.config = yaml.safe_load(f.read())
        except Exception as e:
            sys.stderr.write(f"Cannot read or parse account config file {account_config_file}: {str(e)}\n")
            sys.exit(1)

        # An empty YAML file loads as None; every lookup below needs a mapping.
        if not isinstance(self.config, dict):
            sys.stderr.write(f"Cannot read or parse account config file {account_config_file}: expected a mapping\n")
            sys.exit(1)

        for required in ['payments', 'salary_description', 'salary_payment_day']:
            if not self.config.get(required):
                sys.stderr.write(f"Missing config key: {required}\n")
                sys.exit(1)

        # Every notify_* flag ends up in self.notify(), which needs Pushover credentials.
        if any(self.config.get(flag) for flag in ['notify_shortfall', 'notify_credit', 'notify_deposit', 'notify_withdraw']):
            for required in ['pushover_key', 'pushover_app']:
                if not self.config.get(required):
                    sys.stderr.write(f"Push is enabled but push config key is missing: {required}\n")
                    sys.exit(1)

        self.db = DB(self.monzo_config['db'])

        # Transaction ids already matched by get_last_payment_date(), so two
        # payments with overlapping descriptions cannot claim the same transaction.
        self.seen = set()
        # Per-run cache of currency -> GBP exchange rate.
        self.exchange_rates = {}

        self.provider = self.db.one("select * from provider where name = %s", [PROVIDER])

        if not self.provider:
            sys.stderr.write(f"provider not found in database: {PROVIDER}\n")
            sys.exit(1)

        self.account = self.db.one("select * from account where provider_id = %s and name = %s", [self.provider['id'], self.account_name])

        if not self.account:
            sys.stderr.write(f"account not found in ~/.monzo/config.yaml: {self.account_name}\n")
            sys.exit(1)

    def _config_flag(self, key):
        """Return True when ``key`` is present in the account config with a truthy value."""
        return bool(self.config.get(key))

    def main(self):
        """Print the payment table for the current salary period and reconcile it with the balance.

        When the balance is short of what is still due, and a pot is configured, the
        shortfall may be deposited into the pot or reported by push notification.
        When there is a surplus, it may be withdrawn or reported. The ``auto_*``,
        ``prompt_*`` and ``notify_*`` config flags select which action happens.
        """
        last_salary_date = self.get_last_salary_date()
        next_salary_date = self.get_next_salary_date(last_salary_date)
        # End of the salary period after this one; bounds the "next month" figures.
        end_of_next_salary_period = self.get_next_salary_date(next_salary_date)

        # Running totals, all in pence.
        due = 0                   # still to leave the account before the next salary
        total = 0                 # everything scheduled this period (paid + due)
        next_month = 0            # scheduled for the following period
        next_month_bills_pot = 0  # amount the bills pot needs for the following period

        flex_payments = []

        for payment_list in self.config['payments']:
            if payment_list['type'] == 'Flex':
                # Flex instalments are collated and displayed as one group below.
                flex_payments += payment_list['payments'] or []
                continue

            for payment in payment_list['payments'] or []:
                if payment_list['type'] == 'Amazon Payments':
                    # NOTE: Amazon instalments only feed the bills-pot figure, not TOTAL NEXT MONTH.
                    p_due, p_total, p_bills = self.handle_amazon_payments(payment_list['payment_day'], payment, last_salary_date, next_salary_date)
                    p_next = 0
                else:
                    p_due, p_total, p_next, p_bills = self._handle_regular_payment(
                        payment_list['type'], payment, last_salary_date, next_salary_date, end_of_next_salary_period
                    )

                due += p_due
                total += p_total
                next_month += p_next
                next_month_bills_pot += p_bills

        flex_this_month = 0
        flex_due = []
        flex_status = 'DONE'
        total_remaining = 0

        for payment in flex_payments:
            status, date, amount, payment_num = self.flex_payment_due(payment, next_salary_date)

            if status == 'DONE':
                continue

            amount_pence = round(amount * 100)
            remaining = max(0, payment['amount'] - (payment_num * amount))

            flex_this_month += amount
            total_remaining += remaining

            flex_due.append({
                'status': status,
                'date': date,
                'amount': amount,
                'name': payment['name'],
                'num_paid': payment_num,
                'num_total': payment['months'],
                'remaining': remaining
            })

            # Both DUE and SKIPPED instalments count towards next month's figures.
            next_month_bills_pot += amount_pence
            next_month += amount_pence

            if status == 'DUE':
                due += amount_pence
                total += amount_pence
                # DUE outranks SKIPPED for the group's overall status.
                flex_status = 'DUE'
            elif flex_status == 'DONE':
                flex_status = 'SKIPPED'

        if flex_status != 'DONE':
            # The summary row shows the soonest upcoming instalment date.
            next_flex_date = min(flex_payment['date'] for flex_payment in flex_due)

            self.display(flex_status, 'Flex', 'Flex payment', flex_this_month, None, next_flex_date, None, None, total_remaining)

            for flex_payment in flex_due:
                self.display(
                    flex_payment['status'],
                    'Flex',
                    '- ' + flex_payment['name'],
                    flex_payment['amount'],
                    None,
                    flex_payment['date'],
                    flex_payment['num_paid'],
                    flex_payment['num_total'],
                    flex_payment['remaining']
                )

        if 'pot' in self.config:
            pot = self.db.one("select * from pot where account_id = %s and name = %s", [self.account['id'], self.config['pot']])

            if not pot:
                sys.stderr.write(f"pot not found: {self.config['pot']}\n")
                sys.exit(1)
        else:
            pot = self.account

        balance_pence = round(pot['balance'] * 100)
        shortfall = (due - balance_pence) / 100
        exclude_yearly = self._config_flag('exclude_yearly_from_bills')

        sys.stdout.write("\n")

        sys.stdout.write(" " * 25)
        sys.stdout.write(" TOTAL THIS MONTH:".ljust(31))
        print("£%.2f" % (total / 100))

        if exclude_yearly:
            sys.stdout.write("\n")

        sys.stdout.write(" " * 25)
        sys.stdout.write(" TOTAL NEXT MONTH:".ljust(31))
        print("£%.2f" % (next_month / 100))

        if exclude_yearly:
            sys.stdout.write(" " * 25)
            sys.stdout.write("Bills pot payment:".ljust(31))
            print("£%.2f" % (next_month_bills_pot / 100))

        if round(shortfall * 100) > 0:
            print("      due: £%.2f" % (due / 100))
            print("  balance: £%.2f" % (pot['balance']))
            print("shortfall: £%.2f" % (shortfall))

            # Deposits and shortfall notifications only apply when a pot is configured.
            if 'pot' in self.config:
                self._handle_shortfall(pot, due, shortfall)

        else:
            credit = (balance_pence - due) / 100

            print("    due: £%.2f" % (due / 100))
            print("balance: £%.2f" % (pot['balance']))

            if round(credit * 100) != 0:
                print(" credit: £%.2f" % (credit))
                self._handle_credit(pot, due, credit)

    def _handle_regular_payment(self, payment_type, payment, last_salary_date, next_salary_date, end_of_next_salary_period):
        """Display one non-Flex, non-Amazon payment and return its contributions in pence.

        Returns ``(due, total, next_month, next_month_bills_pot)``.
        """
        due = total = next_month = bills_pot = 0

        last_date, last_amount = self.get_last_payment_date(payment_type, payment)

        if payment.get('fixed'):
            last_amount = payment['amount']

        # Amounts written as '$<n>' are converted to GBP at today's rate.
        if str(payment['amount']).startswith('$'):
            rate = self.get_exchange_rate('usd')
            last_amount = float(str(payment['amount'])[1:]) / rate

        # No matching transaction yet: fall back to the configured amount.
        if last_amount is None:
            last_amount = payment['amount']

        amount_overrides = (self.config.get('last_amount_overrides') or {}).get(payment['name']) or {}
        if last_salary_date in amount_overrides:
            last_amount = amount_overrides[last_salary_date]

        date_overrides = (self.config.get('last_date_overrides') or {}).get(payment['name']) or {}
        if last_date in date_overrides:
            last_date = date_overrides[last_date]

        could_be_due = True
        due_date = last_date

        if 'yearly_month' in payment:
            could_be_due = self.is_yearly_payment_due_this_month(payment, last_salary_date)

            # Next occurrence of the yearly month/day, counting today.
            due_date = datetime.date.today()
            while due_date.month != payment['yearly_month'] or due_date.day != payment['yearly_day']:
                due_date += datetime.timedelta(days=1)

        elif 'renew_date' in payment:
            due_date = payment['renew_date']
        elif due_date:
            # Monthly payment: one month after the last one (clamped to month end).
            due_date = self.add_month(due_date)

        # payment period not started yet
        if 'start_date' in payment and payment['start_date'] >= next_salary_date:
            could_be_due = False
            due_date = payment['start_date']

        # excluded months (council tax)
        if 'exclude_months' in payment:
            if last_date is not None:
                next_payment_month = last_date.month % 12 + 1
            elif 'start_date' in payment:
                next_payment_month = payment['start_date'].month % 12 + 1
            else:
                next_payment_month = None

            if next_payment_month is not None and next_payment_month in payment['exclude_months']:
                could_be_due = False

        amount_pence = round(last_amount * 100)

        if not could_be_due:
            status = 'SKIPPED'
        elif last_date is None or last_date < last_salary_date:
            status = 'DUE'
            due += amount_pence
            total += amount_pence
        else:
            status = 'PAID'
            total += amount_pence

        num_paid = None
        num_total = None
        remaining = None

        if payment_type == 'Finance':
            for key in ('months', 'start_date'):
                if key not in payment:
                    raise Exception(f'finance entries must have key: {key}')

            num_paid = self.get_num_payments_made(payment)
            num_total = payment['months']
            remaining = (num_total - num_paid) * payment['amount']

        self.display(status, payment_type, payment['name'], last_amount, last_date, due_date, num_paid, num_total, remaining)

        # determine if due next month
        if 'yearly_month' in payment:
            if next_salary_date <= due_date <= end_of_next_salary_period:
                next_month += amount_pence

                if not self._config_flag('exclude_yearly_from_bills'):
                    bills_pot += amount_pence

        elif (('start_date' not in payment or payment['start_date'] < end_of_next_salary_period)
                and ('renew_date' not in payment or payment['renew_date'] < end_of_next_salary_period)):
            next_month += amount_pence

            if payment['name'] not in (self.config.get('exclude_from_bills') or []):
                bills_pot += amount_pence

        return due, total, next_month, bills_pot

    def _ask_yes_no(self, prompt):
        """Ask ``prompt`` on stdout until the answer is y or n.

        Empty input or end-of-file counts as "no", matching the ``[y/N]`` default.
        """
        while True:
            sys.stdout.write(prompt)
            sys.stdout.flush()

            line = sys.stdin.readline()

            if not line:
                return False

            answer = line.rstrip().lower()

            if answer in ('', 'n'):
                return False
            if answer == 'y':
                return True

    def _balance_summary(self, amount, due, pot):
        """Return a notification body: ``amount``, then pounds due and pot balance (``due`` is in pence)."""
        return "£%.2f\n£%.2f due, £%.2f available" % (amount, due / 100, pot['balance'])

    def _handle_shortfall(self, pot, due, shortfall):
        """Deposit ``shortfall`` pounds into ``pot`` or send a shortfall notification, per config."""
        deposit = False
        notify = False

        if self._config_flag('auto_deposit'):
            deposit = True
        elif not sys.stdout.isatty():
            notify = self._config_flag('notify_shortfall')
        elif self._config_flag('prompt_deposit'):
            deposit = self._ask_yes_no("\ndeposit shortfall? [y/N] ")
        else:
            notify = self._config_flag('notify_shortfall')

        if deposit:
            # NOTE(review): the return value of deposit_to_pot is not checked; confirm its contract.
            Monzo().deposit_to_pot(self.account['account_id'], pot, shortfall)

            if self._config_flag('notify_deposit'):
                self.notify('%s - pot topped up' % (self.account['name']), self._balance_summary(shortfall, due, pot))

        elif notify:
            self.notify('%s - shortfall' % (self.account['name']), self._balance_summary(shortfall, due, pot))

    def _handle_credit(self, pot, due, credit):
        """Withdraw ``credit`` pounds from ``pot`` or send a credit notification, per config."""
        withdraw = False
        notify = False

        if 'pot' in self.config and self._config_flag('auto_withdraw'):
            withdraw = True
        elif not sys.stdout.isatty():
            notify = self._config_flag('notify_credit')
        elif 'pot' in self.config and self._config_flag('prompt_withdraw'):
            withdraw = self._ask_yes_no("\nwithdraw credit? [y/N] ")
        else:
            notify = self._config_flag('notify_credit')

        if withdraw:
            if not Monzo().withdraw_credit(self.account['account_id'], pot, credit):
                sys.stderr.write("ERROR: failed to withdraw credit\n")
            elif self._config_flag('notify_withdraw'):
                # Only announce the withdrawal once it has actually succeeded.
                self.notify('%s - pot credit withdrawn' % (self.account['name']), self._balance_summary(credit, due, pot))

        elif notify:
            self.notify('%s - credit' % (self.account['name']), self._balance_summary(credit, due, pot))

    def display(self, status, payment_type, payment_name, amount, last_date, due_date, num_paid=None, num_total=None, remaining=None):
        """Print one fixed-width row of the payment table.

        ``num_paid``/``num_total`` render as ``n/m`` and ``remaining`` as pounds when
        given; dates render as ``YYYY-MM-DD`` or blank when falsy.
        """
        if num_paid is not None:
            suffix = '%d/%d' % (
                num_paid,
                num_total
            )
        else:
            suffix = ''

        if remaining is not None:
            remaining = '£%.2f' % (remaining)
        else:
            remaining = ''

        print("%s: %s %s %s %s %s %s %s" % (
            status.rjust(7),
            payment_type.ljust(15),
            payment_name.ljust(25),
            suffix.ljust(4),
            ('£%.2f' % (amount)).ljust(8),
            remaining.ljust(8),
            last_date.strftime('%Y-%m-%d').ljust(12) if last_date else ''.ljust(12),
            due_date.strftime('%Y-%m-%d').ljust(10) if due_date else ''
        ))

    def flex_payment_due(self, payment, next_salary_date):
        """Find the next upcoming instalment of a Flex purchase.

        Instalments fall on day ``flex_payment_date`` (from config), the first on or
        after ``payment['start_date']``, for ``payment['months']`` instalments. Each is
        ``ceil(amount / months)``, except the last is reduced so the total never
        exceeds ``payment['amount']``.

        Returns ``(status, date, amount, payment_num)``:
        status is 'DONE' when every instalment date is before today, 'DUE' when the
        next one falls on or before ``next_salary_date``, 'SKIPPED' otherwise.
        ``payment_num`` is the zero-based index of the next instalment, or ``months``
        when all are done.
        """
        today = datetime.date.today()
        flex_day = self.config['flex_payment_date']
        months = payment['months']
        installment = int(math.ceil(payment['amount'] / months)) if months else 0

        date = payment['start_date']
        amount = 0
        payment_num = months
        due = False
        total_paid = 0

        for i in range(months):
            amount = installment

            if total_paid + amount > payment['amount']:
                amount = payment['amount'] - total_paid

            while date.day != flex_day:
                date += datetime.timedelta(days=1)

            if date >= today:
                due = True
                payment_num = i
                break

            # Step past this instalment's day so the search finds the next one.
            date += datetime.timedelta(days=1)

            total_paid += amount

        if not due:
            status = 'DONE'
        elif date <= next_salary_date:
            status = 'DUE'
        else:
            status = 'SKIPPED'

        return status, date, amount, payment_num

    def notify(self, event, message):
        """Send ``message`` titled ``event`` through Pushover using the account's credentials."""
        pushover = Client(self.config['pushover_key'], api_token=self.config['pushover_app'])
        pushover.send_message(message, title=event)

    def get_last_salary_date(self):
        """Return the date of the most recent salary payment.

        Salary transactions are non-declined transactions on ``salary_account`` (or
        this account) whose description contains any of ``salary_description``. If
        the salary arrived before ``salary_payment_day`` (presumably paid early), the
        date is moved forward to that day, clamped to the last day of the month.
        Exits with status 1 if no salary transaction is found.
        """
        salary_account = self.config.get('salary_account')

        if salary_account and salary_account != self.account_name:
            account = self.db.one("select * from account where provider_id = %s and name = %s", [self.provider['id'], salary_account])

            if not account:
                sys.stderr.write(f"salary account not found: {salary_account}\n")
                sys.exit(1)
        else:
            account = self.account

        descriptions = self.config['salary_description']

        # Accept a single string as well as a list of descriptions.
        if isinstance(descriptions, str):
            descriptions = [descriptions]

        sql = "select * from transaction where declined = 0 and account_id = %s and ("
        sql += ' or '.join(['description like %s'] * len(descriptions))
        # Newest first so the row db.one() returns is the latest salary, not an arbitrary one.
        sql += ") order by `date` desc, id desc"

        params = [account['id']] + ['%' + str(description) + '%' for description in descriptions]

        last_salary_transaction = self.db.one(sql, params)

        if not last_salary_transaction:
            sys.stderr.write("failed to find last salary transaction.\n")
            sys.stderr.write(f"SQL: {sql}\n")
            sys.stderr.write(f"params: {json.dumps(params,indent=4)}\n")
            sys.exit(1)

        last_salary_date = last_salary_transaction['date']

        target_day = min(self.config['salary_payment_day'], self.days_in_month(last_salary_date.year, last_salary_date.month))

        if last_salary_date.day < target_day:
            last_salary_date = last_salary_date.replace(day=target_day)

        return last_salary_date

    def get_next_salary_date(self, last_salary_date):
        """Return the next salary date strictly after ``last_salary_date``.

        That is the next 15th of a month, moved back to the preceding Friday when it
        falls on a weekend. Always returns a ``datetime.date``.

        NOTE(review): uses a fixed 15th rather than ``salary_payment_day``.
        """
        next_salary_date = datetime.date(last_salary_date.year, last_salary_date.month, last_salary_date.day) + datetime.timedelta(days=1)

        while next_salary_date.day != 15:
            next_salary_date += datetime.timedelta(days=1)

        # Saturday (5) / Sunday (6) -> previous Friday.
        while next_salary_date.weekday() in (5, 6):
            next_salary_date -= datetime.timedelta(days=1)

        return next_salary_date

    def get_last_payment_date(self, payment_type, payment):
        """Return ``(date, money_out)`` of the newest unclaimed transaction matching ``payment``.

        A transaction matches when it is not declined and its description contains
        ``payment['desc']`` or any of ``payment['old_desc']``. For fixed payments
        (other than standing orders) ``money_out`` must equal ``payment['amount']``;
        otherwise any outgoing amount matches. The matched transaction id is recorded
        in ``self.seen`` so later payments skip it. Returns ``(None, None)`` if none match.
        """
        if payment_type != 'Standing Order' and payment.get('fixed'):
            sql = "select transaction.* from transaction where declined = 0 and money_out = %s and ( description like %s"
            params = [
                payment['amount'],
                '%' + str(payment['desc']) + '%'
            ]
        else:
            sql = "select transaction.* from transaction where declined = 0 and money_out >0 and ( description like %s"
            params = ['%' + str(payment['desc']) + '%']

        for old_desc in payment.get('old_desc') or []:
            sql += " or description like %s"
            params.append('%' + str(old_desc) + '%')

        sql += " ) order by transaction.id desc"

        for row in self.db.query(sql, params):
            if row['id'] not in self.seen:
                self.seen.add(row['id'])

                return row['date'], row['money_out']

        return None, None

    def days_in_month(self, year, month):
        """Return the number of days in ``month`` of ``year``."""
        return calendar.monthrange(year, month)[1]

    def is_yearly_payment_due_this_month(self, payment, last_salary_date):
        """Return True if a yearly payment's month/day falls in the current salary window.

        The window runs from ``last_salary_date`` to the next 15th of a month after it
        (both inclusive). Every calendar year the window touches is checked, so a
        window spanning New Year (e.g. Dec 15 - Jan 15) still catches January dates.
        """
        window_end = last_salary_date

        while window_end.day <= 15:
            window_end += datetime.timedelta(days=1)

        while window_end.day != 15:
            window_end += datetime.timedelta(days=1)

        date_from = last_salary_date.strftime('%Y-%m-%d')
        date_to = window_end.strftime('%Y-%m-%d')

        # Compared as ISO date strings, so a day that does not exist in a given year
        # (e.g. Feb 29) is still ordered sensibly instead of raising.
        month_day = '-' + str(payment['yearly_month']).rjust(2, '0') + '-' + str(payment['yearly_day']).rjust(2, '0')

        return any(
            date_from <= str(year) + month_day <= date_to
            for year in range(last_salary_date.year, window_end.year + 1)
        )

    def get_exchange_rate(self, currency):
        """Return how many units of ``currency`` one GBP buys.

        Cached in memory for the run and on disk (``exchange_rate_files`` in the main
        config) for the current day; otherwise fetched over HTTP and written to disk.
        """
        if currency in self.exchange_rates:
            return self.exchange_rates[currency]

        rate_file = self.monzo_config['exchange_rate_files'][currency]

        if os.path.exists(rate_file):
            modified = datetime.datetime.fromtimestamp(os.stat(rate_file).st_mtime)

            # Reuse the file only if it was written today.
            if modified.date() == datetime.date.today():
                with open(rate_file) as f:
                    self.exchange_rates[currency] = float(f.read())

                return self.exchange_rates[currency]

        resp = requests.get('https://raw.githubusercontent.com/TheArmagan/currency/main/api/GBP-to-%s.txt' % (currency.upper()), timeout=30)
        resp.raise_for_status()
        # Parse before writing so an invalid response never replaces the cache file.
        rate = float(resp.text)

        with open(rate_file, 'w') as f:
            f.write(resp.text)

        self.exchange_rates[currency] = rate

        return rate

    def handle_amazon_payments(self, payment_day, payment, last_salary_date, next_salary_date):
        """Display one Amazon instalment plan and return its contributions in pence.

        Returns ``(due, total, next_month_bills_pot)``:
        - DUE: this period's instalment is ``min(last_amount, remaining)``; it counts as
          due and towards the total, and next month gets the following instalment.
        - PAID: the last instalment counts towards the total; next month gets
          ``min(last_amount, remaining)``.
        - SKIPPED: nothing this period; the instalment falls in the next period.
        """
        status, num_paid, remaining, last_amount, last_date, due_date = self.get_amazon_payment(payment, last_salary_date, next_salary_date, payment_day)

        self.display(status, 'Amazon Payments', payment['name'], last_amount, last_date, due_date, num_paid, payment['months'], remaining)

        amazon_due = 0
        amazon_total = 0

        if status == 'DUE':
            this_payment = max(0, min(last_amount, remaining))
            amazon_due = round(this_payment * 100)
            amazon_total = round(this_payment * 100)
            due_next_month = max(0, min(last_amount, remaining - this_payment))
        elif status == 'PAID':
            amazon_total = round(last_amount * 100)
            due_next_month = max(0, min(last_amount, remaining))
        else:
            due_next_month = max(0, min(last_amount, remaining))

        amazon_next_month_bills_pot = round(due_next_month * 100)

        return amazon_due, amazon_total, amazon_next_month_bills_pot

    def get_amazon_payment(self, payment_spec, last_salary_date, next_salary_date, payment_day):
        """Summarise an Amazon instalment plan from its matching transactions.

        The plan total ``payment_spec['amount']`` is split into ``months`` instalments
        of ``ceil(amount / months)`` pence, with the final instalment taking the
        difference. Transactions since ``start_date`` for either amount whose
        description contains ``desc`` count as paid instalments.

        Returns ``(status, num_paid, remaining, last_amount, last_date, due_date)``.
        status is 'PAID' if an instalment was taken since ``last_salary_date``, 'DUE'
        if ``payment_day`` occurs between tomorrow and ``next_salary_date``, else
        'SKIPPED'. ``due_date`` is None once every instalment has been paid.
        """
        amount_pence = round(payment_spec['amount'] * 100)
        initial_amount = int(math.ceil(amount_pence / payment_spec['months'])) / 100
        # Rounded to pennies so the SQL equality test is not defeated by float noise.
        final_amount = round(payment_spec['amount'] - (initial_amount * (payment_spec['months'] - 1)), 2)

        sql = "select * from transaction join account on transaction.account_id = account.id where provider_id = %s and declined = 0 and `date` >= %s and description like %s and (money_out = %s or money_out = %s) order by `date` asc"
        params = [
            self.provider['id'],
            payment_spec['start_date'],
            '%' + str(payment_spec['desc']) + '%',
            initial_amount,
            final_amount
        ]

        payments = self.db.query(sql, params)

        num_paid = len(payments)
        remaining = payment_spec['amount']

        for payment in payments:
            remaining -= payment['money_out']

        if num_paid > 0:
            last_amount = payments[-1]['money_out']
            last_date = payments[-1]['date']
        else:
            last_amount = initial_amount
            last_date = None

        if num_paid < payment_spec['months']:
            if last_date:
                due_date = self.add_month(last_date)
            else:
                # First instalment: next occurrence of payment_day after today.
                due_date = datetime.date.today() + datetime.timedelta(days=1)

                while due_date.day != payment_day:
                    due_date += datetime.timedelta(days=1)
        else:
            due_date = None

        # if paid since last salary then status is PAID
        if num_paid > 0 and payments[-1]['date'] >= last_salary_date:
            status = 'PAID'
        else:
            # DUE if payment_day comes round before the next salary date, else SKIPPED.
            status = 'SKIPPED'
            day = datetime.date.today() + datetime.timedelta(days=1)

            while day != next_salary_date:
                if day.day == payment_day:
                    status = 'DUE'
                    break

                day += datetime.timedelta(days=1)

        return status, num_paid, remaining, last_amount, last_date, due_date

    def get_next_amazon_payment(self, i, payment_spec, payment, amount):
        """Return the transaction for ``amount`` about one month after ``payment['date']``.

        Searches a window of +/- 7 days around that date for a non-declined transaction
        whose description contains ``payment_spec['desc']``. ``i`` is unused.
        """
        date_from = self.add_month(payment['date']) - datetime.timedelta(days=7)
        date_to = self.add_month(payment['date']) + datetime.timedelta(days=7)

        payment = self.db.one("select * from transaction join account on transaction.account_id = account.id where provider_id = %s and declined = 0 and `date` >= %s and `date` <= %s and description like %s and money_out = %s", [
            self.provider['id'],
            date_from,
            date_to,
            '%' + payment_spec['desc'] + '%',
            amount
        ])

        return payment

    def add_month(self, date):
        """Return ``date`` one calendar month later, clamped to the last day of that month."""
        if date.month == 12:
            year, month = date.year + 1, 1
        else:
            year, month = date.year, date.month + 1

        day = min(date.day, self.days_in_month(year, month))

        return datetime.date(year, month, day)

    def get_num_payments_made(self, payment):
        """Count non-declined transactions since ``start_date`` for exactly ``payment['amount']``.

        Descriptions may match ``payment['desc']`` or any of ``payment['old_desc']``.
        """
        sql = "select transaction.* from transaction where date >= %s and declined = 0 and money_out = %s and ( description like %s"
        params = [
            payment['start_date'],
            payment['amount'],
            '%' + str(payment['desc']) + '%'
        ]

        for old_desc in payment.get('old_desc') or []:
            sql += " or description like %s"
            params.append('%' + str(old_desc) + '%')

        sql += ')'

        return len(self.db.query(sql, params))

# Script entry point: ./<script> <account_name>
# Guarded so the module can be imported without running the check.
if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("usage: %s <account_name>" % (os.path.basename(sys.argv[0])))
        sys.exit(1)

    p = PaymentFundsCheck(sys.argv[1])
    p.main()
