#!/usr/bin/env python
#
# A tool to create publication list as an html or tex file, based on a bib file (BIBTEX or YAML).
#
# Developed for Python 2.7
#
# Gokberk Cinbis, 2017

# A small note regarding pip on MAC: usually /usr/bin/python and /usr/local/bin/python 
# are independent. Make sure you install required packages using the right pip.

from jinja2 import Environment, FileSystemLoader
import argparse
import yaml     # pip install pyyaml
import bibtexparser # pip install bibtexparser (https://github.com/sciunto-org/python-bibtexparser)
import os
import sys
from numbers import Number
import copy

# Directories handed to jinja2's FileSystemLoader when resolving the template argument:
# the current working directory, followed by the directory that contains this script.
template_search_path = [os.getcwd(),os.path.dirname(os.path.abspath(__file__))]

class Bib2Doc:
    """
    Generic conversion class, shared across all publication list generators (html, tex, etc).
    This class shall be kept independent from the "args" structure used in specific conversion scripts.
    """

    def __init__(self,bibfile):
        '''
        Initialize by loading the contents of one or more bib files.

        @bibfile    A path string, or a list/tuple of path strings. Every file is
                    loaded through append_bib(), in the given order.

        Raises TypeError (a subclass of Exception) when bibfile is neither a string
        nor a list/tuple.
        '''
        self.bib = []
        # num_processed[paper2hashablestring(paper)] gives the number of times the paper is processed
        self.num_processed = {}
        # (str, unicode) under Python 2, plain (str, str) under Python 3.
        string_types = (str, type(u''))
        if isinstance(bibfile, string_types):
            self.append_bib(bibfile)
        elif isinstance(bibfile, (list, tuple)):
            for f in bibfile:
                self.append_bib(f)
        else:
            raise TypeError("bibfile argument should be a string or a list of strings")

    @staticmethod
    def _mklist(val):
        """ Convert None to [], any non-list to [val], and, keep list as is. """
        if val == None:
            return []
        elif not isinstance(val,list):
            return [val]
        return val

    def list_year(self,default_year=None):
        """
        Get the list of unique years in decreasing order.
        If there are papers with no 'year' definition, default_year will be in the list.

        Years are not required to be numeric. To keep the ordering well defined when
        types are mixed, values are ranked as: None (smallest), then numbers, then
        everything else (e.g. strings), and compared by value within the same rank.

        Returns a list, so the result can be iterated more than once.
        """
        def sort_key(year):
            # Comparing None/str/int directly raises TypeError under Python 3,
            # so compare a (rank, value) pair instead.
            if year is None:
                return (0, 0)
            if isinstance(year, Number):
                return (1, year)
            return (2, year)

        years = set(self.getfieldwithdefault(p,'year',default_year) for p in self.bib)
        return sorted(years, key=sort_key, reverse=True)

    def filter(self,field_defaults,**kwargs):
        """
        Filter bib using accepted field values.
        self.filter(field_defaults, field1=val1, field2=val2, ...)
            * field_defaults is a dictionary of default values for fields (None is treated as {}).
              When a field is missing in an entry, it is treated as field_defaults[field] (if defined) or None.
            * val can be a single value or a list of accepted values.
            * An entry is kept only if it is accepted by every given field.
            * Note that fields are case-sensitive

        Returns self.bib itself when no field is given, and a new list otherwise.
        """
        # Add only if needed: If "bib" is defined, uses "bib" instead of internal bibliography.
        defaults = {} if field_defaults is None else field_defaults
        bib = self.bib
        for fld, val in kwargs.items():
            accepted = self._mklist(val)
            default = defaults[fld] if fld in defaults else None
            bib = [p for p in bib if Bib2Doc.getfieldwithdefault(p,fld,default) in accepted]
        return bib

    @staticmethod
    def paper2hashablestring(paper):
        """
        Returns a unique identifier for the entry.
        """
        # No! paper is a dict, therefore, length is longer. assert len(paper)==1, 'len(paper)==1' 
        if not isinstance(paper,type({})):
            raise Exception("This is supposed to be a cell but it is not:\n" + str(paper))
        return str(paper) # this works!

    def processtrack_init(self):
        """
        Forget all per-paper processing counts.
        Meant to be called at the beginning of each comprehensive publication list page.
        Always returns '' so that calling it from a template emits nothing.
        """
        self.num_processed = dict()
        return ''

    def processtrack_markpaper(self,paper):
        """
        Increase the processing count of a paper by one. Use with processtrack_*
        Pass the original record (eg, not the output of html_fixampersand), because the
        count is keyed by the string form of the record.
        Always returns '' so that calling it from a template emits nothing.
        """
        key = Bib2Doc.paper2hashablestring(paper)
        self.num_processed[key] = self.num_processed.get(key, 0) + 1
        return ''

    def processtrack_check(self,warnonly,custombib=None):
        """
        Verify that each paper has been processed exactly once. Use with processtrack_*
        Call this after the publications list is fully constructed.

        @warnonly   If true, problems are printed as a warning; otherwise an Exception is raised.
        @custombib  The list of papers to verify (defaults to the whole bib). If a filtered
                    result is being shown in the page, feed the filter result here.

        Always returns '' so that calling it from a template emits nothing.
        """
        papers = self.bib if custombib is None else custombib
        problems = []
        for entry in papers:
            key = Bib2Doc.paper2hashablestring(entry)
            if key not in self.num_processed:
                problems.append(' * Not processed: ' + key + '\n\n')
                continue
            times = self.num_processed[key]
            if times != 1:
                problems.append(' * Processed ' + str(times) + ' times: ' + key + '\n\n')
        if problems:
            # the leading newline makes the report more readable.
            report = '\n' + ''.join(problems)
            if not warnonly:
                raise Exception(report)
            print("WARNING")
            print(report)
        return ''

    def count_types(self,types):
        """
        Return how many papers in the bib have their 'type' contained in types.
        A paper for which this test raises (e.g. it has no 'type' field) is printed
        and not counted.
        """
        matched = 0
        for entry in self.bib:
            try:
                hit = entry['type'] in types
            except Exception:
                print(entry)
                continue
            if hit:
                matched += 1
        return matched

    def count_preprint(self):
        """ Return the number of preprints, i.e. papers of type under_review or working_paper """
        preprint_types = ['under_review','working_paper']
        return self.count_types(preprint_types)

    @staticmethod
    def add_default_values(paper):
        """ Add default values for the missing or empty fields """
        #if ('national' not in paper) or (paper['national']==None) or (paper['national']==''):
        #    paper['national']=1
        # Design choice note:
        # * Do NOT make national a field. It makes processing very complicated, where one has to check
        #   whether a publication is national or not, for every single publication.
        return paper

    def append_paper(self,paper,convert_year_to_int):
        """
        Append a single entry to the bib, after applying add_default_values().
        If convert_year_to_int is true and the entry has a 'year', the year is replaced
        by int(year) in place.
        """
        entry = Bib2Doc.add_default_values(paper)
        wants_int_year = convert_year_to_int and 'year' in entry
        if wants_int_year:
            entry['year'] = int(entry['year'])
        self.bib.append(entry)

    def append_bib(self,bibfile,convert_year_to_int=True):
        """
        Read the publications database, and append it to the existing bib.

        @bibfile    Path of a YAML file (.yaml or .yml) or a BibTeX file (.bib). The extension
                    is matched case-insensitively.
        convert_year_to_int: Recommended is True. This just ensures that per-year grouping works correctly.

        Raises an Exception when the file extension is not recognized.
        """
        _, ext = os.path.splitext(bibfile)
        ext = ext.lower()
        with open(bibfile, "r") as f:
            if ext in ('.yaml', '.yml'):
                # safe_load_all: yaml.load_all() without an explicit Loader is rejected by
                # PyYAML >= 6 and, in older versions, can construct arbitrary Python objects.
                # The documents are consumed lazily, so iterate while the file is still open.
                for records in yaml.safe_load_all(f): # there is normally a single "records" document
                    # An empty YAML document is loaded as None; treat it as having no entries.
                    for paper in (records or []):
                        self.append_paper(paper,convert_year_to_int=convert_year_to_int)
            elif ext == '.bib':
                records = (bibtexparser.load(f)).entries
                for paper in records:
                    # Fall back to the entry type reported by the parser when 'type' is not given.
                    if 'type' not in paper:
                        paper['type'] = paper['ENTRYTYPE']
                    self.append_paper(paper,convert_year_to_int=convert_year_to_int)
            else:
                raise Exception("Unrecognized file extension in " + bibfile)

    def remove_duplicates(self):
        """
        Remove duplicate entries from the dataset. Relies completely on titles right now, as minor differences shall be ignored.

        The first entry with a given title is kept and later entries with exactly the same
        title are dropped (and printed). Entries without a usable title (missing, None or
        empty) are never considered duplicates of each other and are all kept.
        """
        seen_titles = set()
        new_bib = []
        for p in self.bib:
            title = Bib2Doc.getfieldwithdefault(p,'title','')
            if not title:
                # Nothing to compare on. In particular an explicit None title shall not make
                # all untitled entries look identical.
                new_bib.append(p)
            elif title in seen_titles:
                print("-- Removing the following duplicate entry (possibly more than one ImageLab author has this in his/her bib which is OK) --")
                print(p)
            else:
                seen_titles.add(title)
                new_bib.append(p)
        self.bib = new_bib


    def sort(self,newestFirst=True):
        """
        Sort the publications database according to (year, month, title)
        Uses the following default values: year=2100, month=12, title=''.
        A year or month that is missing or not numeric is replaced by its default;
        a title that is missing or None is replaced by ''.
        month4sort field overrides month when defined for a publication.

        @newestFirst    If True, bib is sorted from newest to oldest.
        """

        DEFAULT_YEAR = 2100
        DEFAULT_MONTH = 12
        def p2priority(p):
            year = Bib2Doc.getfieldwithdefaultNumeric(p,'year',DEFAULT_YEAR)
            month_field = 'month4sort' if Bib2Doc.isdef(p,'month4sort') else 'month'
            month = Bib2Doc.getfieldwithdefaultNumeric(p,month_field,DEFAULT_MONTH)
            title = Bib2Doc.getfieldwithdefault(p,'title','')
            if title is None:
                # An explicit None cannot be compared with a string title under Python 3.
                title = ''
            return (year,month,title) # sort by year first, then month, then title

        # sorted() is stable: entries with equal priority keep their relative order.
        ordered = sorted(self.bib, key=p2priority)
        if newestFirst:
            # Reverse a real list (not reversed(), which would return an iterator).
            ordered.reverse()
        self.bib = ordered
            #self.bib = reversed(self.bib) --> returns an iterator

    @staticmethod
    def get_pdf_href(paper):
        """ Utility function that returns paper.pdf or '#' """
        return paper['pdf'] if Bib2Doc.isdef(paper,'pdf') else '#' 

    @staticmethod
    def get_venue(paper):
        """ Return vanue name (journal for article, booktitle otherwise) """
        if Bib2Doc.isdef(paper,'type') and (paper['type'] == 'article'):
            return Bib2Doc.getfieldwithdefault(paper,'journal','--',True)
        else:
            return Bib2Doc.getfieldwithdefault(paper,'booktitle','--',True)

    # Lookup table from month number (1..12) to English month name, used by get_month_name().
    _month2name = {1: 'January', 2: 'February', 3: 'March', 4: 'April', 5: 'May', 6: 'June', 7: 'July', 
            8: 'August', 9: 'September', 10: 'October', 11: 'November', 12: 'December'}

    @classmethod
    def get_month_name(cls,paper):
        """
        Utility function that maps paper.month to the month name via _month2name.
        Any failure (paper.month is undefined or unrecognized) results in "".
        """
        try:
            name = cls._month2name[paper['month']]
        except Exception:
            name = ""
        return name

    @staticmethod
    def isdef(paper,field):
        """ Return True iff paper[field] is defined and not empty """
        return True if ((field in paper) and (paper[field])) else False

    @staticmethod
    def getfieldwithdefault(paper,field,default,warnIfMissing=False):
        """ Return paper[field] (if defined) or default """
        try:
            return paper[field]
        except Exception:
            if warnIfMissing:
                print('=======')
                print('WARNING')
                print('=======')
                print('The field "' + field + '" is missing in the following entry') 
                for k,v in paper.items():
                    print('   ' + k + ": " + str(v))
            return default

    @staticmethod
    def getfieldwithdefaultNumeric(p,field,default):
        """ Get field of paper p. If the field is missing or not Numeric, return default """
        x = Bib2Doc.getfieldwithdefault(p,field,default)
        if not isinstance(x, Number):
            return default
        return x

    @staticmethod
    def get_vol_num_pp_loc_mon_year(p):
        """ Return a string according to the following template and the available information: 'volume (number), pp. pages, location, month year' """
        out = ''
        # Don't force "article" type -- if Bib2Doc.getfieldwithdefault(p,'type','')=='article' --
        if Bib2Doc.isdef(p,'volume'):
            if Bib2Doc.isdef(p,'number'):
                out += str(p['volume']) + '(' + str(p['number']) + '), '
            else:
                out += str(p['volume']) + ", "
        if Bib2Doc.isdef(p,'pages'):
            out += 'pp. ' + str(p['pages']) + ', '
        if Bib2Doc.isdef(p,'location'):
            out += p['location'] + ', '
        out += Bib2Doc.get_month_name(p) + ' ' # possibly ""
        if Bib2Doc.isdef(p,'year'):
            out += str(p['year'])
        return out

    ###### DEPRECATED ###### 

    # fix_html_ampersand: use html_fixampersand instead.
    # field2htmllink: use html_field2link instead.

    ###### HTML-SPECIFIC ###### 

    @staticmethod
    def html_field2link(paper,field,text=None,target="_blank",css_class=''):
        """ 
        Return HTML link with href=paper.field, if paper[field] is defined, and '' otherwise. 
        Text is field by default.
        """
        if not (css_class == ''):
            css_class = ' class="' + css_class + '"'
        return ("<a href=" + paper[field] + ' target="' + target + css_class + '">' + (text or field) + '</a>') if Bib2Doc.isdef(paper,field) else ""

    @staticmethod
    def html_fixampersand(paper):
        """
        Replace & with &amp; in all string-valued fields of paper except the following: {pdf,img}
        Returns a modified deep-copy of the input, to avoid unintended changes in the original bib records.
        """
        paper = copy.deepcopy(paper)
        for key in paper:
            if key in {'pdf','img'}:
                continue # do not alter these fields as they contain link/path.
            if type(paper[key])==str:
                paper[key] = paper[key].replace('&','&amp;') 
        return paper


    ###### TEX-SPECIFIC ###### 

    @staticmethod
    def tex_fixampersand(paper):
        """
        Replace & with \\& in all string-valued fields of paper except the following: {pdf,img}
        Returns a modified deep-copy of the input, to avoid unintended changes in the original bib records.
        """
        paper = copy.deepcopy(paper)
        for key in paper:
            if key in {'pdf','img'}:
                continue # do not alter these fields as they contain link/path.
            if type(paper[key])==str:
                paper[key] = paper[key].replace('&','\\&') 
        return paper


