""" betavioplot.py

Alex Godfrey 
"""
import argparse
import os
import sys
import string

# utility functions
def _error(msg):
    sys.exit("[Error] %s" % msg)
def _error_param(param, param_val):
    _error("Invalid input for param {:}: {:}".format(param, param_val))
def _warn(msg):
    sys.stderr.write("[Warning] %s\n" % msg)
def _get_bool(b):
    """ Parse a bool represented as a string and return True/False.
    
    Interpreted as True:
    True, true, T, t, TRUE
    
    Interpreted as False:
    False, false, F, f, FALSE
    
    Raises ValueError if match not found.
    """
    b0 = b
    b = string.lower(b)
    if b == 'true' or b == 't':
        return True
    if b == 'false' or b == 'f':
        return False
    raise ValueError("could not conver string to bool: %s" % b0)

# check for and import numpy and matplotlib; collect every missing package
# so the user gets a single message listing all of them
libsnotfound = []
try:
    import numpy as np
except ImportError:
    libsnotfound.append('numpy')
try:
    import matplotlib.pyplot as plt
    import matplotlib.markers as mm
    import matplotlib.colors as mc
    import matplotlib.font_manager as fm
    from matplotlib import rcParams
except ImportError:
    libsnotfound.append('matplotlib')
if libsnotfound:
    # str.join works on Python 2 and 3; string.join was removed in Python 3
    _error("Please download the following packages: "
           + ", ".join(libsnotfound))


# Set default font: use Helvetica Neue when its TTF file is present in the
# font directory chosen below, otherwise matplotlib's default font properties.
HOMEDIR = os.path.expanduser("~")
# NOTE(review): both candidate directories are specific to the author's
# machines; on other systems the fallback font is used.
TTFDIR = ("{}/fonts".format(HOMEDIR) if "/Users/a" in HOMEDIR
          else "/lab/page/alex/fonts")
HN_FILE = "{}/HelveticaNeue.ttf".format(TTFDIR)
if os.path.exists(HN_FILE):
    FONT_LG, FONT_MD = (fm.FontProperties(fname=HN_FILE, size=s)
                        for s in ('large', 'medium'))
else:
    FONT_LG, FONT_MD = (fm.FontProperties(size=s)
                        for s in ('large', 'medium'))

# Fallback plot parameters: used for any parameter the control file does not
# set, and for every parameter when only a datafile is given with '-d'.
DEFAULTS = dict(
    x_label='',
    y_label='',
    title='',
    fig_height=5,
    fig_width=7,
    y_min=None,          # None: let the plot choose the limit
    y_max=None,
    alpha=0.5,
    marker='o',
    col_default='k',     # color for classes with no color in the datafile
    filled=False,
    vert_labs=False,
    show_median=True,
    med_color='k',
    med_stroke=0.75,
    taper=14,            # max Beta shape parameter (see _choose_ab)
)

# Converter applied to each raw string value read from the control file,
# keyed by parameter name. Built by grouping names that share a converter.
PARAM_TYPES = dict.fromkeys(('datafile', 'outfile', 'x_label', 'y_label',
                             'title', 'marker', 'col_default', 'med_color'),
                            str)
PARAM_TYPES.update(dict.fromkeys(('fig_height', 'fig_width', 'y_min',
                                  'y_max', 'alpha', 'med_stroke', 'taper'),
                                 float))
PARAM_TYPES.update(dict.fromkeys(('filled', 'vert_labs', 'show_median'),
                                 _get_bool))

# Color Palettes
# Nine-color dark palette, selected with col_default 'dark9'.
DARK9 = [
    '#9E9E2C', '#C43434', '#1DB5A8', '#813B87', '#D10575',
    '#DD9D27', '#8A5034', '#2871C2', '#666666',
]

# Color Brewer Dark2 Qualitative Palette (www.ColorBrewer.org)
CBDARK2_DICT = {
    'dark2_1': '#1b9e77',
    'dark2_2': '#d95f02',
    'dark2_3': '#7570b3',
    'dark2_4': '#e7298a',
    'dark2_5': '#66a61e',
    'dark2_6': '#e6ab02',
    'dark2_7': '#a6761d',
    'dark2_8': '#666666',
}
# Same colors in palette order (dark2_1 .. dark2_8). A generator expression
# is used so no loop variable leaks into module scope on Python 2.
CBDARK2_LIST = list(CBDARK2_DICT['dark2_%d' % k] for k in range(1, 9))

