#!/usr/bin/env python
# Copyright 2019-2022 DADoES, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License in the root directory in the "LICENSE" file or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import time
import anatools
import subprocess
from anatools.lib.channel import Channel, find_channelfile


def update_credentials(client, volumes, aws_dir='/home/anadev/.aws'):
    """Fetch mount credentials and store them in the ``[volumes]`` AWS profile.

    Calls ``client.mount_volumes`` for the given volumes, then rewrites the
    ``credentials`` file inside ``aws_dir``: every profile already present in
    the file is kept, and the ``[volumes]`` profile is replaced with the keys
    that were just returned.

    Args:
        client: object exposing ``mount_volumes(volumes=...)``.
        volumes: volume IDs forwarded to ``client.mount_volumes``.
        aws_dir: directory containing the ``credentials`` file. It is created
            when it does not exist.

    Returns:
        The value returned by ``client.mount_volumes``.

    Raises:
        SystemExit: with status 1 when ``client.mount_volumes`` returns a
            falsy value.
    """
    data = client.mount_volumes(volumes=volumes)
    if not data:
        print_color('There was an error retrieving mount credential, please contact Rendered.ai for support.', 'ff0000')
        sys.exit(1)

    credentials_path = os.path.join(aws_dir, 'credentials')
    awsprofiles = {}
    if os.path.isfile(credentials_path):
        with open(credentials_path, 'r') as awscredfile:
            # Lines that precede any "[profile]" header are filed under
            # [default], matching how the file was grouped before.
            profile = '[default]'
            awsprofiles[profile] = []
            for line in awscredfile:
                line = line.rstrip()
                if not line:
                    continue
                if line.startswith('[') and line.endswith(']'):
                    profile = line
                    awsprofiles[profile] = []
                else:
                    # Re-attach the newline removed by rstrip(); writelines()
                    # does not add separators, so without it the entries of a
                    # profile would be fused into a single line on rewrite.
                    awsprofiles[profile].append(line + '\n')

    credentials = data['credentials']
    awsprofiles['[volumes]'] = [
        f"aws_access_key_id={credentials['accesskeyid']}\n",
        f"aws_secret_access_key={credentials['accesskey']}\n",
        f"aws_session_token={credentials['sessiontoken']}\n",
    ]

    os.makedirs(aws_dir, exist_ok=True)
    with open(credentials_path, 'w') as awscredfile:
        for profile, entries in awsprofiles.items():
            # Profiles without entries (e.g. the implicit [default]) are not
            # written back.
            if entries:
                awscredfile.write(profile + '\n')
                awscredfile.writelines(entries)
    return data

def mount_volumes(data, volumes, verbose):
    """Start one s3fs process per volume and link each mount into the workspace.

    For every volume ID a mount point is created under
    ``/home/anadev/data/volumes``, an ``s3fs`` command is launched in the
    background (in its own session), and a symlink to the mount point is
    placed in ``<workspace>/data/volumes``, where ``<workspace>`` is the first
    entry returned by ``os.listdir('/workspaces')``.

    Args:
        data: mapping with ``keys`` and ``rw`` sequences that are indexed by
            the position of each volume in ``volumes``.
        volumes: volume IDs to mount.
        verbose: when truthy, s3fs output is left attached to this process's
            stdout; otherwise stdout is discarded.
    """
    ws_dataroot = os.path.join('/workspaces', os.listdir("/workspaces")[0], "data")
    ws_volroot = os.path.join(ws_dataroot, "volumes")
    os.makedirs('/home/anadev/data/volumes', exist_ok=True)
    os.makedirs(ws_volroot, exist_ok=True)
    # stdout=None is Popen's default, i.e. inherit the parent's stdout.
    s3fs_stdout = None if verbose else subprocess.DEVNULL
    for i, volumeId in enumerate(volumes):
        print(f'Mounting volume {volumeId}...', end='', flush=True)
        os.makedirs(f'/home/anadev/data/volumes/{volumeId}', exist_ok=True)
        rw = ''
        # Volumes flagged 'r' get these extra s3fs options appended.
        if data['rw'][i] == 'r': rw = '-o allow_other -o umask=0002'
        command = f's3fs {data["keys"][i]} /home/anadev/data/volumes/{volumeId} -o profile=volumes -o endpoint=us-west-2 -o url="https://s3-us-west-2.amazonaws.com" -o use_cache=/tmp/s3fs/{volumeId} {rw} -f -d'
        # os.setsid puts s3fs in its own session so it is not tied to this
        # process's terminal session.
        subprocess.Popen(command, stdout=s3fs_stdout, shell=True, preexec_fn=os.setsid)
        ws_voldir = os.path.join(ws_volroot, volumeId)
        # lexists() also detects a dangling symlink left behind by an unclean
        # shutdown; exists() would follow the link, report False, and the
        # symlink() call below would then fail with FileExistsError.
        if os.path.lexists(ws_voldir): os.unlink(ws_voldir)
        os.symlink(f'/home/anadev/data/volumes/{volumeId}', ws_voldir)
        print('complete!')