def getargs(argv=None):
    """
    Parse the command line arguments of the script.

    @argv   Optional list of argument strings (without the program name). When None,
            argparse reads the arguments from sys.argv.

    Returns the argparse namespace with the fields template, outputfile and bibfile
    (bibfile is a list with one item per --bibfile option).
    """
    parser = argparse.ArgumentParser(description='''
        Create a publication list document, based on a jinja2 template and a YAML or BibTeX bibliography file.
        See https://github.com/gcinbis/bib2doc and http://jinja.pocoo.org/docs/templates 
        ''')
    parser.add_argument('template',help='Path of the jinja2 based HTML OR TEX template.')
    parser.add_argument('outputfile',help='Path of the output file.')
    parser.add_argument('--bibfile',action='append',help='Add a YAML (.yaml) or BibTeX (.bib) bib file path. (Can be used multiple times to utilize multiple bib files)',required=True)
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = getargs()

    # Load all the given bib files, then clean up and order the entries.
    bib2doc = Bib2Doc(args.bibfile)
    bib2doc.remove_duplicates() # todo/maybe: make this optional.
    bib2doc.sort(newestFirst=True) # todo/maybe: make this optional.

    # Render the template; the Bib2Doc object is available in the template as "bib2doc".
    loader = FileSystemLoader(template_search_path)
    env = Environment(loader=loader,trim_blocks=True)
    stream = env.get_template(args.template).stream(bib2doc=bib2doc)
    stream.dump(args.outputfile)

