# Support for regression of a single GCM, with optional regridding
'''
This code was adapted for use in bivariate analysis, but it reflects a style that was useful for other purposes, namely large-scale
pattern scaling of climate models.
Mainly, it returns a wide range of information about the regressions.
JHR 
'''
'''
Code to support the SI (supplementary information) for
Reconciling the signal and noise of atmospheric warming on decadal timescales
Roger N Jones* and James H Ricketts

Victoria Institute of Strategic Economic Studies, Victoria University, Melbourne, Victoria 8001, Australia
Correspondence to: Roger N. Jones (roger.jones@vu.edu.au)

And related publications

This code made available for informational purposes under the Creative Commons Attribution 3.0 License (enabling electronic and paper copies);
'''
import numpy as np
from scipy.stats import t
import numpy.ma as ma


class regressException(Exception):
    '''Raised by the regression routines when their inputs are inconsistent
    (for example, regress() raises it when len(xs) != len(data)).'''
#define a varable to control print diagnostics

printRegressDiags=False

statkeylist=["sse","ssb","ssa","sx" ,"sxx","sxy","sy" ,"syy" ,"ssx","ssy"]

def regress(data, xs, stats=None, asDict=False):
    '''
    Compute ordinary least-squares regressions of data on xs.

    data is either a single 1-D series or an array whose first axis matches
    xs (e.g. (time, lat, lon) or (time, cells)); in the latter case an
    independent regression is fitted for every cell of the trailing axes.
    Masked arrays should go through masked_regress instead (their masks are
    not honoured here).
    It returns slopes (beta) and then offsets (alpha).

    The optional third parameter names extra descriptive statistics to return,
    in the order specified.
    see http://en.wikipedia.org/wiki/Simple_linear_regression
    Values are any of
        "sse" : residual variance estimate, SSE / (n - 2); NaN when n <= 2
        "ssb" : variance estimate of beta
        "ssa" : variance estimate of alpha
        "sx"  : sum of x
        "sxx" : sum of squares of x
        "sxy" : sum of x * y
        "sy"  : sum of y
        "syy" : sum of squares of y
        "ssx" : square of the sum of x
        "ssy" : square of the sum of y
    Unknown names are ignored.

    Parameters
    ----------
    data : array_like, first axis of length len(xs)
    xs : 1-D sequence of regressor values
    stats : sequence of str, optional
    asDict : bool
        If True the extras are returned as a dict, otherwise as a list of
        (name, value) tuples (repeated names are repeated).

    Returns
    -------
    (beta, alpha) when stats is None or [], else (beta, alpha, extras).
    If every x value is identical, the slope is undefined (inf/nan).

    Raises
    ------
    regressException if len(xs) != len(data).
    '''
    if len(xs) != len(data):
        raise regressException("error: regress.regress. rank of xs and ys needs to agree")

    # Convert once to float arrays: plain Python lists now work for both
    # arguments, and integer inputs cannot trigger floor division on Python 2.
    ya = np.asarray(data, dtype=float)
    xa = np.asarray(xs, dtype=float)
    n = ya.shape[0]
    # Shape x as (n, 1, 1, ...) so it broadcasts against every cell of ya.
    xcol = xa.reshape((n,) + (1,) * (ya.ndim - 1))

    sx = np.sum(xa)
    sxx = np.sum(xa * xa)
    ssx = sx * sx
    sy = np.sum(ya, axis=0)
    syy = np.sum(ya * ya, axis=0)
    sxy = np.sum(ya * xcol, axis=0)

    # n * sum((x - mean(x))**2); zero when all xs are equal.
    sxx_c = n * sxx - ssx
    beta = (n * sxy - sx * sy) / sxx_c
    alpha = (sy - beta * sx) / n

    if stats is None or stats == []:
        return beta, alpha

    wanted = set(stats)
    ssy = sy * sy
    values = {"sx": sx, "sxx": sxx, "sxy": sxy,
              "sy": sy, "syy": syy, "ssx": ssx, "ssy": ssy}
    if wanted & set(["sse", "ssb", "ssa"]):
        # Residual variance needs at least three points (n - 2 dof).
        if n > 2:
            sse = (1.0 / (n * (n - 2))) * (n * syy - ssy - beta * beta * sxx_c)
        else:
            sse = np.nan
        ssb = (n * sse) / sxx_c
        values["sse"] = sse
        values["ssb"] = ssb
        values["ssa"] = ssb * sxx / n

    if asDict:
        extras = {k: values[k] for k in stats if k in values}
    else:
        extras = [(k, values[k]) for k in stats if k in values]
    return beta, alpha, extras

