#!/usr/bin/env python
"""This script looks through a file for a bipartition table that could be split
  over seprate lines (interleaved).
It uses simple regex pattern recognition, so there is a chance that it will
  misidentify text as a splits table -- check your results!
"""

import sys
import re
import os
# Base name of the running script; prefixes every diagnostic message.
script_name = os.path.basename(sys.argv[0])

# When True, debug() echoes its messages to standard error.
verbose = False
def warn(msg):
    """Write *msg* to standard error, prefixed with the script name."""
    text = "%s: %s\n" % (script_name, msg)
    sys.stderr.write(text)
def debug(msg):
    """Write *msg* to standard error like warn(), but only when the
    module-level ``verbose`` flag is true; otherwise do nothing.
    """
    if not verbose:
        return
    text = "%s: %s\n" % (script_name, msg)
    sys.stderr.write(text)

# Collect the input streams: standard input when no file arguments are given,
# otherwise one open file per command-line argument, in argument order.
sources = []
if len(sys.argv) == 1:
    sources.append(sys.stdin)
else:
    for fn in sys.argv[1:]:
        try:
            # Plain text mode.  The old 'U' flag was deprecated in Python 3
            # and removed in 3.11, where it raises ValueError for every
            # file; Python 3 text mode already translates newlines.
            sources.append(open(fn, 'r'))
        except (IOError, OSError):
            # Only genuine I/O failures are reported as "could not open";
            # KeyboardInterrupt and programming errors propagate.
            sys.exit('Could not open the file "%s"' % fn)

# Marker line that precedes a splits table: any line that starts with
# "Bipartitions found".
before_table_pat = re.compile(r'^Bipartitions found')
# Column-number header row: digits and spaces only, optionally followed by
# the "Freq   %" column titles.  Group 1 is set only when those titles are
# present, which parse_splits_table uses to recognise the last page.
table_header_pat = re.compile(r'^[0-9 ]+(Freq\s+%)?\s*$')
# Dashed rule under the header: dashes followed by optional trailing
# whitespace up to the end of the line (used with .match(), so the dashes
# must begin the line).
header_sep_pat = re.compile(r'[-]+\s*$')
# Row on a page without frequency columns: a run of '.' and '*' characters;
# group 1 also captures any whitespace that follows the run.
partial_table_pat = re.compile(r'([.*]+\s*)')
# Row on a page with frequency columns: the '.'/'*' run, an integer and a
# number made of digits and dots; group 1 excludes an optional trailing '%'.
full_table_pat = re.compile(r'([.*]+\s+\d+\s+[0-9.]+)%?')

def table_done(header, prev_table_rows, new_rows, line_num):
    """Fold one finished page of a splits table into the rows gathered so far.

    header -- list of header-row strings; must be non-empty, because the
        length of its last entry sets the width of the dashed separator.
    prev_table_rows -- rows accumulated from earlier pages (empty for the
        first page).  When non-empty it is extended in place: row n of
        new_rows is appended to row n.
    new_rows -- the rows of the page that has just been read.
    line_num -- zero-based index of the current input line; it is reported
        one-based in the error message.

    Returns a tuple (header + [separator], rows).  ``rows`` is
    prev_table_rows after merging, or new_rows itself when there were no
    previous rows.

    Raises AssertionError when the page has a different number of rows than
    the pages before it.
    """
    if prev_table_rows:
        if len(prev_table_rows) != len(new_rows):
            # Raised explicitly rather than with ``assert`` so the check is
            # still performed when Python runs with -O.
            raise AssertionError(
                "Number of rows differs from previous pages of the splits table. At line %d"
                % (1 + line_num))
        for n, nl in enumerate(new_rows):
            prev_table_rows[n] += nl
    else:
        prev_table_rows = new_rows
    sep = '-' * len(header[-1])
    return header + [sep], prev_table_rows

