"""
Copyright 2019 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
  http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied.  See the License for the
specific language governing permissions and limitations
under the License.
"""
import asyncio
import datetime as dt
import logging
import queue
import sys
import weakref
from abc import ABCMeta
from concurrent.futures import ThreadPoolExecutor
from inspect import signature
from itertools import zip_longest, takewhile
from typing import Optional, Union
from gs_quant.base import InstrumentBase, RiskKey, Scenario, get_enum_value
from gs_quant.common import PricingLocation, RiskMeasure
from gs_quant.context_base import ContextBaseWithDefault
from gs_quant.datetime.date import business_day_offset, today
from gs_quant.risk import CompositeScenario, DataFrameWithInfo, ErrorValue, FloatWithInfo, MarketDataScenario, \
    StringWithInfo
from gs_quant.risk.results import PricingFuture
from gs_quant.session import GsSession
from gs_quant.target.common import PricingDateAndMarketDataAsOf
from gs_quant.target.risk import RiskPosition, RiskRequest, RiskRequestParameters
from gs_quant.tracing import Tracer
from tqdm import tqdm
from .markets import CloseMarket, LiveMarket, Market, close_market_date, OverlayMarket, RelativeMarket
_logger = logging.getLogger(__name__)
CacheResult = Union[DataFrameWithInfo, FloatWithInfo, StringWithInfo]
class PricingCache(metaclass=ABCMeta):
    """
    Weakref cache for instrument calcs
    """
    __cache = weakref.WeakKeyDictionary()
    @classmethod
    def clear(cls):
        __cache = weakref.WeakKeyDictionary()
    @classmethod
    def get(cls, risk_key: RiskKey, instrument: InstrumentBase) -> Optional[CacheResult]:
        return cls.__cache.get(instrument, {}).get(risk_key)
    @classmethod
    def put(cls, risk_key: RiskKey, instrument: InstrumentBase, result: CacheResult):
        if not isinstance(result, ErrorValue) and not isinstance(risk_key.market, LiveMarket):
            cls.__cache.setdefault(instrument, {})[risk_key] = result
    @classmethod
    def drop(cls, instrument: InstrumentBase):
        if instrument in cls.__cache:
            cls.__cache.pop(instrument)
[docs]class PricingContext(ContextBaseWithDefault):
    """
    A context for controlling pricing and market data behaviour
    """
[docs]    def __init__(self,
                 pricing_date: Optional[dt.date] = None,
                 market_data_location: Optional[Union[PricingLocation, str]] = None,
                 is_async: bool = None,
                 is_batch: bool = None,
                 use_cache: bool = None,
                 visible_to_gs: Optional[bool] = None,
                 request_priority: Optional[int] = None,
                 csa_term: Optional[str] = None,
                 timeout: Optional[int] = None,
                 market: Optional[Market] = None,
                 show_progress: Optional[bool] = None,
                 use_server_cache: Optional[bool] = None,
                 market_behaviour: Optional[str] = 'ContraintsBased',
                 set_parameters_only: bool = False):
        """
        The methods on this class should not be called directly. Instead, use the methods on the instruments,
        as per the examples
        :param pricing_date: the date for pricing calculations. Default is today
        :param market_data_location: the location for sourcing market data ('NYC', 'LDN' or 'HKG' (defaults to LDN)
        :param is_async: if True, return (a future) immediately. If False, block (defaults to False)
        :param is_batch: use for calculations expected to run longer than 3 mins, to avoid timeouts.
            It can be used with is_async=True|False (defaults to False)
        :param use_cache: store results in the pricing cache (defaults to False)
        :param visible_to_gs: are the contents of risk requests visible to GS (defaults to False)
        :param request_priority: the priority of risk requests
        :param csa_term: the csa under which the calculations are made. Default is local ccy ois index
        :param timeout: the timeout for batch operations
        :param market: a Market object
        :param show_progress: add a progress bar (tqdm)
        :param use_server_cache: cache query results on the GS servers
        :param market_behaviour: the behaviour to build the curve for pricing ('ContraintsBased' or 'Calibrated'
            (defaults to ContraintsBased))
        :param set_parameters_only: if true don't stop embedded pricing contexts submitting their jobs.
        **Examples**
        To change the market data location of the default context:
        >>> from gs_quant.markets import PricingContext
        >>>
        >>> PricingContext.current = PricingContext(market_data_location='LDN')
        For a blocking, synchronous request:
        >>> from gs_quant.instrument import IRCap
        >>> cap = IRCap('5y', 'GBP')
        >>>
        >>> with PricingContext():
        >>>     price_f = cap.dollar_price()
        >>>
        >>> price = price_f.result()
        For an asynchronous request:
        >>> with PricingContext(is_async=True):
        >>>     price_f = cap.dollar_price()
        >>>
        >>> while not price_f.done:
        >>>     ...
        """
        super().__init__()
        if market and market_data_location and market.location is not \
                
