import cvxpy as cp
import copy
from numpy import linalg as LA
import os
import traceback
import glob
import numpy as np

from ..signal.engine     import *
from ..signal.interface  import *
from ..risk              import *
from ..domain            import *

# === Axioma specific covariance integration ===
from model.trading.process_axioma_data import process_axioma_specific_cov
AXIOMA_DATA_DIR = "./axioma_data"

## ==================================================================================================================

# Configure pandas display so that logged / interactive frame dumps are complete
# and floats are readable (4 decimal places), rather than pandas' truncated defaults.
_PD_DISPLAY_OPTIONS = {
    'display.max_rows': None,                    # show all rows
    'display.max_columns': None,                 # show all columns
    'display.width': 1000,                       # wider console width
    'display.max_colwidth': None,                # full content in each column
    'display.float_format': '{:.4f}'.format,     # 4-decimal float formatting
}
for _pd_option, _pd_value in _PD_DISPLAY_OPTIONS.items():
    pd.set_option(_pd_option, _pd_value)

def matrix_shrink(Q):
    """Return a low-rank re-approximation of Q via truncated SVD.

    Keeps at least `portion` of the singular values and always at least K,
    but never more than the number of singular values available.

    Args:
        Q: 2-D array-like to shrink.

    Returns:
        ndarray of the same shape as Q, reconstructed from the top-r
        singular triplets.
    """
    # numpy returns U, singular values S (descending), and V transposed (Vh).
    U, S, Vh = LA.svd(Q)
    num_sing_values = S.shape[0]

    # log some singular values (slicing is safe even when fewer than 5 exist)
    PyLog.info(f"matrix_shrink sing values top5 {S[0:5]} bot5 {S[-5:]}")

    # very cut down settings we have used in some tests:  portion= 1/8; K = 10
    # typical keep most settings portion= 7/8; K= 20
    portion = 7/8
    K = 20
    num_sing_values_to_use = max(int(num_sing_values * portion), K)
    r = min(num_sing_values, num_sing_values_to_use)

    # Spectrum-decay diagnostic. The original computed S[0]/S[4], which raised
    # IndexError for matrices with fewer than 5 singular values; clamp the index.
    decay = S[0] / S[min(4, num_sing_values - 1)] if num_sing_values > 0 else float('nan')
    PyLog.info(f"matrix_shrink got {num_sing_values} singular values. portion:{portion:.3f} "
               f"K:{K} reconstructing with {r}. matrix shape {Q.shape} first/last  {decay:.2f}")

    # Rank-r reconstruction: only the leading r columns/rows participate.
    return U[:, :r] @ np.diag(S[:r]) @ Vh[:r, :]

class RebalConfig:
    """Parameter container for a rebalance/optimization run.

    Every constructor argument left as ``None`` falls back to the value in
    ``DEFAULT_CONFIG`` (when a default exists for that name); an explicitly
    supplied value always wins.  Attributes carry the same names as the
    constructor parameters.
    """

    # Per-solver optimizer defaults: iteration budgets and tolerance ladders.
    dctDefaultOptParams = {
        'SCS': {'maxIter': 200_000,
                'cycleIter': 5_000,
                'targetTolerance': 1e-4,
                'toleranceRange': [7e-4, 1e-3, 2e-3, 3e-3, 4e-3, 5e-3]},
        'ECOS': {'maxIter': 20_000}
    }

    # Fallback value for each constructor parameter (keyed by parameter name).
    DEFAULT_CONFIG = {'configName': 'DEFAULT',
                      'stratName': 'strat_asia_v1',
                      'maxLeverage': +4.00,
                      'minNetExposure': -0.05,
                      'maxNetExposure': +0.40,
                      'minBetaAdjNet': -0.01,
                      'maxBetaAdjNet': +0.01,
                      'gamma': 1.8,
                      'adaptiveGamma': False,
                      'adaptiveGammaTargetRisk': 0.10,
                      'adaptiveGammaTolerance': 0.0025,
                      'tau': 4.0,
                      'maxRisk': 0.13,
                      'offDiagRisk': False,
                      'splitLongShort': False,
                      'splitLongShortStartDate': PyDate.asDate(20220615),
                      'bcostLookback': 8,
                      'maxShortUtilization': 0.8,
                      'pbList': None,
                      'bcostMultiplierName': None,
                      'HAUM': False,
                      'hypotheticalAUM': 100_000_000,
                      'univwtThreshold': 0.25,
                      'holdingThreshold': 0.00005,
                      'rmodelName': 'risk_model',
                      'tmodelName': 'tcost_model',
                      'impactCostScale': 1.0,
                      'alphaName': 'alpha',
                      'tradeUniverseName': None,
                      'themeBoundName': 'trading_alpha_theme_exposure_bounds',
                      'fcostName': 'trading_long_financing_cost',
                      'bcostName': 'trading_short_borrow_cost',
                      'defaultFCost': 0.005,
                      'defaultBCost': 0.15,
                      'longBoundName': 'trading_position_bounds_long',
                      'shortBoundName': 'trading_position_bounds_short',
                      'defaultLongBound': 0.000025,
                      'defaultShortBound': 0.000025,
                      'liquidityName': 'adv_usdm_blended',
                      'defaultLiquidityUSDM': 1.0,
                      'maxLiquidityBoundLong': None,
                      'maxLiquidityBoundShort': None,
                      'applyShortBan': True,
                      'applyExchangeHolidays': False,
                      'advName': 'adv_usdm_blended',
                      'countryBoundName': 'trading_country_bounds_frame',
                      'industryBoundName': 'trading_industry_bounds_frame',
                      'sizeBoundName': 'trading_size_bounds_frame',
                      'targetTolerance': 1e-4,
                      'toleranceRange': [7e-4, 1e-3, 2e-3, 3e-3, 4e-3, 5e-3],
                      'minShortAvailM': 0,
                      'shrinkfcov' : False,
                      'enable_gs_stability_score' : False,
                      'tightCountry' : False,}

    @classmethod
    def defaultConfig(cls, parameter):
        """Return the default value for *parameter* (KeyError if none exists)."""
        return cls.DEFAULT_CONFIG[parameter]

    def __init__(self,
                 configName              = None,
                 stratName               = None,
                 maxLeverage             = None,
                 minNetExposure          = None,
                 maxNetExposure          = None,
                 minBetaAdjNet           = None,
                 maxBetaAdjNet           = None,
                 gamma                   = None,
                 adaptiveGamma           = None,
                 adaptiveGammaTargetRisk = None,
                 adaptiveGammaTolerance  = None,
                 tau                     = None,
                 maxRisk                 = None,
                 offDiagRisk             = None,
                 splitLongShort          = None,
                 splitLongShortStartDate = None,
                 maxShortUtilization     = None,
                 pbList                  = None,
                 bcostMultiplierName     = None,
                 HAUM                    = None,
                 hypotheticalAUM         = None,
                 bcostLookback           = None,
                 univwtThreshold         = None,
                 holdingThreshold        = None,
                 rmodelName              = None,
                 tmodelName              = None,
                 impactCostScale         = None,
                 alphaName               = None,
                 tradeUniverseName       = None,
                 themeBoundName          = None,
                 fcostName               = None,
                 bcostName               = None,
                 defaultFCost            = None,
                 defaultBCost            = None,
                 longBoundName           = None,
                 shortBoundName          = None,
                 liquidityName           = None,
                 defaultLiquidityUSDM    = None,
                 maxLiquidityBoundLong   = None,
                 maxLiquidityBoundShort  = None,
                 defaultLongBound        = None,
                 defaultShortBound       = None,
                 applyShortBan           = None,
                 applyExchangeHolidays   = None,
                 advName                 = None,
                 maxAdvProp              = None,
                 countryBoundName        = None,
                 industryBoundName       = None,
                 sizeBoundName           = None,
                 targetTolerance         = None,
                 toleranceRange          = None,
                 minShortAvailM          = None,
                 shrinkfcov              = None,
                 enable_gs_stability_score = None,
                 tightCountry            = None,):

        # Uniformly resolve every parameter: an explicit value wins, otherwise
        # fall back to DEFAULT_CONFIG when a default exists for that name.
        # Parameters with no DEFAULT_CONFIG entry (e.g. maxAdvProp) pass
        # through unchanged.
        # NOTE: this also fixes themeBoundName, which previously ignored its
        # DEFAULT_CONFIG entry and was stored as-is (None when not supplied).
        # locals() here contains exactly `self` plus the constructor arguments,
        # since no other locals have been created yet.
        params = dict(locals())
        params.pop('self')
        for name, value in params.items():
            if value is None and name in self.DEFAULT_CONFIG:
                value = self.DEFAULT_CONFIG[name]
            setattr(self, name, value)

# Module-level singleton carrying all DEFAULT_CONFIG values; used as the
# default `rebalConfig` argument in Rebalance.run.
RebalConfig_DEFAULT = RebalConfig()

## ==================================================================================================================

class Rebalance:

    # Placeholder "all cash" portfolio: a single 100% weight in USD.
    EMPTY_PORTFOLIO = pd.Series(1.0, index=['$USD'])

    @classmethod
    def saveSimScratch(cls, df, date=None, name_without_path=None):
        """Debug/audit helper: dump *df* into the user's sim scratch directory.

        Writes the frame as both parquet and csv for post-run inspection.
        Place a call wherever you want a dataframe for analysis, e.g.::

            cls.saveSimScratch(dframe, tradeDate, "weights_early")

        Args:
            df: dataframe to save down.
            date: if given, its YYYYMMDD form is appended to the file name.
            name_without_path: base file name, no path, no extension
                (defaults to "scratch_file").
        """
        # Best-effort USER lookup so this debug helper still works in
        # cron/batch environments where $USER is not exported (previously a
        # missing $USER raised KeyError and aborted the run).
        user = os.environ.get('USER', 'unknown')
        base = f"/home/{user}/scratch/r"
        os.makedirs(base, exist_ok=True)
        if name_without_path is None:
            name_without_path = "scratch_file"
        if date:
            date_str = date.strftime("%Y%m%d")
            filebase = f"{base}/{name_without_path}_{date_str}"
        else:
            filebase = f"{base}/{name_without_path}"
        df.to_parquet(f"{filebase}.pq")
        df.to_csv(f"{filebase}.csv")

    @classmethod
    def run(cls, rebalConfig=RebalConfig_DEFAULT, signalDate=PyDate.asDate(20191115),
            preOptWeights=EMPTY_PORTFOLIO, NAV=100000000, solver=cp.SCS, maxIter=200000, cycleIter=5000,
            verbose=True, checkDCP=True, optimal=False, tradeRestrictions=None):


        tradeDate = PyDate.nextWeekday(signalDate)
        PyLog.info('optimizing for SD:{} / TD:{}'.format(PyDate.asISO(signalDate), PyDate.asISO(tradeDate)))


        dctOptResult = dict()
        dctOptResult['rebalConfig'] = rebalConfig
        dctOptResult['signalDate']  = signalDate
        dctOptResult['preOptNAV']   = NAV

        if len(preOptWeights) == 0:
            dfWeights = pd.DataFrame({'assetKey': ['assetKey'], 'preOptWeights': [0.0]}).iloc[:0]
        else:
            dfWeights = preOptWeights.reset_index().rename(columns={'index': 'assetKey', 0: 'preOptWeights'})
            dfWeights = dfWeights[[Real.isNonZero(x) for x in dfWeights['preOptWeights']]]



        stratName = rebalConfig.stratName
        rmodel    = SignalMgr.get(rebalConfig.rmodelName, signalDate, stratName)
        tmodel    = SignalMgr.get(rebalConfig.tmodelName, signalDate, stratName).set_index('assetKey')
#        gmult     = Trading.getModelParameters('rrmult', signalDate)
        PyLog.info("Rebalance.py.run SET GMULT=1.0 --FROM YB TRADING GOOGLE SHEET")
        gmult=1.0

        PyLog.info("Rebalance.py.run SET TCAFMULT=1.0 --FROM YB TRADING GOOGLE SHEET")
        tcafmult =1.0
