#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 13 08:37:19 2024

@author: jdaniel
"""

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from get_weibull import *
from get_consumption import *
from scipy.stats import qmc, norm, truncnorm, lognorm
from outputPresent import *
from altMarket import *
import pickle
import copy

class otherParams:
    """Central values of the scalar (not market-specific) model parameters.

    Attributes
    ----------
    prodemiss : float
        Production loss fraction. ``1 - 1/1.05`` is chosen so that
        ``1/(1-prodemiss) - 1`` is ~0.05.
    prodemisssig : float
        Relative spread of ``prodemiss`` used when sampling alternatives.
    solvent : float
        Solvent fraction (previously 0.093).
    solventsig : float
        Relative spread of ``solvent`` used when sampling alternatives.
    iyears : int
        Number of years to track values after production.
    """

    def __init__(self):
        # production loss fraction and its relative spread
        self.prodemiss, self.prodemisssig = 1 - 1/1.05, 1.0
        # solvent fraction and its relative spread
        self.solvent, self.solventsig = 0.1, 0.5
        # tracking horizon (years after production)
        self.iyears = 200
        
class outputInfo:
    """Container of per-market, per-year time series.

    Every attribute is an independent zero-initialised float array of shape
    ``(nmarkets, iyears)``. Names starting with ``c`` hold running
    (cumulative) totals of the matching annual series; names ending in
    ``alt`` are parallel alternative series.
    """

    def __init__(self, nmarkets,iyears):
        fieldnames = (
            'prodemiss', 'totalemiss', 'cprodemiss',
            'solventemiss', 'csolventemiss',
            'installemiss', 'cinstallemiss',
            'activebank', 'totalbank', 'activeemiss', 'cactiveemiss',
            'activedecom', 'decomemiss', 'decomemissalt',
            'cdecomemiss', 'cdecomemissalt',
            'inactivebank', 'inactivebankalt',
            'inactiveemiss', 'cinactiveemiss',
            'inactiveemissalt', 'cinactiveemissalt',
        )
        # a fresh array per attribute so no two series share storage
        for fieldname in fieldnames:
            setattr(self, fieldname, np.zeros([nmarkets, iyears]))

def getReleaseParams(filename,sheet):
    """Read per-market release parameters from one sheet of an Excel workbook.

    The first column of the sheet is skipped (row labels); every remaining
    column header is taken as a market name.

    Returns
    -------
    (nummarkets, markets, vals)
        Number of markets, the column Index of market names, and a float
        array holding the sheet body without its first column.
    """
    table = pd.read_excel(filename, sheet_name=sheet)
    body = table.iloc[:, 1:]
    marketnames = table.columns[1:]
    # copy into a float array (assignment casts the frame's values to float64)
    values = np.zeros(body.shape, dtype=float)
    values[:] = body
    return len(marketnames), marketnames, values


def outputKernels(i,output,params,emissvals):
    """Fill row ``i`` of ``output`` with the life-cycle kernel for one unit of consumption.

    The kernel tracks, year by year after the unit is consumed, production,
    solvent, installation, active-use, decommissioning and inactive-bank
    emissions, plus the active and inactive banks and running totals of each
    emission series.

    Parameters
    ----------
    i : int
        Market (row) index to fill; all other rows are untouched.
    output : outputInfo
        Container written in place (row ``i`` of each array).
    params : otherParams-like
        Reads ``prodemiss``, ``solvent`` and ``iyears``; it is also passed to
        ``getweibull``.
    emissvals : sequence
        Release parameters for this market. Index 0 is the fraction lost at
        installation, 2 the annual release fraction of the inactive bank, and
        3 the fraction of decommissioned material emitted. These meanings are
        read off how they are used below. The whole vector is also passed to
        ``getweibull``.

    Returns
    -------
    (output, weibullactive, weibullemiss, iflag)
        ``iflag`` is 1 if the inactive bank went negative anywhere, else 0.
    """
    iflag = 0
    # If prodemiss is the fraction of production lost, delivering one unit
    # requires 1/(1-prodemiss) units, so the excess is emitted in year 0.
    output.prodemiss[i,0] = 1./(1.-params.prodemiss) - 1.
    # Solvent use is released half in year 0 and half in year 1.
    output.solventemiss[i,0] = params.solvent/2.
    output.solventemiss[i,1] = params.solvent/2.
    # Of the non-solvent remainder, emissvals[0] is lost at installation.
    output.installemiss[i,0] = emissvals[0]*(1.-params.solvent)
    installed = (1.-params.solvent)*(1-emissvals[0])
    weibullactive, weibullemiss = getweibull(i,params,emissvals)
    output.activebank[i,:] = weibullactive*installed
    output.activebank[i,0] = 0   #for conservation
    output.activeemiss[i,:] = weibullemiss*installed
    # Decommissioned amount = shrinkage of the active bank not explained by
    # active emissions. The last year has no successor and stays 0.
    output.activedecom[i,0] = installed - output.activebank[i,1] - output.activeemiss[i,0]
    for j in range(1,params.iyears-1):
        output.activedecom[i,j] = output.activebank[i,j] - output.activebank[i,j+1] - output.activeemiss[i,j]
    output.decomemiss[i,:] = output.activedecom[i,:] * emissvals[3]
    output.inactivebank[i,0] = 0
    output.decomemissalt[i,:] = output.activedecom[i,:] * emissvals[3] *0.
    output.inactivebankalt[i,0] = 0
    # Inactive bank: last year's stock loses a fraction emissvals[2] and gains
    # the decommissioned material that was not emitted on decommissioning.
    for j in range(1,params.iyears):
        output.inactivebank[i,j] = output.inactivebank[i,j-1] * (1.-emissvals[2]) + output.activedecom[i,j-1] -\
            output.decomemiss[i,j-1]
        output.inactiveemiss[i,j-1] = output.inactivebank[i,j-1] * emissvals[2]
        output.inactivebankalt[i,j] = output.inactivebank[i,j-1] * (1.-emissvals[2]) * 0
        output.inactiveemissalt[i,j-1] = output.inactivebank[i,j-1] * emissvals[2] * 0.
    if (np.min(output.inactivebank[i,:]) < 0): iflag = 1
    # Running totals over the WHOLE horizon: runmodel reads the last element,
    # so it must carry the full total rather than being left at 0.
    for name in ('prodemiss', 'solventemiss', 'installemiss', 'activeemiss',
                 'decomemiss', 'inactiveemiss', 'decomemissalt', 'inactiveemissalt'):
        getattr(output, 'c' + name)[i, :] = np.cumsum(getattr(output, name)[i, :])
    return(output,weibullactive,weibullemiss,iflag)

def runmodel(years,nummarkets,params,con,kernel,emissvals):
    """Convolve annual consumption with the unit kernels to build market time series.

    For every tracked year ``i`` and every consumption year ``j <= i`` that
    falls within ``years``, ``con[j, :]`` times the kernel at age ``i - j`` is
    added to the result. From model year 2002 on, the markets listed in
    ``ieurope`` have their decommissioning and inactive-bank emissions set to
    zero. Their cumulative totals are held at the previous year's value, and
    their inactive bank only decays.

    Parameters
    ----------
    years : sequence of numbers
        Consumption years; ``years[0]`` is model year 0. Assumed consecutive.
    nummarkets : int
        Number of markets (rows of every result array).
    params : otherParams-like
        Only ``iyears`` (tracking horizon) is read.
    con : ndarray
        Consumption indexed ``[year, market]``.
    kernel : outputInfo
        Per-unit kernels as produced by ``outputKernels``.
    emissvals : ndarray
        Release parameters indexed ``[parameter, market]``; row 2 is the
        annual release fraction of the inactive bank.

    Returns
    -------
    outputInfo
        Aggregated series, including ``totalbank`` and ``totalemiss``.
    """
    ieurope = [2]   # market row(s) labelled "Europe" that get the 2002 end-of-life treatment
    # Series built by convolution (the c* ones are convolved running totals).
    fields = ('activebank', 'inactivebank', 'csolventemiss', 'cprodemiss',
              'cinstallemiss', 'cactiveemiss', 'cdecomemiss', 'cinactiveemiss',
              'prodemiss', 'solventemiss', 'installemiss', 'activeemiss',
              'decomemiss', 'inactiveemiss')
    results = outputInfo(nummarkets,params.iyears)
    pairs = [(getattr(results, name), getattr(kernel, name)) for name in fields]
    year0 = years[0]
    for i in range(0,params.iyears):
        year1 = year0 + float(i)
        for j in range(0,i+1):
            # consumption years increase with j, so once past the last one we are done
            if year0 + float(j) > years[-1]:
                break
            for res, ker in pairs:
                res[:, i] += con[j, :] * ker[:, i-j]
        if (year1 >= 2002):
            for k in ieurope:
                results.decomemiss[k,i] = 0
                results.inactiveemiss[k,i] = 0
                if i == 0:
                    # no earlier year to carry forward (avoid wrapping to index -1)
                    results.cdecomemiss[k,i] = 0
                    results.cinactiveemiss[k,i] = 0
                    results.inactivebank[k,i] = 0
                else:
                    results.cdecomemiss[k,i] = results.cdecomemiss[k,i-1]
                    # carry the CUMULATIVE total forward (not last year's annual value)
                    results.cinactiveemiss[k,i] = results.cinactiveemiss[k,i-1]
                    results.inactivebank[k,i] = results.inactivebank[k,i-1]*(1-emissvals[2,k])

        results.totalbank[:,i] = results.activebank[:,i] + results.inactivebank[:,i]
        results.totalemiss[:,i] = results.inactiveemiss[:,i] + results.decomemiss[:,i]+\
                        results.activeemiss[:,i]+results.installemiss[:,i]+results.solventemiss[:,i]+\
                        results.prodemiss[:,i]
    return(results)

def getAltParams(sample_num,params,normaldistr=0):
    """Draw Monte Carlo variants of the scalar production/solvent parameters.

    Sample 0 is an unperturbed copy of ``params`` (the central case). Samples
    1..sample_num-1 get ``prodemiss`` and ``solvent`` from a 2-D Latin
    hypercube.

    Parameters
    ----------
    sample_num : int
        Number of parameter sets to return.
    params : otherParams-like
        Supplies the central values ``prodemiss``/``solvent`` and their
        relative spreads ``prodemisssig``/``solventsig``. Every returned set
        is an independent deep copy of it.
    normaldistr : int, optional
        0 (default): uniform on ``[mean - std, mean + std]``, where
        ``std = mean * relative spread``.
        1: normal with that mean/std, truncated to ``[1e-7, 1000]``.

    Returns
    -------
    list
        ``sample_num`` independent parameter objects.

    Raises
    ------
    ValueError
        If ``normaldistr`` is neither 0 nor 1.
    """
    paramsmc = [copy.deepcopy(params) for _ in range(sample_num)]
    mean = np.array([params.prodemiss, params.solvent])
    std = np.array([params.prodemiss*params.prodemisssig, params.solvent*params.solventsig])
    dimension = mean.size
    if (normaldistr == 1):
        cliplow, cliphi = 0.0000001, 1000
        # truncation limits expressed in standard-score units, one per dimension
        a = (cliplow - mean)/std
        b = (cliphi - mean)/std
        lhd = qmc.LatinHypercube(d=dimension, optimization="random-cd").random(n=sample_num)
        sample = truncnorm(a, b, loc=mean, scale=std).ppf(lhd)
    elif (normaldistr == 0):
        sampler = qmc.LatinHypercube(d=dimension)
        sample = qmc.scale(sampler.random(n=sample_num), mean - std, mean + std)
    else:
        raise ValueError(f"normaldistr must be 0 or 1, got {normaldistr!r}")
    # sample 0 stays at the central values
    for i in range(1,sample_num):
        paramsmc[i].prodemiss = sample[i,0]
        paramsmc[i].solvent = sample[i,1]

    return(paramsmc)

def getAltReleaseParams(sample_num,emissvals):
    """Draw Monte Carlo samples of the per-market release parameters.

    ``emissvals`` has shape ``(2*nparam, nmarkets)``. Rows ``0..nparam-1``
    are central values; rows ``nparam..2*nparam-1`` are the matching
    uncertainties, in the same row order.

    * The first 4 parameter rows are sampled as central value times a
      mean-one log-normal factor whose coefficient of variation is the
      matching uncertainty row.
    * The remaining parameter rows are sampled from a normal distribution:
      mean = central value, standard deviation = uncertainty row (absolute).

    All dimensions share one Latin hypercube design. Sample 0 is replaced by
    the central values. Rows 0 and 3 are capped at 1 in every sample.

    Parameters
    ----------
    sample_num : int
        Number of samples.
    emissvals : array_like
        Parameter table described above; any number of markets (columns).

    Returns
    -------
    ndarray
        Shape ``(sample_num, nparam, nmarkets)``.

    Raises
    ------
    ValueError
        If ``emissvals`` is not 2-D with an even number of rows, at least 8.
    """
    emissvals = np.asarray(emissvals, dtype=float)
    if emissvals.ndim != 2:
        raise ValueError(f"emissvals must be 2-D, got shape {emissvals.shape}")
    nrows, nmark = emissvals.shape
    nparam = nrows // 2
    nlognorm = 4   # leading parameter rows sampled as relative (log-normal) perturbations
    if nrows % 2 or nparam < nlognorm:
        raise ValueError(f"emissvals needs an even number of rows >= {2*nlognorm}, got shape {emissvals.shape}")
    center = emissvals[:nparam, :]
    # flattened parameter-major / market-minor, matching the reshape below
    spread = emissvals[nparam:, :].ravel()
    nln = nlognorm * nmark   # LHS columns that are log-normal; the rest are normal
    lhd = qmc.LatinHypercube(d=nparam*nmark, optimization="random-cd").random(n=sample_num)
    # log-normal shape for a given coefficient of variation, then rescale to mean 1
    sigln = np.sqrt(np.log(spread[:nln]*spread[:nln] + 1))
    sampleln = lognorm.ppf(lhd[:, :nln], s=sigln) / np.exp(sigln**2/2)
    samplenorm = norm.ppf(lhd[:, nln:], loc=center.ravel()[nln:], scale=spread[nln:])
    sampleout = np.concatenate((sampleln, samplenorm), axis=1).reshape(sample_num, nparam, nmark)
    # relative factors -> absolute values; normal rows are already absolute
    sampleout[:, :nlognorm, :] *= center[:nlognorm, :]
    sampleout[0] = center
    # rows 0 and 3 must not exceed 1 (NaNs, if any, are left untouched)
    sampleout[:, [0, 3], :] = np.minimum(sampleout[:, [0, 3], :], 1.0)
    return(sampleout)
 
#------------------------------------------------------------------------
#          MAIN PROGRAM   #for paper submitted on 1/16/2025
#------------------------------------------------------------------------

paramsin = otherParams()
check = 0   # set to 1 to run getcheck() budget checks on every kernel
sample_num = 1000  #number of Monte Carlo samples
tmp1, tmp2 = np.zeros(sample_num), np.zeros(sample_num)
emissFile = 'emiss_params_alt.xlsx'    # this file determines the markets and life cycle parameters for each one
emissSheet = 'testall'
segfilename = 'segmentation_allyears.xlsx' # this file has information to determine market fraction in each region and defines regions
                                #must be same order of regions as production files used in get_consumption
sheetname1 = 'use-2001'    #primary segmentation information
sheetname2 = 'use-2005'  #using fraction of market that is filled by -141b
sheetname3 = 'use-2008'
nummarkets, markets, emissvalsin = getReleaseParams(emissFile,emissSheet) # get number of markets and their parameters

marketSegRefrig2001, marketSegNoRefrig2001 = getSegmentation(segfilename,sheetname1)     #first index is region, second is market
marketSegRefrig2005, marketSegNoRefrig2005 = getSegmentation(segfilename,sheetname2)     #first index is region, second is market
marketSegRefrig2008, marketSegNoRefrig2008 = getSegmentation(segfilename,sheetname3)     #first index is region, second is market

years, production, consumptionin, feedstock, regions = get_consumption()     #get scaled consumption from production files; determines length of production time
s = np.shape(consumptionin)
consump_bymarket_mc = np.zeros((sample_num,s[0],s[1],nummarkets))
consumptionin = consumptionin*1.0

# Monte Carlo draws of the scalar and per-market parameters (sample 0 = central case)
paramvalsmc = getAltParams(sample_num,paramsin)
emissvalsmc = getAltReleaseParams(sample_num,emissvalsin)
nummarkuse = 12
varmarket = 0.1

marketrefrigmc2001 = AltMarket_FractionMethod(sample_num,nummarkuse,varmarket,marketSegRefrig2001)
marketrefrigmc2005 = AltMarket_FractionMethod(sample_num,nummarkuse,varmarket,marketSegRefrig2005)
marketrefrigmc2008 = AltMarket_FractionMethod(sample_num,nummarkuse,varmarket,marketSegRefrig2008)
marketSegRefrigmc2001 = marketmc(marketrefrigmc2001,marketSegRefrig2001)
marketSegRefrigmc2005 = marketmc(marketrefrigmc2005,marketSegRefrig2005)
marketSegRefrigmc2008 = marketmc(marketrefrigmc2008,marketSegRefrig2008)
nummarkuse = 12
marketnorefrigmc2001 = AltMarket_FractionMethod(sample_num,nummarkuse,varmarket,marketSegNoRefrig2001)
marketnorefrigmc2005 = AltMarket_FractionMethod(sample_num,nummarkuse,varmarket,marketSegNoRefrig2005)
marketnorefrigmc2008 = AltMarket_FractionMethod(sample_num,nummarkuse,varmarket,marketSegNoRefrig2008)
marketSegNoRefrigmc2001 = marketmc(marketnorefrigmc2001,marketSegNoRefrig2001)
marketSegNoRefrigmc2005 = marketmc(marketnorefrigmc2005,marketSegNoRefrig2005)
marketSegNoRefrigmc2008 = marketmc(marketnorefrigmc2008,marketSegNoRefrig2008)

niyears = paramvalsmc[0].iyears   # tracking horizon; every per-year array uses this length
results = outputInfo(nummarkets,niyears)
resultsmc = [results for i in range(sample_num)]   # placeholders, each replaced in the loop below

# Weibull curves: latest sample (weibull*) and the central sample 0 (*keep).
# Sized by the model horizon rather than a literal 200 so they match the kernels.
weibull, weibullkeep = np.zeros([nummarkets,niyears]),np.zeros([nummarkets,niyears])
weibullem, weibullemkeep = np.zeros([nummarkets,niyears]),np.zeros([nummarkets,niyears])
for ii in range(0,sample_num):
    print(ii)
    params = copy.deepcopy(paramvalsmc[ii])
    emissvals = copy.deepcopy(emissvalsmc[ii,:,:])
    marketSegRefrig2001, marketSegRefrig2005, marketSegRefrig2008, marketSegNoRefrig2001,\
        marketSegNoRefrig2005, marketSegNoRefrig2008 =\
            marketSegRefrigmc2001[ii,:,:],marketSegRefrigmc2005[ii,:,:],marketSegRefrigmc2008[ii,:,:],\
            marketSegNoRefrigmc2001[ii,:,:],marketSegNoRefrigmc2005[ii,:,:],marketSegNoRefrigmc2008[ii,:,:]

    consump_bymarket = getSegmentedValsMult(years,consumptionin,marketSegRefrig2001,marketSegRefrig2005,\
                    marketSegRefrig2008,marketSegNoRefrig2001,marketSegNoRefrig2005,marketSegNoRefrig2008)    #calculate consumption in each market by region
    consump_bymarket_mc[ii,:,:,:] = consump_bymarket
    consumption = np.sum(consump_bymarket,axis=1)  #sums over all regions leaving array of (years,markets)

    outputKernel = outputInfo(nummarkets,params.iyears)
    if (ii == 0):
        consump_bymarket_keep = consump_bymarket
        outputKernelKeep = outputKernel
    for i in range(0,nummarkets):
        outputKernel,wa,we,iflag = outputKernels(i,outputKernel,params,emissvals[:,i])    #calculate kernels for unit emission in a given year
        if (i == 0 and ii==0):
            makefig1(outputKernel)
        weibull[i,:] = wa
        weibullem[i,:] = we
        if (iflag == 1):
            # negative inactive bank: pause so the run can be inspected
            print('iflag triggered')
            input()
        if (ii == 0):
            weibullkeep[i,:] = wa
            weibullemkeep[i,:] = we
        if (check):
            print(markets[i])
            getcheck(i,outputKernel)   #does budget checking to ensure full time series == 1

    results = runmodel(years,nummarkets,params,consumption,outputKernel,emissvals)
    if (np.min(results.inactivebank) < 0 ):
        print('uh oh')
        input()
    else:
        print(i)   # NOTE(review): prints the last market index, not the sample index ii — confirm intent
        print(np.min(results.inactivebank))
    resultsmc[ii] = results

# NOTE(review): `results` here is the LAST Monte Carlo sample, while the *_keep
# arrays come from sample 0 — confirm outputPresent expects that pairing.
ibestfit = outputPresent(years,markets,regions,consump_bymarket_keep,consump_bymarket_mc,results,resultsmc,weibullkeep,weibullemkeep)

# Sequential dumps; a reader must load them back in exactly this order.
with open('pickled.pkl', 'wb') as f:
    pickle.dump(years, f)
    pickle.dump(markets, f)
    pickle.dump(regions, f)
    pickle.dump(consump_bymarket_keep, f)
    pickle.dump(consump_bymarket_mc, f)
    pickle.dump(results, f)
    pickle.dump(resultsmc, f)
    pickle.dump(weibullkeep, f)
    pickle.dump(weibullemkeep, f)
    pickle.dump(ibestfit, f)
    pickle.dump(paramvalsmc, f)
    pickle.dump(emissvals, f)
    pickle.dump(emissvalsmc, f)
    pickle.dump(marketSegRefrigmc2001, f)
    pickle.dump(marketSegRefrigmc2005, f)
    pickle.dump(marketSegRefrigmc2008, f)
    pickle.dump(marketSegNoRefrigmc2001, f)
    pickle.dump(marketSegNoRefrigmc2005, f)
    pickle.dump(marketSegNoRefrigmc2008, f)
    pickle.dump(consumptionin, f)
    pickle.dump(nummarkets, f)
    pickle.dump(weibull, f)
    pickle.dump(weibullem, f)
    pickle.dump(check, f)
    pickle.dump(sample_num, f)