get_enum_value(PricingLocation, market_data_location):
            raise ValueError('market.location and market_data_location cannot be different')
        if not market and pricing_date and pricing_date > dt.date.today() + dt.timedelta(5):
            # We allow a small tolerance to rolling over weekends/holidays
            # We should use a calendar but not everyone has access
            raise ValueError(
                'The PricingContext does not support a pricing_date in the future. Please use the RollFwd Scenario '
                'to roll the pricing_date to a future date')
        if market:
            market_date = None
            if isinstance(market, OverlayMarket) or isinstance(market, CloseMarket):
                market_date = getattr(market, 'date', None) or getattr(market.market, 'date', None)
            if isinstance(market, RelativeMarket):
                market_date = market.market.from_market.date if market.market.from_market.date > dt.date.today() \
                    
else market.market.to_market.date
            if market_date:
                if market_date > dt.date.today():
                    raise ValueError(
                        'The PricingContext does not support a market dated in the future. Please use the RollFwd '
                        'Scenario to roll the pricing_date to a future date')
        if not market_data_location:
            if market:
                market_data_location = market.location
        market_data_location = get_enum_value(PricingLocation, market_data_location)
        self.__pricing_date = pricing_date
        self.__csa_term = csa_term
        self.__market_behaviour = market_behaviour
        self.__is_async = is_async
        self.__is_batch = is_batch
        self.__timeout = timeout
        self.__use_cache = use_cache
        self.__visible_to_gs = visible_to_gs
        self.__request_priority = request_priority
        self.__market_data_location = market_data_location
        self.__market = market
        self.__show_progress = show_progress
        self.__use_server_cache = use_server_cache
        self.__max_per_batch = None
        self.__max_concurrent = None
        self.__set_parameters_only = set_parameters_only
        self.__pending = {}
        self._group_by_date = True
        self.__attrs_on_entry = {} 
    def __save_attrs_to(self, attr_dict):
        attr_dict['pricing_date'] = self.__pricing_date
        attr_dict['csa_term'] = self.__csa_term
        attr_dict['market_behaviour'] = self.__market_behaviour
        attr_dict['is_batch'] = self.__is_batch
        attr_dict['is_async'] = self.__is_async
        attr_dict['timeout'] = self.__timeout
        attr_dict['use_cache'] = self.__use_cache
        attr_dict['visible_to_gs'] = self.__visible_to_gs
        attr_dict['request_priority'] = self.__request_priority
        attr_dict['market_data_location'] = self.__market_data_location
        attr_dict['market'] = self.__market
        attr_dict['show_progress'] = self.__show_progress
        attr_dict['use_server_cache'] = self.__use_server_cache
        attr_dict['_max_concurrent'] = self.__max_concurrent
        attr_dict['_max_per_batch'] = self.__max_per_batch
    def _inherited_val(self, parameter, default=None, from_active=False):
        if from_active:
            # some properties are inherited from the active context
            if self != self.active_context and getattr(self.active_context, parameter) is not None:
                return getattr(self.active_context, parameter)
        if not self.is_entered and (not PricingContext.has_prior or self is not PricingContext.prior):
            # if not yet entered, get property from current (would-be prior) so that getters still display correctly
            if PricingContext.current is not self and PricingContext.current and getattr(PricingContext.current,
                                                                                         parameter) is not None:
                return getattr(PricingContext.current, parameter)
        else:
            # if entered, inherit from the prior
            if PricingContext.has_prior and PricingContext.prior is not self and getattr(PricingContext.prior,
                                                                                         parameter) is not None:
                return getattr(PricingContext.prior, parameter)
        # default if nothing to inherit
        return default
    def _on_enter(self):
        self.__save_attrs_to(self.__attrs_on_entry)
        self.__market_data_location = self.market_data_location
        self.__pricing_date = self.pricing_date
        self.__market = self.market
        self.__csa_term = self.csa_term
        self.__market_behaviour = self.market_behaviour
        self.__is_async = self.is_async
        self.__is_batch = self.is_batch
        self.__timeout = self.timeout
        self.__use_cache = self.use_cache
        self.__visible_to_gs = self.visible_to_gs
        self.__request_priority = self.request_priority
        self.__show_progress = self.show_progress
        self.__use_server_cache = self.use_server_cache
        self.__max_concurrent = self._max_concurrent
        self.__max_per_batch = self._max_per_batch
    def __reset_atts(self):
        self.__pricing_date = self.__attrs_on_entry.get('pricing_date')
        self.__csa_term = self.__attrs_on_entry.get('csa_term')
        self.__market_behaviour = self.__attrs_on_entry.get('market_behaviour')
        self.__is_async = self.__attrs_on_entry.get('is_async')
        self.__is_batch = self.__attrs_on_entry.get('is_batch')
        self.__timeout = self.__attrs_on_entry.get('timeout')
        self.__use_cache = self.__attrs_on_entry.get('use_cache')
        self.__visible_to_gs = self.__attrs_on_entry.get('visible_to_gs')
        self.__request_priority = self.__attrs_on_entry.get('request_priority')
        self.__market_data_location = self.__attrs_on_entry.get('market_data_location')
        self.__market = self.__attrs_on_entry.get('market')
        self.__show_progress = self.__attrs_on_entry.get('show_progress')
        self.__use_server_cache = self.__attrs_on_entry.get('use_server_cache')
        self.__max_concurrent = self.__attrs_on_entry.get('_max_concurrent')
        self.__max_per_batch = self.__attrs_on_entry.get('_max_per_batch')
        self.__attrs_on_entry = {}
    def _on_exit(self, exc_type, exc_val, exc_tb):
        try:
            if exc_val:
                raise exc_val
            else:
                self.__calc()
        finally:
            self.__reset_atts()
    def __calc(self):
        def run_requests(requests_: list, provider_, create_event_loop: bool, pc_attrs: dict, span):
            if create_event_loop:
                asyncio.set_event_loop(asyncio.new_event_loop())
            results = queue.Queue()
            done = False
            try:
                with session:
                    provider_.run(requests_, results, pc_attrs['_max_concurrent'], progress_bar,
                                  timeout=pc_attrs['timeout'], span=span)
            except Exception as e:
                provider_.enqueue(results, ((k, e) for k in self.__pending.keys()))
            while self.__pending and not done:
                done, chunk_results = provider_.drain_queue(results)
                for (risk_key_, priceable_), result in chunk_results:
                    future = self.__pending.pop((risk_key_, priceable_), None)
                    if future is not None:
                        future.set_result(result)
                        if pc_attrs['use_cache']:
                            PricingCache.put(risk_key_, priceable_, result)
            if not pc_attrs['is_async']:
                # In async mode we can't tell if we've completed, we could be re-used
                while self.__pending:
                    (risk_key_, _), future = self.__pending.popitem()
                    future.set_result(ErrorValue(risk_key_, 'No result returned'))
        # Group requests optimally
        requests_by_provider = {}
        for (key, instrument) in self.__pending.keys():
            dates_markets, measures = requests_by_provider.setdefault(key.provider, {}) \
                
