#!python

import os
import sys
import time
import json
import re
import datetime
import yaml
import pwd
from monzo_db import DB

# Display column -> row field whose value, when truthy, replaces that column
# before rendering (consulted by Monzo.display).
FIELD_MAP = dict(date='created_at')

class Monzo:
    """Command-line viewer for transactions stored in the Monzo database.

    Constructing an instance runs the whole CLI: it parses argv-style
    arguments, loads ``~/.monzo/config.yaml``, queries the DB and prints a
    table to stdout.
    """

    # Search query; the caller appends the closing parentheses (optionally
    # preceded by amount clauses) to balance "((... and (description like %s".
    _SEARCH_SQL = "select transaction.*, account.name as account, pot.name as pot, transaction_metadata.value as mastercard_lifecycle_id from transaction join account on transaction.account_id = account.id left join transaction_metadata on transaction_metadata.transaction_id = transaction.id and transaction_metadata.`key` = %s left join pot on transaction.pot_id = pot.id where declined = %s and ((money_in >0 or money_out >0) and (description like %s"

    # Listing query used when no search terms are given.
    _LIST_SQL = "select transaction.*, account.name as account, pot.name as pot, transaction_metadata.value as mastercard_lifecycle_id from transaction join account on transaction.account_id = account.id left join pot on transaction.pot_id = pot.id left join transaction_metadata on transaction_metadata.transaction_id = transaction.id and transaction_metadata.`key` = %s where declined = %s and transaction.transaction_id is not null and (money_in >0 or money_out >0)"

    def __init__(self, args):
        """Run the CLI.

        :param args: argv-style list; ``args[0]`` (the program name) is ignored.
        """
        # Parse first so that -h works without a config file or database.
        include_pots, show_declined, terms = self._parse_args(args)

        self.config = self._load_config()
        self.db = DB(self.config['db'])

        sql, params = self._build_query(terms, include_pots, show_declined)
        rows = self.db.query(sql, params)
        rows = self.process_pending_refunds(rows)

        display_columns = ['account', 'pot', 'date', 'money_in', 'money_out', 'description']
        if show_declined:
            display_columns.append('decline_reason')

        self.display(rows, display_columns)

    def _parse_args(self, args):
        """Split argv into flags and search terms.

        ``-h`` calls :meth:`help` (which exits); ``-p`` includes pot
        transactions; ``-d`` shows declined transactions. Every other
        argument is a search term.

        :returns: ``(include_pots, show_declined, terms)``.
        """
        include_pots = False
        show_declined = False
        terms = []

        for arg in args[1:]:
            if arg == '-h':
                self.help()
            elif arg == '-p':
                include_pots = True
            elif arg == '-d':
                show_declined = True
            else:
                terms.append(arg)

        return include_pots, show_declined, terms

    def _load_config(self):
        """Ensure ``~/.monzo`` exists and return the parsed ``config.yaml``."""
        homedir = pwd.getpwuid(os.getuid()).pw_dir
        monzo_dir = f"{homedir}/.monzo"
        os.makedirs(monzo_dir, 0o755, exist_ok=True)

        # Context manager so the config file handle is always closed.
        with open(f"{monzo_dir}/config.yaml") as config_fh:
            return yaml.safe_load(config_fh)

    def _build_query(self, terms, include_pots, show_declined):
        """Build the SQL and parameter list for the transaction query.

        With no *terms*, every transaction is listed. Otherwise the joined
        terms are matched against the description. When the terms look like
        an amount, they are also matched against money_in/money_out:
        a whole number ``N`` matches ``N.00``-``N.99``, and a decimal matches
        exactly.

        :param terms: list of search words (may be empty).
        :param include_pots: when False, descriptions like ``pot_0000%`` are excluded.
        :param show_declined: selects declined (1) instead of non-declined (0) rows.
        :returns: ``(sql, params)`` for a ``%s``-placeholder DB API.
        """
        declined = 1 if show_declined else 0

        if not terms:
            sql = self._LIST_SQL
            params = ['metadata_mastercard_lifecycle_id', declined]
            if not include_pots:
                sql += " and description not like %s"
                params.append('pot_0000%')
            sql += " order by date, created_at"
            return sql, params

        q = ' '.join(terms)
        sql = self._SEARCH_SQL
        params = ['metadata_mastercard_lifecycle_id', declined, '%' + q + '%']

        # ASCII-only patterns: str.isdigit() also accepts characters like '²',
        # and stripping every '.' accepted strings such as '1.2.3'.
        if re.fullmatch(r'[0-9]+', q):
            money_from = f'{q}.00'
            money_to = f'{q}.99'
            sql += " or (money_in >0 and money_in >= %s and money_in <= %s) or (money_out >0 and money_out >= %s and money_out <= %s)))"
            params += [money_from, money_to, money_from, money_to]
        elif q != '.' and re.fullmatch(r'[0-9]*\.[0-9]*', q):
            if q.endswith('.'):
                # "5." -> "5.0" so it is a well-formed decimal literal.
                q += '0'
            sql += " or (money_in >0 and money_in = %s) or (money_out >0 and money_out = %s)))"
            params += [q, q]
        else:
            sql += "))"

        if not include_pots:
            sql += " and description not like %s"
            params.append('pot_0000%')

        sql += " group by id order by date, created_at"
        return sql, params

    def help(self):
        """Print usage information and exit the process."""
        prog = os.path.basename(sys.argv[0])
        print("usage:\n")
        print("%s [-p] [-d] [search string]\n" % (prog))
        print("-p    :    include pot transactions")
        print("-d    :    show declined transactions")
        sys.exit()

    def display(self, data, columns):
        """Print *data* as a fixed-width text table on stdout.

        Cells are normalised in place before printing:

        * columns listed in FIELD_MAP take the mapped field's value when it
          is truthy;
        * the 'date' column is shifted with :meth:`adjust_timestamp`, and it
          is collapsed to ``HH:MM`` when it falls on the same formatted day
          as the previous dated row;
        * a pending row's non-None, non-string money_out is wrapped in ``*``;
        * None becomes an empty string;
        * datetimes and dates are formatted ``dd/mm`` (``dd/mm/yy`` outside
          the current year), with ``HH:MM`` appended for datetimes.

        The date column is right-aligned; the other columns are left-aligned.
        Columns are separated by two spaces.

        :param data: list of row dicts; mutated in place.
        :param columns: keys to display, in order.
        """
        widths = {key: len(key) for key in columns}
        last_date = None
        today = datetime.datetime.now()

        for row in data:
            for key in columns:
                if key in FIELD_MAP and row[FIELD_MAP[key]]:
                    row[key] = row[FIELD_MAP[key]]

                # Rows without any date are shown blank instead of crashing.
                if key == 'date' and row[key] is not None:
                    row[key] = self.adjust_timestamp(row[key])
                    pattern = '%d/%m' if row[key].year == today.year else '%d/%m/%y'
                    this_date = row[key].strftime(pattern)
                    if this_date == last_date:
                        # Same day as the previous row: only show the time.
                        row[key] = row[key].strftime('%H:%M')
                    last_date = this_date

                # String values were already annotated (e.g. refund marker ' R').
                if (key == 'money_out' and row['pending']
                        and row['money_out'] is not None
                        and not isinstance(row['money_out'], str)):
                    row['money_out'] = '*' + str(row['money_out']) + '*'

                value = row[key]
                if value is None:
                    value = ''
                elif isinstance(value, datetime.datetime):
                    pattern = '%d/%m %H:%M' if value.year == today.year else '%d/%m/%y %H:%M'
                    value = value.strftime(pattern)
                elif isinstance(value, datetime.date):
                    pattern = '%d/%m' if value.year == today.year else '%d/%m/%y'
                    value = value.strftime(pattern)
                row[key] = value

                widths[key] = max(widths[key], len(str(value)))

        out = sys.stdout
        out.write(''.join(key.ljust(widths[key] + 2) for key in columns) + "\n")
        out.write(''.join('-' * (widths[key] + 2) for key in columns) + "\n")

        for row in data:
            cells = []
            for key in columns:
                if key == 'date':
                    cells.append(str(row[key]).rjust(widths[key]) + "  ")
                else:
                    cells.append(str(row[key]).ljust(widths[key] + 2))
            out.write(''.join(cells) + "\n")

    def sanitise(self, string):
        """Collapse every run of whitespace in *string* to a single space."""
        # Raw string: '\s' in a plain literal is an invalid escape sequence.
        return re.sub(r'\s+', ' ', string)

    def process_pending_refunds(self, rows):
        """Mark pending debits that already have a matching incoming refund.

        A pending row is marked when its money_out equals the money_in of a
        row sharing the same mastercard_lifecycle_id. Its money_out then
        becomes ``"<amount> R"``.

        :param rows: list of row dicts; mutated in place.
        :returns: the same list.
        """
        # lifecycle id -> money_in (a later row wins on duplicates)
        pending_returned = {}
        for row in rows:
            if row['money_in'] is not None and row['mastercard_lifecycle_id'] is not None:
                pending_returned[row['mastercard_lifecycle_id']] = row['money_in']

        for row in rows:
            lifecycle_id = row['mastercard_lifecycle_id']
            if (row['pending'] and lifecycle_id in pending_returned
                    and pending_returned[lifecycle_id] == row['money_out']):
                row['money_out'] = str(row['money_out']) + ' R'

        return rows

    @staticmethod
    def _last_sunday(year, month):
        """Return the day of month of the last Sunday in a 31-day *month*."""
        # weekday(): Monday == 0 ... Sunday == 6
        return 31 - (datetime.date(year, month, 31).weekday() + 1) % 7

    def is_within_bst(self, dt):
        """Return True if *dt* falls within British Summer Time.

        BST runs from 01:00 UTC on the last Sunday of March until 01:00 UTC
        on the last Sunday of October. A datetime *dt* is compared against
        those instants. NOTE(review): this assumes *dt* is a UTC timestamp,
        which is what adjust_timestamp's +1h shift implies; confirm the DB
        stores UTC. A plain ``datetime.date`` is compared by day only.
        """
        if dt.month < 3 or dt.month > 10:
            return False
        if 3 < dt.month < 10:
            return True

        change = dt.replace(day=self._last_sunday(dt.year, dt.month))
        if isinstance(dt, datetime.datetime):
            # The clocks change at 01:00 UTC, not at dt's own time of day.
            change = change.replace(hour=1, minute=0, second=0, microsecond=0)

        if dt.month == 3:
            return dt >= change
        return dt < change

    def adjust_timestamp(self, dt):
        """Shift *dt* to UK local time by adding one hour during BST.

        None is returned unchanged so rows without a date can still be shown.
        """
        if dt is None:
            return None
        if self.is_within_bst(dt):
            return dt + datetime.timedelta(hours=1)
        return dt


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    Monzo(sys.argv)