#        tcafmult  = Trading.getModelParameters('tcafmult', signalDate)


        signals = {rebalConfig.alphaName      : 'alpha',
                   rebalConfig.fcostName      : 'fcost',
                   rebalConfig.bcostName      : 'bcost',
                   rebalConfig.longBoundName  : 'longBoundSoft',
                   rebalConfig.shortBoundName : 'shortBoundSoft',
                   'latest_mktcap_usdm'       : 'mktcap',
                   'univwt'                   : 'univwt',
                   'model_country'            : 'modelCountry'}



        dframe = SignalMgr.getFrame(list(signals.keys()), signalDate, stratName).\
            drop(columns=['signalDate']).rename(columns=signals).set_index('assetKey')


        tmult = SignalMgr.get('trading_tcost_amortization_factor_multiplier', signalDate, stratName).iloc[0]
        dct   = {'tcafmult': tcafmult, 'tmult': tmult}

        ## hard position bounds ---------------------------------------------------------------------------------
        dframe = dframe.assign(longBoundHard = 1.0)
        dframe = dframe.assign(shortBoundHard = 1.0)

        dframe = dframe[dframe['univwt'] >= rebalConfig.univwtThreshold]
        dframe = dframe[~dframe['alpha'].isnull()]

        if rebalConfig.tradeUniverseName is not None:
            PyLog.info(f"Getting rebalConfig.tradeUniverseName:{rebalConfig.tradeUniverseName} signalDate:{signalDate} stratName:{stratName}")
            universe = SignalMgr.get(rebalConfig.tradeUniverseName, signalDate, stratName)
            dframe_sz_was = len(dframe)
            dframe = dframe[dframe.index.isin(universe.index)]
            dframe_sz_is = len(dframe)
            PyLog.info(f"dframe filtered from {dframe_sz_was} to {dframe_sz_is}. universe size:{len(universe)}")

        ## drop assets that are excluded from both long and short bounds
        dframe = dframe[~dframe['longBoundSoft'].isnull() | ~dframe['shortBoundSoft'].isnull()]

        # assets = list((set(rmodel['assets']).intersection(set(dframe.index)).intersection(set(tmodel.index))).\
        #     union(set(dfWeights['assetKey'])))
        assets = set(rmodel['assets']).intersection(set(dframe.index))
        assets = assets.intersection(set(tmodel.index))
        assets = sorted(list(assets))
        dframe = dframe.reindex(index=assets).reset_index()

        dframe = dframe.assign(fcost  = dframe['fcost'].fillna(rebalConfig.defaultFCost))
        dframe = dframe.assign(bcost  = dframe['bcost'].fillna(rebalConfig.defaultBCost))

        ## position bounds -------------------------------------------------------------------------------------
        dframe = dframe.assign(longBoundSoft  = dframe['longBoundSoft'].fillna(rebalConfig.defaultLongBound))
        dframe = dframe.assign(shortBoundSoft = dframe['shortBoundSoft'].fillna(rebalConfig.defaultShortBound))



        ## in addition, we want to restrict short to 2% of market cap
        dframe = dframe.assign(shortAvail = dframe['mktcap'].fillna(10.0) * 0.02 * 1000000 / NAV)

        ## cut on min shortAvail in M;- i.e. zero out shortAvail if number is <X

        minShortAvailM = rebalConfig.minShortAvailM
        PyLog.info('DCH ZERO out shortAvail<(${}M)<    // zero shorts for ({}) names out of ({}) which have available shorts'.format(minShortAvailM, len(dframe.loc[dframe.shortAvail<minShortAvailM]) ,len(dframe.loc[dframe.shortAvail>0]) ))
        dframe.loc[dframe.shortAvail < minShortAvailM, 'shortAvail'] = 0

        # dframe = dframe.assign(shortBound = dframe[['shortBound', 'shortAvail']].min(axis=1))
        dframe = dframe.assign(shortBoundHard = dframe[['shortBoundHard', 'shortAvail']].min(axis=1))



        dframe = dframe.drop(columns='shortAvail')

        ## apply liquidity-based position bounds ---------------------------------------------------------------
        if (rebalConfig.maxLiquidityBoundLong is not None) or (rebalConfig.maxLiquidityBoundShort is not None):
            dfm = SignalMgr.getFrame(rebalConfig.liquidityName, signalDate, stratName)
            dframe = dframe.merge(dfm[['assetKey', rebalConfig.liquidityName]], how='left', on='assetKey')
            dframe[rebalConfig.liquidityName] = \
                dframe[rebalConfig.liquidityName].fillna(rebalConfig.defaultLiquidityUSDM)
            if rebalConfig.maxLiquidityBoundLong is not None:
                dframe = dframe.assign(
                    liquidityBoundLong = rebalConfig.maxLiquidityBoundLong * dframe[rebalConfig.liquidityName]
                                         * 1000000 / NAV)
                dframe = dframe.assign(longBoundSoft = dframe[['longBoundSoft', 'liquidityBoundLong']].min(axis=1))
                dframe = dframe.drop(columns=['liquidityBoundLong'])
            if rebalConfig.maxLiquidityBoundShort is not None:
                dframe = dframe.assign(
                    liquidityBoundShort = rebalConfig.maxLiquidityBoundShort * dframe[rebalConfig.liquidityName]
                                          * 1000000 / NAV)
                dframe = dframe.assign(shortBoundSoft = dframe[['shortBoundSoft', 'liquidityBoundShort']].min(axis=1))
                dframe = dframe.drop(columns=['liquidityBoundShort'])
            dframe = dframe.drop(columns=[rebalConfig.liquidityName])
        ## -----------------------------------------------------------------------------------------------------

        ## apply short bans ------------------------------------------------------------------------------------
        if rebalConfig.applyShortBan:
            if PyDate.ge(signalDate, PyMonth.firstWeekday(202003)):
                if PyDate.le(signalDate, PyMonth.firstWeekday(202105)):
                    countries = ['KR', 'ID']
                else:
                    countries = ['ID']
                dframe = dframe.assign(
                    preOptWeights = Filter.bound(-preOptWeights.reindex(dframe['assetKey']).fillna(0.0), lower=0.0))
                dframe.loc[dframe['modelCountry'].isin(countries), 'preOptWeights'] =\
                    dframe.loc[dframe['modelCountry'].isin(countries), 'preOptWeights']
                dframe.loc[~dframe['modelCountry'].isin(countries), 'preOptWeights'] = 1.0
                # dframe = dframe.assign(shortBound = dframe[['shortBound', 'preOptWeights']].min(axis=1))
                dframe = dframe.assign(shortBoundHard = dframe[['shortBoundHard', 'preOptWeights']].min(axis=1))
                dframe = dframe.drop(columns='preOptWeights')
        ## -----------------------------------------------------------------------------------------------------

        # === KR SHORTS ZERO-OUT (no KR shorts) =======================================
        #### KR shorting is not working in the optimizer, we have consistently lost money since the short 
        # ban was lifted in April 2025, so below code disable KR short selling. This is temporary till 
        # we figure out what is the core issue with KR shorts

        # Policy: immediately cover any KR short; forbid new KR shorts. KR longs unaffected.
        try:
            # Detect Korea by modelCountry; also accept RIC suffixes (.KS = KOSPI, .KQ = KOSDAQ) if present
            is_kr = dframe['modelCountry'].astype(str).str.upper().eq('KR')

            if 'ric' in dframe.columns:
                is_kr = is_kr | dframe['ric'].astype(str).str.upper().str.endswith(('.KS', '.KQ'))
            is_kr = is_kr.fillna(False).to_numpy()

            # Asset-level: set shortBoundHard to 0 (no shorts). Keep longBoundHard unchanged.
            sbh = dframe['shortBoundHard'].astype(float).fillna(0.0).to_numpy()
            sbh[is_kr] = 0.0
            dframe['shortBoundHard'] = sbh

            # Per-availability (splitLongShort=True): also clamp line-level short bounds to 0
            if 'assetKey' in dfAvail.columns and 'shortBoundHard' in dfAvail.columns:
                kr_keys = set(dframe.loc[is_kr, 'assetKey'])
                if len(kr_keys) > 0:
                    mask_av = dfAvail['assetKey'].isin(kr_keys)
                    dfAvail.loc[mask_av, 'shortBoundHard'] = 0.0

            PyLog.info(f"KR shorts disabled: set shortBoundHard=0 for {int(is_kr.sum())} KR names (asset + availability).")
        except Exception as e:
            PyLog.info(f"KR shorts disable not applied: {e}")
        # === end KR SHORTS ZERO-OUT ==================================================


        PyLog.info(f"dframe len:{len(dframe)} {dframe.columns.tolist()}")


       ########  Go impose TradeRestrictions pre-optimiser on longBoundHard/ShortBoundHard
        if(tradeRestrictions):
            from .TradeRestrictions  import TradeRestrictions

            #fn_ob = 'preOpt_preTR.{}.csv'.format(str(tradeDate).replace('-','')); sim_dir = '/mnt/signal/simulation/EG001'; fn_ob = os.path.join(f"{sim_dir}/{fn_ob}")
            #dframe.to_csv(fn_ob, index=False, float_format='%.10f')
            #PyLog.info('tr.applyTradeRestrictions PRE-OPT ...go')

            added_preOptWeights=False
            if not 'preOptWeights' in dframe.columns:
                preOptWeights_df = pd.DataFrame({'assetKey':preOptWeights.index, 'preOptWeights':preOptWeights.values})
                dframe = dframe.merge(preOptWeights_df, how='left', on='assetKey').fillna(0.0)
                added_preOptWeights=True

            dframe = tradeRestrictions.applyTradeRestrictions( tradeDate, dframe, NAV)


            #fn_ob = 'preOpt_postTR.{}.csv'.format(str(tradeDate).replace('-','')); sim_dir = '/mnt/signal/simulation/EG001'; fn_ob = os.path.join(f"{sim_dir}/{fn_ob}")
            #dframe.to_csv(fn_ob, index=False, float_format='%.10f')

            if added_preOptWeights and 'preOptWeights' in dframe.columns:
                dframe.drop(['preOptWeights'], axis=1, inplace=True)

    #        ###############  Done impose TradeRestrictions ######################################################

        ## dframe is what goes into optimization / dfWeights is for record keeping
        dframe = dframe.merge(dfWeights, how='left', on='assetKey')

        PyLog.info(f"tradeDate: {tradeDate} rebalConfig.splitLongShort:{rebalConfig.splitLongShort} rebalConfig.splitLongShortStartDate:{rebalConfig.splitLongShortStartDate}")
        ## Note, the time condition here is different from that in Simulation. We want the splitLongShort
        ## applied in optimization first before applying it in the portfolio accounting.
        if rebalConfig.splitLongShort and (tradeDate >= rebalConfig.splitLongShortStartDate):

            dfAvail, dfMarginal = cls.compileShortAvailability(rebalConfig, NAV, tradeDate, dframe)
            #this was allowing optimal starting positions to be constructed exceeding avail borrow. Commenting out.
            #if optimal:
            #    dfMarginal = dfAvail.groupby('assetKey').agg({'shortBoundHard': sum}).reset_index().\
            #        rename(columns={'shortBoundHard': 'availablePct'})
            dframe = dframe.merge(dfMarginal.rename(columns={'availablePct': 'shortBound'}), how='left', on='assetKey')
            dframe = dframe.assign(shortBound=Real.isNegative(dframe['preOptWeights']).astype(int) *
                                              abs(dframe['preOptWeights'].fillna(0.0))
                                              + dframe['shortBound'].fillna(0.0)
                                              + rebalConfig.holdingThreshold / 100.0)
            dframe = dframe.assign(shortBoundHard=dframe[['shortBoundHard', 'shortBound']].min(axis=1))

            """
            ## reflect exchange holidays 
            if rebalConfig.applyExchangeHolidays and (not optimal):
                lstHolidays = cls.getExchangeHolidays(tradeDate)
                dframe = dframe.assign(holiday=dframe['modelCountry'].isin(lstHolidays))
                dframe.loc[dframe['holiday'], 'longBoundHard'] \
                    = Filter.bound(dframe.loc[dframe['holiday'], 'preOptWeights'].fillna(0.0), lower=0.0) \
                      + rebalConfig.holdingThreshold / 100.0
                dframe.loc[dframe['holiday'], 'shortBoundHard'] \
                    = Filter.bound(-dframe.loc[dframe['holiday'], 'preOptWeights'].fillna(0.0), lower=0.0) \
                      + rebalConfig.holdingThreshold / 100.0
            ## reflect exchange holidays 
            """

            if 'holiday' not in dframe.columns:
                dframe = dframe.assign(holiday=False)

            ## NEW reflect exchange holidays 
            if rebalConfig.applyExchangeHolidays and (not optimal):

                try:
                    lstHolidays = cls.getExchangeHolidays(tradeDate)
                    PyLog.info(f"tradeDate = {tradeDate}, lstHolidays={lstHolidays}")

                    dframe = dframe.assign(holiday=dframe['modelCountry'].isin(lstHolidays))

                    cur = dframe['preOptWeights'].fillna(0.0)
                    cur_long  = Filter.bound(cur,  lower=0.0)
                    cur_short = Filter.bound(-cur, lower=0.0)

                    dframe.loc[dframe['holiday'], 'longBoundHard']  = cur_long.loc[dframe['holiday']]
                    dframe.loc[dframe['holiday'], 'shortBoundHard'] = cur_short.loc[dframe['holiday']]

                except Exception as e:
                    import traceback
                    PyLog.error(f"Exception inside holiday block: {type(e).__name__}: {e}")
                    PyLog.error(traceback.format_exc())
                    raise

            ## NEW reflect exchange holidays 

            ## NEW reflect exchange holidays in dfAvail also
            dframe = dframe.assign(idx = list(range(len(dframe))))
            dfAvail = dfAvail.merge(dframe[['assetKey', 'idx']], how='left', on='assetKey')

            if rebalConfig.applyExchangeHolidays and (not optimal):

                # assetKey-indexed view of dframe for the .map() bound lookups below.
                lk = dframe.set_index('assetKey')

                # update only the holiday rows in dfAvail
                hol_keys = set(dframe.loc[dframe['holiday'], 'assetKey'])

                mask_av_hol = dfAvail['assetKey'].isin(hol_keys)

                if 'longBoundHard' in dframe.columns:

                    dfAvail.loc[mask_av_hol, 'longBoundHard']  = dfAvail.loc[mask_av_hol, 'assetKey'].map(lk['longBoundHard'])

                if 'shortBoundHard' in dframe.columns:
                    dfAvail.loc[mask_av_hol, 'shortBoundHard'] = dfAvail.loc[mask_av_hol, 'assetKey'].map(lk['shortBoundHard'])

                PyLog.info(f"Unique modelCountry values: {dframe['modelCountry'].unique()[:10]}")
                PyLog.info(f"lstHolidays: {lstHolidays}")

                # Debugging: verify holiday logic and propagation 

                num_holidays = dframe['holiday'].sum() if 'holiday' in dframe.columns else 0
                PyLog.info(f"Detected {num_holidays} holiday assets out of {len(dframe)}")

                if num_holidays > 0:
                    PyLog.info("Sample holiday-pinned assets (up to 5 per country):")

                    holidays_df = dframe.loc[dframe['holiday']]

                    for ctry, sub in holidays_df.groupby('modelCountry'):
                        sample = sub[['assetKey']].head(5)
                        PyLog.info(f"\nCountry: {ctry} — {len(sub)} total holiday assets, showing up to 5 assetKeys:\n"
                                   + sample.to_string(index=False))

                    # Cross-check dfAvail bounds match for those same assets
                    # (inner merge restricts the check to holiday assets present in both
                    # frames; both diffs should be ~0 when the sync above succeeded).
                    merged_check = dfAvail.merge(
                        dframe.loc[dframe['holiday'], ['assetKey', 'longBoundHard', 'shortBoundHard']],
                        how='inner', on='assetKey', suffixes=('_avail', '_dframe')
                    )
                    diff_long = (merged_check['longBoundHard_avail'] - merged_check['longBoundHard_dframe']).abs().sum()
                    diff_short = (merged_check['shortBoundHard_avail'] - merged_check['shortBoundHard_dframe']).abs().sum()
                    PyLog.info(f"Holiday bounds sync check — long diff: {diff_long:.3e}, short diff: {diff_short:.3e}")

            ## NEW reflect exchange holidays in dfAvail also

            # availAggr: (len(dframe) x len(dfAvail)) 0/1 aggregation matrix mapping each
            # dfAvail column to its dframe row, so dframe-level exposure = availAggr @ w.
            # NOTE(review): dfAvail['idx'] comes from a left merge and may be NaN/float for
            # unmatched rows, which would fail here — confirm upstream invariant.
            availAggr = np.zeros((len(dframe), len(dfAvail)))
            for n in range(len(dfAvail)):
                availAggr[dfAvail.iloc[n]['idx'], n] = 1

            factors = rmodel['factors']
            fload = rmodel['fload'].reindex(index=assets, columns=factors)

            # Expected returns (column vector) and diagonal specific-risk matrix.
            mu = dframe[['alpha']].to_numpy()
            D = np.diag(rmodel['srisk'].reindex(index=assets) ** 2)

            """
            # === Axioma CCSC: try to replace specific covariance; fall back on local D if False/None ===
            try:
                # Build a minimal dframe for join: needs assetKey and ric
                if 'ric' in dframe.columns:
                    # Use the RIC already present in dframe
                    dframe_for_join = dframe[['assetKey', 'ric']]
                else:
                    # Try to get RICs from the model universe frame
                    try:
                        mfm = SignalMgr.getStatic('model_universe_frame')
                        mfm = mfm.rename(columns={'RkdDisplayRIC': 'ric'})
                        dframe_for_join = dframe[['assetKey']].merge(
                            mfm[['assetKey', 'ric']], how='left', on='assetKey'
                        )

                    except Exception as e:
                        PyLog.info(
                            f"Axioma integration: unable to build RICs from model_universe_frame "
                            f"({e}); using empty ric column."
                        )
                        dframe_for_join = dframe.assign(ric=None)[['assetKey', 'ric']]

                # existing specific covariance in optimizer order
                scov = rmodel['scov'].reindex(index=assets, columns=assets)

                # use the diagonal of scov as the existing specific variance
                spec_diag_vec = np.diag(scov.to_numpy())

                axioma_out = process_axioma_specific_cov(
                    data_dir=AXIOMA_DATA_DIR,
                    trade_date=tradeDate,
                    assets=assets,
                    dframe=dframe_for_join,
                    spec_diag_df=spec_diag_vec,
                    splitLongShort=True
                )

                if axioma_out is False:
                    PyLog.info("Axioma integration: function returned False — using local specific matrix D.")
                elif isinstance(axioma_out, dict) and axioma_out.get('new_spec_cov') is not None:
                    D = axioma_out['new_spec_cov']
                    PyLog.info(f"Axioma hybrid specific covariance applied. Shape={D.shape}")
                else:
                    PyLog.info("Axioma integration: no covariance provided — using local specific matrix D.")
            except Exception as e:
                PyLog.info(f"Axioma integration failed (using local D). Reason: {e}")


            # === end Axioma CCSC integration ===
            """

            ## optional off-diagonal risk 
            if rebalConfig.offDiagRisk:
                # Overlay pairwise specific covariances (xKey, yKey, cov) onto the
                # diagonal matrix D, keeping D symmetric.
                scov = rmodel['scov']
                scov = scov.merge(dframe[['assetKey', 'idx']].rename(columns={'assetKey': 'xKey', 'idx': 'xIdx'}),
                                  how='inner', on='xKey')
                scov = scov.merge(dframe[['assetKey', 'idx']].rename(columns={'assetKey': 'yKey', 'idx': 'yIdx'}),
                                  how='inner', on='yKey')
                for xIdx, yIdx, cov in zip(scov['xIdx'], scov['yIdx'], scov['cov']):
                    D[xIdx, yIdx] = D[yIdx, xIdx] = cov
            ## optional off-diagonal risk 

            # Factor covariance, optionally SVD-shrunk via matrix_shrink.
            S = rmodel['fcov'].reindex(index=factors, columns=factors).to_numpy()
            if rebalConfig.shrinkfcov:
                S = matrix_shrink(S)
            else:
                PyLog.info(f"rebalConfig.shrinkfcov:{rebalConfig.shrinkfcov} using fcov S as is")
            F = fload.to_numpy()

            # Decision variable: one weight per available asset; long/short split via
            # positive/negative parts (both nonnegative cvxpy expressions).
            w = cp.Variable((len(dfAvail), 1))
            aggrw = availAggr @ w
            wlong = cp.pos(w)
            wshort = cp.neg(w)

            # China Connect short-side stability penalty (only if file exists and scores > 4)
            # NOTE(review): `== True` is non-idiomatic; if changed to a truthiness test it
            # must be changed consistently at every check of this flag, because
            # china_short_penalty is only bound inside this branch.
            if rebalConfig.enable_gs_stability_score == True:
                PyLog.info(f"Looking at GS stability files for tradeDate = {tradeDate}")
                china_short_penalty = build_china_connect_short_penalty(dfAvail, wshort, tradeDate)
            else:
                PyLog.info(f"GS Stability flag set to {rebalConfig.enable_gs_stability_score} for tradeDate = {tradeDate}")


            # Factor exposures of the aggregated portfolio.
            f = F.T @ aggrw

            # Turnover-cost blend: btau in [0,1] weights the actual tcost; atau kicks in
            # only once otau exceeds 1, weighting the average-cost proxy.
            rtau = copy.copy(rebalConfig.tau)
            otau = tmult * rtau
            atau = cp.Parameter(nonneg=True)
            atau.value = max(0, otau - 1)
            btau = cp.Parameter(nonneg=True)
            btau.value = min(1, otau)

            Lmax = cp.Parameter()
            Lmax.value = rebalConfig.maxLeverage
            ## Lmax = rebalConfig.maxLeverage

            ## long financing cost / short borrow cost
            fcost = dfAvail['fcost'].to_numpy() @ wlong
            bcost = dfAvail['bcost'].to_numpy() @ wshort

            ## expected return / risk
            netReturn = mu.T @ aggrw - fcost - bcost
            variance = cp.quad_form(f, S) + cp.quad_form(aggrw, D)

            ## transactions cost
            tmodel = tmodel.reindex(index=assets)
            # Trade vector relative to pre-optimization weights; buys/sells are its
            # positive/negative parts.
            trade = (aggrw - preOptWeights.reindex(index=assets).fillna(0.0).to_frame().to_numpy())
            tbuy = cp.pos(trade)
            tsell = cp.neg(trade)

            # Linear cost plus 3/2-power market-impact terms scaled by sqrt(NAV).
            linearCost = tmodel['linearBuyCoeff'].to_numpy() @ tbuy + tmodel['linearSellCoeff'].to_numpy() @ tsell
            impactBuy = cp.multiply(cp.sqrt(NAV) * rebalConfig.impactCostScale,
                                    tmodel['impactCoeffTH'].to_numpy() @ cp.power(tbuy, 3 / 2))
            impactSell = cp.multiply(cp.sqrt(NAV) * rebalConfig.impactCostScale,
                                     tmodel['impactCoeffTH'].to_numpy() @ cp.power(tsell, 3 / 2))
            tcost = linearCost + impactBuy + impactSell

            # Alternative tcost using average linear coefficients (weighted by atau above).
            avgLinearCost = tmodel['linearAverageCoeff'].to_numpy() @ (tbuy + tsell)
            avgTcost = avgLinearCost + impactBuy + impactSell

            ## constraints ------------------------------------------------------------------------------------
            constraints = []
            penalty = 0

            # Disabled holiday-freeze equality constraint, kept as an inert string
            # literal for reference; it has no runtime effect.
            """
            ## Freeze holiday assets: no trades on exchange holidays
            if rebalConfig.applyExchangeHolidays and (not optimal):
                try:
                    if 'holiday' in dframe.columns and dframe['holiday'].any():
                        # preOptWeights in dframe order
                        cur_vec = dframe['preOptWeights'].fillna(0.0).to_numpy().reshape(-1, 1)

                        # indices in aggrw corresponding to holiday assets
                        hol_idx = dframe.loc[dframe['holiday'], 'idx'].to_numpy()

                        if len(hol_idx) > 0:
                            # aggrw[hol_idx] == current weights  →  zero trades for those names
                            constraints.append(aggrw[hol_idx] == cur_vec[hol_idx])
                            PyLog.info(f"Holiday freeze: pinned {len(hol_idx)} assets to preOptWeights.")
                except Exception as e:
                    import traceback
                    PyLog.error(f"Exception inside holiday-freeze block: {type(e).__name__}: {e}")
                    PyLog.error(traceback.format_exc())
                    raise
            """

            # Start with China Connect short penalty if present
            if rebalConfig.enable_gs_stability_score == True:
                PyLog.info("Adding china short penalty")
                if china_short_penalty is not None:
                    penalty = penalty + china_short_penalty

            ## max leverage
            # constraints.append(cp.norm(w, 1) <= Lmax)
            # pen = cp.pos(cp.norm(w, 1) - Lmax - 0.1)
            # Leverage enforced as a soft hinge penalty rather than a hard constraint.
            pen = cp.pos(cp.norm(w, 1) - Lmax)
            penalty = penalty + pen

            ## min leverage
            # if rebalConfig.minLeverage is not None:
            #     Lmin = cp.Parameter()
            #     Lmin.value = rebalConfig.minLeverage
            #     constraints.append(cp.norm(w, 1) >= Lmin)

            ## risk bound
            # constraints.append(variance <= (rebalConfig.maxRisk ** 2))
            # Soft risk cap with a 5x weight on variance overshoot.
            pen = 5 * cp.pos(variance - (rebalConfig.maxRisk ** 2))
            penalty = penalty + pen

            ## net exposure bounds
            # constraints.append(cp.sum(w) >= rebalConfig.minNetExposure)
            # constraints.append(cp.sum(w) <= rebalConfig.maxNetExposure)
            pen = cp.pos(cp.sum(aggrw) - rebalConfig.maxNetExposure) \
                  + cp.pos(rebalConfig.minNetExposure - cp.sum(aggrw))
            penalty = penalty + pen

            ## individual position bounds - hard ---------------------------------------------------------
            # Hard bounds at both the aggregated (dframe) and per-listing (dfAvail) level.
            constraints.append(cp.max(availAggr @ wlong - dframe[['longBoundHard']].to_numpy()) <= 0.0)
            constraints.append(cp.max(availAggr @ wshort - dframe[['shortBoundHard']].to_numpy()) <= 0.0)



            constraints.append(cp.max(wlong - dfAvail[['longBoundHard']].to_numpy()) <= 0.0)
            constraints.append(cp.max(wshort - dfAvail[['shortBoundHard']].to_numpy()) <= 0.0)


            ## individual position bounds - soft ---------------------------------------------------------
            pen = cp.sum(cp.pos(cp.pos(aggrw) - dframe[['longBoundSoft']].to_numpy())) \
                  + cp.sum(cp.pos(cp.neg(aggrw) - dframe[['shortBoundSoft']].to_numpy()))
            penalty = penalty + pen

            ## apply combined position bounds ===========================================================
            if rebalConfig.offDiagRisk:
                # if rebalConfig.applyExchangeHolidays and not 'HK' in lstHolidays:
                # combAggr sums each covariance pair (x, y): row n of (combAggr @ aggrw)
                # is the combined weight of pair n, soft-bounded at 1.05x the looser of
                # the two members' soft bounds.
                combAggr = np.zeros((len(scov), len(assets)))
                for n in range(len(scov)):
                    combAggr[n, scov['xIdx'].iloc[n]] = combAggr[n, scov['yIdx'].iloc[n]] = 1
                dfm = scov.merge(dframe[['assetKey', 'longBoundSoft', 'shortBoundSoft']].
                                 rename(
                    columns={'assetKey': 'xKey', 'longBoundSoft': 'xLong', 'shortBoundSoft': 'xShort'}),
                                 how='left', on='xKey')
                dfm = dfm.merge(dframe[['assetKey', 'longBoundSoft', 'shortBoundSoft']].
                                rename(
                    columns={'assetKey': 'yKey', 'longBoundSoft': 'yLong', 'shortBoundSoft': 'yShort'}),
                                how='left', on='yKey')
                dfm = dfm.assign(longBoundSoft=1.05 * dfm[['xLong', 'yLong']].max(axis=1))
                dfm = dfm.assign(shortBoundSoft=1.05 * dfm[['xShort', 'yShort']].max(axis=1))
                # dfm = dfm.assign(grossBound = 1.5 * dfm[['xLong', 'yLong', 'xShort', 'yShort']].max(axis=1))
                if len(dfm) > 0:
                    # constraints.append(cp.max(cp.pos(aggr @ w) - dfm[['longBound']].to_numpy()) <= 0.0)
                    # constraints.append(cp.max(cp.neg(aggr @ w) - dfm[['shortBound']].to_numpy()) <= 0.0)
                    pen = cp.sum(cp.pos(cp.pos(combAggr @ aggrw) - dfm[['longBoundSoft']].to_numpy())) \
                          + cp.sum(cp.pos(cp.neg(combAggr @ aggrw) - dfm[['shortBoundSoft']].to_numpy()))
                    penalty = penalty + pen
            ## ==========================================================================================

            ## country net/gross exposures
            if not Strategy.isSingleCountry(stratName):
                cBounds = SignalMgr.get(rebalConfig.countryBoundName, signalDate, stratName)
                # if rebalConfig.applyExchangeHolidays:
                #     cBounds = cBounds[~cBounds['cfactor'].isin(['country_' + x for x in lstHolidays])]
                # Restrict to country factors present in both the risk model and the bounds.
                cfactors = [x for x in factors if x.startswith('country_') and x in cBounds['cfactor'].tolist()]
                cBounds = cBounds.set_index('cfactor')
                cBounds = cBounds.reindex(index=cfactors)
                cBounds = cBounds.assign(maxGross=cBounds['maxGross'].fillna(0.0))
                cBounds = cBounds.assign(minNet=cBounds['minNet'].fillna(0.0))
                cBounds = cBounds.assign(maxNet=cBounds['maxNet'].fillna(0.0))
                if rebalConfig.tightCountry:
                    PyLog.info(f"tightening country bounds rebalConfig.tightCountry: {rebalConfig.tightCountry} stratName:{stratName}")
                    cBounds['minNet'] = -0.0
                    cBounds['maxNet'] = 0.0
                #PyLog.info(f"these cBounds are\n{cBounds} total max gross {cBounds.maxGross.sum() :.4f} ")
                # china hack. lower china leave room for JP, TH

                # NOTE(review): hard-coded gross overrides; if any of these factors is not
                # in cfactors, .loc assignment INSERTS a new row, which would break the
                # reshape(cNetExp.shape) calls below — confirm these factors always exist.
                cBounds.loc["country_CN", "maxGross"] = 0.34
                cBounds.loc["country_JP", "maxGross"] = 0.40
                cBounds.loc["country_TH", "maxGross"] = 0.16

                #PyLog.info(f"after setting china these cBounds are\n{cBounds} total max gross {cBounds.maxGross.sum() :.4f} ")
                # Country net exposure: soft hinge penalty (heavier when tightCountry).
                cNetExp = fload.reindex(index=assets, columns=cfactors).to_numpy().T @ aggrw
                maxNet = cBounds[['maxNet']].to_numpy().reshape(cNetExp.shape)
                minNet = cBounds[['minNet']].to_numpy().reshape(cNetExp.shape)
                #constraints.append(cp.max(cNetExp - maxNet) <= 0.0)
                #constraints.append(cp.min(cNetExp - minNet) >= 0.0)
                if rebalConfig.tightCountry:
                    K = 50
                    PyLog.info(f"tightening country bounds with high penalty {K} rebalConfig.tightCountry: {rebalConfig.tightCountry} stratName:{stratName}")
                    pen = K * cp.sum(cp.pos(cNetExp - maxNet)) + K * cp.sum(cp.pos(minNet - cNetExp))
                else:
                    pen = cp.sum(cp.pos(cNetExp - maxNet)) + cp.sum(cp.pos(minNet - cNetExp))
                penalty = penalty + pen
                # Country gross exposure: soft bound scaled by max leverage.
                cGrossExp = fload.reindex(index=assets, columns=cfactors).to_numpy().T @ cp.abs(aggrw)
                # maxGross = (rebalConfig.maxLeverage * 1.05 * cBounds[['maxGross']]).to_numpy().reshape(cGrossExp.shape)
                maxGross = (rebalConfig.maxLeverage * cBounds[['maxGross']]).to_numpy().reshape(cGrossExp.shape)
                # constraints.append(cp.max(cGrossExp - maxGross) <= 0.0)
                pen = cp.sum(cp.pos(cGrossExp - maxGross))
                penalty = penalty + pen

            ## industry net/gross exposures
            iBounds = SignalMgr.get(rebalConfig.industryBoundName, signalDate, stratName)
            #PyLog.info(f"iBounds are: {iBounds}")
            # Restrict to industry factors present in both the risk model and the bounds.
            ifactors = [x for x in factors if x.startswith('ind_') if x in iBounds['ifactor'].tolist()]

            iBounds = iBounds.set_index('ifactor')
            iBounds = iBounds.reindex(index=ifactors)
            iBounds = iBounds.assign(maxGross=iBounds['maxGross'].fillna(0.0))
            iBounds = iBounds.assign(minNet=iBounds['minNet'].fillna(0.0))
            iBounds = iBounds.assign(maxNet=iBounds['maxNet'].fillna(0.0))

            # Industry net exposure: soft hinge penalties (hard constraints disabled).
            iNetExp = fload.reindex(index=assets, columns=ifactors).to_numpy().T @ aggrw
            maxNet = iBounds[['maxNet']].to_numpy().reshape(iNetExp.shape)
            minNet = iBounds[['minNet']].to_numpy().reshape(iNetExp.shape)
            # constraints.append(cp.max(iNetExp - maxNet) <= 0.0)
            # constraints.append(cp.min(iNetExp - minNet) >= 0.0)
            pen = cp.sum(cp.pos(iNetExp - maxNet)) + cp.sum(cp.pos(minNet - iNetExp))
            penalty = penalty + pen

            # Industry gross exposure: soft bound scaled by max leverage.
            iGrossExp = fload.reindex(index=assets, columns=ifactors).to_numpy().T @ cp.abs(aggrw)
            # maxGross = (rebalConfig.maxLeverage * 1.05 * iBounds[['maxGross']]).to_numpy().reshape(iGrossExp.shape)
            maxGross = (rebalConfig.maxLeverage * iBounds[['maxGross']]).to_numpy().reshape(iGrossExp.shape)
            # constraints.append(cp.max(iGrossExp - maxGross) <= 0.0)
            pen = cp.sum(cp.pos(iGrossExp - maxGross))
            penalty = penalty + pen


            ## size net/gross exposures
            # NOTE(review): unlike country/industry, size factors are not filtered against
            # the bounds frame, so reindex may add all-NaN rows zeroed by fillna below.
            sfactors = [x for x in factors if x.startswith('size_')]

            sBounds = SignalMgr.get(rebalConfig.sizeBoundName, signalDate, stratName).set_index('sfactor')
            sBounds = sBounds.reindex(index=sfactors)
            sBounds = sBounds.assign(maxGross=sBounds['maxGross'].fillna(0.0))
            sBounds = sBounds.assign(minNet=sBounds['minNet'].fillna(0.0))
            sBounds = sBounds.assign(maxNet=sBounds['maxNet'].fillna(0.0))

            sNetExp = fload.reindex(index=assets, columns=sfactors).to_numpy().T @ aggrw
            maxNet = sBounds[['maxNet']].to_numpy().reshape(sNetExp.shape)
            minNet = sBounds[['minNet']].to_numpy().reshape(sNetExp.shape)
            # constraints.append(cp.max(sNetExp - maxNet) <= 0.0)
            # constraints.append(cp.min(sNetExp - minNet) >= 0.0)
            pen = cp.sum(cp.pos(sNetExp - maxNet)) + cp.sum(cp.pos(minNet - sNetExp))
            penalty = penalty + pen

            sGrossExp = fload.reindex(index=assets, columns=sfactors).to_numpy().T @ cp.abs(aggrw)
            # maxGross = (rebalConfig.maxLeverage * 1.05 * sBounds[['maxGross']]).to_numpy().reshape(sGrossExp.shape)
            maxGross = (rebalConfig.maxLeverage * sBounds[['maxGross']]).to_numpy().reshape(sGrossExp.shape)
            sbounds_maxGross = sBounds[['maxGross']]
            PyLog.info(f"maxGross: {maxGross} = rebalConfig.maxLeverage {rebalConfig.maxLeverage} * {sbounds_maxGross}")
            # constraints.append(cp.max(sGrossExp - maxGross) <= 0.0)
            pen = cp.sum(cp.pos(sGrossExp - maxGross))
            penalty = penalty + pen

            ## net exposure to sbeta_market
            # Beta-adjusted net exposure — kept as hard constraints (not soft penalties).
            mNetExp = fload.reindex(index=assets, columns=['sbeta_market']).to_numpy().T @ aggrw
            constraints.append(cp.min(mNetExp) >= rebalConfig.minBetaAdjNet)
            constraints.append(cp.max(mNetExp) <= rebalConfig.maxBetaAdjNet)
            PyLog.info(f"mNetExp: {mNetExp} rebalConfig.minBetaAdjNet:{rebalConfig.minBetaAdjNet} rebalConfig.maxBetaAdjNet:{rebalConfig.maxBetaAdjNet}")
            ## phase out value_liq in 2017~ ============================================================
            # if PyDate.ge(signalDate, 20170101) and PyDate.le(signalDate, PyMonth.lastWeekday(202012)):
            # # if PyDate.ge(signalDate, 20170101):
            #     vliqExp = fload.reindex(index=assets, columns=['value_liq']).to_numpy().T @ w
            #     bound = 2.0 * Filter.bound(PyDate.span(signalDate, 20170331) / 90, lower=0.01)
            #     constraints.append(cp.max(vliqExp) <= bound)
            #     constraints.append(cp.min(vliqExp) >= -bound)
            ## ==========================================================================================