def _choose_ab(n, max_ab=14):
    """ Return the parameter of a beta dist given n points in a class.
    """
    ab = max_ab - int(n) / 10
    if ab < 1:
        return 1
    return ab

def _beta_jitter(y_sorted, ymin, ybin_h, ybins, max_ab):
    """ Return Beta-distributed offsets in [0, 1], one per value of y_sorted.

    The y axis from `ymin` is cut into `ybins` bins of height `ybin_h`; the
    last bin also takes every remaining value (including the maximum). The
    points in each bin get offsets from Beta(ab, ab) with
    ab = _choose_ab(<points in bin>, max_ab). `y_sorted` must be ascending
    so each bin is a contiguous slice and offsets line up with their values.
    """
    offsets = np.zeros(len(y_sorted))
    start = 0
    upper = ymin + ybin_h
    for j in range(ybins):
        if j == ybins - 1:
            stop = len(y_sorted)
        else:
            # first index whose value is >= upper, i.e. the end of this bin
            stop = int(np.searchsorted(y_sorted, upper, side='left'))
        count = stop - start
        ab = _choose_ab(count, max_ab=max_ab)
        offsets[start:stop] = np.random.beta(a=ab, b=ab, size=count)
        start = stop
        upper += ybin_h
    return offsets


def betavioplot(x, ybins=80, alpha=0.5, marker='o', colors='k', pad=0.2,
                ax=None, labels=None, filled=False, vert_labels=True,
                show_median=True, taper=1., med_color='k', fig_height=5,
                fig_width=7, med_stroke=0.75, ylim=None):
    """ Make a beta-violin plot.

    Each data class is drawn as a column of points. The y range of all
    classes is split into `ybins` equal-height bins; within a column, the
    x offset of each point is drawn from a symmetric Beta distribution whose
    shape parameter (from `_choose_ab`) shrinks as its bin gets more crowded,
    so dense regions spread wider.

    Parameters
    ----------
    x (sequence of sequences) : numeric values, one sequence per class
    ybins (int) : number of y bins used to measure local density
    alpha (float) : marker opacity
    marker (str) : matplotlib marker style
    colors (str or list) : one color, or a list cycled over the classes
    pad (float) : horizontal gap between adjacent class columns
    ax (Axes) : axes to draw on; a new figure and axes are made if None
    labels (list) : x tick labels, one per class; 1..n are used if None
    filled (bool) : fill markers (True) or draw only their edges (False)
    vert_labels (bool) : rotate the given tick labels vertically
    show_median (bool) : draw a short horizontal line at each class median
    taper (float) : largest Beta shape parameter, passed to `_choose_ab`
    med_color, med_stroke : color and line width of the median lines
    fig_height, fig_width : figure size, applied to the figure of `ax`
    ylim (tuple) : (lower, upper) y limits; None, or a None entry, means
        use the data range padded by 10% on that side

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError if `x` is empty or `labels` has the wrong length
    """
    if len(x) == 0:
        raise ValueError("No data classes to plot")
    if labels is not None and len(labels) != len(x):
        raise ValueError("Length of labels must match number of data classes")
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()

    fig.set_figheight(fig_height)
    fig.set_figwidth(fig_width)

    xbin_w = 1.
    line_width = xbin_w / 4.    # length of each median line

    # Shared y range and bin height, so all classes use the same bins.
    ymin = min(min(x_i) for x_i in x)
    ymax = max(max(x_i) for x_i in x)
    ybin_h = (ymax - ymin) / float(ybins)

    if isinstance(colors, str):
        colors = [colors]
    nc = len(colors)

    centers = []
    center = 0
    for i, x_i in enumerate(x):
        # Sort so that offsets generated bin-by-bin (lowest bin first) are
        # paired with the values they were generated for; scattering the
        # unsorted data against them scrambles the violin shape.
        y_i = np.sort(np.asarray(x_i, dtype=float))
        xs = _beta_jitter(y_i, ymin, ybin_h, ybins, taper) + (center - xbin_w / 2)

        if show_median:
            # One line per class (previously redrawn once per y bin).
            med = np.median(y_i)
            ax.plot([center - line_width / 2, center + line_width / 2],
                    [med, med], color=med_color, lw=med_stroke)

        color = colors[i % nc]
        if filled:
            facecolors, edgecolors = color, 'none'
        else:
            facecolors, edgecolors = 'none', color
        ax.scatter(xs, y_i, marker=marker, edgecolors=edgecolors,
                   facecolors=facecolors, alpha=alpha)

        centers.append(center)
        center += xbin_w + pad

    ax.set_xticks(centers)
    if labels is not None:
        lab_rot = 'vertical' if vert_labels else 'horizontal'
        ax.set_xticklabels(labels, rotation=lab_rot)
    else:
        ax.set_xticklabels(list(range(1, len(x) + 1)))
    ax.set_xlim((centers[0] - xbin_w, centers[-1] + xbin_w))

    # Fill in any unset y limit with the data range padded by 10%.
    if ylim is None:
        ylim = (None, None)
    ylo, yhi = ylim
    yrange = ymax - ymin
    if ylo is None:
        ylo = ymin - 0.1 * yrange
    if yhi is None:
        yhi = ymax + 0.1 * yrange
    ax.set_ylim((ylo, yhi))

    ax = _set_axes_fonts(ax)

    return fig, ax