.setdefault((key.params, key.scenario), {}) \
                
.setdefault(instrument, (set(), set()))
            dates_markets.add((key.date, key.market))
            measures.add(key.risk_measure)
        requests_for_provider = {}
        if requests_by_provider:
            session = GsSession.current
            request_visible_to_gs = session.is_internal() if self.__visible_to_gs is None else self.__visible_to_gs
            for provider, by_params_scenario in requests_by_provider.items():
                grouped_requests = {}
                for (params, scenario), positions_by_dates_markets_measures in by_params_scenario.items():
                    for instrument, (dates_markets, risk_measures) in positions_by_dates_markets_measures.items():
                        grouped_requests.setdefault((params, scenario, tuple(sorted(dates_markets)),
                                                     tuple(sorted(risk_measures))),
                                                    []).append(instrument)
                requests = []
                # Restrict to 1,000 instruments and 1 date in a batch, until server side changes are made
                for (params, scenario, dates_markets, risk_measures), instruments in grouped_requests.items():
                    for insts_chunk in [tuple(filter(None, i)) for i in
                                        zip_longest(*[iter(instruments)] * self._max_per_batch)]:
                        for dates_chunk in [tuple(filter(None, i)) for i in
                                            zip_longest(*[iter(dates_markets)] * (
                                                    1 if self._group_by_date else self._max_per_batch))]:
                            requests.append(RiskRequest(
                                tuple(RiskPosition(instrument=i, quantity=i.instrument_quantity,
                                                   instrument_name=i.name) for i in insts_chunk),
                                risk_measures,
                                parameters=params,
                                wait_for_results=not self.__is_batch,
                                scenario=scenario,
                                pricing_and_market_data_as_of=tuple(
                                    PricingDateAndMarketDataAsOf(pricing_date=d, market=m)
                                    for d, m in dates_chunk),
                                request_visible_to_gs=request_visible_to_gs,
                                use_cache=self.__use_server_cache,
                                priority=self.__request_priority
                            ))
                requests_for_provider[provider] = requests
            show_status = self.__show_progress and \
                          