#
            # Attach ticker/quoteCountry/ric metadata to dframe for the debug dumps below.
            mfm = SignalMgr.getStatic('model_universe_frame')
            mfm = mfm.rename(columns={'RkdTicker': 'ticker', 'RkdDisplayRIC': 'ric', 'Country': 'quoteCountry'})
            dframe = mfm[['assetKey', 'ticker', 'quoteCountry', 'ric']].merge(dframe, how='right', on='assetKey')


            # Debug slices; only sb_lb_0 and the full dframe are persisted below.
            df_lbs_sbs_ne0 = dframe[ (dframe['longBoundSoft'] !=0 ) | (dframe['shortBoundSoft'] !=0 ) ]
            univ_wt0 = dframe[ dframe['univwt']!=0 ]
            preOptWeights0 = dframe[ (dframe['preOptWeights']==0) | ( dframe['preOptWeights'].isnull() ) ]
            sb_lb_0 = dframe[ (dframe['longBoundHard'] !=0 ) | (dframe['shortBoundHard'] !=0 ) ]
#            sb_lb_0 = dframe

            # NOTE(review): os.environ['USER'] raises KeyError when USER is unset (e.g.
            # under cron/systemd); output paths are hard-coded to /home/{user}/scratch.
            user = os.environ['USER']
            sb_lb_0.to_csv(f"/home/{user}/scratch/nonZeroHardBounds.csv", index=False, columns=['assetKey','ticker','quoteCountry','ric']) 
            dframe.to_csv(f"/home/{user}/scratch/preOptUniv.csv", index=False, columns=['assetKey','ticker','quoteCountry','ric']) 


            ## apply alpha theme exposure bounds ========================================================
            if rebalConfig.themeBoundName is not None:
                themeBounds = SignalMgr.get(rebalConfig.themeBoundName, signalDate, stratName)
                themeBounds = FrameUtil.toSeries(themeBounds, keyCol='factor', valCol='bound') * rebalConfig.maxLeverage
                themes = list(themeBounds.index)
                themeExp = fload.reindex(index=assets, columns=themes).to_numpy().T @ aggrw
                # constraints.append(cp.max(themeExp - themeBounds.to_numpy().reshape(themeExp.shape)) <= 0)
                # constraints.append(cp.min(themeExp + themeBounds.to_numpy().reshape(themeExp.shape)) >= 0)
                # Symmetric soft bounds around zero on each theme exposure.
                pen = cp.sum(cp.pos(themeExp - themeBounds.to_numpy().reshape(themeExp.shape))) \
                      + cp.sum(cp.pos(- themeExp - themeBounds.to_numpy().reshape(themeExp.shape)))
                penalty = penalty + pen
            ## here we are constructing symmetric bounds around 0 but it is the upper bound that we are
            ## primarily concerned about
            ## ==========================================================================================

            ## penalty - apply max trade constraints ====================================================
            if (not optimal) and (rebalConfig.maxAdvProp is not None):
                dfm = SignalMgr.getFrame(rebalConfig.advName, signalDate, stratName)
                dfm = dframe[['assetKey']].merge(dfm[['assetKey', rebalConfig.advName]], how='left', on='assetKey')
                dfm = dfm.assign(
                    maxTradeWeights=rebalConfig.maxAdvProp * dfm[rebalConfig.advName].fillna(0.0) * 1000000 / NAV)
                ## may need to relax this condition for certain conditions (i.e. forced trading)
                # constraints.append(cp.max(cp.abs(trade) - dfm[['maxTradeWeights']].to_numpy()) <= 0.0)
                # NOTE(review): assets with missing ADV get maxTradeWeights == 0 via
                # fillna(0), making this a division by zero — confirm upstream guarantees
                # nonzero ADV for all tradeable assets.
                pen = (cp.abs(trade) - dfm[['maxTradeWeights']].to_numpy()) / dfm[['maxTradeWeights']].to_numpy()
                # pen = 4 * cp.max(cp.pos(pen))
                pen = 4 * cp.sum(cp.pos(pen))
                penalty = penalty + pen
            ## ==========================================================================================

            # Risk-aversion parameter; its value is set per-solve below (adaptive or fixed).
            gamma = cp.Parameter(nonneg=True)

            # Uncomment this to print final dataframe before optimization starts
            #dframe.to_csv("dframe_final.csv")

            PyLog.info(f"constraints==== {len(constraints)}===")

            if rebalConfig.adaptiveGamma:

                # Iteratively adjust gamma until realized portfolio risk hits
                # adaptiveGammaTargetRisk within tolerance: step by ±offset while both
                # solves sit on the same side of the target, then bisect (and shrink the
                # step) once the target is bracketed.
                prevGamma = 0.0
                currGamma = 0.0
                prevRisk = 0.0
                currRisk = 0.0
                offset = 1.0

                while Real.isZero(currRisk) or \
                        abs(currRisk - rebalConfig.adaptiveGammaTargetRisk) > rebalConfig.adaptiveGammaTolerance:

                    if Real.isZero(currRisk):
                        # First pass: start from the configured gamma.
                        rgamma = rebalConfig.gamma
                    elif Real.isZero(prevRisk):
                        rgamma = currGamma + np.sign(currRisk - rebalConfig.adaptiveGammaTargetRisk) * offset
                    elif Real.isPositive(np.sign(currRisk - rebalConfig.adaptiveGammaTargetRisk)
                                         * np.sign(prevRisk - rebalConfig.adaptiveGammaTargetRisk)):
                        # Same side of target twice: keep stepping in the same direction.
                        rgamma = currGamma + np.sign(currRisk - rebalConfig.adaptiveGammaTargetRisk)
                    else:
                        # Bracketed the target: bisect and shrink the step.
                        rgamma = (prevGamma + currGamma) / 2
                        offset = offset / 4

                    gamma.value = gmult * rgamma

                    ## optimization setup
                    if optimal:
                        problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - penalty), constraints)
                    else:
                        problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - btau * tcost - atau * avgTcost
                                                         - penalty), constraints)

                    numIters = cls.solveProblem(rebalConfig=rebalConfig, problem=problem, maxIter=maxIter,
                                                cycleIter=cycleIter, solver=solver, verbose=verbose, checkDCP=checkDCP,
                                                optimal=optimal)

                    # Fallback: keep pre-optimization weights if the solver produced nothing.
                    if w.value is None:
                        PyLog.info("   Optimizer failed to converge")
                        w.value = dfAvail[['preOptWeights']].fillna(0.0).to_numpy()

                    weights = pd.Series(aggrw.value.flatten(), index=assets)

                    prevGamma = currGamma
                    prevRisk = currRisk

                    currGamma = rgamma
                    currRisk = RiskModel.computeRisk(rmodel, weights)

            else:

                # Single solve with the configured (fixed) gamma.
                rgamma = rebalConfig.gamma
                gamma.value = gmult * rgamma

                ## optimization setup
                if optimal:
                    problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - penalty), constraints)
                else:
                    problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - btau * tcost - atau * avgTcost
                                                     - penalty), constraints)

                ## problem.solve
                numIters = cls.solveProblem(rebalConfig=rebalConfig, problem=problem, maxIter=maxIter,
                                            cycleIter=cycleIter, solver=solver, verbose=verbose, checkDCP=checkDCP,
                                            optimal=optimal)

                # Fallback: keep pre-optimization weights if the solver produced nothing.
                if w.value is None:
                    PyLog.info("   Optimizer failed to converge")
                    w.value = dfAvail[['preOptWeights']].fillna(0.0).to_numpy()



            # Map solved weights back: dfAvail holds the raw optimizer output; dframe the
            # aggregated view; dfWeights applies the holding threshold and derives trades.
            dfAvail = dfAvail.assign(optimalWeights = list(w.value.flatten()))
            dfAvail = dfAvail.assign(postOptWeights = dfAvail['optimalWeights'])
            # dfAvail.loc[abs(dfAvail['postOptWeights']) < rebalConfig.holdingThreshold, 'postOptWeights'] = 0.0

            dframe = dframe.assign(optimalWeights = availAggr @ dfAvail['optimalWeights'])
            dframe = dframe.assign(postOptWeights = availAggr @ dfAvail['postOptWeights'])

            dfWeights = dfWeights.merge(dframe[['assetKey', 'optimalWeights', 'postOptWeights']],
                                        how='outer', on='assetKey')
            # Zero out sub-threshold holdings before computing trade weights.
            dfWeights.loc[abs(dfWeights['postOptWeights'].fillna(0.0)) < rebalConfig.holdingThreshold, 'postOptWeights'] = 0.0
            dfWeights = dfWeights.assign(
                tradeWeights = dfWeights['postOptWeights'].fillna(0.0) - dfWeights['preOptWeights'].fillna(0.0))

            # --- Start China penalty post-solve logging ---

            if rebalConfig.enable_gs_stability_score == True:
                if china_short_penalty is not None and hasattr(china_short_penalty, "_china_penalty_info"):
                    try:
                        # Diagnostic attributes presumably attached by
                        # build_china_connect_short_penalty — confirm against that helper.
                        info = china_short_penalty._china_penalty_info
                        weights_arr = china_short_penalty._china_penalty_weights
                        scale = china_short_penalty._china_scale

                        PyLog.info("ChinaConnect: post-solve penalty details for penalized assets (top 50):")
                        for idx, assetKey, score, pw in info[:50]:
                            if wshort.value is None:
                                short_w = None
                                contrib = None
                            else:
                                short_w = float(wshort.value[idx])
                                contrib = scale * pw * short_w if short_w is not None else None

                            PyLog.info(
                                f"    assetKey={assetKey}, score={score}, penaltyWeight={pw}, "
                                f"shortWeight={(short_w if short_w is not None else 'None')}, "
                                f"penaltyContribution={(contrib if contrib is not None else 'None')}"
                            )

                        if len(info) > 50:
                            PyLog.info(f"    ...and {len(info) - 50} more penalized assets")

                        total_pen = float(china_short_penalty.value) if china_short_penalty.value is not None else None
                        PyLog.info(f"ChinaConnect: total China penalty = {total_pen}")

                    except Exception as e:
                        # Logging is best-effort: never fail the rebalance over diagnostics.
                        PyLog.error(f"ChinaConnect: error computing post-solve penalty log: {e}")
            # --- End China penalty post-solve logging ---


        else:
            PyLog.info(f"Entering split block. rebalConfig.splitLongShort:{rebalConfig.splitLongShort} rebalConfig.splitLongShortStartDate:{rebalConfig.splitLongShortStartDate}")
            PyLog.info(f"rebalConfig.applyExchangeHolidays:{rebalConfig.applyExchangeHolidays} optimal:{optimal}")

            # Disabled original holiday-bound logic, kept as an inert string literal for
            # reference; it has no runtime effect.
            """
            ## reflect exchange holidays -------------------------------------------------------------------
            if rebalConfig.applyExchangeHolidays and (not optimal):
                lstHolidays = cls.getExchangeHolidays(tradeDate)
                dframe = dframe.assign(holiday=dframe['modelCountry'].isin(lstHolidays))
                dframe.loc[dframe['holiday'], 'longBoundHard'] \
                    = Filter.bound(dframe.loc[dframe['holiday'], 'preOptWeights'].fillna(0.0), lower=0.0) \
                      + rebalConfig.holdingThreshold / 100.0
                dframe.loc[dframe['holiday'], 'shortBoundHard'] \
                    = Filter.bound(-dframe.loc[dframe['holiday'], 'preOptWeights'].fillna(0.0), lower=0.0) \
                      + rebalConfig.holdingThreshold / 100.0
            ## ---------------------------------------------------------------------------------------------
            """

            ## NEW reflect exchange holidays
            if rebalConfig.applyExchangeHolidays and (not optimal):
                lstHolidays = cls.getExchangeHolidays(tradeDate)
                dframe = dframe.assign(holiday=dframe['modelCountry'].isin(lstHolidays))

                # Current long/short magnitudes become the pinned hard bounds on holidays.
                # NOTE(review): unlike the disabled version above, no holdingThreshold
                # slack is added here — confirm that is intentional.
                cur = dframe['preOptWeights'].fillna(0.0)
                cur_long  = Filter.bound(cur,  lower=0.0)
                cur_short = Filter.bound(-cur, lower=0.0)

                dframe.loc[dframe['holiday'], 'longBoundHard']  = cur_long.loc[dframe['holiday']]
                dframe.loc[dframe['holiday'], 'shortBoundHard'] = cur_short.loc[dframe['holiday']]

            factors = rmodel['factors']
            fload  = rmodel['fload'].reindex(index=assets, columns=factors)
            ## NEW reflect exchange holidays

            # Expected returns (column vector) and diagonal specific-risk matrix.
            mu = dframe[['alpha']].to_numpy()
            D  = np.diag(rmodel['srisk'].reindex(index=assets) ** 2)

            ## optional off-diagonal risk ------------------------------------------------------------------------------
            PyLog.info(f"rebalConfig.offDiagRisk:{rebalConfig.offDiagRisk}")
            if rebalConfig.offDiagRisk:
                # Overlay pairwise specific covariances onto the diagonal matrix D,
                # keeping D symmetric.
                scov = rmodel['scov']
                dfm = pd.DataFrame({'assetKey': assets, 'idx': range(len(assets))})
                scov = scov.merge(dfm.rename(columns={'assetKey': 'xKey', 'idx': 'xIdx'}), how='inner', on='xKey')
                scov = scov.merge(dfm.rename(columns={'assetKey': 'yKey', 'idx': 'yIdx'}), how='inner', on='yKey')
                for xIdx, yIdx, cov in zip(scov['xIdx'], scov['yIdx'], scov['cov']):
                    D[xIdx, yIdx] = D[yIdx, xIdx] = cov
            ## ---------------------------------------------------------------------------------------------------------

            # Factor covariance, optionally SVD-shrunk via matrix_shrink.
            S  = rmodel['fcov'].reindex(index=factors, columns=factors).to_numpy()
            if rebalConfig.shrinkfcov:
                S = matrix_shrink(S)
            else:
                PyLog.info(f"rebalConfig.shrinkfcov:{rebalConfig.shrinkfcov} using fcov S as is")
            F  = fload.to_numpy()

            # Decision variable over the full asset list (no dfAvail aggregation in this
            # branch); long/short split via positive/negative parts.
            w = cp.Variable((len(assets), 1))
            ## wlong  = cp.maximum( w, 0)
            ## wshort = cp.maximum(-w, 0)
            wlong  = cp.pos(w)
            wshort = cp.neg(w)

            f = F.T @ w

            # Turnover-cost blend (same scheme as the non-split branch above).
            rtau = rebalConfig.tau
            otau = tmult * rtau
            atau = cp.Parameter(nonneg=True)
            atau.value = max(0, otau - 1)
            btau = cp.Parameter(nonneg=True)
            btau.value = min(1, otau)

            Lmax = cp.Parameter()
            Lmax.value = rebalConfig.maxLeverage
            ## Lmax = rebalConfig.maxLeverage

            ## long financing cost / short borrow cost
            fcost = dframe['fcost'].to_numpy() @ wlong
            bcost = dframe['bcost'].to_numpy() @ wshort

            ## expected return / risk
            netReturn = mu.T @ w - fcost - bcost
            variance = cp.quad_form(f, S) + cp.quad_form(w, D)

            ## transactions cost
            tmodel = tmodel.reindex(index=assets)
            trade = (w - preOptWeights.reindex(index=assets).fillna(0.0).to_frame().to_numpy())
            tbuy  = cp.pos(trade)
            tsell = cp.neg(trade)

            linearCost = tmodel['linearBuyCoeff'].to_numpy() @ tbuy + tmodel['linearSellCoeff'].to_numpy() @ tsell
            ## impactCost = (tmodel['impactCoeffTH'].to_numpy() * ((NAV * (tbuy + tsell)) ** (3/2))) / NAV
            ## impactCost = tmodel['impactCoeffTH'].to_numpy() * cp.multiply(1/NAV, cp.power(cp.multiply(NAV, tbuy + tsell), 3/2))
            ## impactBuy  = cp.multiply(1 / NAV, tmodel['impactCoeffTH'].to_numpy() * cp.power(cp.multiply(NAV, tbuy),  3 / 2))
            ## impactSell = cp.multiply(1 / NAV, tmodel['impactCoeffTH'].to_numpy() * cp.power(cp.multiply(NAV, tsell), 3 / 2))
            impactBuy  = cp.multiply(cp.sqrt(NAV) * rebalConfig.impactCostScale,
                                     tmodel['impactCoeffTH'].to_numpy() @ cp.power(tbuy,  3 / 2))
            impactSell = cp.multiply(cp.sqrt(NAV) * rebalConfig.impactCostScale,
                                     tmodel['impactCoeffTH'].to_numpy() @ cp.power(tsell, 3 / 2))
            tcost = linearCost + impactBuy + impactSell

            avgLinearCost = tmodel['linearAverageCoeff'].to_numpy() @ (tbuy + tsell)
            avgTcost = avgLinearCost + impactBuy + impactSell


            ## constraints --------------------------------------------------------------------
            constraints = []
            penalty = 0

            ## max leverage
            # constraints.append(cp.norm(w, 1) <= Lmax)
            # pen = cp.pos(cp.norm(w, 1) - Lmax - 0.1)
            pen = cp.pos(cp.norm(w, 1) - Lmax)
            penalty = penalty + pen

            ## min leverage
            # if rebalConfig.minLeverage is not None:
            #     Lmin = cp.Parameter()
            #     Lmin.value = rebalConfig.minLeverage
            #     constraints.append(cp.norm(w, 1) >= Lmin)

            ## risk bound
            # constraints.append(variance <= (rebalConfig.maxRisk ** 2))
            pen = 5 * cp.pos(variance - (rebalConfig.maxRisk ** 2))
            penalty = penalty + pen

            ## net exposure bounds
            # constraints.append(cp.sum(w) >= rebalConfig.minNetExposure)
            # constraints.append(cp.sum(w) <= rebalConfig.maxNetExposure)
            pen = cp.pos(cp.sum(w) - rebalConfig.maxNetExposure) \
                  + cp.pos(rebalConfig.minNetExposure - cp.sum(w))
            penalty = penalty + pen

            ## individual position bounds - hard ---------------------------------------------------------
            constraints.append(cp.max(wlong  - dframe[['longBoundHard']].to_numpy())  <= 0.0)
            constraints.append(cp.max(wshort - dframe[['shortBoundHard']].to_numpy()) <= 0.0)

            ## individual position bounds - soft ---------------------------------------------------------
            pen = cp.sum(cp.pos(cp.pos(w) - dframe[['longBoundSoft']].to_numpy())) \
                  + cp.sum(cp.pos(cp.neg(w) - dframe[['shortBoundSoft']].to_numpy()))
            penalty = penalty + pen

            ## apply combined position bounds ===========================================================
            if rebalConfig.offDiagRisk:
                # if rebalConfig.applyExchangeHolidays and not 'HK' in lstHolidays:
                    aggr = np.zeros((len(scov), len(assets)))
                    for n in range(len(scov)):
                        aggr[n, scov['xIdx'].iloc[n]] = aggr[n, scov['yIdx'].iloc[n]] = 1
                    dfm = scov.merge(dframe[['assetKey', 'longBoundSoft', 'shortBoundSoft']].
                                     rename(columns={'assetKey': 'xKey', 'longBoundSoft': 'xLong', 'shortBoundSoft': 'xShort'}),
                                     how='left', on='xKey')
                    dfm = dfm.merge(dframe[['assetKey', 'longBoundSoft', 'shortBoundSoft']].
                                    rename(columns={'assetKey': 'yKey', 'longBoundSoft': 'yLong', 'shortBoundSoft': 'yShort'}),
                                    how='left', on='yKey')
                    dfm = dfm.assign(longBoundSoft = 1.05 * dfm[['xLong', 'yLong']].max(axis=1))
                    dfm = dfm.assign(shortBoundSoft = 1.05 * dfm[['xShort', 'yShort']].max(axis=1))
                    # dfm = dfm.assign(grossBound = 1.5 * dfm[['xLong', 'yLong', 'xShort', 'yShort']].max(axis=1))
                    if len(dfm) > 0:
                        # constraints.append(cp.max(cp.pos(aggr @ w) - dfm[['longBound']].to_numpy()) <= 0.0)
                        # constraints.append(cp.max(cp.neg(aggr @ w) - dfm[['shortBound']].to_numpy()) <= 0.0)
                        pen = cp.sum(cp.pos(cp.pos(aggr @ w) - dfm[['longBoundSoft']].to_numpy())) \
                              + cp.sum(cp.pos(cp.neg(aggr @ w) - dfm[['shortBoundSoft']].to_numpy()))
                        penalty = penalty + pen
            ## ==========================================================================================

            ## country net/gross exposures
            if not Strategy.isSingleCountry(stratName):
                cBounds  = SignalMgr.get(rebalConfig.countryBoundName, signalDate, stratName)
                # if rebalConfig.applyExchangeHolidays:
                #     cBounds = cBounds[~cBounds['cfactor'].isin(['country_' + x for x in lstHolidays])]
                cfactors = [x for x in factors if x.startswith('country_') and x in cBounds['cfactor'].tolist()]
                cBounds  = cBounds.set_index('cfactor')
                cBounds  = cBounds.reindex(index=cfactors)
                cBounds  = cBounds.assign(maxGross = cBounds['maxGross'].fillna(0.0))
                cBounds  = cBounds.assign(minNet = cBounds['minNet'].fillna(0.0))
                cBounds  = cBounds.assign(maxNet = cBounds['maxNet'].fillna(0.0))
                if rebalConfig.tightCountry:
                    PyLog.info(f"tightening country bounds rebalConfig.tightCountry: {rebalConfig.tightCountry} stratName:{stratName}")
                    cBounds['minNet'] = -0.0
                    cBounds['maxNet'] = 0.0
                # aggrw is  w
                cNetExp = fload.reindex(index=assets, columns=cfactors).to_numpy().T @ w 
                maxNet = cBounds[['maxNet']].to_numpy().reshape(cNetExp.shape)
                minNet = cBounds[['minNet']].to_numpy().reshape(cNetExp.shape)
                PyLog.info(f"here cBounds are\n{cBounds}")
                #constraints.append(cp.max(cNetExp - maxNet) <= 0.0)
                #constraints.append(cp.min(cNetExp - minNet) >= 0.0)
                if rebalConfig.tightCountry:
                    K = 50
                    PyLog.info(f"tightening country bounds with high penalty {K} rebalConfig.tightCountry: {rebalConfig.tightCountry} stratName:{stratName}")
                    pen = K * cp.sum(cp.pos(cNetExp - maxNet)) + K * cp.sum(cp.pos(minNet - cNetExp))
                else:
                    pen = cp.sum(cp.pos(cNetExp - maxNet)) + cp.sum(cp.pos(minNet - cNetExp))
                penalty = penalty + pen
                cGrossExp = fload.reindex(index=assets, columns=cfactors).to_numpy().T @ cp.abs(w)
                # maxGross = (rebalConfig.maxLeverage * 1.05 * cBounds[['maxGross']]).to_numpy().reshape(cGrossExp.shape)
                maxGross = (rebalConfig.maxLeverage * cBounds[['maxGross']]).to_numpy().reshape(cGrossExp.shape)
                # constraints.append(cp.max(cGrossExp - maxGross) <= 0.0)
                pen = cp.sum(cp.pos(cGrossExp - maxGross))
                penalty = penalty + pen

            ## industry net/gross exposures
            iBounds  = SignalMgr.get(rebalConfig.industryBoundName, signalDate, stratName)

            ifactors = [x for x in factors if x.startswith('ind_') if x in iBounds['ifactor'].tolist()]

            iBounds  = iBounds.set_index('ifactor')
            iBounds  = iBounds.reindex(index=ifactors)
            iBounds  = iBounds.assign(maxGross = iBounds['maxGross'].fillna(0.0))
            iBounds  = iBounds.assign(minNet = iBounds['minNet'].fillna(0.0))
            iBounds  = iBounds.assign(maxNet = iBounds['maxNet'].fillna(0.0))

            iNetExp  = fload.reindex(index=assets, columns=ifactors).to_numpy().T @ w
            maxNet = iBounds[['maxNet']].to_numpy().reshape(iNetExp.shape)
            minNet = iBounds[['minNet']].to_numpy().reshape(iNetExp.shape)
            # constraints.append(cp.max(iNetExp - maxNet) <= 0.0)
            # constraints.append(cp.min(iNetExp - minNet) >= 0.0)
            pen = cp.sum(cp.pos(iNetExp - maxNet)) + cp.sum(cp.pos(minNet - iNetExp))
            penalty = penalty + pen

            iGrossExp = fload.reindex(index=assets, columns=ifactors).to_numpy().T @ cp.abs(w)
            # maxGross = (rebalConfig.maxLeverage * 1.05 * iBounds[['maxGross']]).to_numpy().reshape(iGrossExp.shape)
            maxGross = (rebalConfig.maxLeverage * iBounds[['maxGross']]).to_numpy().reshape(iGrossExp.shape)
            # constraints.append(cp.max(iGrossExp - maxGross) <= 0.0)
            pen = cp.sum(cp.pos(iGrossExp - maxGross))
            penalty = penalty + pen

            ## size net/gross exposures
            sfactors = [x for x in factors if x.startswith('size_')]

            sBounds  = SignalMgr.get(rebalConfig.sizeBoundName, signalDate, stratName).set_index('sfactor')
            sBounds  = sBounds.reindex(index=sfactors)
            sBounds  = sBounds.assign(maxGross = sBounds['maxGross'].fillna(0.0))
            sBounds  = sBounds.assign(minNet = sBounds['minNet'].fillna(0.0))
            sBounds  = sBounds.assign(maxNet = sBounds['maxNet'].fillna(0.0))

            sNetExp  = fload.reindex(index=assets, columns=sfactors).to_numpy().T @ w
            maxNet = sBounds[['maxNet']].to_numpy().reshape(sNetExp.shape)
            minNet = sBounds[['minNet']].to_numpy().reshape(sNetExp.shape)
            # constraints.append(cp.max(sNetExp - maxNet) <= 0.0)
            # constraints.append(cp.min(sNetExp - minNet) >= 0.0)
            pen = cp.sum(cp.pos(sNetExp - maxNet)) + cp.sum(cp.pos(minNet - sNetExp))
            penalty = penalty + pen

            PyLog.info(f"sfactors: {sfactors}")
            PyLog.info(f"fload:\n{fload[sfactors].iloc[0]}")
            sGrossExp = fload.reindex(index=assets, columns=sfactors).to_numpy().T @ cp.abs(w)
            # maxGross = (rebalConfig.maxLeverage * 1.05 * sBounds[['maxGross']]).to_numpy().reshape(sGrossExp.shape)
            maxGross = (rebalConfig.maxLeverage * sBounds[['maxGross']]).to_numpy().reshape(sGrossExp.shape)
            PyLog.info(f"sBounds are\n{sBounds}")
            sbounds_maxGross = sBounds[['maxGross']]
            PyLog.info(f"maxGross: {maxGross} = rebalConfig.maxLeverage {rebalConfig.maxLeverage} * {sbounds_maxGross}")
            # constraints.append(cp.max(sGrossExp - maxGross) <= 0.0)
            pen = cp.sum(cp.pos(sGrossExp - maxGross))
            penalty = penalty + pen


            ## net exposure to sbeta_market
            mNetExp = fload.reindex(index=assets, columns=['sbeta_market']).to_numpy().T @ w
            constraints.append(cp.min(mNetExp) >= rebalConfig.minBetaAdjNet)
            constraints.append(cp.max(mNetExp) <= rebalConfig.maxBetaAdjNet)
            PyLog.info(f"mNetExp: {mNetExp} rebalConfig.minBetaAdjNet:{rebalConfig.minBetaAdjNet} rebalConfig.maxBetaAdjNet:{rebalConfig.maxBetaAdjNet}")


            ## phase out value_liq in 2017~ ============================================================
            # if PyDate.ge(signalDate, 20170101) and PyDate.le(signalDate, PyMonth.lastWeekday(202012)):
            # # if PyDate.ge(signalDate, 20170101):
            #     vliqExp = fload.reindex(index=assets, columns=['value_liq']).to_numpy().T @ w
            #     bound = 2.0 * Filter.bound(PyDate.span(signalDate, 20170331) / 90, lower=0.01)
            #     constraints.append(cp.max(vliqExp) <= bound)
            #     constraints.append(cp.min(vliqExp) >= -bound)
            ## ==========================================================================================

            ## apply alpha theme exposure bounds ========================================================
            if rebalConfig.themeBoundName is not None:
                themeBounds = SignalMgr.get(rebalConfig.themeBoundName, signalDate, stratName)
                themeBounds = FrameUtil.toSeries(themeBounds, keyCol='factor', valCol='bound') * rebalConfig.maxLeverage
                themes = list(themeBounds.index)
                themeExp = fload.reindex(index=assets, columns=themes).to_numpy().T @ w
                # constraints.append(cp.max(themeExp - themeBounds.to_numpy().reshape(themeExp.shape)) <= 0)
                # constraints.append(cp.min(themeExp + themeBounds.to_numpy().reshape(themeExp.shape)) >= 0)
                pen = cp.sum(cp.pos(themeExp - themeBounds.to_numpy().reshape(themeExp.shape))) \
                    + cp.sum(cp.pos(- themeExp - themeBounds.to_numpy().reshape(themeExp.shape)))
                penalty = penalty + pen
            ## here we are constructing symmetric bounds around 0 but it is the upper bound that we are
            ## primarily concerned about
            ## ==========================================================================================

            ## penalty - apply max trade constraints ====================================================
            if (not optimal) and (rebalConfig.maxAdvProp is not None):
                dfm = SignalMgr.getFrame(rebalConfig.advName, signalDate, stratName)
                dfm = dframe[['assetKey']].merge(dfm[['assetKey', rebalConfig.advName]], how='left', on='assetKey')
                dfm = dfm.assign(
                    maxTradeWeights = rebalConfig.maxAdvProp * dfm[rebalConfig.advName].fillna(0.0) * 1000000 / NAV)
                ## may need to relax this condition for certain conditions (i.e. forced trading)
                # constraints.append(cp.max(cp.abs(trade) - dfm[['maxTradeWeights']].to_numpy()) <= 0.0)
                pen = (cp.abs(trade) - dfm[['maxTradeWeights']].to_numpy()) / dfm[['maxTradeWeights']].to_numpy()
                # pen = 4 * cp.max(cp.pos(pen))
                pen = 4 * cp.sum(cp.pos(pen))
                penalty = penalty + pen
            ## ==========================================================================================

            gamma = cp.Parameter(nonneg=True)

            if rebalConfig.adaptiveGamma:

                prevGamma = 0.0
                currGamma = 0.0
                prevRisk = 0.0
                currRisk = 0.0
                offset = 1.0

                while Real.isZero(currRisk) or \
                        abs(currRisk - rebalConfig.adaptiveGammaTargetRisk) > rebalConfig.adaptiveGammaTolerance:

                    if Real.isZero(currRisk):
                        rgamma = rebalConfig.gamma
                    elif Real.isZero(prevRisk):
                        rgamma = currGamma + np.sign(currRisk - rebalConfig.adaptiveGammaTargetRisk) * offset
                    elif Real.isPositive(np.sign(currRisk - rebalConfig.adaptiveGammaTargetRisk)
                                         * np.sign(prevRisk - rebalConfig.adaptiveGammaTargetRisk)):
                        rgamma = currGamma + np.sign(currRisk - rebalConfig.adaptiveGammaTargetRisk)
                    else:
                        rgamma = (prevGamma + currGamma) / 2
                        offset = offset / 4

                    gamma.value = gmult * rgamma

                    ## optimization setup
                    if optimal:
                        problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - penalty), constraints)
                    else:
                        problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - btau * tcost - atau * avgTcost
                                                         - penalty), constraints)

                    numIters = cls.solveProblem(rebalConfig=rebalConfig, problem=problem, maxIter=maxIter,
                                                cycleIter=cycleIter, solver=solver, verbose=verbose, checkDCP=checkDCP,
                                                optimal=optimal)

                    if w.value is None:
                        PyLog.info("   Optimizer failed to converge")
                        w.value = dframe[['preOptWeights']].fillna(0.0).to_numpy()

                    weights = pd.Series(w.value.flatten(), index=assets)

                    prevGamma = currGamma
                    prevRisk = currRisk

                    currGamma = rgamma
                    currRisk = RiskModel.computeRisk(rmodel, weights)

            else:

                rgamma = rebalConfig.gamma
                gamma.value = gmult * rgamma

                ## optimization setup
                if optimal:
                    problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - penalty), constraints)
                else:
                    problem = cp.Problem(cp.Maximize(netReturn - gamma * variance - btau * tcost - atau * avgTcost
                                                     - penalty), constraints)

                ## problem.solve
                numIters = cls.solveProblem(rebalConfig=rebalConfig, problem=problem, maxIter=maxIter,
                                            cycleIter=cycleIter, solver=solver, verbose=verbose, checkDCP=checkDCP,
                                            optimal=optimal)

                if w.value is None:
                    PyLog.info("   Optimizer failed to converge")
                    w.value = dframe[['preOptWeights']].fillna(0.0).to_numpy()

                weights = pd.Series(w.value.flatten(), index=assets)

            dfWeights = dfWeights.merge(weights.reset_index().rename(columns={'index': 'assetKey', 0: 'optimalWeights'}),
                                        how='outer', on='assetKey')

            dfWeights = dfWeights.assign(postOptWeights = dfWeights['optimalWeights'])
            dfWeights.loc[abs(dfWeights['postOptWeights'].fillna(0.0)) < rebalConfig.holdingThreshold, 'postOptWeights'] = 0.0
            # weights.loc[abs(weights) < rebalConfig.holdingThreshold] = 0.0
            # dfWeights = dfWeights.merge(weights.reset_index().rename(columns={'index': 'assetKey', 0: 'postOptWeights'}),
            #                             how='outer', on='assetKey')


        dctOptResult['numIters']  = numIters

        ## including cvxpy objects increased output file from 100K to 80MB
        dctOptResult['gamma']     = rgamma
        dctOptResult['netReturn'] = netReturn.value
        dctOptResult['variance']  = variance.value

        dfWeights = dfWeights.merge(dframe[['assetKey', 'longBoundHard', 'shortBoundHard']], how='left', on='assetKey')
        dfWeights = dfWeights.assign(longBoundHard = dfWeights['longBoundHard'].fillna(0.0))
        dfWeights = dfWeights.assign(shortBoundHard = dfWeights['shortBoundHard'].fillna(0.0))

        dfWeights = dfWeights.assign(preOptWeights  = dfWeights['preOptWeights'].fillna(0.0))
        dfWeights = dfWeights.assign(optimalWeights = dfWeights['optimalWeights'].fillna(0.0))
        dfWeights = dfWeights.assign(postOptWeights = dfWeights['postOptWeights'].fillna(0.0))
        dfWeights = dfWeights.assign(tradeWeights   = dfWeights['postOptWeights'] - dfWeights['preOptWeights'])
        dctOptResult['dfWeights'] = dfWeights

        dframe = dfWeights.merge(tmodel.reset_index(), how='left', on='assetKey')

        linearCost = NAV * ((Filter.bound(dframe['tradeWeights'], lower=0.0) *
                             dframe['linearBuyCoeff'].fillna(dframe['linearBuyCoeff'].max())) -
                            (Filter.bound(dframe['tradeWeights'], upper=0.0) *
                             dframe['linearSellCoeff'].fillna(dframe['linearSellCoeff'].max()))).sum()
        dctOptResult['linearCostUSD'] = linearCost

        impactCost = rebalConfig.impactCostScale * np.power(NAV, 3 / 2) * \
                     (np.power(abs(dframe['tradeWeights']), 3 / 2) *
                      dframe['impactCoeffTH'].fillna(dframe['impactCoeffTH'].max())).sum()
        dctOptResult['impactCostUSD'] = impactCost

        dctOptResult['tcostUSD'] = linearCost + impactCost

        dctOptResult['fcost'] = fcost.value[0]
        dctOptResult['bcost'] = bcost.value[0]
        dctOptResult['tradeWeight'] = abs(dfWeights['tradeWeights']).sum()
        dctOptResult['tradeValUSD'] = dctOptResult['tradeWeight'] * dctOptResult['preOptNAV']

        return dctOptResult

    @classmethod
    def getExchangeHolidays(cls, tradeDate):
        """Return the list of quote-country codes whose exchanges are closed on tradeDate.

        Reads the static 'trading_exchange_holidays_master_frame' signal, keeps rows for
        the given tradeDate with trading == 'No', and returns their 'quoteCountry' codes.
        When Hong Kong ('HK') is closed, 'CN' and 'XH' are added as well (northbound
        Connect flow is unavailable when HK is shut).

        :param tradeDate: trade date to look up holidays for (same type as the frame's
                          'tradeDate' column — presumably a PyDate/int key; confirm against caller)
        :return: list of country-code strings (de-duplicated only when the HK expansion fires)
        """
        hframe = SignalMgr.getStatic('trading_exchange_holidays_master_frame')
        hframe = hframe[hframe['tradeDate'] == tradeDate]
        hframe = hframe[hframe['trading'] == 'No']
        lstHolidays = hframe['quoteCountry'].tolist()
        # BUG FIX: `'HK' in series` tests the Series *index* (row labels), not its values,
        # so the CN/XH expansion effectively never triggered. Test the extracted list.
        if 'HK' in lstHolidays:
            lstHolidays = list(set(lstHolidays + ['CN', 'XH']))
        return lstHolidays

    @classmethod
    def solveProblem(cls, rebalConfig, problem, maxIter=200000, cycleIter=5000,
                     solver=cp.ECOS, verbose=False, checkDCP=False, optimal=False):
        """Solve a cvxpy problem with solver-specific retry/tolerance ladders.

        SCS path: solve in cycles of `cycleIter` iterations at the target tolerance;
        if the iteration budget is exhausted, retry briefly at progressively looser
        tolerances taken from rebalConfig.toleranceRange.
        Non-SCS path: try ECOS at a tight tolerance, fall back to ECOS at a looser
        tolerance with double the iteration budget, and finally to SCS defaults.

        :param rebalConfig: config providing targetTolerance / toleranceRange / tau
        :param problem:     cp.Problem to solve (mutated in place by .solve)
        :param maxIter:     total iteration budget per tolerance attempt (SCS) or per ECOS call
        :param cycleIter:   iterations per SCS solve cycle
        :param solver:      cvxpy solver constant (cp.SCS selects the cycle ladder)
        :param verbose:     pass-through to problem.solve
        :param checkDCP:    assert the problem is DCP before solving
        :param optimal:     affects only log output (tau reported as 0.0 when True)
        :return: total number of solver iterations consumed
        """
        if checkDCP:
            dcpFlag = problem.is_dcp()
            PyLog.info('Problem is DCP (Disciplined Convex Problem): {}'.format(dcpFlag))
            PyLog.assertion(dcpFlag, "Optimization problem is not DCP (Disciplined Convex Problem)")

        if solver == cp.SCS:

            targetTolerance = rebalConfig.targetTolerance
            toleranceRange = rebalConfig.toleranceRange
            # BUG FIX: the log lines below referenced `rtau`, which is not defined in this
            # method's scope (it is a local of the caller) and raised NameError; use the
            # config value directly.
            rtau = rebalConfig.tau

            nAttempt = 0
            maxAttempt = len(toleranceRange)
            success = False
            auxIter = 10            # short probe budget for the loosened-tolerance retries
            numIters = 0

            while (not success) and (nAttempt < maxAttempt):

                maxCycles = int(np.ceil(maxIter / cycleIter))
                nCycles = 0
                cumIters = 0

                while (not success) and (nCycles < maxCycles):

                    tol = targetTolerance
                    try:
                        problem.solve(verbose=verbose, solver=solver, max_iters=cycleIter, eps=tol)
                    except Exception:
                        # Best-effort: SCS raising mid-cycle (e.g. max-iters) is expected;
                        # progress is judged from solver_stats below. Narrowed from bare
                        # `except:` so KeyboardInterrupt/SystemExit still propagate.
                        pass

                    cumIters += problem.solver_stats.num_iters
                    numIters += problem.solver_stats.num_iters
                    nCycles += 1

                    PyLog.info('tau : {:.2f}  /  numIters : {}  / cumIters : {}  /  tol : {:e}'.format(
                        0.0 if optimal else rtau, problem.solver_stats.num_iters, cumIters, tol))

                    # Converged before exhausting the cycle budget -> done.
                    if (problem.solver_stats.num_iters < cycleIter):
                        success = True

                if (problem.solver_stats.num_iters < cycleIter):
                    success = True
                else:
                    # Budget exhausted at target tolerance: probe briefly at the next
                    # looser tolerance in the ladder.
                    tol = toleranceRange[nAttempt]
                    try:
                        problem.solve(verbose=verbose, solver=solver, max_iters=auxIter, eps=tol)
                    except Exception:
                        pass

                    PyLog.info('tau : {:.2f}  /  numIters : {}  /  tol : {:e}'.format(
                        0.0 if optimal else rtau, problem.solver_stats.num_iters, tol))

                    numIters += problem.solver_stats.num_iters

                    if (problem.solver_stats.num_iters < auxIter):
                        success = True
                    else:
                        nAttempt += 1

        else:
            ## Level 1: ECOS at tight tolerances.
            try:
                success = True
                PyLog.info(f"level one ECOS solve ")
                #problem.solve(verbose=verbose, solver=cp.ECOS, max_iters=maxIter, abstol=1e-6, reltol=1e-4, feastol=2e-6, abstol_inacc=2e-6, reltol_inacc=2e-4, feastol_inacc=2e-6)
                val1 = problem.solve(verbose=verbose,
                                     solver=cp.ECOS,
                                     max_iters=maxIter,
                                     warm_start=True,
                                     abstol=1e-5,
                                     reltol=5e-4,
                                     feastol=1e-5)

                #CLARABEL : cvxpy deprecated ECOS in May 2024
                #PyLog.info(f"level one CLARABEL solve")
                #problem.solve(verbose=verbose, solver=cp.CLARABEL)

            except Exception as ee:
                success = False
                PyLog.warning(f"Error: optimization failure first level tolerance. td: caught: {ee}")
                PyLog.warning(traceback.format_exc())

            ## Level 2: ECOS again, looser tolerances and double the iteration budget.
            if (not success) or (problem.status == 'infeasible'):
                try:
                    success = True
                    PyLog.info(f"level two ECOS solve ")
                    #problem.solve(verbose=verbose, solver=cp.ECOS, max_iters=maxIter,
                    #              abstol=1e-6, reltol=5e-4, feastol=2e-6,
                    #              abstol_inacc=2e-6, reltol_inacc=1e-3, feastol_inacc=2e-6)
                    # --- Level 2 (fallback, looser than Level 1) ---
                    val2 = problem.solve(
                        solver=cp.ECOS,
                        max_iters=int(maxIter * 2),  # give the fallback extra room
                        warm_start=True,
                        abstol=5e-5,
                        reltol=1e-3,
                        feastol=2e-5,
                        verbose=verbose
                    )

                except Exception as ee:
                    success = False
                    PyLog.warning(f"Error: ECOS failure second level tolerance. td: caught: {ee}")
                    PyLog.warning(traceback.format_exc())

            ## Level 3: SCS with library defaults as the last resort.
            if (not success) or (problem.status == 'infeasible'):
                try:
                    PyLog.info(f"level three SCS solve ")
                    #problem.solve(verbose=verbose, solver=cp.ECOS, max_iters=maxIter,
                    #              abstol=1e-6, reltol=5e-3, feastol=2e-6,
                    #              abstol_inacc=2e-6, reltol_inacc=1e-2, feastol_inacc=2e-6)
                    problem.solve(verbose=verbose, solver=cp.SCS)
                except Exception as ee:
                    PyLog.warning(f"SCS failure third level tolerance. td: caught: {ee}")
                    PyLog.warning(traceback.format_exc())

            # ROBUSTNESS: if every level failed before a solver ran, solver_stats can be
            # None — report 0 iterations instead of raising AttributeError.
            numIters = problem.solver_stats.num_iters if problem.solver_stats is not None else 0

        return numIters


    @classmethod
    def runOptimal(cls, rebalConfig=RebalConfig_DEFAULT, signalDate=PyDate.asDate(20191115), NAV=100_000_000,
                   solver=cp.SCS, maxIter=200000, cycleIter=5000, verbose=True, checkDCP=True, tradeRestrictions=None):
        """Run the optimizer from an empty starting portfolio (optimal=True pass-through to run)."""
        # Delegate to run(); the only fixed choices are the empty portfolio and optimal=True.
        return cls.run(rebalConfig=rebalConfig,
                       signalDate=signalDate,
                       preOptWeights=cls.EMPTY_PORTFOLIO,
                       NAV=NAV,
                       solver=solver,
                       maxIter=maxIter,
                       cycleIter=cycleIter,
                       verbose=verbose,
                       checkDCP=checkDCP,
                       optimal=True,
                       tradeRestrictions=tradeRestrictions)

    @classmethod
    def generateRiskReport(cls, dctOptResult, weightCol='postOptWeights', maxReturnDate=PyDate.today()):
        """Build a RiskReport for the weights held in an optimization result dict.

        :param dctOptResult:  result dict produced by the optimizer (provides
                              'rebalConfig', 'dfWeights', 'signalDate')
        :param weightCol:     which weight column of dfWeights to report on
        :param maxReturnDate: latest return date passed through to the report
        """
        cfg = dctOptResult['rebalConfig']
        # Extract the requested weight column as an assetKey-indexed series.
        port = FrameUtil.toSeries(dctOptResult['dfWeights'], keyCol='assetKey', valCol=weightCol)
        return RiskReport.runPortfolio(
            port=port,
            signalDate=dctOptResult['signalDate'],
            alphaName=cfg.alphaName,
            fcostName=cfg.fcostName,
            bcostName=cfg.bcostName,
            rmodelName=cfg.rmodelName,
            envName=Strategy.getModelName(cfg.stratName),
            minWeight=cfg.holdingThreshold,
            maxReturnDate=maxReturnDate)

    @classmethod
    def printRiskReport(cls, dctOptResult, plotChart=False):
        """Generate the risk report for an optimization result and print it."""
        return RiskReport.print(cls.generateRiskReport(dctOptResult), plotChart)



    @classmethod
    def compileShortAvailability(cls, rebalConfig, NAV, tradeDate, dframe):
        """
        Build the tiered short-availability frame used by the optimizer.

        Consolidates per-broker short inventory over a lookback window into
        per-asset rate "tiers" (sorted by borrow rate), caps each tier's hard
        short bound by its available inventory, and attaches borrow costs.

        Args:
            rebalConfig: rebalance configuration; reads bcostLookback, pbList,
                HAUM / hypotheticalAUM, maxShortUtilization, bcostMultiplierName.
            NAV: fund NAV, used as the denominator when rebalConfig.HAUM is falsy.
            tradeDate: trade date; the lookback window ends here.
            dframe: per-asset frame; must contain 'assetKey', 'preOptWeights',
                'shortBoundSoft', 'shortBoundHard'.

        Returns:
            (dfAvail, dfMarginal):
              dfAvail    - one row per (assetKey, tier) with adjusted bounds,
                           bcost, and a 'key' column of the form
                           '<assetKey>:<tier>'. tier == -1 marks assets with
                           no short availability at all.
              dfMarginal - per-asset availability (fraction of AUM) on
                           tradeDate only.
        """
        lookback = rebalConfig.bcostLookback
        endTD = tradeDate
        startTD = PyDate.minusWeekdays(endTD, lookback - 1)

        # Exponential time weights, normalized so the most recent day = 1.
        dfWgts = pd.DataFrame({'tradeDate': PyDate.sequenceWeekday(startTD, endTD, decreasing=True),
                               'timeWgts': Stats.expwts(lookback, lookback/2)})
        dfWgts = dfWgts.assign(timeScale = dfWgts['timeWgts'] / dfWgts['timeWgts'].iloc[0])

        dfRaw = ShortAvailability.getRange(startTD, endTD, pbList=rebalConfig.pbList)
        dfRaw = dfRaw[dfRaw['assetKey'].isin(dframe['assetKey'])]

        if rebalConfig.HAUM:
            AUM = rebalConfig.hypotheticalAUM
        else:
            AUM = NAV

        # A JPM-connect-specific utilization boost used to live here; with no
        # JPM broker, max shorting depends only on this config variable.
        dfRaw = dfRaw.assign(scale = rebalConfig.maxShortUtilization)

        dfRaw = dfRaw.assign(vsHAUM=dfRaw['scale'] * dfRaw['notionalUSD'] / AUM)
        dfRaw = dfRaw.merge(dfWgts[['tradeDate', 'timeScale']], how='left', on='tradeDate')
        ## availableRate is float and works poorly as a groupby key; bucket by
        ## a fixed-precision string representation instead.
        dfRaw = dfRaw.assign(rateClass=['{:.4f}'.format(x) for x in dfRaw['availableRate']])

        # Marginal availability: today's inventory only, summed across brokers.
        # Use the string 'sum' -- passing the Python builtin to agg() is
        # deprecated in pandas.
        dfMarginal = dfRaw[dfRaw['tradeDate'] == tradeDate].\
            groupby('assetKey').agg({'vsHAUM': 'sum'}).reset_index().\
            rename(columns={'vsHAUM': 'availablePct'})

        # Time-weighted AVERAGE availability per (asset, rate bucket).
        # (A plain weighted sum here would incorrectly add up all lookback days.)
        dfConsol = dfRaw.groupby(['assetKey', 'rateClass']) \
                    .apply(lambda dfm: (dfm['vsHAUM'] * dfm['timeScale']).sum() / dfm['timeScale'].sum()) \
                    .reset_index().rename(columns={0: 'availablePct'})

        dfConsol = dfConsol.assign(availableRate=dfConsol['rateClass'].astype(float))
        # Cumulative availability bands [cumLo, cumHi) walking up the rate ladder.
        dfConsol = dfConsol.sort_values(by=['assetKey', 'availableRate'], ascending=True). \
            groupby('assetKey').apply(lambda dfm: dfm.assign(cumHi=dfm['availablePct'].cumsum())).reset_index(drop=True)
        dfConsol = dfConsol.groupby('assetKey'). \
            apply(lambda dfm: dfm.assign(cumLo=[0.0] + dfm['cumHi'][:-1].tolist())).reset_index(drop=True)
        dfConsol = dfConsol.merge(dframe[['assetKey', 'preOptWeights']], how='left', on='assetKey')

        ## discard tiers whose band starts beyond what the short bounds allow
        softShortBoundBuffer = 0.01
        dfConsol = dfConsol.merge(dframe[['assetKey', 'shortBoundSoft', 'shortBoundHard']], how='left', on='assetKey')
        dfConsol = dfConsol.assign(bufferedShortBoundSoft=dfConsol['shortBoundSoft'] + softShortBoundBuffer)
        dfConsol = dfConsol[dfConsol['cumLo'] <= dfConsol['shortBoundHard']]
        dfConsol = dfConsol[dfConsol['cumLo'] <= dfConsol['bufferedShortBoundSoft']]

        ## technically, this is not strict but we are going to simplify the logic and keep track of
        ## all availability and marginal availability
        dfConsol = dfConsol.sort_values(by=['assetKey', 'availableRate'], ascending=True). \
            groupby('assetKey').apply(lambda dfm: dfm.assign(tier=range(len(dfm)))).reset_index(drop=True)

        columns = ['assetKey', 'tier', 'availableRate', 'availablePct']
        dfBaseTier = dfConsol[dfConsol['tier'] == 0][columns].merge(dframe, how='inner', on='assetKey')

        # availableRate is already in the units bcost expects (a previous /100
        # scaling was deliberately removed).
        dfBaseTier = dfBaseTier.assign(bcost=dfBaseTier['availableRate'] )
        # Tier-0 hard short bound is replaced by its availability outright
        # (not min'd with the asset-level bound -- deliberate simplification).
        dfBaseTier = dfBaseTier.assign(shortBoundHard=dfBaseTier['availablePct'])
        dfBaseTier = dfBaseTier.drop(columns=['availableRate', 'availablePct'])

        # Higher tiers: short-only slices (no long exposure), full fcost.
        dfHighTier = dfConsol[dfConsol['tier'] != 0][['assetKey', 'tier', 'availableRate', 'availablePct']]. \
            merge(dframe, how='inner', on='assetKey')
        dfHighTier = dfHighTier.assign(longBoundHard=0.0)
        dfHighTier = dfHighTier.assign(shortBoundHard=dfHighTier['availablePct'])
        dfHighTier = dfHighTier.assign(bcost=dfHighTier['availableRate'] )
        dfHighTier = dfHighTier.assign(fcost=1.00)
        # np.nan: the np.NaN alias was removed in NumPy 2.0.
        dfHighTier = dfHighTier.assign(preOptWeights=np.nan)
        dfHighTier = dfHighTier.drop(columns=['availableRate', 'availablePct'])

        ## no short availability at all: prohibit shorting, punitive borrow cost
        dfNoAvail = dframe[~dframe['assetKey'].isin(dfConsol['assetKey'])]
        dfNoAvail = dfNoAvail.assign(shortBoundHard=0.0)
        dfNoAvail = dfNoAvail.assign(bcost=1.0)
        dfNoAvail = dfNoAvail.assign(tier=-1)

        dfAvail = pd.concat([dfBaseTier, dfHighTier, dfNoAvail], ignore_index=True)
        dfAvail = dfAvail.assign(key=dfAvail['assetKey'] + ':' + dfAvail['tier'].astype(str))
        dfAvail = dfAvail.sort_values(by=['assetKey', 'tier'])

        ## optional signal-driven bcost multiplier; assets missing from the
        ## signal frame fall back to the signal's maxMultiplier.
        if rebalConfig.bcostMultiplierName is not None:
            signalDate = PyDate.prevWeekday(tradeDate)
            signalObj = Signal.registryLookup(rebalConfig.bcostMultiplierName)
            dfm = SignalMgr.getFrame(signalObj.signalName, signalDate).\
                rename(columns={signalObj.signalName: 'multiplier'})
            dfAvail = dfAvail.merge(dfm[['assetKey', 'multiplier']], how='left', on='assetKey')
            dfAvail = dfAvail.assign(multiplier = dfAvail['multiplier'].fillna(signalObj.maxMultiplier))
            dfAvail = dfAvail.assign(bcost = dfAvail['bcost'] * dfAvail['multiplier'])
            dfAvail = dfAvail.drop(columns='multiplier')

        return dfAvail, dfMarginal