def _set_axes_fonts(ax):
    """ Apply the module fonts to every text label on an Axes instance.

    The title gets FONT_LG; the axis labels and all tick labels get
    FONT_MD. Existing label text is kept. Returns the same Axes.
    """
    ax.set_title(ax.get_title(), fontproperties=FONT_LG)
    for setter, getter in ((ax.set_xlabel, ax.get_xlabel),
                           (ax.set_ylabel, ax.get_ylabel)):
        setter(getter(), fontproperties=FONT_MD)
    for ticklabels in (ax.get_xticklabels(), ax.get_yticklabels()):
        for text in ticklabels:
            text.set_fontproperties(FONT_MD)
    return ax


def _clean_ctl_line(line):
    """ Strip a control-file line down to its '<param>:<value>' text.

    Drops everything from the first '#' on, then removes all whitespace
    except inside a quoted span. If the line contains a single quote (checked
    before double quotes), the text between the first and last occurrence of
    that quote character is kept verbatim and the quote characters dropped.
    """
    if '#' in line:
        line = line[:line.index('#')]
    for quote in ("'", '"'):
        if quote in line:
            i1 = line.index(quote)
            i2 = line.rindex(quote)
            return (''.join(line[:i1].split()) + line[i1 + 1:i2]
                    + ''.join(line[i2 + 1:].split()))
    return ''.join(line.split())


def parse_ctl_file(ctlfile):
    """ Return a dict of parameters parsed from control file.

    Control file format (one parameter per line):
        <param_name> : <value>
    All whitespace (spaces and tabs) outside a quoted value is removed; a
    value wrapped in '...' or "..." keeps its inner whitespace.
    Any characters on a line after '#' are treated as a comment.
    Lines without ':' are skipped with a warning; lines with more than one
    ':', an empty name or value, or a repeated name abort the program.

    Parameters
    ----------
    ctlfile (str) : path to control file

    Returns
    -------
    params (dict) : a dict mapping parameter names to converted values, with
        DEFAULTS filled in for any parameter not given
    """
    params = {}
    with open(ctlfile) as f:
        for lineno, raw in enumerate(f, 1):
            line = _clean_ctl_line(raw)
            if not line:
                continue
            ncolons = line.count(':')
            if ncolons == 0:
                _warn("Uncommented line {} of ctl file has no ".format(lineno)
                      + "':' and will be skipped.")
                continue
            if ncolons > 1:
                _error("Line {} of ctl file has multiple ':' ".format(lineno)
                       + "and cannot be parsed.")

            par, val = line.split(':')
            if not par:
                _error("Line {} of ctl file has no param name".format(lineno))
            if not val:
                _error("Line {} of ctl file has no param value".format(lineno))
            if par in params:
                _error("Parameter '{}' was entered twice".format(par))
            params[par] = val

    # check validity and convert param types
    params = _check_convert_params(params)

    # add defaults for unset params
    for p, default in DEFAULTS.items():
        params.setdefault(p, default)

    return params
    
