###############################################################################
# Extract relevant fields from the fits files provided by Chang et al. 2015   #
# Merge everything into a single catalog and write to disk as csv file        #
# For input files see http://irfu.cea.fr/Pisp/yu-yen.chang/sw.html            #
#                                                                             #
# (c) R. Feldmann, 2017                                                       #
# no warranty, use at your own risk                                           #
# please cite: Feldmann et al. 2017, MNRAS Letters 470, L59, arXiv:1705.03014 #
###############################################################################
import pandas as pd
from astropy.io import fits

basepath = '.'
cat1_filename = 'sw_input.fits'
cat2_filename = 'sw_output.fits'

# Columns taken from the input catalog: name in the csv -> name in the fits
# table. The order of the entries is the column order of the data frame.
cat1_columns = {
    'id': 'id',
    'z': 'redshift',
}

# Columns taken from the output catalog: name in the csv -> name in the fits
# table. 'flag' is only needed for the selection below; it is not written out.
cat2_columns = {
    'id': 'id',
    'lmass': 'lmass50_all',
    'vmax': 'vmax',
    'lsfr': 'lsfr50_all',
    'lsfr16': 'lsfr16_all',
    'lsfr84': 'lsfr84_all',
    'lssfr': 'lssfr50_all',
    'lssfr16': 'lssfr16_all',
    'lssfr84': 'lssfr84_all',
    'flag': 'flag',
}


def native_byteorder(column):
    """Return a copy of the numpy array *column* in native byte order.

    The element values are unchanged; only the in-memory byte order is
    converted. This replaces the former ``byteswap().newbyteorder()`` idiom,
    which no longer works because ``ndarray.newbyteorder`` was removed in
    NumPy 2.0. Unlike that idiom, it does not swap data that is already in
    native byte order.
    """
    # NOTE: fits tables are normally stored big-endian, and pandas expects
    # native byte order for operations such as merge -- hence the conversion.
    return column.astype(column.dtype.newbyteorder('='))


def read_catalog(filename, columns):
    """Read selected columns of a fits table into a pandas DataFrame.

    filename -- name of the fits file, relative to ``basepath``
    columns  -- dict mapping the data frame column name to the name of the
                column in the fits table (HDU 1)

    Returns a DataFrame whose columns follow the order of *columns*.
    """
    table = fits.getdata(basepath + '/' + filename, 1)
    return pd.DataFrame(data={name: native_byteorder(table[source])
                              for name, source in columns.items()})


dat1 = read_catalog(cat1_filename, cat1_columns)
dat2 = read_catalog(cat2_filename, cat2_columns)

# combine catalogs (inner join: only ids present in both catalogs are kept)
data = pd.merge(dat1, dat2, on=['id'])

# -- selection --
# only those with flag == 1, see Chang et al. 2015
sel = data.query('(flag == 1)')

# -- output as csv, put columns in sensible order --
# missing values are written as '.'; the data frame index is not written
sel.to_csv(path_or_buf='Chang2015.csv', na_rep='.', index=False,
           columns=['id', 'z', 'vmax', 'lmass', 'lsfr', 'lsfr16', 'lsfr84',
                    'lssfr', 'lssfr16', 'lssfr84'])