(len(requests_for_provider) > 1 or len(next(iter(requests_for_provider.values()))) > 1)
            request_pool = ThreadPoolExecutor(len(requests_for_provider)) \
                
if len(requests_for_provider) > 1 or self.__is_async else None
            progress_bar = tqdm(total=len(self.__pending), position=0, maxinterval=1,
                                file=sys.stdout) if show_status else None
            completion_futures = []
            # Requests might get dispatched asynchronously and the PricingContext gets cleaned up on exit.
            # We should use a saved state of the object when dispatching async requests, except for self.__pending
            # All attributes are immutable, so a shared dictionary is sufficient. __pending remains shared.
            attrs_for_request = {}
            self.__save_attrs_to(attrs_for_request)
            span = Tracer.get_instance().active_span
            for provider, requests in requests_for_provider.items():
                if request_pool:
                    completion_future = request_pool.submit(run_requests, requests, provider, True,
                                                            attrs_for_request, span)
                    if not self.__is_async:
                        completion_futures.append(completion_future)
                else:
                    run_requests(requests, provider, False, attrs_for_request, span)
            # Wait on results if not async, so exceptions are surfaced
            if request_pool:
                request_pool.shutdown(False)
                all(f.result() for f in completion_futures)
    def __risk_key(self, risk_measure: RiskMeasure, provider: type) -> RiskKey:
        return RiskKey(provider, self.__pricing_date, self.__market, self._parameters, self._scenario, risk_measure)
    @property
    def _parameters(self) -> RiskRequestParameters:
        return RiskRequestParameters(csa_term=self.__csa_term, raw_results=True,
                                     market_behaviour=self.__market_behaviour)
    @property
    def _scenario(self) -> Optional[MarketDataScenario]:
        scenarios = Scenario.path
        if not scenarios:
            return None
        return MarketDataScenario(scenario=scenarios[0] if len(scenarios) == 1 else
        CompositeScenario(scenarios=tuple(reversed(scenarios))))
    @property
    def active_context(self):
        # active context cannot be below self on the stack - this also prevents infinite recursion when inheriting
        path = takewhile(lambda x: x != self, reversed(PricingContext.path))
        return next((c for c in path if c.is_entered and not c.set_parameters_only), self)
    @property
    def is_current(self) -> bool:
        return self == PricingContext.current
    @property
    def _max_concurrent(self) -> int:
        return self.__max_concurrent if self.__max_concurrent else self._inherited_val('_max_concurrent', default=1000)
    @_max_concurrent.setter
    def _max_concurrent(self, value):
        self.__max_concurrent = value
    @property
    def _max_per_batch(self) -> int:
        return self.__max_per_batch if self.__max_per_batch else self._inherited_val('_max_per_batch', default=1000)
    @_max_per_batch.setter
    def _max_per_batch(self, value):
        self.__max_per_batch = value
    @property
    def is_async(self) -> bool:
        if self.__is_async is not None:
            return self.__is_async
        return self._inherited_val('is_async', default=False)
    @property
    def is_batch(self) -> bool:
        return self.__is_batch if self.__is_batch else self._inherited_val('is_batch', default=False)
    @property
    def market(self) -> Market:
        return self.__market if self.__market else CloseMarket(
            date=close_market_date(self.market_data_location, self.pricing_date),
            location=self.market_data_location)
    @property
    def market_data_location(self) -> PricingLocation:
        return self.__market_data_location if self.__market_data_location else self._inherited_val(
            'market_data_location', from_active=True, default=PricingLocation.LDN)
    @property
    def csa_term(self) -> str:
        return self.__csa_term if self.__csa_term else self._inherited_val('csa_term')
    @property
    def show_progress(self) -> bool:
        return self.__show_progress if self.__show_progress else self._inherited_val('show_progress', default=False)
    @property
    def timeout(self) -> int:
        return self.__timeout if self.__timeout else self._inherited_val('timeout')
    @property
    def request_priority(self) -> int:
        return self.__request_priority if self.__request_priority else self._inherited_val('request_priority')
    @property
    def use_server_cache(self) -> bool:
        return self.__use_server_cache if self.__use_server_cache is not None else self._inherited_val(
            'use_server_cache', False)
    @property
    def market_behaviour(self) -> str:
        return self.__market_behaviour if self.__market_behaviour else self._inherited_val(
            'market_behaviour', default='ContraintsBased')
    @property
    def pricing_date(self) -> dt.date:
        """Pricing date"""
        if self.__pricing_date is not None:
            return self.__pricing_date
        default_pricing_date = business_day_offset(today(self.market_data_location), 0, roll='preceding')
        return self._inherited_val('pricing_date', default=default_pricing_date)
    @property
    def use_cache(self) -> bool:
        """Cache results"""
        return self.__use_cache if self.__use_cache else self._inherited_val('use_cache', default=False)
    @property
    def visible_to_gs(self) -> Optional[bool]:
        """Request contents visible to GS"""
        return self.__visible_to_gs if self.__visible_to_gs else self._inherited_val('visible_to_gs')
    @property
    def set_parameters_only(self) -> bool:
        return self.__set_parameters_only
    def clone(self, **kwargs):
        clone_kwargs = {k: getattr(self, k, None) for k in signature(self.__init__).parameters.keys()}
        clone_kwargs.update(kwargs)
        return self.__class__(**clone_kwargs)
    def _calc(self, instrument: InstrumentBase, risk_key: RiskKey) -> PricingFuture:
        pending = self.active_context.__pending
        from gs_quant.instrument import DummyInstrument
        if isinstance(instrument, DummyInstrument):
            return PricingFuture(StringWithInfo(value=instrument.dummy_result, risk_key=risk_key))
        future = pending.get((risk_key, instrument))
        if future is None:
            future = PricingFuture()
            cached_result = PricingCache.get(risk_key, instrument) if self.use_cache else None
            if cached_result is not None:
                future.set_result(cached_result)
            else:
                pending[(risk_key, instrument)] = future
        return future
    def calc(self, instrument: InstrumentBase, risk_measure: RiskMeasure) -> PricingFuture:
        """
        Calculate the risk measure for the instrument. Do not use directly, use via instruments
        :param instrument: The instrument
        :param risk_measure: The measure we wish to calculate
        :return: A PricingFuture whose result will be the calculation result
        **Examples**
        >>> from gs_quant.instrument import IRSwap
        >>> from gs_quant.risk import IRDelta
        >>>
        >>> swap = IRSwap('Pay', '10y', 'USD', fixed_rate=0.01)
        >>> delta = swap.calc(IRDelta)
        """
        return self._calc(instrument, self.__risk_key(risk_measure, instrument.provider)) 
