#! /usr/bin/env python3

import sys,argparse,pathlib

# Field delimiter implied by each supported file extension (matched exactly,
# so e.g. ".CSV" is rejected just like any other unknown suffix).
_SEPARATORS = {".csv": ",", ".tsv": "\t"}


def _parse_args(argv=None):
    """Parse the script's three positional command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The argparse namespace with ``metaFile``, ``replaceFile`` and
        ``fieldToReplace`` attributes.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Can correct names in a column using a given file")
    parser.add_argument(
        'metaFile', type=str,
        help='metadata/cluster/etc file with values you want to correct')
    parser.add_argument(
        'replaceFile', type=str,
        help='two column file, with col1 being the old value and col2 being the new value')
    parser.add_argument(
        'fieldToReplace', type=str,
        help='field name from metadata header that contains the values to replace')
    return parser.parse_args(argv)


def _separator_for(path: str, arg_name: str) -> str:
    """Return the field delimiter implied by the extension of *path*.

    Exits with status 1 and a message on stderr when the extension is
    neither ``.csv`` nor ``.tsv``. *arg_name* is the argument name used in
    that message.
    """
    # From https://stackoverflow.com/questions/541390/extracting-extension-from-filename-in-python
    suffix = pathlib.Path(path).suffix
    if suffix not in _SEPARATORS:
        # sys.exit(<str>) writes the message to stderr and exits with status 1,
        # keeping diagnostics out of the data stream on stdout.
        sys.exit(f"{arg_name} must have tsv or csv file extension")
    return _SEPARATORS[suffix]


def _split_fields(line: str, sep: str) -> list:
    """Split one input line into fields on *sep*.

    Only the line terminator is removed: a bare ``rstrip()`` would also eat
    trailing separators and silently drop empty trailing columns.
    """
    return line.rstrip("\r\n").split(sep)


def _load_replacements(lines, sep: str) -> dict:
    """Build the old-value -> new-value mapping from the replacement file.

    Args:
        lines: Iterable of text lines (e.g. an open file).
        sep: Field delimiter of the replacement file.

    Returns:
        Dict mapping column 1 to column 2; when an old value occurs more
        than once, the last occurrence wins.

    Blank lines are skipped. A non-blank line with fewer than two columns
    exits with status 1 and a message naming the offending line.
    """
    replacements = {}
    for line_no, line in enumerate(lines, start=1):
        fields = _split_fields(line, sep)
        if fields == [""]:
            continue
        if len(fields) < 2:
            sys.exit(f"replaceFile line {line_no} must have at least two columns")
        replacements[fields[0]] = fields[1]
    return replacements


def _replace_column(lines, sep: str, field: str, replacements: dict):
    """Yield the metadata lines with one column's values replaced.

    Args:
        lines: Iterable of text lines whose first line is the header.
        sep: Field delimiter of the metadata file.
        field: Header name of the column to rewrite.
        replacements: Mapping of old value -> new value.

    Yields:
        Each output line (header first) without a trailing newline.

    Exits with status 1 when *field* is not present in the header.
    """
    line_iter = iter(lines)
    header = _split_fields(next(line_iter, ""), sep)
    if field not in header:
        sys.exit(f"field '{field}' not found in metaFile header")
    column = header.index(field)
    yield sep.join(header)

    for line in line_iter:
        fields = _split_fields(line, sep)
        # Rows too short to contain the target column (e.g. blank lines)
        # are passed through unchanged instead of raising IndexError.
        if column < len(fields):
            fields[column] = replacements.get(fields[column], fields[column])
        yield sep.join(fields)


def main(argv=None):
    """Rewrite one column of a CSV/TSV metadata file and print it to stdout.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 (message on stderr) on an unsupported file
    extension, a malformed replacement line, or an unknown field name.
    """
    args = _parse_args(argv)

    # Validate both extensions before touching either file.
    msep = _separator_for(args.metaFile, "metaFile")
    rsep = _separator_for(args.replaceFile, "replaceFile")

    # Build a dict of old -> new values from file
    with open(args.replaceFile, "r") as rfh:
        replacements = _load_replacements(rfh, rsep)

    # Go through the metadata file and replace values as needed
    with open(args.metaFile, "r") as mfh:
        for out_line in _replace_column(mfh, msep, args.fieldToReplace, replacements):
            print(out_line)


if __name__ == "__main__":
    main()