def _check_convert_params(params):
    """ Check user-input parameter values for validity and convert type.

    Exits via _error (SystemExit) on the first problem found:
      - a missing datafile or outfile entry,
      - an unreadable datafile or an unwritable output directory,
      - an unknown parameter name,
      - a value PARAM_TYPES cannot convert,
      - a converted value outside its allowed range.
    Color values written as 'hex_<digits>' are rewritten to '#<digits>',
    because '#' starts a comment in the control file.

    Parameters
    ----------
    params (dict) : raw parameter name -> string value from the ctl file

    Returns
    -------
    params (dict) : the same dict, with values converted in place
    """
    # input/output files were given
    if 'datafile' not in params:
        _error("Please specify a file containing data")
    if 'outfile' not in params:
        _error("Please specify a file to write the plot to")

    # input file and output directory exist and are readable/writeable
    datafile = params['datafile']
    if not os.path.exists(datafile):
        _error("Cannot find datafile %s" % datafile)
    if not os.access(datafile, os.R_OK):
        _error("Cannot read datafile %s" % datafile)
    outdir = os.path.dirname(params['outfile']) or '.'
    if not os.path.exists(outdir):
        _error("Cannot find directory '%s' for writing" % outdir)
    if not os.access(outdir, os.W_OK):
        _error("Cannot write to directory '%s'" % outdir)

    # type and value validity of other parameters
    for p in params:
        # Reject unknown names cleanly instead of crashing with a KeyError.
        if p not in PARAM_TYPES:
            _error("Unknown parameter '{}' in ctl file".format(p))
        try:
            params[p] = PARAM_TYPES[p](params[p])
        except ValueError:
            _error_param(p, params[p])
        val = params[p]

        if p == 'alpha':
            if not 0 < val <= 1:
                _error_param(p, val)
        elif p == 'marker':
            # `markers` is a class attribute; no need to build a MarkerStyle.
            if val not in mm.MarkerStyle.markers:
                _error_param(p, val)
        elif p in ('col_default', 'med_color'):
            # Palette names only make sense for the per-class default color;
            # the median line is drawn with one matplotlib color.
            if p == 'col_default' and val in ('dark2', 'dark9'):
                continue
            if val.startswith('hex_'):
                val = params[p] = '#' + val[4:]
            try:
                mc.colorConverter.to_rgb(val)
            except ValueError:
                _error_param(p, val)
        elif p == 'taper':
            if val < 1:
                _error_param(p, val)
        elif p in ('fig_height', 'fig_width', 'med_stroke'):
            if val < 0:
                _error_param(p, val)

    return params
        

def parse_datafile(datafile):
    """ Return lists of data, labels, and colors parsed from file.

    Datafile format (one line per data class; blank lines are skipped):
        <label[,color]><TAB>x1,x2,x3,...,xn

    The label field and the data list are separated by exactly one tab. A
    trailing comma on the data list is tolerated. A color may be any
    matplotlib color spec; 'hex_<digits>' is rewritten to '#<digits>'.

      Example (columns separated by a tab):
      gene1,r   1.1,2,3.4,6
      gene2,b   1,1,2,5,6,10,-1

    Parameters
    ----------
    datafile (str) : path to file containing data

    Returns
    -------
    data (list) : a list of lists of floats; one internal list per class
    labels (list) : the labels for each class; len(labels)==len(data)
    colors (list) : one color per class, None where the line gave no color
    """
    data = []
    labels = []
    colors = []
    with open(datafile) as f:
        for i, line in enumerate(f):
            line = line.rstrip()
            if line == '':
                continue
            try:
                lab_col, datalist = line.split('\t')
            except ValueError:
                _error("Line {} of datafile has unexpected format".format(i+1))

            # Label with an optional ',color' suffix (at most one comma).
            try:
                lab, col = lab_col.split(',') if ',' in lab_col else (lab_col, '')
            except ValueError:
                _error("Label in line {} of datafile".format(i+1)
                       + " has unexpected format")
            if col == '':
                col = None
            else:
                if col.startswith('hex_'):
                    col = '#' + col[4:]
                try:
                    mc.colorConverter.to_rgb(col)
                except ValueError:
                    _error("Unrecognized color in datafile: "
                           + "'{}'".format(col))

            # A real list (not a lazy map iterator on Python 3), because the
            # plotting code iterates each class more than once.
            try:
                d = [float(v) for v in datalist.strip(',').split(',')]
            except ValueError:
                _error("Data in line {} of datafile".format(i+1)
                       + " has unexpected format")
            data.append(d)
            labels.append(lab)
            colors.append(col)
    return data, labels, colors
                
        