class PositionContext(ContextBaseWithDefault):
    """
    A context for controlling portfolio position behaviour
    """
    def __init__(self,
                 position_date: Optional[dt.date] = None):
        """
        The methods on this class should not be called directly. Instead, use the methods on the portfolios,
        as per the examples
        :param position_date: the date for pricing calculations. Default is today
        
        **Examples**
        To change the position date of the default context:
        >>> from gs_quant.markets import PositionContext
        >>> import datetime
        >>>
        >>> PricingContext.current = PositionContext(datetime.date(2021, 1, 2))
        For a pricing a portfolio with positions held on a specific date:
        >>> from gs_quant.markets.portfolio import Portfolio
        >>> portfolio = Portfolio.get(portfolio_id='MQPORTFOLIOID')
        >>>
        >>> with PositionContext():
        >>>     portfolio.price()
        >>>
        For an asynchronous request:
        >>> with PositionContext(), PricingContext(is_async=True):
        >>>     price_f = portfolio.price()
        >>>
        >>> while not price_f.done:
        >>> ...
        """
        super().__init__()
        if position_date:
            if position_date > dt.date.today():
                raise ValueError("The PositionContext does not support a position_date in the future")
        self.__position_date = position_date if position_date \
            else business_day_offset(dt.date.today(), 0, roll='preceding')
    @property
    def position_date(self):
        return self.__position_date
    @classmethod
    def default_value(cls) -> object:
        return PositionContext()
    def clone(self, **kwargs):
        clone_kwargs = {k: getattr(self, k, None) for k in signature(self.__init__).parameters.keys()}
        clone_kwargs.update(kwargs)
        return self.__class__(**clone_kwargs)