def xs_matched_to_data(data, xs):
    '''
    Expand the regressor values xs to the shape of data.

    Entry i along the first axis of the result is filled with xs[i]
    (broadcast over the remaining axes), and the result carries a copy of
    data's mask. Masked sums over axis 0 therefore use the same samples for
    x as for y.

    Parameters
    ----------
    data : array_like, usually a numpy.ma.MaskedArray; first axis matches xs
    xs : sequence with one value per entry along data's first axis

    Returns
    -------
    numpy.ma.MaskedArray of floats with the shape of data.

    Raises
    ------
    ValueError if len(xs) differs from the length of data's first axis.
    '''
    shape = np.shape(data)
    if len(xs) != shape[0]:
        raise ValueError("xs_matched_to_data: len(xs)=%d does not match data's first axis (%d)"
                         % (len(xs), shape[0]))
    # Fill a plain ndarray: assigning into a MaskedArray would unmask the
    # assigned cells and, on some NumPy versions, write through a shared mask.
    values = np.zeros(shape)
    for i, xval in enumerate(xs):
        values[i] = xval
    # Copy the mask so later edits to the result can never alter data's mask.
    mask = ma.getmask(data)
    if mask is not ma.nomask:
        mask = mask.copy()
    return ma.array(values, mask=mask)
    
def masked_regress(data, xs, stats=None, asDict=False):
    '''
    Compute linear regressions on a gridded basis for masked arrays.

    Works like regress(), but masked entries of data are excluded from every
    sum and from the sample count n. Each cell of the trailing axes is
    therefore fitted using only its own unmasked samples. Sums over x also
    include only the positions where data is unmasked.
    It returns slopes (beta) and then offsets (alpha).

    The optional third parameter names extra descriptive statistics to return,
    in the order specified.
    see http://en.wikipedia.org/wiki/Simple_linear_regression
    Values are any of
        "sse" : residual variance estimate, SSE / (n - 2); masked where n <= 2
        "ssb" : variance estimate of beta
        "ssa" : variance estimate of alpha
        "sx"  : sum of x
        "sxx" : sum of squares of x
        "sxy" : sum of x * y
        "sy"  : sum of y
        "syy" : sum of squares of y
        "ssx" : square of the sum of x
        "ssy" : square of the sum of y
    Unknown names are ignored.

    Parameters
    ----------
    data : numpy.ma.MaskedArray (a plain array is wrapped as unmasked);
        first axis matches xs
    xs : 1-D sequence of regressor values
    stats : sequence of str, optional
    asDict : bool
        If True the extras are returned as a dict, otherwise as a list of
        (name, value) tuples.

    Returns
    -------
    (beta, alpha) when stats is None or [], else (beta, alpha, extras).
    '''
    # The code below relies on MaskedArray methods (count, mask);
    # ma.asarray wraps a plain ndarray as an unmasked MaskedArray.
    data = ma.asarray(data)
    n = data.count(0)                   # unmasked samples per cell (scalar for 1-D data)
    xa = xs_matched_to_data(data, xs)   # x values, masked wherever data is masked

    # Sums along the sample axis; identical code for 1-D and gridded data.
    sx = ma.sum(xa, axis=0)
    sxx = ma.sum(xa * xa, axis=0)
    ssx = sx * sx
    sy = ma.sum(data, axis=0)
    syy = ma.sum(data * data, axis=0)
    sxy = ma.sum(data * xa, axis=0)

    # n * sum((x - mean(x))**2) per cell.
    sxx_c = n * sxx - ssx
    beta = (n * sxy - sx * sy) / sxx_c
    alpha = (sy - beta * sx) / n

    if stats is None or stats == []:
        return beta, alpha

    wanted = set(stats)
    ssy = sy * sy
    values = {"sx": sx, "sxx": sxx, "sxy": sxy,
              "sy": sy, "syy": syy, "ssx": ssx, "ssy": ssy}
    if wanted & set(["sse", "ssb", "ssa"]):
        dof = n * (n - 2)
        resid_ss = n * syy - ssy - beta * beta * sxx_c
        if np.ndim(dof) == 0:
            sse = (1.0 / dof) * resid_ss if dof > 0 else ma.masked
        else:
            # Mask cells with n <= 2 instead of dividing by zero or a negative
            # number (matches regress(), which returns NaN there).
            sse = (1.0 / ma.masked_less_equal(dof, 0)) * resid_ss
        ssb = (n * sse) / sxx_c
        values["sse"] = sse
        values["ssb"] = ssb
        values["ssa"] = ssb * sxx / n

    if asDict:
        extras = {k: values[k] for k in stats if k in values}
    else:
        extras = [(k, values[k]) for k in stats if k in values]
    return beta, alpha, extras
 