def load_china_connect_inventory_scores(input_date, max_lookback_days):
    """
    Try to read China Connect Inventory Score file for the given signal date.
    From what I have noticed, for tradeDate say 11/19, we get files for  11/18 and that
    too at 9pm NY time but our optimizer runs at 3pm NY time so these 11/18 files come late.
    So we process only two days earlier files for 11/17. Hence I have put in logic for max_lookback_days
    as 3 so we will try to get trade_date(say 11/19) files, if not (11/18), if not (11/17)

    Args:
        input_date: signal date to start the search from (anything PyDate.asDate accepts).
        max_lookback_days: number of working days (including input_date) to try.

    Returns:
        DataFrame ['assetKey', 'stabilityScore'] (duplicates aggregated by max),
        or None on any failure (missing env/dir/file, bad layout, exception).
    """
    import PyUtil.PyRicX as RicX

    try:
        pbdata_dir = os.environ.get('PBDIR')
        if not pbdata_dir:
            PyLog.info("ChinaConnect: PBDIR env var not set; skipping China Connect short penalty.")
            return None

        base_dir = os.path.join(pbdata_dir, "GS_Stability_Files")
        if not os.path.isdir(base_dir):
            PyLog.info(f"ChinaConnect: directory {base_dir} does not exist; skipping China Connect short penalty.")
            return None

        # ----------------------------------------------------------
        # Loop over working days (signalDate, -1 WD, -2 WD, etc.)
        # ----------------------------------------------------------
        dates_to_try = []
        d = PyDate.asDate(input_date)

        for _ in range(max_lookback_days):
            dates_to_try.append(d)
            d = PyDate.prevWeekday(d)

        filename = None

        for d in dates_to_try:
            date_str = d.strftime("%Y%m%d")
            pattern_xls  = os.path.join(base_dir, f"*China_Connect*In*{date_str}.xls")
            pattern_xlsx = os.path.join(base_dir, f"*China_Connect*In*{date_str}.xlsx")

            files = glob.glob(pattern_xls) + glob.glob(pattern_xlsx)

            if files:
                files.sort(key=os.path.getmtime)
                filename = files[-1]  # newest file for that date
                # BUGFIX: log the actual chosen file path instead of a
                # placeholder literal.
                PyLog.info(f"ChinaConnect: using inventory score file {filename} (date {date_str})")
                break
            else:
                PyLog.info(f"ChinaConnect: no file for date {date_str}, trying previous working day...")

        if not filename:
            PyLog.info(f"ChinaConnect: no China_Connect file found in last {max_lookback_days} working days; skipping penalty.")
            return None
        # ----------------------------------------------------------

        # Read without header; find the row where the table starts
        raw = pd.read_excel(filename, header=None)
        col0 = raw.iloc[:, 0].astype(str).str.strip()
        header_rows = raw.index[col0 == "RIC"].tolist()
        if not header_rows:
            # BUGFIX: report which file had the bad layout.
            PyLog.info(f"ChinaConnect: could not find header row with 'RIC' in {filename}; skipping penalty.")
            return None

        h = header_rows[0]
        df = raw.iloc[h:].copy()
        df.columns = df.iloc[0]
        df = df.iloc[1:]  # drop header row

        df = df.rename(columns={c: str(c).strip() for c in df.columns})
        expected_cols = {"RIC", "SCORE", "SEDOL", "ISIN"}
        if not expected_cols.issubset(df.columns):
            PyLog.info(f"ChinaConnect: expected columns {expected_cols} not found; skipping penalty.")
            return None

        df = df[["RIC", "SCORE", "SEDOL", "ISIN"]].copy()
        df["RIC"] = df["RIC"].astype(str).str.strip()
        df["SCORE"] = pd.to_numeric(df["SCORE"], errors="coerce")

        # Keep only ZK and SH
        mask = df["RIC"].str.endswith(".ZK") | df["RIC"].str.endswith(".SH")
        df = df[mask].copy()
        if df.empty:
            PyLog.info("ChinaConnect: no .ZK/.SH rows in file; nothing to do.")
            return None

        # Convert to internal_ric: ZK->SZ, SH->SS
        df["internal_ric"] = df["RIC"]
        df.loc[df["RIC"].str.endswith(".ZK"), "internal_ric"] = df["RIC"].str.replace(".ZK", ".SZ", regex=False)
        df.loc[df["RIC"].str.endswith(".SH"), "internal_ric"] = df["RIC"].str.replace(".SH", ".SS", regex=False)

        # Map to assetKey
        df = RicX.add_asset_code(df, "internal_ric")
        if "assetKey" not in df.columns:
            PyLog.info("ChinaConnect: add_asset_code failed; skipping penalty.")
            return None

        df = df[["assetKey", "SCORE"]].dropna()
        if df.empty:
            PyLog.info("ChinaConnect: no usable rows remain; skipping penalty.")
            return None

        # Aggregate duplicates by max score
        df = df.groupby("assetKey", as_index=False)["SCORE"].max()
        df = df.rename(columns={"SCORE": "stabilityScore"})

        PyLog.info(f"ChinaConnect: loaded {len(df)} inventory score rows with assetKey.")
        return df

    except Exception as e:
        # Deliberate best-effort boundary: any failure logs and disables the
        # penalty rather than aborting the rebalance.
        PyLog.error(f"ChinaConnect: exception while loading inventory scores: {e}")
        PyLog.error(traceback.format_exc())
        return None