def parse_cmd_line():
    """ Parse command-line arguments and check the input file is usable.

    Exactly one of '-c' (control file) or '-d' (data file) must be given;
    the chosen file must exist and be readable. Any violation aborts via
    _error. Returns the argparse namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--ctlfile', help="path to control file")
    parser.add_argument('-d', '--datafile',
                        help="path to file containing data")
    parser.add_argument('-o', '--outfile',
                        help=("name of output plot file, should end with "
                              "'.pdf' or '.png'; if none given and a data file "
                              "is passed from the command line with '-d', the "
                              "plot will be saved in betavioplot.pdf in the "
                              "working directory."))
    args = parser.parse_args()

    have_ctl = args.ctlfile is not None
    have_data = args.datafile is not None
    if not (have_ctl or have_data):
        _error("Must provide either control file with '-c' or datafile "
               "with '-d'.")
    if have_ctl and have_data:
        _error("If using a control file, specify the data file location "
               "in the control file.")

    path = args.ctlfile if have_ctl else args.datafile
    if not os.path.exists(path):
        _error("Cannot find file: '{}'".format(path))
    if not os.access(path, os.R_OK):
        _error("Cannot read file: '{}'".format(path))

    return args
    
def main():
    """ Command-line entry point: read inputs, draw the plot, and save it.

    Parameters come from the control file when '-c' is given; otherwise
    from DEFAULTS plus the '-d' datafile and optional '-o' output name
    (default 'betavioplot.pdf').
    """
    args = parse_cmd_line()

    if args.ctlfile is None:
        # Copy, so adding datafile/outfile does not modify DEFAULTS itself.
        params = dict(DEFAULTS)
        params['datafile'] = args.datafile
        if args.outfile is None:
            params['outfile'] = 'betavioplot.pdf'
        else:
            params['outfile'] = args.outfile
    else:
        params = parse_ctl_file(args.ctlfile)

    data, labels, colors = parse_datafile(params['datafile'])

    # Classes with no color in the datafile get the default: the matching
    # (cycled) entry of a named palette, or the single default color.
    # Explicit per-class colors from the datafile are kept in both cases.
    palettes = {'dark9': DARK9, 'dark2': CBDARK2_LIST}
    col_default = params['col_default']
    if col_default in palettes:
        palette = palettes[col_default]
        colors = [palette[i % len(palette)] if col is None else col
                  for i, col in enumerate(colors)]
    else:
        colors = [col_default if col is None else col for col in colors]

    # Pass y limits only when at least one is set; with ylim=None
    # betavioplot pads the data range itself.
    y_lo, y_hi = params['y_min'], params['y_max']
    ylim = None if (y_lo is None and y_hi is None) else (y_lo, y_hi)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Set the text first so the font pass inside betavioplot styles it.
    ax.set_xlabel(params['x_label'])
    ax.set_ylabel(params['y_label'])
    ax.set_title(params['title'])

    fig, ax = betavioplot(data, alpha=params['alpha'], marker=params['marker'],
                          colors=colors, ax=ax, labels=labels,
                          filled=params['filled'],
                          vert_labels=params['vert_labs'],
                          show_median=params['show_median'],
                          taper=params['taper'], med_color=params['med_color'],
                          fig_height=params['fig_height'],
                          fig_width=params['fig_width'],
                          med_stroke=params['med_stroke'],
                          ylim=ylim)

    fig.savefig(params['outfile'])

    plt.close(fig)
    
# Run as a script: parse the command line, draw the plot, and save it.
if __name__ == '__main__':
    main()