def analysed_regress(data, xs):
    '''
    Perform a regression and collect the derived statistics needed for later analysis.

    Masked arrays are routed to masked_regress (per-cell sample counts);
    anything else goes to regress. Returns a dict containing every entry
    named in statkeylist, plus:
        "n"         : number of samples (per cell for masked input)
        "SSR"       : regression sum of squares, beta * Sxy
        "SSE"       : residual sum of squares, Syy - SSR
        "sigma"     : residual standard deviation, sqrt(SSE / (n - 2))
        "t95"       : two-sided 95% Student-t critical value, n - 2 dof
        "beta_conf" : half-width of the 95% confidence interval on beta
        "beta", "alpha" : slope and intercept
        "t"         : t statistic of the slope, beta / stderr
        "prob"      : two-sided p-value of "t" (this key previously held the
                      Student-t density t.pdf(t), which is not a probability)
        "y_conf"    : 95% confidence half-width of the fitted mean at each x
        "mse"       : SSE / n
        "stderr"    : standard error of beta, sigma / sqrt(Sxx)
        "rsq"       : coefficient of determination
    Here Sxx, Syy, Sxy are the centred sums of squares / cross-products.
    '''
    if isinstance(data, ma.masked_array):
        beta, alpha, stats = masked_regress(data, xs, statkeylist, True)
        n = data.count(0)
    else:
        beta, alpha, stats = regress(data, xs, statkeylist, True)
        n = len(xs)
    stats["n"] = n
    nf = n * 1.0   # float divisor: avoids floor division under Python 2

    sx = stats["sx"]
    sy = stats["sy"]
    Sxx = (n * stats["sxx"] - sx * sx) / nf
    Syy = (n * stats["syy"] - sy * sy) / nf
    Sxy = (n * stats["sxy"] - sx * sy) / nf

    SSR = beta * Sxy
    SSE = Syy - SSR
    sigma = np.sqrt(SSE / (n - 2.))
    t95 = t.ppf(0.975, n - 2)
    sqrtsxx = np.sqrt(Sxx)
    # Standard error of the slope; beta_conf and the t statistic are both
    # consistent with it (beta_conf = t95 * stderr, t = beta / stderr).
    stderr = sigma / sqrtsxx
    tstat = sqrtsxx * beta / sigma

    # Two-sided p-value. scipy drops masks, so re-apply the mask of tstat.
    prob = 2.0 * t.sf(np.abs(tstat), n - 2)
    if isinstance(tstat, ma.MaskedArray):
        prob = ma.array(prob, mask=ma.getmask(tstat))

    xbar = sx / nf
    y_conf = np.array([t95 * sigma * np.sqrt(1. / n + ((x - xbar) * (x - xbar)) / Sxx)
                       for x in xs])

    stats["SSR"] = SSR
    stats["SSE"] = SSE
    stats["sigma"] = sigma
    stats["t95"] = t95
    stats["beta_conf"] = t95 * sigma / sqrtsxx
    stats["beta"] = beta
    stats["alpha"] = alpha
    stats["t"] = tstat
    stats["prob"] = prob
    stats["y_conf"] = y_conf
    stats["mse"] = SSE / nf
    stats["stderr"] = stderr
    stats["rsq"] = (Sxy * Sxy) / (Sxx * Syy)
    return stats
    
def residuals(data, xs, stats, dump=False):
    '''
    Compute fitted values and residuals from a previous regression.

    Parameters
    ----------
    data : array_like or numpy.ma.MaskedArray
        Observations; first axis matches xs.
    xs : sequence
        Regressor values.
    stats : dict
        Must contain "beta" and "alpha" (e.g. the dict returned by
        analysed_regress).
    dump : bool
        If True, print data, yhat, xs and stats for debugging.

    Returns
    -------
    (yhat, resid) where yhat = beta * x + alpha and resid = data - yhat.

    If data is masked and xs is not, xs is first expanded to data's shape
    with data's mask (via xs_matched_to_data), so yhat is masked where data is.
    '''
    beta = stats["beta"]
    alpha = stats["alpha"]
    if isinstance(data, np.ma.masked_array) and not isinstance(xs, np.ma.masked_array):
        yhat = beta * xs_matched_to_data(data, xs) + alpha
    else:
        yhat = np.array([beta * x + alpha for x in xs])
    resid = data - yhat
    if dump:
        # Single-argument print(...) behaves identically on Python 2 and 3.
        print("data")
        print(data)
        print("yhat")
        print(yhat)
        print("xs")
        print(xs)
        print("stats")
        print(stats)
    return yhat, resid
    

    
if __name__ == "__main__":
    # Placeholder entry point; print(...) with one argument works on Python 2 and 3.
    print("Test cases go here")