def build_china_connect_short_penalty(dfAvail, wshort, input_date):
    """
    Construct a cvxpy objective term penalizing shorts in unstable
    China Connect names.

    Behavior:
      - Only rows of dfAvail (the split-block optimizer universe) are eligible.
      - Applies to the short-side variable wshort only.
      - stabilityScore <= 4 carries no penalty; above 4 the per-asset weight
        grows linearly as (score - 4).
      - Penalized assets are logged before the solve, and metadata is attached
        to the returned expression for post-solve inspection.

    Returns:
        cp.Expression, or None when no penalty applies (no score file, empty
        scores, missing assetKey column, or nothing scored above 4).
    """

    # Hardcoded for now; ideally this lookback belongs in rebalance.config.
    max_lookback_days = 3

    # Pull stability scores from the latest available file.
    df_scores = load_china_connect_inventory_scores(input_date, max_lookback_days)
    if df_scores is None or df_scores.empty:
        return None

    # assetKey -> stabilityScore lookup
    score_by_asset = df_scores.set_index("assetKey")["stabilityScore"]

    if "assetKey" not in dfAvail.columns:
        PyLog.info("ChinaConnect: dfAvail has no 'assetKey' column; cannot apply stability penalty.")
        return None

    weights = np.zeros(len(dfAvail))
    flagged = []  # (idx, assetKey, score, penaltyWeight)

    # One penalty weight per dfAvail row, aligned by position.
    for row_idx, asset in enumerate(dfAvail["assetKey"]):
        if asset not in score_by_asset.index:
            continue
        score = score_by_asset.loc[asset]
        if pd.isna(score) or not score > 4:
            continue
        # 5→1, 6→2, 7→3, 8→4 (can be tuned)
        pw = max(0.0, float(score) - 4.0)
        weights[row_idx] = pw
        flagged.append((row_idx, asset, float(score), pw))

    if not flagged:
        PyLog.info("ChinaConnect: no dfAvail assets with stabilityScore > 4; no penalty applied.")
        return None

    # Log penalized assets BEFORE solve (no short weights yet)
    PyLog.info(f"ChinaConnect: {len(flagged)} dfAvail assets will have short-side stability penalty.")
    PyLog.info("ChinaConnect: penalized assets (before solve, top 50):")
    for idx, a, s, pw in flagged[:50]:
        PyLog.info(f"    idx={idx}, assetKey={a}, score={s}, penaltyWeight={pw}")
    if len(flagged) > 50:
        PyLog.info(f"    ...and {len(flagged) - 50} more penalized assets")

    # Strength knob for this penalty term.
    china_scale = 0.1  # adjust as needed or make configurable

    # weights is constant data; cvxpy treats it as coefficients on wshort.
    weight_col = weights.reshape(-1, 1)
    penalty_expr = china_scale * cp.sum(cp.multiply(weight_col, wshort))

    # Stash metadata on the expression for post-solve debug logging.
    penalty_expr._china_penalty_info = flagged
    penalty_expr._china_penalty_weights = weight_col
    penalty_expr._china_scale = china_scale

    PyLog.info("ChinaConnect: short-side stability penalty term added to objective.")
    return penalty_expr

