###############################################################################
#
# Rayleigh-Benard Convection via Dedalus
#
# Author: Aiden Huffman
# Date: March 1st, 2023
# 
# Description: A simple example of running a simulation in Dedalus
# with the Rayleigh-Benard instability as the example. This
# demonstration is based off of the tutorials that already exist
# for Dedalus at:
# 
#   https://dedalus-project.readthedocs.io/en/latest/pages/tutorials.html
#
###############################################################################

import numpy as np
import dedalus.public as de
from dedalus.extras import flow_tools
import time

# The easiest way to run simulations in parallel is with mpiexec;
# mpi4py.MPI lets us query information about the participating MPI
# processes.
from mpi4py import MPI

# argparse lets us pass simulation parameters on the command line
import argparse

# Logging is a crucial component of writing verifiable code; in Python
# this is best done with the standard logging module.
import logging

# Logger for this script; its level is configured from the command line
# once the arguments have been parsed.
logger = logging.getLogger(name=__name__)

parser = argparse.ArgumentParser(
    description=(
        "Build a spectral method and simulate the formation of convective cells"
    ),
)

def fetch_args(parser, argv=None):
    """Register the simulation's command-line options, parse them and
    return them as a flat dictionary.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to add the options to. The options are added on every call,
        so a given parser must only be passed in once (argparse rejects
        duplicate option strings).
    argv : list of str, optional
        Arguments to parse instead of the process command line. ``None``
        (the default) makes argparse read ``sys.argv[1:]``, i.e. the
        original behaviour; passing a list is mainly useful for testing.

    Returns
    -------
    dict
        Keys 'Lx', 'Lz', 'P' (Prandtl), 'Ra' (Rayleigh), 'nx', 'nz',
        'T' (stopping time), 'f' (output name) and 'LOG_LEVEL'
        (upper-case logging level name).

    Raises
    ------
    SystemExit
        Raised by argparse for invalid arguments (including an unknown
        logging level) or when ``--help`` is requested.
    """
    # Critical Rayleigh number for classical temperature driven instability
    # is around 1708, with a wavenumber of around 3.12.
    #
    # Real-valued defaults are written as floats: argparse only applies
    # `type` to *string* defaults, so an integer default would otherwise
    # stay an int while the same option given on the command line is float.
    parser.add_argument('-Lx', '--Lx', type=np.float64, default=4.68, help="Length of x-axis")
    parser.add_argument('-Lz', '--Lz', type=np.float64, default=1.0, help="Length of z-axis")
    parser.add_argument('--Prandtl', '-P', type=np.float64, default=1.0, help='Prandtl number')
    parser.add_argument('--Rayleigh', '-Ra', type=np.float64, default=2500.0, help='Rayleigh number')
    parser.add_argument('-nx', '--nx', type=int, default=64, help='Number of spectral coefficients or grid points in x')
    parser.add_argument('-nz', '--nz', type=int, default=64, help='Number of spectral coefficients or grid points in z')
    parser.add_argument('--time', '-T', type=np.float64, default=100.0, help='Simulation stopping time')
    parser.add_argument('--filename', '-f', type=str, default='simulation', help='Name of folder for output')
    # logging.Logger.setLevel only understands upper-case level names, so
    # normalise the case here and reject unknown levels up front instead of
    # crashing later when the level is applied.
    parser.add_argument('-ll', '--LOG_LEVEL', type=str.upper, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        help='Logging level (case-insensitive)')

    args = parser.parse_args(argv)

    print(vars(args))

    # Map argparse attribute names onto the short keys used by the script.
    return {
        'Lx': args.Lx,
        'Lz': args.Lz,
        'P': args.Prandtl,
        'Ra': args.Rayleigh,
        'nx': args.nx,
        'nz': args.nz,
        'T': args.time,
        'f': args.filename,
        'LOG_LEVEL': args.LOG_LEVEL,
    }

# Fetch command line arguments
par = fetch_args(parser)

# To run the program with your own arguments, or across several MPI
# processes, try:
#
#   python dedalus_ex.py -Ra 1800 -Lx 9 -nx 128
#   mpiexec -n 2 python dedalus_ex.py
#
# The first runs with a Rayleigh number of 1800, a horizontal length of 9
# and 128 grid points (modes) in x. The second runs the default
# configuration split across 2 MPI processes. Prefer nx and nz that are
# powers of 2: the spectral transforms are FFT-based, and FFTs are fastest
# for such sizes.

# Set logging level. Logger.setLevel only accepts upper-case level names
# (e.g. 'INFO'), so normalise the case of the command-line value.
logger.setLevel(str(par['LOG_LEVEL']).upper())

