###############################################################################
#
# Build VDS
#
# Author: Aiden Huffman
# Date: March 1st, 2023
# 
# Description: Constructs a virtual data set from the results of dedalus_ex.py
#
###############################################################################

import h5py
import argparse
import logging, os
from pathlib import Path
from dedalus.tools import post

# Setup logger
# Loggers must be obtained through logging.getLogger rather than instantiated
# directly, so that the logger is registered with the logging hierarchy.
logger = logging.getLogger(__name__)
# A registered logger inherits the root level (WARNING by default), so the
# level is set explicitly to keep INFO messages flowing to the file handler.
logger.setLevel(logging.INFO)
# Keep output confined to the log file: do not forward records to any
# handlers that other libraries may have attached to the root logger.
logger.propagate = False

formatter = logging.Formatter('%(levelname)s:%(module)s.%(funcName)s:%(asctime)s:%(message)s')
# mode='w' truncates the log on every run
fileHandler = logging.FileHandler("vds_construction.log", mode='w')
fileHandler.setFormatter(formatter)
fileHandler.setLevel(logging.INFO)
logger.addHandler(fileHandler)

# Command line interface: a single optional flag pointing at the Dedalus
# output folder
parser = argparse.ArgumentParser(description="Build a virtual dataset (VDS.h5) from Dedalus output")
parser.add_argument('-p', type=str, default='simulation', help="Path to folder generated by dedalus")

args = parser.parse_args()

filepath = Path(args.p)

# The steps that follow treat the path as a folder (joining file names onto it
# and globbing inside it), so reject both a missing path and a plain file
if not filepath.is_dir():
    raise FileNotFoundError(f"Folder {filepath} not found")

# Delete VDS if there's one already there, so it is rebuilt from scratch and
# is not collected together with the simulation output files
vds_path = filepath / "VDS.h5"
if vds_path.exists():
    vds_path.unlink()

# Combine the per-process outputs (simulation_s1_p<n>) into one file per set
# (simulation_s1). The merged files can be large and the merge can be slow,
# so keep in mind how much data is being combined.
post.merge_process_files(filepath, cleanup=True)


def _set_number(h5_path):
    """Return the integer that follows the first character of the last
    '_'-separated token of the file stem (e.g. 'simulation_s12' -> 12)."""
    last_token = h5_path.stem.rsplit('_', 1)[-1]
    return int(last_token[1:])


# Gather every h5 file in the folder and order them numerically by set number
h5_files = list(filepath.glob('*.h5'))
h5_files.sort(key=_set_number)

logger.info(f"h5 files merged, total list: {h5_files}")

# Get shapes
if not h5_files:
    raise FileNotFoundError(f"No .h5 files found in {filepath}")

# h5_files is already ordered numerically by set number. It must NOT be
# re-sorted here: a plain sort is lexicographic on the path, which places
# e.g. "_s10" before "_s2" and would pick the wrong "last" file.
with h5py.File(h5_files[0], 'r') as f:
    # Shape of the 'T' task in the first file; used as the generic per-file shape
    gen_sh = f['tasks']['T'].shape
with h5py.File(h5_files[-1], 'r') as f:
    # The last file may hold fewer writes than the others, so its number of
    # time stamps is read separately from the 'sim_time' dimension scale
    f_timestamps = len(f['tasks']['T'].dims[0]['sim_time'])

# Explicit raise instead of assert so the check survives `python -O`
if len(gen_sh) != 3:
    raise ValueError("This only handles 2D data")

# Total number of time stamps in the virtual dataset: every file except the
# last contributes gen_sh[0] entries, the last one contributes f_timestamps
n_timestamps = gen_sh[0]*(len(h5_files)-1) + f_timestamps

# Task layouts (T, u_x, u_z): time stamps along the first axis, followed by
# the two remaining axes of the per-file task shape
task_shape = (n_timestamps, gen_sh[1], gen_sh[2])
layout_T = h5py.VirtualLayout(shape=task_shape, dtype='f8')
layout_ux = h5py.VirtualLayout(shape=task_shape, dtype='f8')
layout_uz = h5py.VirtualLayout(shape=task_shape, dtype='f8')

# Scales layouts. Shapes are passed as genuine 1-tuples; note that `(n)` is
# just a parenthesised integer, not a tuple.
layout_t = h5py.VirtualLayout(shape=(n_timestamps,), dtype='f8')
layout_x = h5py.VirtualLayout(shape=(gen_sh[1],), dtype='f8')
layout_z = h5py.VirtualLayout(shape=(gen_sh[2],), dtype='f8')

# Build virtual dataset
# Every file is mapped into a contiguous slab along the first (time) axis of
# the layouts. All files but the last contribute gen_sh[0] time stamps; the
# last one contributes f_timestamps.
with h5py.File(filepath / "VDS.h5", 'w', libver='latest') as f:

    # The spatial scales are mapped from the first file only
    # (assumes the grids are the same in every file -- TODO confirm)
    layout_x[0:gen_sh[1]] = h5py.VirtualSource(h5_files[0], 'scales/x/1.0',
                                               shape=(gen_sh[1],))
    layout_z[0:gen_sh[2]] = h5py.VirtualSource(h5_files[0], 'scales/z/1.0',
                                               shape=(gen_sh[2],))

    last_idx = len(h5_files) - 1
    for idx, h5_file in enumerate(h5_files):

        logger.info(f"Current file: {h5_file}")

        # Number of time stamps stored in this file
        n_t = f_timestamps if idx == last_idx else gen_sh[0]

        # Slab of the layouts' time axis covered by this file
        start = idx*gen_sh[0]
        stop = start + n_t
        task_sh = (n_t, gen_sh[1], gen_sh[2])

        # Setup virtual sources for each task we recorded. The same dataset
        # names are used for every file, including the last one.
        vsource_T = h5py.VirtualSource(h5_file, 'tasks/T', shape=task_sh)
        vsource_ux = h5py.VirtualSource(h5_file, 'tasks/u_x', shape=task_sh)
        vsource_uz = h5py.VirtualSource(h5_file, 'tasks/u_z', shape=task_sh)
        vsource_t = h5py.VirtualSource(h5_file, 'scales/sim_time', shape=(n_t,))

        # Store sources
        layout_T[start:stop, :, :] = vsource_T
        layout_ux[start:stop, :, :] = vsource_ux
        layout_uz[start:stop, :, :] = vsource_uz
        layout_t[start:stop] = vsource_t

    # Generate virtual data source
    logger.info("Creatng virtual dataset...")
    f.create_virtual_dataset('T', layout_T, fillvalue=0)
    f.create_virtual_dataset('u_x', layout_ux, fillvalue=0)
    f.create_virtual_dataset('u_z', layout_uz, fillvalue=0)
    f.create_virtual_dataset('x', layout_x)
    f.create_virtual_dataset('z', layout_z)
    f.create_virtual_dataset('t', layout_t, fillvalue=0)

    logger.info("Completed VDS construction")