from moonshot import Moonshot
from moonshot.commission import PerShareCommission
from quantrocket import get_prices
class USStockCommission(PerShareCommission):
    """
    Per-share commission settings, referenced as the COMMISSION_CLASS of
    FirstHalfHourPredictsLastHalfHour.
    """
    # Broker commission per share. The currency is not stated in this file;
    # presumably USD given the class name -- confirm.
    BROKER_COMMISSION_PER_SHARE = 0.005
class FirstHalfHourPredictsLastHalfHour(Moonshot):
    """
    Intraday strategy that buys (sells) if the market is up (down) during the first
    and penultimate half-hour.

    A long (short) signal is produced when both of these returns are positive
    (negative):

    - the return from the previous row's 15:59 close to the 10:00 open, and
    - the return from the 15:00 open to the 15:30 open.

    Gross returns are measured from the 15:30 open to the 15:59 close.

    If ``MIN_VIX`` is set to a truthy number, signals are additionally zeroed
    wherever the 14:00 VIX close is not at or above that level.
    """

    CODE = 'first-last'
    # Price database and the bar times/fields the methods below read from it.
    DB = 'usstock-1min'
    DB_TIMES = ['10:00:00', '15:00:00', '15:30:00', '15:59:00']
    DB_FIELDS = ['Open','Close']
    # Single security traded; the same identifier is used as the benchmark.
    SIDS = ["FIBBG000BDTBL9"]
    COMMISSION_CLASS = USStockCommission
    # Slippage setting (basis points, per the attribute name).
    SLIPPAGE_BPS = 0.5
    # Optional VIX floor; None (or any falsy value) disables the VIX filter.
    MIN_VIX = None
    BENCHMARK = "FIBBG000BDTBL9"
    BENCHMARK_TIME = "15:59:00"

    def prices_to_signals(self, prices):
        """
        Build the long/short/flat signals.

        Parameters
        ----------
        prices : DataFrame
            Prices whose index has a field level (selected with
            ``.loc["Open"]`` / ``.loc["Close"]``) and a "Time" level.

        Returns
        -------
        DataFrame
            1 where both half-hour returns are positive, -1 where both are
            negative, 0 otherwise (optionally zeroed by the VIX filter).
        """
        closes = prices.loc["Close"]
        opens = prices.loc["Open"]

        # First half-hour return: previous row's 15:59 close to today's 10:00
        # open. shift() leaves the first row NaN, which fails both comparisons
        # below and therefore yields a 0 signal.
        prior_closes = closes.xs('15:59:00', level="Time").shift()
        ten_oclock_prices = opens.xs('10:00:00', level="Time")
        first_half_hour_returns = (ten_oclock_prices - prior_closes) / prior_closes

        # Penultimate half-hour return: 15:00 open to 15:30 open.
        fifteen_oclock_prices = opens.xs('15:00:00', level="Time")
        fifteen_thirty_prices = opens.xs('15:30:00', level="Time")
        penultimate_half_hour_returns = (fifteen_thirty_prices - fifteen_oclock_prices) / fifteen_oclock_prices

        long_signals = (first_half_hour_returns > 0) & (penultimate_half_hour_returns > 0)
        short_signals = (first_half_hour_returns < 0) & (penultimate_half_hour_returns < 0)

        # 1 where long; elsewhere -1 where short and 0 where neither.
        signals = long_signals.astype(int).where(long_signals, -short_signals.astype(int))

        if self.MIN_VIX:

            vix = get_prices("vix-30min",
                             fields="Close",
                             start_date=signals.index.min(),
                             end_date=signals.index.max(),
                             times="14:00:00")

            # Squeeze only the column axis: a bare squeeze() would collapse a
            # one-row, one-column result to a scalar, losing the date index
            # that the broadcast and mask below rely on.
            vix = vix.loc["Close"].xs("14:00:00", level="Time").squeeze(axis="columns")

            # Repeat the VIX series under every signals column so the mask has
            # the same column labels as signals.
            vix = signals.apply(lambda x: vix)
            signals = signals.where(vix >= self.MIN_VIX, 0)
        return signals

    def signals_to_target_weights(self, signals, prices):
        """
        Use the signals unchanged as target weights.

        Returns a copy of ``signals``; ``prices`` is not used.
        """
        target_weights = signals.copy()
        return target_weights

    def target_weights_to_positions(self, target_weights, prices):
        """
        Use the target weights unchanged as positions.

        Returns a copy of ``target_weights``; ``prices`` is not used.
        """
        positions = target_weights.copy()
        return positions

    def positions_to_gross_returns(self, positions, prices):
        """
        Compute gross returns from the 15:30 open to the 15:59 close.

        Returns the percent change between those two prices multiplied
        by ``positions``.
        """
        opens = prices.loc["Open"]
        closes = prices.loc["Close"]

        entry_prices = opens.xs("15:30:00", level="Time")
        session_closes = closes.xs("15:59:00", level="Time")
        pct_changes = (session_closes - entry_prices) / entry_prices
        gross_returns = pct_changes * positions
        return gross_returns