####################################################################
#
# Domain
#
####################################################################

# Physical extents and resolutions, both ordered (x, z).
domain_lengths = [par['Lx'], par['Lz']]
dim_res = [par['nx'], par['nz']]

# `dealias` is the grid scale factor used when evaluating nonlinear terms.
# Values above 1 (e.g. 3/2 for quadratic nonlinearities) pad the grid to
# reduce aliasing errors; 1 means no padding.
x_basis = de.Fourier('x', int(dim_res[0]), interval=(0, domain_lengths[0]), dealias=1)
z_basis = de.Chebyshev('z', int(dim_res[1]), interval=(0, domain_lengths[1]), dealias=1)
dims = [x_basis, z_basis]

domain = de.Domain(dims, grid_dtype=np.double)

# State variables: temperature, the two velocity components, their
# z-derivatives (first-order formulation in z) and the pressure.
problem = de.IVP(domain, variables=['T', 'Tz', 'u_x', 'uz_x', 'u_z', 'uz_z', 'p'])

####################################################################
#
# Parameters
#
####################################################################

# Nondimensional control parameters, made available by name to the
# equation strings.
problem.parameters['Ra'] = par['Ra']  # Rayleigh number
problem.parameters['Pr'] = par['P']   # Prandtl number

####################################################################
#
# Substitutions
#
####################################################################

# Named sub-expressions the equation strings can use in place of the full
# expressions; registered in this order.
for _name, _expr in (
    # Diffusive terms: Laplacians, with the second z-derivative taken of the
    # first-order auxiliary variables (Tz, uz_x, uz_z).
    ('LT', "dx(dx(T))+dz(Tz)"),
    ('Lu_x', "dx(dx(u_x))+dz(uz_x)"),
    ('Lu_z', "dx(dx(u_z))+dz(uz_z)"),
    # Advective terms: -(u . grad) applied to each field.
    ('AT', "-u_x * dx(T) - u_z * Tz"),
    ('Au_x', "-u_x * dx(u_x) - u_z * uz_x"),
    ('Au_z', "-u_x * dx(u_z) - u_z * uz_z"),
):
    problem.substitutions[_name] = _expr
del _name, _expr

####################################################################
#
# Equations
#
####################################################################

# Derivatives: Dedalus v2 requires equations that are first order in the
# Chebyshev (z) direction, so each needed z-derivative is its own variable.
problem.add_equation("Tz - dz(T) = 0")
problem.add_equation("uz_x - dz(u_x) = 0")
problem.add_equation("uz_z - dz(u_z) = 0")

# Temperature: diffusion (LT) on the implicit left-hand side, advection (AT)
# on the explicit right-hand side. Its unit diffusivity means time is
# measured in thermal-diffusion units.
problem.add_equation("dt(T) - LT = AT") # Temperature

# Flow. With that scaling the momentum equation reads
#   du/dt + u.grad(u) = -grad(p) + Pr*lap(u) + Pr*Ra*(T - 1) z_hat
# so the Prandtl number multiplies both the viscous and the buoyancy terms;
# without it the --Prandtl option would have no effect. For Pr = 1 this is
# the Pr-free form. The constant -1 in (T - 1) only shifts the hydrostatic
# pressure.
problem.add_equation("dx(u_x) + uz_z = 0") # Incompressibility
problem.add_equation("dt(u_x) + dx(p) - Pr*Lu_x = Au_x") # Velocity field in x
problem.add_equation("dt(u_z) + dz(p) - Pr*Lu_z = Au_z + Pr*Ra*(T-1)") # Velocity field in z

####################################################################
#
# Boundary Conditions
#
####################################################################

# Boundary conditions. left()/right() evaluate a field at the two ends of
# the Chebyshev interval, here z = 0 (bottom) and z = Lz (top).

# Fixed temperatures: hot bottom wall (T = 1), cold top wall (T = 0).
for _bc in ("left(T) = 1", "right(T) = 0"):
    problem.add_bc(_bc)

# No-slip: the horizontal velocity vanishes on both walls.
for _bc in ("left(u_x) = 0", "right(u_x) = 0"):
    problem.add_bc(_bc)
del _bc

# Impermeable walls. For the horizontally uniform mode (nx == 0) the
# incompressibility equation reduces to dz(u_z) = 0, so right(u_z) = 0
# already forces u_z = 0 and left(u_z) = 0 would be redundant. That mode
# instead gets left(p) = 0, which fixes the otherwise arbitrary constant in
# the pressure (a "gauge" choice) and makes the pressure unique.
problem.add_bc("left(u_z) = 0", condition="nx != 0")
problem.add_bc("left(p) = 0", condition="nx == 0")
problem.add_bc("right(u_z) = 0")

