#!/usr/bin/env python3
#
# test-decode-pretbc - run the decoding toolchain against pre-generated TBC files
# Copyright (C) 2019-2022 Adam Sampson
# Copyright (C) 2022 Chad Page
# Copyright (C) 2025 Simon Inns
#
# This file is part of ld-decode.
#
# test-decode-pretbc is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# This script is for testing the toolchain with pre-generated TBC files,
# skipping the ld-decode and ld-cut stages that are no longer part of this repo.

import argparse
import json
import numpy
import os
import shutil
import sqlite3
import subprocess
import sys

# Module-level state shared between main() and the stage-runner functions below.
dry_run = False      # --dry-run: echo commands to stderr instead of executing them
source_dir = None    # set by main(); currently mirrors build_dir
build_dir = None     # set by main(); tree containing the built tools under test

def resolve_tool(*candidates):
    """Return the first candidate path that exists on disk.

    If none of the candidates exist, the first one is returned anyway so
    that the eventual command failure reports a sensible path.
    """
    return next((c for c in candidates if os.path.exists(c)), candidates[0])

def die(*args):
    """Report a fatal error on stderr and terminate with exit status 1."""
    message = ' '.join(str(arg) for arg in args)
    print(message, file=sys.stderr)
    sys.exit(1)

def safe_unlink(filename):
    """Delete filename; a file that is already absent is not an error."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        # Nothing to delete — that's fine.
        pass

def clean(args, suffixes):
    """Delete any leftover output files with the given suffixes."""

    for path in (args.output + suffix for suffix in suffixes):
        safe_unlink(path)

def run_command(cmd, **kwopts):
    """Run an external command, as subprocess.call would.

    The command line is echoed to stderr first; in dry-run mode nothing
    else happens.  A non-zero exit status is fatal.
    """

    print('\n>>>', ' '.join(cmd), file=sys.stderr)
    if dry_run:
        return

    # Flush both of our streams in case they're buffered, so our output
    # stays ordered relative to the child process's output.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()

    status = subprocess.call(cmd, stderr=subprocess.STDOUT, **kwopts)
    if status != 0:
        die(cmd[0], 'failed with exit code', status)

def copy_pretbc_files(args):
    """Copy pre-generated TBC files to the working directory."""

    # The input argument is a base path, e.g. test-data/ntsc/ve-snw-cut
    base_name = os.path.basename(args.pretbc_input)
    pretbc_dir = os.path.dirname(args.pretbc_input)

    # .tbc and .tbc.db must be present; the others are optional.
    # (.tbc.json is optional because .tbc.db is the metadata actually used.)
    required = ('.tbc', '.tbc.db')
    extensions = ['.tbc', '.tbc.json', '.tbc.db', '.efm', '.pcm', '.log']

    print(f'\nCopying pre-generated TBC files from {pretbc_dir}/', file=sys.stderr)

    for ext in extensions:
        src = os.path.join(pretbc_dir, base_name + ext)
        dst = args.output + ext

        if not os.path.exists(src):
            if ext in required:
                die(f'Required file {src} does not exist')
            continue

        print(f'  {os.path.basename(src)} -> {os.path.basename(dst)}', file=sys.stderr)
        if not dry_run:
            shutil.copy2(src, dst)

def run_ld_process_vbi(args):
    """Run ld-process-vbi, then validate the SQLite metadata it produced.

    Performs the sanity checks requested on the command line: median
    black PSNR (--expect-bpsnr), expected VBI values (--expect-vbi) and
    expected VITC values (--expect-vitc), plus the --print-vbi and
    --print-vitc dumps.  Any failed check is fatal.
    """

    clean(args, ['.tbc.json.bup', '.tbc.db.bup'])

    cmd = [resolve_tool(
        os.path.join(build_dir, 'bin', 'ld-process-vbi'),
        os.path.join(build_dir, 'src', 'ld-process-vbi', 'ld-process-vbi'),
    )]
    cmd += [args.output + '.tbc']
    run_command(cmd)

    if dry_run:
        return

    # Read the SQLite database output
    db_file = args.output + '.tbc.db'
    if not os.path.exists(db_file):
        die(db_file, 'does not exist')

    # Initialise conn before the try block: if sqlite3.connect itself
    # raised, the finally clause below would otherwise hit a NameError
    # that masked the real failure.
    conn = None
    try:
        conn = sqlite3.connect(db_file)
        conn.row_factory = sqlite3.Row  # Enable column access by name
        cursor = conn.cursor()

        # Check black SNR
        if args.expect_bpsnr is not None:
            cursor.execute("SELECT b_psnr FROM vits_metrics")
            bpsnr_values = [row['b_psnr'] for row in cursor.fetchall() if row['b_psnr'] is not None]
            if bpsnr_values:
                bpsnr = numpy.median(bpsnr_values)
                if args.expect_bpsnr > bpsnr:
                    die(db_file, 'has median bPSNR', bpsnr, 'dB, expected',
                        args.expect_bpsnr, 'dB')

        # Print VBI data (useful for finding --expect_vbi values)
        if args.print_vbi:
            cursor.execute("SELECT vbi0, vbi1, vbi2 FROM vbi ORDER BY field_id")
            for row in cursor.fetchall():
                print("VBI data:", [row['vbi0'], row['vbi1'], row['vbi2']])

        # Check for a field with the expected VBI values
        if args.expect_vbi is not None:
            cursor.execute("SELECT vbi0, vbi1, vbi2 FROM vbi")  # Check all capture_ids
            found_vbi = False
            for row in cursor.fetchall():
                vbi_data = [row['vbi0'], row['vbi1'], row['vbi2']]
                if vbi_data == args.expect_vbi:
                    found_vbi = True
                    break

            if not found_vbi:
                die(db_file, 'did not contain a field with VBI values',
                    args.expect_vbi)

        # Print VITC data
        if args.print_vitc:
            cursor.execute("SELECT vitc0, vitc1, vitc2, vitc3, vitc4, vitc5, vitc6, vitc7 FROM vitc ORDER BY field_id")
            for row in cursor.fetchall():
                vitc_data = [row[f'vitc{i}'] for i in range(8)]
                print("VITC data:", vitc_data)

        # Check for a field with the expected VITC values
        if args.expect_vitc is not None:
            cursor.execute("SELECT vitc0, vitc1, vitc2, vitc3, vitc4, vitc5, vitc6, vitc7 FROM vitc")
            found_vitc = False
            for row in cursor.fetchall():
                vitc_data = [row[f'vitc{i}'] for i in range(8)]
                if vitc_data == args.expect_vitc:
                    found_vitc = True
                    break

            if not found_vitc:
                die(db_file, 'did not contain a field with VITC values',
                    args.expect_vitc)

    except sqlite3.Error as e:
        die(db_file, 'SQLite error:', str(e))
    finally:
        if conn is not None:
            conn.close()

def run_ld_export_metadata(args):
    """Run ld-export-metadata to produce the CSV and ffmetadata outputs."""

    clean(args, ['.vits.csv', '.vbi.csv', '.ffmetadata'])

    tool = resolve_tool(
        os.path.join(build_dir, 'bin', 'ld-export-metadata'),
        os.path.join(build_dir, 'src', 'ld-export-metadata', 'ld-export-metadata'),
    )
    run_command([
        tool,
        '--vits-csv', args.output + '.vits.csv',
        '--vbi-csv', args.output + '.vbi.csv',
        '--ffmetadata', args.output + '.ffmetadata',
        args.output + '.tbc.db',
    ])

def run_ld_export_decode_metadata(args):
    """Run ld-export-decode-metadata and cross-check its JSON output.

    The exported JSON is compared field-by-field against the SQLite
    database it was generated from: video parameters, per-field
    parameters, dropouts and VITS metrics must all agree (floats to
    within 0.05).  Any mismatch is fatal.
    """

    json_file = args.output + '.tbc.export.json'
    db_file = args.output + '.tbc.db'

    # Check if SQLite DB exists
    if not dry_run:
        if not os.path.exists(db_file):
            die(db_file, 'does not exist')

    clean(args, ['.tbc.export.json'])

    cmd = [resolve_tool(
        os.path.join(build_dir, 'bin', 'ld-export-decode-metadata'),
        os.path.join(build_dir, 'src', 'ld-export-decode-metadata', 'ld-export-decode-metadata'),
    )]
    cmd += ['--input-sqlite', db_file]
    run_command(cmd)

    if dry_run:
        return

    if not os.path.exists(json_file):
        die(json_file, 'does not exist')

    try:
        with open(json_file) as fjson:
            json_data = json.load(fjson)
    except UnicodeDecodeError as e:
        die(json_file, 'Failed to decode Unicode:', str(e))
    except json.JSONDecodeError as e:
        die(json_file, 'Failed to decode JSON:', str(e))
    # Deliberately narrow: a bare 'except:' here would also swallow
    # KeyboardInterrupt and SystemExit (including die() itself).
    except Exception:
        die(json_file, 'reading failed for unknown reason')

    # Initialise conn before the try block: if sqlite3.connect itself
    # raised, the finally clause below would otherwise hit a NameError
    # that masked the real failure.
    conn = None
    try:
        conn = sqlite3.connect(db_file)
        conn.row_factory = sqlite3.Row  # Enable column access by name
        cursor = conn.cursor()

        # (SQLite column name, JSON key) pairs to compare
        video_params = [ ('video_sample_rate','sampleRate'), ('active_video_start','activeVideoStart'),
            ('active_video_end','activeVideoEnd'), ('field_width','fieldWidth'), ('field_height','fieldHeight'),
            ('number_of_sequential_fields', 'numberOfSequentialFields'), ('colour_burst_start','colourBurstStart'),
            ('colour_burst_end','colourBurstEnd') ]

        # check video parameters
        cursor.execute("SELECT " + ", ".join([vp[0] for vp in video_params]) + " FROM capture WHERE capture_id == 1")
        rows = cursor.fetchall()
        if len(rows) != 1:
            die(db_file, 'could not find capture with id 1')
        try:
            for vp in video_params:
                # float types have rounding issues when comparing, therefore this code
                if abs(rows[0][vp[0]] - json_data['videoParameters'][vp[1]]) > 0.05:
                    die(json_file, vp[0], 'incorrect', rows[0][vp[0]], 'vs.', json_data['videoParameters'][vp[1]])
        except TypeError as e:
            die(json_file, 'could not find index:', str(e))

        field_params = [ ('disk_loc','diskLoc'), ('field_phase_id','fieldPhaseID'), ('file_loc','fileLoc'),
            ('is_first_field','isFirstField'), ('sync_conf','syncConf') ]

        # check field params
        cursor.execute("SELECT field_id, " + ", ".join([fp[0] for fp in field_params]) + " FROM field_record WHERE capture_id == 1 ORDER BY field_id")
        rows = cursor.fetchall()

        if len(rows) != len(json_data['fields']):
            die(json_file, 'incorrect number of fields:', len(rows), 'vs.', len(json_data['fields']))
        try:
            for row in rows:
                for fp in field_params:
                    if row[fp[0]] != json_data['fields'][row['field_id']][fp[1]]:
                        die(json_file, 'field', row['field_id'], ':', fp[0], 'incorrect', row[fp[0]], 'vs.', json_data['fields'][row['field_id']][fp[1]])
        except TypeError as e:
            die(json_file, 'could not find index:', str(e))

        # check dropouts: every dropout row in the DB must appear in the
        # JSON's parallel fieldLine/startx/endx arrays for that field
        cursor.execute("SELECT field_id, field_line, startx, endx FROM drop_outs WHERE capture_id == 1 ORDER BY field_id")

        try:
            for row in cursor.fetchall():
                jdo = json_data['fields'][row['field_id']]['dropOuts']
                line_idx = [i for i in range(len(jdo['fieldLine'])) if jdo['fieldLine'][i] == row['field_line']]
                if len(line_idx) == 0:
                    die('could not find dropout for field', row['field_id'], 'in line', row['field_line'])
                found_dout = False
                for i in line_idx:
                    if jdo['startx'][i] == row['startx'] and jdo['endx'][i] == row['endx']:
                        found_dout = True
                        break
                if not found_dout:
                    die('could not find dropout for field', row['field_id'], 'in line', row['field_line'], 'for range', row['startx'], 'to', row['endx'])
        except TypeError as e:
            die(json_file, 'could not find index:', str(e))

        vits_metrics_params = [ ('b_psnr','bPSNR'), ('w_snr','wSNR') ]

        # check vits metrics (0.0 in the DB means "not measured", so skip)
        cursor.execute("SELECT field_id, " + ", ".join([vmp[0] for vmp in vits_metrics_params]) + " FROM vits_metrics WHERE capture_id == 1 ORDER BY field_id")

        try:
            for row in cursor.fetchall():
                for vmp in vits_metrics_params:
                    if row[vmp[0]] != 0.0 and abs(row[vmp[0]] - json_data['fields'][row['field_id']]['vitsMetrics'][vmp[1]]) > 0.05:
                        die(json_file, 'field', row['field_id'], ':', vmp[0], 'incorrect', row[vmp[0]], 'vs.', json_data['fields'][row['field_id']]['vitsMetrics'][vmp[1]])
        except TypeError as e:
            die(json_file, 'could not find index:', str(e))

    except sqlite3.Error as e:
        die(db_file, 'SQLite error:', str(e))
    finally:
        if conn is not None:
            conn.close()

def run_efm_decoder(args):
    """Run the EFM decoding chain: efm-decoder-f2, -d24, then -audio."""

    if args.no_efm:
        return

    clean(args, ['.digital.wav', '.f2', '.d24'])
    efm_file = args.output + '.efm'
    f2_file = args.output + '.f2'
    d24_file = args.output + '.d24'
    wav_file = args.output + '.digital.wav'

    if not dry_run:
        # Check if the input file exists and is not empty
        if not os.path.exists(efm_file):
            die(efm_file, 'does not exist')
        if os.stat(efm_file).st_size == 0:
            die(efm_file, 'is empty')

    # Each pipeline stage as (tool, extra options, input file, output file)
    stages = [
        # EFM T-values to F2 Section
        ('efm-decoder-f2',
         ['--no-timecodes'] if args.no_efm_timecodes else [],
         efm_file, f2_file),
        # F2 Section to Data24 Section
        ('efm-decoder-d24', [], f2_file, d24_file),
        # Data24 to Audio WAV
        ('efm-decoder-audio', [], d24_file, wav_file),
    ]
    for tool, options, infile, outfile in stages:
        cmd = [resolve_tool(
            os.path.join(build_dir, 'bin', tool),
            os.path.join(build_dir, 'src', 'efm-decoder', 'tools', tool, tool),
        )]
        cmd += options
        cmd += [infile, outfile]
        run_command(cmd)

    # Check there are enough output samples in the WAV file
    if (args.expect_efm_samples is not None) and (not dry_run):
        if not os.path.exists(wav_file):
            die(wav_file, 'does not exist')
        # WAV file has a 44-byte header, then 16-bit stereo samples (4 bytes per sample pair)
        wav_size = os.stat(wav_file).st_size
        if wav_size < 44:
            die(wav_file, 'is too small to be a valid WAV file')
        sample_pairs = (wav_size - 44) // 4  # 2 bytes per channel * 2 channels
        if sample_pairs < args.expect_efm_samples:
            die(wav_file, 'contains', sample_pairs,
                'samples; expected at least', args.expect_efm_samples)

def run_ld_dropout_correct(args):
    """Run ld-dropout-correct to produce the dropout-corrected TBC."""

    clean(args, ['.doc.tbc', '.doc.tbc.db'])

    tool = resolve_tool(
        os.path.join(build_dir, 'bin', 'ld-dropout-correct'),
        os.path.join(build_dir, 'src', 'ld-dropout-correct', 'ld-dropout-correct'),
    )
    run_command([tool, '--overcorrect',
                 args.output + '.tbc', args.output + '.doc.tbc'])

def run_ld_chroma_decoder(args, decoder):
    """Run ld-chroma-decoder (optionally forcing a specific decoder) and
    sanity-check the amount of RGB output produced."""

    clean(args, ['.rgb'])
    rgb_file = args.output + '.rgb'

    tool = resolve_tool(
        os.path.join(build_dir, 'bin', 'ld-chroma-decoder'),
        os.path.join(build_dir, 'src', 'ld-chroma-decoder', 'ld-chroma-decoder'),
    )
    cmd = [tool]
    if decoder is not None:
        cmd.extend(['--decoder', decoder])
    cmd.extend([args.output + '.doc.tbc', rgb_file])
    run_command(cmd)

    # Check there are enough output frames
    if (args.expect_frames is None) or dry_run:
        return
    if not os.path.exists(rgb_file):
        die(rgb_file, 'does not exist')
    # Frame geometry depends on the TV standard; 2 bytes per sample,
    # 3 samples (R, G, B) per pixel.
    frame_w, frame_h = (928, 576) if args.pal else (760, 488)
    frame_count = os.stat(rgb_file).st_size // (2 * 3 * frame_w * frame_h)
    if frame_count < args.expect_frames:
        die(rgb_file, 'contains', frame_count,
            'frames; expected at least', args.expect_frames)

def parse_list_arg(s, num):
    """Parse a comma-separated argument of exactly num integer values.

    Raises ValueError if the count is wrong or a value isn't an integer
    (argparse reports this as a usage error).
    """
    parts = s.split(",")
    if len(parts) != num:
        raise ValueError(f"Argument must have {num} values")
    return [int(part) for part in parts]

def parse_vbi_arg(s):
    """argparse type for --expect-vbi: three comma-separated integers."""
    return parse_list_arg(s, 3)

def parse_vitc_arg(s):
    """argparse type for --expect-vitc: eight comma-separated integers."""
    return parse_list_arg(s, 8)

def main():
    """Parse command-line arguments and run the decode toolchain end to end."""
    parser = argparse.ArgumentParser(description='Run the decoding toolchain against pre-generated TBC files')
    group = parser.add_argument_group("Decoding")
    group.add_argument('pretbc_input', metavar='pretbc-input',
                       help='Pre-generated TBC base name (e.g., test-data/ntsc/ve-snw-cut)')
    group.add_argument('output', metavar='output', nargs='?', default='testout/test',
                       help='base name for output files (default testout/test)')
    group.add_argument('-n', '--dry-run', action='store_true',
                       help='show commands, rather than running them')
    group.add_argument('--build', metavar='DIR',
                       help='build tree to test (default current directory)')
    group.add_argument('--pal', action='store_true',
                       help='source is PAL (default NTSC)')
    group.add_argument('--no-efm', action='store_true', dest='no_efm',
                       help='source has no EFM')
    group.add_argument('--no-efm-timecodes', action='store_true', dest='no_efm_timecodes',
                       help='source EFM has no timecodes')
    group.add_argument('--decoder', metavar='decoder', action='append',
                       dest='decoders', default=[],
                       help='use specific ld-chroma-decoder decoder '
                            '(use more than once to test multiple decoders)')
    group = parser.add_argument_group("Sanity checks")
    group.add_argument('--expect-frames', metavar='N', type=int,
                       help='expect at least N frames of video output')
    group.add_argument('--expect-bpsnr', metavar='DB', type=float,
                       help='expect median bPSNR of at least DB')
    group.add_argument('--expect-vbi', metavar='N,N,N', type=parse_vbi_arg,
                       help='expect at least one field with VBI values N,N,N')
    group.add_argument('--print-vbi', action='store_true',
                       help='Print VBI values (to find out what to pass into expect-vbi)')
    group.add_argument('--expect-vitc', metavar='N,N,N,N,N,N,N,N', type=parse_vitc_arg,
                       help='expect at least one field with VITC values N,N,N,N,N,N,N,N')
    group.add_argument('--print-vitc', action='store_true',
                       help='Print VITC values (to find out what to pass into expect-vitc)')
    group.add_argument('--expect-efm-samples', metavar='N', type=int,
                       help='expect at least N stereo pairs of samples in EFM output')
    args = parser.parse_args()

    # Publish the dry-run flag to the stage-runner functions
    global dry_run
    dry_run = args.dry_run
    # No --decoder given: run ld-chroma-decoder once with its own default
    if args.decoders == []:
        args.decoders = [None]

    if not os.path.exists(args.pretbc_input + '.tbc'):
        die('Pre-generated TBC file', args.pretbc_input + '.tbc', 'does not exist')

    # Find the build directory
    global source_dir, build_dir
    build_dir = os.getcwd()
    if args.build:
        build_dir = args.build
    source_dir = build_dir

    # Every tool the pipeline will invoke must already be built under
    # build_dir/bin; otherwise skip the whole test rather than fail it.
    required_tools = [
        os.path.join(build_dir, 'bin', 'ld-process-vbi'),
        os.path.join(build_dir, 'bin', 'ld-export-metadata'),
        os.path.join(build_dir, 'bin', 'ld-export-decode-metadata'),
        os.path.join(build_dir, 'bin', 'ld-dropout-correct'),
        os.path.join(build_dir, 'bin', 'ld-chroma-decoder'),
    ]
    if not args.no_efm:
        required_tools += [
            os.path.join(build_dir, 'bin', 'efm-decoder-f2'),
            os.path.join(build_dir, 'bin', 'efm-decoder-d24'),
            os.path.join(build_dir, 'bin', 'efm-decoder-audio'),
        ]
    missing = [path for path in required_tools if not os.path.exists(path)]
    if missing:
        # Exit status 77 is the conventional "test skipped" code used by
        # automake/meson-style test harnesses.
        print('SKIP: required tool(s) not built:', ', '.join(missing), file=sys.stderr)
        sys.exit(77)

    print('Processing', args.pretbc_input, 'using tools from', build_dir, file=sys.stderr)

    # Remove display environment variables, as the decoding tools shouldn't
    # depend on having a display
    for var in ('DISPLAY', 'WAYLAND_DISPLAY'):
        if var in os.environ:
            del os.environ[var]

    # Ensure the directory containing output files exists
    output_dir = os.path.dirname(args.output)
    if output_dir != '':
        os.makedirs(output_dir, exist_ok=True)

    # Copy pre-generated TBC files to working directory
    copy_pretbc_files(args)

    # Run the stages of the decoding toolchain (skipping ld-decode/ld-cut)
    run_ld_process_vbi(args)
    run_ld_export_metadata(args)
    run_ld_export_decode_metadata(args)
    run_efm_decoder(args)
    run_ld_dropout_correct(args)
    for decoder in args.decoders:
        run_ld_chroma_decoder(args, decoder)

    print('\nProcessing', args.pretbc_input, 'completed successfully')

if __name__ == '__main__':
    main()
