#!/usr/bin/env python

from __future__ import print_function

import sys
import os
import subprocess

from pyLAD.Base import Base
from pyLAD.Parser import Parser
from pyLAD.Mapping import Fastq, Preprocessor
from pyLAD.Segmenter import Segmenter


class LADetector(Base):
    """Command-line driver that runs the whole LADetector pipeline.

    Constructing an instance parses the command-line arguments and then
    immediately executes every pipeline step in order: load the segmenter,
    map fastq files, load counts, bin the data, partition the data and
    write the results. All output goes under the directory ``args.NAME``
    and most output files are named ``NAME/NAME<suffix>``.
    """

    def __init__(self):
        """Parse the arguments and run every pipeline step in sequence.

        If the parser returns ``None`` no step is run and the base class
        initialiser is not called.
        """
        parser = Parser()
        self.args = parser.parse_args()
        if self.args is None:
            return
        super(LADetector, self).__init__(verbose=self.args.VERBOSE,
                                         log_name='LADetector')
        self.load_segmenter()
        self.map_fastq()
        self.load_counts()
        self.bin_data()
        self.partition_data()
        self.write_results()

    def _make_output_dir(self):
        """Create the output directory ``args.NAME`` if it is missing.

        Done in-process (rather than by spawning ``mkdir -p``) so that the
        directory is guaranteed to exist before any later step writes to it.

        Raises:
            OSError: if the directory cannot be created and does not exist.
        """
        outdir = self.args.NAME
        if os.path.isdir(outdir):
            return
        try:
            os.makedirs(outdir)
        except OSError:
            # Another process may have created it in the meantime; only
            # propagate the error if the directory is still missing.
            if not os.path.isdir(outdir):
                raise

    def load_segmenter(self):
        """Create ``self.segmenter`` and populate it with RE data.

        A previously saved ``NAME/NAME.npz`` file is loaded when present.
        Otherwise, RE data is loaded from ``args.RE`` and/or ``args.INDEX``.
        If the segmenter still has no data afterwards the program exits
        with status 1. Unalignable regions are loaded when ``args.UNALIGN``
        is given.
        """
        # If needed, make the output directory
        self._make_output_dir()

        saved_fname = "%s/%s.npz" % (self.args.NAME, self.args.NAME)
        self.segmenter = Segmenter(verbose=self.args.VERBOSE)
        # Try loading a saved segmenter file
        if os.path.exists(saved_fname):
            self.segmenter.load(saved_fname)

        # Load RE data
        have_re_source = (self.args.RE is not None
                          or self.args.INDEX is not None)
        if self.segmenter.data is None and have_re_source:
            self.segmenter.load_RE(fname=self.args.RE, index=self.args.INDEX,
                                   focus=self.args.FOCUS)

        # RE data wasn't loaded, stop execution. Passing a string to
        # sys.exit prints it to stderr and still exits with status 1.
        if self.segmenter.data is None:
            sys.exit("LADetector: no RE data could be loaded, "
                     "stopping execution")
        if self.args.UNALIGN is not None:
            self.segmenter.load_unalignable(fname=self.args.UNALIGN)

    def map_fastq(self):
        """Preprocess and align every fastq file, recording the bam names.

        Does nothing when ``args.FASTQ`` is ``None``. Otherwise
        ``args.BAM`` is replaced by a list with one sub-list per entry of
        ``args.FASTQ``, holding the expected bam file name(s) for each
        fastq file. Alignment is skipped for a fastq file whose bam
        output(s) already exist; preprocessing is skipped when no
        preprocessing option is set or the preprocessed file already
        exists.
        """
        if self.args.FASTQ is None:
            return
        # Preprocessing is only needed if at least one option requests it.
        needs_preproc = (len(self.args.SPLIT) != 0
                         or len(self.args.ADAPTERS) != 0
                         or self.args.QTRIM)
        self.args.BAM = []
        for cond in self.args.FASTQ:
            cond_bams = []
            self.args.BAM.append(cond_bams)
            for fname in cond:
                stem = fname.split('/')[-1].split('.fastq')[0]
                bam_fname = "%s/%s.bam" % (self.args.NAME, stem)
                if needs_preproc:
                    preproc_fname = ("%s/%s.preprocessed.fastq" %
                                     (self.args.NAME, stem))
                else:
                    preproc_fname = None

                # With multimapping, two extra bam files are expected.
                expected_bams = [bam_fname]
                if self.args.MULTIMAPPING:
                    expected_bams.append(bam_fname.replace('.bam',
                                                           '.1.bam'))
                    expected_bams.append(bam_fname.replace('.bam',
                                                           '.2.bam'))

                if all(os.path.exists(bam) for bam in expected_bams):
                    self.logger.info("The bam file(s) %s already exists. "
                                     "Skipping preprocessing and alignment" %
                                     bam_fname)
                else:
                    self._preprocess_and_align(fname, preproc_fname)
                cond_bams.extend(expected_bams)

    def _preprocess_and_align(self, fname, preproc_fname):
        """Preprocess (if requested) and align a single fastq file.

        Args:
            fname: path of the original fastq file.
            preproc_fname: path of the preprocessed fastq file, or ``None``
                when no preprocessing is needed.
        """
        fastq_fname = preproc_fname
        if preproc_fname is None:
            self.logger.info("No preprocessing needed. "
                             "Skipping preprocessing")
            fastq_fname = fname
        elif os.path.exists(preproc_fname):
            self.logger.info("The preprocessed file %s already "
                             "exists. Skipping preprocessing" %
                             preproc_fname)
        else:
            Preprocessor(fastq=fname,
                         outdir=self.args.NAME,
                         threads=self.args.THREADS,
                         quality_trim=self.args.QTRIM,
                         quality_cutoff=self.args.QCUTOFF,
                         split_by=self.args.SPLIT,
                         adapters=self.args.ADAPTERS,
                         minsize=self.args.MINREAD,
                         verbose=self.args.VERBOSE)
        Fastq(fastq=fastq_fname,
              index=self.args.INDEX,
              threads=self.args.THREADS,
              outdir=self.args.NAME,
              multimapping=self.args.MULTIMAPPING,
              trim_size=self.args.TRIMSIZE,
              seed=self.args.SEED,
              verbose=self.args.VERBOSE)
        if not self.args.KEEPPRE and preproc_fname is not None:
            os.remove(preproc_fname)
        # Remove intermediate fastq files left next to the aligned input.
        # NOTE(review): 'mult2' is not in this list - confirm whether the
        # aligner can produce it.
        for tag in ['un', 'un1', 'un2', 'mult', 'mult1']:
            leftover = fastq_fname.replace('.fastq', '.%s.fastq' % tag)
            if os.path.exists(leftover):
                os.remove(leftover)

    def load_counts(self):
        """Load read data into the segmenter for each condition.

        Entries of ``args.BAM`` (preferred) or ``args.COUNTS`` are paired
        positionally with the conditions 'treatment' and 'control'. Only
        the first two entries are used; if fewer are present, only those
        are loaded. Does nothing when both are ``None``.
        """
        if self.args.BAM is None and self.args.COUNTS is None:
            return
        conditions = ('treatment', 'control')
        if self.args.BAM is not None:
            for condition, fnames in zip(conditions, self.args.BAM):
                for fname in fnames:
                    self.segmenter.load_reads(
                        fname=fname,
                        condition=condition,
                        count_overlaps=self.args.OVERLAP,
                        maxdist=self.args.MAXDIST)
        else:
            for condition, fnames in zip(conditions, self.args.COUNTS):
                for fname in fnames:
                    self.segmenter.load_counts(fname=fname,
                                               condition=condition)

    def bin_data(self):
        """Bin the segmenter data when ``args.BINSIZE`` is positive."""
        if self.args.BINSIZE > 0:
            self.segmenter.bin_data(binsize=self.args.BINSIZE)

    def partition_data(self):
        """Segment the data, then call LADs from the segmentation."""
        self.segmenter.segment_data(threads=self.args.THREADS)
        self.segmenter.find_LADs_from_segmentation(mindip=self.args.MINDIP,
                                                   maxdip=self.args.MAXDIP)

    def write_results(self):
        """Save the segmenter state and write the requested output files.

        Always writes ``NAME/NAME.npz`` and the 'LADs' and 'DIPs' bed
        files. Bedgraph files and the PDF plot are written only when
        ``args.BEDGRAPH`` / ``args.PLOT`` are set.
        """
        prefix = "%s/%s" % (self.args.NAME, self.args.NAME)
        self.segmenter.save("%s.npz" % prefix)
        if self.args.BEDGRAPH:
            for data in ['score', 'treatment', 'control']:
                self.segmenter.write_bedgraph(
                    target=data,
                    fname="%s_%s.bg" % (prefix, data))
        for data in ['LADs', 'DIPs']:
            self.segmenter.write_bed(
                target=data,
                fname="%s_%s.bed" % (prefix, data))
        if self.args.PLOT:
            self.segmenter.plot_scores(
                fname="%s.pdf" % prefix,
                binsize=self.args.PBINSIZE,
                show_lads=self.args.LADS,
                show_dips=self.args.DIPS,
                show_means=self.args.MEANS,
                show_partitions=self.args.SEGMENTS)


# Script entry: constructing LADetector parses the command line and runs
# the whole pipeline; afterwards terminate with a success exit status.
LADetector()
raise SystemExit(0)
