#!/usr/bin/env python3
import anndata as ad
import pandas as pd
import csv, sys, argparse

def _read_color_names(path):
	"""Parse the tab-separated mapping file at *path* into a dict.

	The first column of each line is used as the metadata (obs) key and the
	second column as the key under ``uns`` that holds its colors. Extra
	columns are ignored and blank lines are skipped.

	Raises:
		ValueError: if a non-blank line has fewer than two tab-separated fields.
	"""
	colornames = {}
	with open(path) as colorfile:
		for lineno, raw in enumerate(colorfile, start=1):
			fields = raw.strip().split('\t')
			if fields == ['']:
				# Whitespace-only line (e.g. a trailing newline at end of file).
				continue
			if len(fields) < 2:
				raise ValueError(
					f"{path}, line {lineno}: expected two tab-separated fields, got {raw.rstrip()!r}"
				)
			colornames[fields[0]] = fields[1]
	return colornames


def _write_color_rows(writer, adata, obs_key, uns_key):
	"""Write one ``[category, color]`` row per category of ``adata.obs[obs_key]``.

	Categories and the colors stored in ``adata.uns[uns_key]`` are paired by
	position; zip() stops at the shorter of the two sequences.
	"""
	categories = adata.obs[obs_key].cat.categories
	for category, color in zip(categories, adata.uns[uns_key]):
		writer.writerow([category, color])


def main():
	"""Command-line entry point: dump category/color pairs from an h5ad file.

	For every obs column ``<name>`` that has a ``<name>_colors`` entry in
	``uns``, one tab-separated ``category<TAB>color`` row is written per
	category. If ``--color_names`` is given, the obs-key -> uns-key pairs
	listed in that file are written afterwards, in addition to the
	automatically matched ones.

	Exits with status 1 (after printing help) when run without arguments, and
	with status 2 via argparse on missing required options, a malformed
	mapping file, or mapped keys that are absent from the h5ad file.
	"""
	parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter, description="Fetch colors from h5ad")
	parser.add_argument('-i', '--input_h5ad', required=True, help='input h5ad file')
	parser.add_argument('-o', '--output_file', required=True, help='output filename in .tsv format')
	parser.add_argument('-c', '--color_names', help='tab-separated file of metadata keys and corresponding color values. Only needed if metadata field names do not match color values in uns.')

	# Print help message if no arguments are supplied. This has to run before
	# parse_args(), which would otherwise reject the missing required options.
	if len(sys.argv) == 1:
		parser.print_help(sys.stderr)
		sys.exit(1)

	args = parser.parse_args()

	colornames = {}
	if args.color_names:
		try:
			colornames = _read_color_names(args.color_names)
		except ValueError as err:
			parser.error(str(err))

	adata = ad.read_h5ad(args.input_h5ad)

	# Validate the explicit mapping before the output file is created, so a
	# typo does not leave a partially written file behind.
	for obs_key, uns_key in colornames.items():
		if obs_key not in adata.obs:
			parser.error(f"metadata key {obs_key!r} from {args.color_names} not found in obs")
		if uns_key not in adata.uns:
			parser.error(f"color key {uns_key!r} from {args.color_names} not found in uns")

	# newline='' lets the csv module control line endings itself (avoids
	# '\r\r\n' on platforms that translate '\n').
	with open(args.output_file, "w", newline='') as csv_file:
		writer = csv.writer(csv_file, delimiter='\t')
		for metavalue in adata.obs.keys():
			color_key = f"{metavalue}_colors"
			if color_key in adata.uns:
				_write_color_rows(writer, adata, metavalue, color_key)
		for obs_key, uns_key in colornames.items():
			_write_color_rows(writer, adata, obs_key, uns_key)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