def unmount_volumes(volumes):
    """Best-effort teardown of every mounted volume.

    Sends ``kill -9`` to each process ID reported by ``pidof s3fs``, then for
    each volume ID force-unmounts and deletes its mount point under
    ``/home/anadev/data/volumes`` and removes the matching workspace symlink.
    Failures in any step are swallowed so the remaining steps still run.

    Args:
        volumes: volume IDs whose mounts and symlinks should be removed.
    """
    workspace = os.listdir("/workspaces")[0]
    ws_volroot = os.path.join('/workspaces', workspace, "data", "volumes")
    try:
        print('Unmounting volumes...', end='')
        try:
            pidof_output = subprocess.check_output(["pidof", "s3fs"])
            for token in pidof_output.split():
                pid = int(token)
                try:
                    killer = subprocess.Popen(f'kill -9 {pid}', stdout=subprocess.PIPE, shell=True)
                    killer.wait()
                except BaseException:
                    pass
        except BaseException:
            # pidof exits non-zero when no s3fs process is running.
            pass
        for volumeId in volumes:
            try:
                subprocess.Popen(f'sudo umount -f /home/anadev/data/volumes/{volumeId}', shell=True).wait()
                subprocess.Popen(f'sudo rm -rf /home/anadev/data/volumes/{volumeId}', shell=True).wait()
                os.unlink(os.path.join(ws_volroot, volumeId))
            except BaseException:
                pass
        print('complete.')
    except BaseException:
        pass


def mount_loop(client, volumes, verbose=0):
    """Keep volumes mounted, remounting them with new credentials periodically.

    Each cycle refreshes the credentials, mounts the volumes, counts down
    3500 seconds in 10-second steps, refreshes the credentials again and
    unmounts. The loop only ends on ``KeyboardInterrupt``, in which case the
    volumes are unmounted and the process exits.

    Args:
        client: forwarded to ``update_credentials``.
        volumes: volume IDs to keep mounted.
        verbose: forwarded to ``mount_volumes``.
    """
    while True:
        try:
            mount_info = update_credentials(client, volumes)
            mount_volumes(mount_info, volumes, verbose)
            # 350 ticks of 10 seconds: 3500, 3490, ..., 10.
            for remaining in range(3500, 0, -10):
                print(f'Remounting volumes in {remaining}s...', end='\r')
                time.sleep(10)
            update_credentials(client, volumes)
            unmount_volumes(volumes)
        except KeyboardInterrupt:
            unmount_volumes(volumes)
            sys.exit()


def print_color(text, color):
    """Print ``text`` wrapped in a 24-bit ANSI foreground color sequence.

    Args:
        text: string to print.
        color: six hexadecimal digits, two each for red, green and blue
            (for example ``'ff0000'``).
    """
    red, green, blue = (int(color[start:start + 2], 16) for start in range(0, 6, 2))
    opening = '\033[38;2;' + f'{red};{green};{blue}' + 'm'
    closing = '\033[0m'
    print(opening + text + closing)



def _str2bool(value):
    """Interpret a command-line value as a boolean.

    ``type=bool`` would turn any non-empty string, including ``"False"``,
    into ``True``. Here the usual "false" spellings yield ``False`` and every
    other value yields ``True``.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ('', 'false', 'f', 'no', 'n', '0', 'off')


# Command-line interface. The boolean options accept an explicit value
# ("--verbose True", "--verbose false") or can be given bare ("--verbose").
parser = argparse.ArgumentParser()
parser.add_argument('--channel', type=str, default=None)
parser.add_argument('--volumes', type=str, default=None)
parser.add_argument('--environment', type=str, choices=['prod', 'test', 'dev', 'infra'], default='prod')
parser.add_argument('--organization', type=str, default=None)
parser.add_argument('--email', type=str, default=None)
parser.add_argument('--password', type=str, default=None)
parser.add_argument('--local', type=_str2bool, nargs='?', const=True, default=False)
parser.add_argument('--verbose', type=_str2bool, nargs='?', const=True, default=False)
parser.add_argument('--unmount', type=_str2bool, nargs='?', const=True, default=False)
args = parser.parse_args()

# Collect volume IDs from --volumes (comma separated, optional brackets and
# quotes) and from the channel file.
volumes = []
if args.volumes:
    requested = [item.strip(' \'"') for item in args.volumes.replace('[', '').replace(']', '').split(',')]
    requested = [item for item in requested if item]
    if not requested:
        print('Failed to parse --volumes input, expecting a list of volumeIds.')
        sys.exit(1)
    # extend, not append: appending the list itself nests it inside `volumes`
    # and the de-duplication below then fails on an unhashable element.
    volumes.extend(requested)
if args.channel is None: args.channel = find_channelfile()
if args.channel:
    channel = Channel(args.channel)
    for package in channel.packages.keys():
        if 'volumes' in channel.packages[package]:
            for volume in channel.packages[package]['volumes'].keys():
                # Entries whose value is 'local' are not mounted.
                if channel.packages[package]['volumes'][volume] != 'local':
                    volumes.append(channel.packages[package]['volumes'][volume])
# De-duplicate while keeping first-seen order.
volumes = list(dict.fromkeys(volumes))
if len(volumes):
    # NOTE(review): exits with status 1 even after unmounting; kept as-is in
    # case callers rely on it -- confirm whether 0 is intended.
    if args.unmount: unmount_volumes(volumes); sys.exit(1)
    client = anatools.client(
        environment=args.environment,
        email=args.email,
        password=args.password,
        interactive=False,
        local=args.local,
        verbose=args.verbose)
    if len(client.organizations) == 0: sys.exit(1)
    if args.organization: client.set_organization(args.organization)
    # Build a new list rather than removing from `volumes` while iterating
    # over it, which would skip the element after each removal.
    permitted = []
    for volume in volumes:
        if volume in client.volumes:
            permitted.append(volume)
        else:
            print_color(f'Warning: Unable to mount volume {volume}, permission denied.', 'ff0000')
    volumes = permitted
    if not volumes:
        print_color('No volumes left to mount.', 'ff0000')
        sys.exit(1)
    print_color(f'This process to mount Volumes was successfully enabled and will remain open to refresh Volumes. Killing this process will unmount Volumes. Press CTRL+C or close the terminal to kill the process.', 'ffff00')
    mount_loop(client, volumes, args.verbose)

else: print('No volumes specified.'); sys.exit(1)