def parse_splits_table(inp):
    """Extract every splits (bipartition) table from the lines of *inp*.

    A table begins after a line that starts with "Bipartitions found".  It
    may be printed as several pages, each consisting of column-number header
    rows, a dashed separator, the rows of '.'/'*' characters and a blank
    line that ends the page.  The page whose header carries the "Freq ... %"
    titles is treated as the last page of the table.  Rows of successive
    pages are joined with table_done() so that each returned row spans all
    columns; rows of the last page keep their frequency columns.

    inp -- iterable of text lines (for example an open file).

    Returns a list of tables, each table being a list of row strings.  A
    warning is written with warn() when no "Bipartitions found" line was
    seen at all.

    Raises AssertionError when a dashed separator appears before any column
    header row or when a line inside a table does not look like a table
    row; table_done() raises it as well when pages differ in row count.
    """
    table_list = []
    table_started = False
    reading_header = False
    reading_table = False
    # Bound up front: a line right after the marker that is neither a header
    # row nor a separator used to reach the debug call below before any
    # assignment and fail with UnboundLocalError.
    is_final_page = False
    table_pat = partial_table_pat
    curr_header = []     # header rows accumulated over the pages read so far
    curr_table = []      # table rows accumulated over the pages read so far
    working_header = []  # header rows of the page currently being read
    working_table = []   # table rows of the page currently being read
    for line_num, line in enumerate(inp):
        if not (reading_header or reading_table):
            # Outside any table: wait for the marker line.
            if before_table_pat.match(line):
                table_started = True
                reading_header = True
                is_final_page = False
                working_header = []
                debug("READING HEADER")
            continue

        # Strip only the line terminator.  Slicing off the last character
        # unconditionally would eat real data on a final line that has no
        # trailing newline.
        s_line = line.rstrip('\r\n')

        if reading_header:
            if not s_line:
                continue
            m = table_header_pat.match(s_line)
            if m:
                working_header.append(s_line)
                # Only the last page's header carries the "Freq ... %" titles.
                is_final_page = m.group(1) is not None
                debug("group 1 of header pattern is %s " % str(m.group(1)))
                debug("s_line was %s " % s_line)
            elif header_sep_pat.match(s_line):
                if not working_header:
                    # Explicit raise so the check survives ``python -O``.
                    raise AssertionError(
                        "Expecting column number header before ----- separator at line number %d"
                        % (1 + line_num))
                if curr_header:
                    # Bottom-align this page's header rows with the rows
                    # gathered from earlier pages and extend them sideways.
                    shared = min(len(curr_header), len(working_header))
                    for back in range(1, shared + 1):
                        curr_header[-back] += working_header[-back]
                else:
                    curr_header = working_header
                working_header = []
                reading_header = False
                reading_table = True
                debug("READING TABLE")
                table_pat = full_table_pat if is_final_page else partial_table_pat
            debug("header is %s final row" % ("" if is_final_page else "NOT"))
            continue

        # From here on we are inside the rows of a page.
        if s_line:
            m = table_pat.match(s_line)
            if not m:
                raise AssertionError(
                    "Expecting at a splits table row at line number %d" % (1 + line_num))
            working_table.append(m.group(1))
            continue

        # A blank line closes the page.
        reading_table = False
        tab = table_done(curr_header, curr_table, working_table, line_num)[1]
        if is_final_page:
            table_list.append(tab)
            curr_table = []
            curr_header = []
        else:
            # Keep only the merged rows.  The header returned by table_done
            # has a separator appended, which must not take part in the
            # header alignment of the next page.
            curr_table = tab
            reading_header = True
            first_row_len = len(curr_table[0]) if curr_table else 0
            debug("READING NEXT PAGE OF HEADER len(curr_table)= %d len(curr_table[0]) = %d " % (len(curr_table), first_row_len))
        working_table = []

    if working_table:
        # Input ended without the blank line that normally closes a page.
        tab = table_done(curr_header, curr_table, working_table, line_num)[1]
        table_list.append(tab)
    if (not table_list) and (not table_started):
        warn('A line starting with "Bipartitions found" was not found')
    return table_list




# Write every table found in each source to standard output, one row per line
# with a blank line after each table.  The run stops with an error message at
# the first source in which no table is found.
outstream = sys.stdout
for n, inp in enumerate(sources):
    try:
        t = parse_splits_table(inp)
    finally:
        # The stream has been consumed; release handles opened from the
        # command line.  Standard input is left open.
        if inp is not sys.stdin:
            inp.close()
    if not t:
        if len(sys.argv) > 1:
            # sources[n] was opened from command-line argument n + 1.
            sys.exit('No splits table found in %s' % sys.argv[1 + n])
        else:
            sys.exit('Splits table not found')
    else:
        for table in t:
            for line in table:
                outstream.write("%s\n" % line)
            outstream.write("\n")