# Construct the solver with SBDF3, a third-order semi-implicit BDF scheme
# (linear left-hand side implicit, right-hand side explicit).
solver = problem.build_solver(de.timesteppers.SBDF3)

# Stopping conditions: only the simulated-time limit is active; the
# wall-clock and iteration limits are disabled by setting them to infinity.
solver.stop_sim_time = par['T']
solver.stop_wall_time = solver.stop_iteration = np.inf

# Ask Dedalus for the global grid shape and the slice of it owned by this
# MPI process, so every process can cut its own piece out of one shared,
# globally drawn noise array.
g_shape = domain.dist.grid_layout.global_shape(scales=1)
slices = domain.dist.grid_layout.slices(scales=1)
z = domain.grid(1)

# MPI communicator information (not used further in this section).
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

temp = solver.state['T']

# Fetch some "random" noise. Drawing it on the global grid with a fixed
# seed makes the initial condition independent of the number of processes.
rand = np.random.RandomState(seed=42)
noise = rand.standard_normal(g_shape)[slices]

# Conductive profile satisfying the boundary conditions T(0) = 1 and
# T(Lz) = 0 for any Lz. `1 - z/Lz` is used rather than `Lz - z`, which only
# matches them when Lz = 1. The perturbation is tapered by
# 4 z (Lz - z) / Lz**2, which is 0 on both walls and 1 at mid-depth, so the
# noise does not contradict the Dirichlet conditions.
taper = 4 * z * (par['Lz'] - z) / par['Lz']**2
temp['g'] = (1 - z / par['Lz']) + 0.01 * noise * taper

# Keep the auxiliary variable Tz consistent with the new temperature field.
temp.differentiate('z', out=solver.state['Tz'])
ux = solver.state['u_x']

# Checkpoint ("backup") of the full state (T, u_x, u_z, p), written once
# every tenth of the total simulation time, i.e. every T/10.
backup = solver.evaluator.add_file_handler("{}_backup".format(par['f']),
                                           sim_dt=par['T']/10, max_writes=100)
for _field in ('T', 'u_x', 'u_z', 'p'):
    backup.add_task(solver.state[_field], layout='g')

# Output snapshots of T, u_x and u_z every 0.1 units of *simulation* time
# (not wall-clock seconds). The handler gets its own name instead of
# re-binding `backup`, so `backup` keeps referring to the checkpoint handler.
snapshots = solver.evaluator.add_file_handler("{}".format(par['f']),
                                              sim_dt=0.1, max_writes=1000)
for _field in ('T', 'u_x', 'u_z'):
    snapshots.add_task(solver.state[_field], layout='g')
del _field

# CFL condition for adaptive timestepping
dt = 1e-3
CFL = flow_tools.CFL(solver, initial_dt=dt, cadence=10, safety=0.5,
                     max_change=1.5, min_change=0.5, max_dt=0.125, threshold=0.05)
CFL.add_velocities(('u_x', 'u_z'))

# Diagnostics are computed and logged only every `log_cadence` iterations:
# evaluating the depth integral and the global MPI reductions below on every
# step is wasted work and floods the log.
log_cadence = 10

comm_world = MPI.COMM_WORLD
run_start = time.time()
try:
    while solver.proceed:
        dt = CFL.compute_dt()
        solver.step(dt)

        if (solver.iteration - 1) % log_cadence == 0:
            # Depth-integrated temperature; half its horizontal range is
            # reported as "Amp".
            int_z = de.operators.integrate(solver.state['T'], 'z').evaluate()['g']
            # Each process holds only its own part of the grid (possibly
            # nothing), so take local extrema with safe initial values and
            # reduce them across all processes. Otherwise every rank would
            # report its own local amplitude.
            global_max = comm_world.allreduce(np.max(int_z, initial=-np.inf), op=MPI.MAX)
            global_min = comm_world.allreduce(np.min(int_z, initial=np.inf), op=MPI.MIN)
            amp = (global_max - global_min)/2
            logger.info(f"Percent Complete: {solver.sim_time/par['T']*100:.0f}{4*' '}Sim Time: {solver.sim_time:.4}{4*' '}dt: {dt:.4}{4*' '}Amp: {amp:.4}")
finally:
    # Summary is logged whether the loop finished normally or raised.
    logger.info("Iterations: %i    Sim end time: %.4g    Run time: %.2f sec",
                solver.iteration, solver.sim_time, time.time() - run_start)