# Source code for gs_quant.base

"""
Copyright 2019 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied.  See the License for the
specific language governing permissions and limitations
under the License.
"""
import builtins
import copy
import datetime as dt
import logging
from abc import ABC, ABCMeta, abstractmethod
from collections import namedtuple
from dataclasses import Field, InitVar, MISSING, dataclass, field, fields, replace
from enum import EnumMeta
from functools import update_wrapper
from typing import Iterable, Mapping, Optional, Union, Tuple

import numpy as np
from dataclasses_json import config, global_config
from dataclasses_json.core import _decode_generic, _is_supported_generic
from inflection import camelize, underscore

from gs_quant.context_base import ContextBase, ContextMeta
from gs_quant.json_convertors import encode_date_or_str, decode_date_or_str, decode_optional_date, encode_datetime, \
    decode_datetime, decode_float_or_str, decode_instrument, encode_dictable, decode_quote_report, decode_quote_reports, \
    decode_custom_comment, decode_custom_comments, decode_hedge_type, decode_hedge_types

_logger = logging.getLogger(__name__)

# Set of the names exposed by the builtins module.
# NOTE(review): not referenced in the code visible here - confirm it is still used elsewhere.
__builtins = set(dir(builtins))
# Unbound object-level accessors. Base.__getattr__ / Base.__setattr__ / Base.from_instance call these directly so
# that attribute reads/writes bypass Base's own overridden __getattr__ / __setattr__ (avoiding re-entry into them).
__getattribute__ = object.__getattribute__
__setattr__ = object.__setattr__

# Memo used by _get_underscore: attribute name -> its inflection.underscore() (snake_case) form.
_rename_cache = {}

def exclude_none(o):
    """Exclusion predicate: True exactly when the value is ``None``."""
    value_is_missing = o is None
    return value_is_missing


def exclude_always(_o):
    """Exclusion predicate that ignores its argument and always returns True."""
    return True


# Backwards-compatible alias for the original, misspelled public name.
exlude_always = exclude_always


def is_iterable(o, t):
    """Return True if *o* is iterable and every element it yields is an instance of *t*.

    An empty iterable counts as True. Note that *o* is iterated, so a one-shot iterator is consumed.
    """
    if not isinstance(o, Iterable):
        return False
    return all(map(lambda element: isinstance(element, t), o))


def is_instance_or_iterable(o, t):
    """Return True if *o* is an instance of *t*, or an iterable whose elements are all instances of *t*."""
    if isinstance(o, t):
        return True
    return is_iterable(o, t)


def _get_underscore(arg):
    """Return ``inflection.underscore(arg)`` (the snake_case form), memoised in ``_rename_cache``."""
    snake = _rename_cache.get(arg)
    if snake is None:
        snake = _rename_cache[arg] = underscore(arg)
    return snake

def handle_camel_case_args(cls):
    """
    Class decorator letting ``cls(...)`` accept camelCase keyword arguments.

    Every keyword that is not entirely upper-case is converted to snake_case; the result is then translated
    through ``cls._field_mappings()`` before the original ``__init__`` is called.

    :param cls: the class whose ``__init__`` is wrapped (in place)
    :return: the same class
    :raises ValueError: (at construction time) if both the camelCase and snake_case spellings are supplied
    """
    original_init = cls.__init__

    def wrapper(self, *args, **kwargs):
        renamed = {}

        for name, val in kwargs.items():
            if not name.isupper():
                snake = _get_underscore(name)
                # Reject e.g. fixedRate=... together with fixed_rate=...
                if snake != name and snake in kwargs:
                    raise ValueError('{} and {} both specified'.format(name, snake))
                name = snake

            renamed[cls._field_mappings().get(name, name)] = val

        return original_init(self, *args, **renamed)

    cls.__init__ = update_wrapper(wrapper=wrapper, wrapped=original_init)
    return cls


# dataclasses_json field config whose `exclude` predicate is exclude_none, i.e. the field is omitted
# from encoded output when its value is None (per dataclasses_json's `exclude` semantics).
field_metadata = config(exclude=exclude_none)
# dataclasses_json field config whose `exclude` predicate always returns True, i.e. the field is
# always omitted from encoded output.
name_metadata = config(exclude=exlude_always)


class RiskKey(namedtuple('RiskKey', ('provider', 'date', 'market', 'params', 'scenario', 'risk_measure'))):
    """Named tuple identifying a risk calculation: provider, date, market, params, scenario and risk measure."""

    @property
    def ex_measure(self):
        """This key with ``risk_measure`` set to None, always returned as a plain ``RiskKey``."""
        *leading, _measure = self
        return RiskKey(*leading, None)

    @property
    def fields(self):
        """The names of this key's fields, in positional order."""
        return self._fields


class EnumBase:
    """
    Mixin for ``Enum`` types whose member values are strings.

    Adds case-insensitive lookup by value, ordering by value, pickling by value, and uses the raw
    value for both ``str()`` and ``repr()``.
    """

    @classmethod
    def _missing_(cls: EnumMeta, key):
        # Called by Enum when exact value lookup fails: retry case-insensitively against member values
        wanted = (key if isinstance(key, str) else str(key)).lower()
        for member in cls.__members__.values():
            if member.value.lower() == wanted:
                return member
        return None

    def __reduce_ex__(self, protocol):
        # Pickle as "call the enum class with this member's value"
        return type(self), (self.value,)

    def __lt__(self, other):
        mine, theirs = self.value, other.value
        return mine < theirs

    def __repr__(self):
        return str(self)

    def __str__(self):
        return self.value


class HashableDict(dict):
    """
    A ``dict`` that is hashable, so it can be used as a dict key or a set member.

    The hash is computed from the current contents, so it is consistent with ``dict`` equality
    (which ignores insertion order), including for nested dicts. Values, other than nested dicts,
    must themselves be hashable.
    """

    @staticmethod
    def hashables(in_dict) -> Tuple:
        """
        Return the items of *in_dict* as a tuple of (key, value) pairs, recursively converting nested dicts.

        The result follows *in_dict*'s iteration order.
        """
        hashables = []
        for it in in_dict.items():
            if isinstance(it[1], dict):
                hashables.append((it[0], HashableDict.hashables(it[1])))
            else:
                hashables.append(it)
        return tuple(hashables)

    @staticmethod
    def _order_independent(in_dict) -> frozenset:
        """Return an insertion-order-independent, hashable snapshot of *in_dict* (nested dicts included)."""
        return frozenset(
            (key, HashableDict._order_independent(value) if isinstance(value, dict) else value)
            for key, value in in_dict.items())

    def __hash__(self):
        # dict.__eq__ ignores insertion order, so the hash must too: equal objects must hash equally.
        # (Hashing the ordered tuple from hashables() gave {'a': 1, 'b': 2} and {'b': 2, 'a': 1}
        # different hashes despite comparing equal.)
        return hash(HashableDict._order_independent(self))


class DictBase(HashableDict):
    """
    Hashable dict whose keys are stored camelCased but may be accessed as snake_case keys or attributes.

    ``_PROPERTIES`` (set by subclasses) restricts which names are allowed; when it is empty any key is
    accepted. ``None`` values are never stored.
    """

    _PROPERTIES = set()

    def __init__(self, *args, **kwargs):
        """
        :raises AttributeError: if ``_PROPERTIES`` is non-empty and a keyword argument is not in it
        """
        if self._PROPERTIES:
            invalid_arg = next((k for k in kwargs.keys() if k not in self._PROPERTIES), None)
            if invalid_arg is not None:
                raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{invalid_arg}'")

        # Keyword keys are camelCased and None values dropped; positional args are passed to dict unchanged
        super().__init__(*args, **{camelize(k, uppercase_first_letter=False): v for k, v in kwargs.items()
                                   if v is not None})

    def __getitem__(self, item):
        return super().__getitem__(camelize(item, uppercase_first_letter=False))

    def __setitem__(self, key, value):
        # Assigning None is silently ignored rather than stored
        if value is not None:
            return super().__setitem__(camelize(key, uppercase_first_letter=False), value)

    def __getattr__(self, item):
        """
        Attribute-style read access to the dict's entries.

        :raises AttributeError: if the name is not an allowed property (when ``_PROPERTIES`` is set)
                                or not a key (when it is not)
        """
        if self._PROPERTIES:
            if _get_underscore(item) in self._PROPERTIES:
                # Keys are stored camelCased by __init__/__setitem__, so a raw self.get('fixed_rate') misses
                # 'fixedRate'. Prefer an exact raw key (the historical lookup), else use the camelCase form.
                return self.get(item, self.get(camelize(item, uppercase_first_letter=False)))
        elif item in self:
            return self[item]

        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{item}'")

    def __setattr__(self, key, value):
        """
        Store attribute assignments as dict entries, except for names that are real attributes.

        :raises AttributeError: if ``_PROPERTIES`` is set and the snake_case form of *key* is not in it
        """
        if key in dir(self):
            return super().__setattr__(key, value)
        elif self._PROPERTIES and _get_underscore(key) not in self._PROPERTIES:
            raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{key}'")

        self[key] = value

    @classmethod
    def properties(cls) -> set:
        """The set of allowed property names (empty means unrestricted)."""
        return cls._PROPERTIES


class Base(ABC):
    """The base class for all generated classes"""

    __fields_by_name = None
    __field_mappings = None

    def __getattr__(self, item):
        fields_by_name = __getattribute__(self, '_fields_by_name')()

        if item.startswith('_') or item in fields_by_name:
            return __getattribute__(self, item)

        # Handle setting via camelCase names (legacy behaviour) and field mappings from disallowed names
        snake_case_item = _get_underscore(item)
        field_mappings = __getattribute__(self, '_field_mappings')()
        snake_case_item = field_mappings.get(snake_case_item, snake_case_item)

        try:
            return __getattribute__(self, snake_case_item)
        except AttributeError:
            return __getattribute__(self, item)

    def __setattr__(self, key, value):
        # Handle setting via camelCase names (legacy behaviour)
        snake_case_key = _get_underscore(key)
        snake_case_key = self._field_mappings().get(snake_case_key, snake_case_key)
        fld = self._fields_by_name().get(snake_case_key)

        if fld:
            if not fld.init:
                raise ValueError(f'{key} cannot be set')

            key = snake_case_key
            value = self.__coerce_value(fld.type, value)

        __setattr__(self, key, value)

    def __repr__(self):
        if self.name is not None:
            return f'{self.name} ({self.__class__.__name__})'

        return super().__repr__()

    @classmethod
    def __coerce_value(cls, typ: type, value):
        if isinstance(value, np.generic):
            # Handle numpy types
            return value.item()
        elif hasattr(value, 'tolist'):
            # tolist converts scalar or array to native python type if not already native.
            return value()
        elif typ in (DictBase, Optional[DictBase]) and isinstance(value, Base):
            return value.to_dict()
        if _is_supported_generic(typ):
            return _decode_generic(typ, value, False)
        else:
            return value

    @classmethod
    def _fields_by_name(cls) -> Mapping[str, Field]:
        if cls is Base:
            return {}

        if cls.__fields_by_name is None:
            cls.__fields_by_name = {f.name: f for f in fields(cls)}

        return cls.__fields_by_name

    @classmethod
    def _field_mappings(cls) -> Mapping[str, str]:
        if cls is Base:
            return {}

        if cls.__field_mappings is None:
            field_mappings = {}
            for fld in fields(cls):
                config_fn = fld.metadata.get('dataclasses_json', {}).get('letter_case')
                if config_fn:
                    mapped_name = config_fn('field_name')
                    if mapped_name:
                        field_mappings[mapped_name] = fld.name

            cls.__field_mappings = field_mappings
        return cls.__field_mappings

    def clone(self, **kwargs):
        """
            Clone this object, overriding specified values

            :param kwargs: property names and values, e.g. swap.clone(fixed_rate=0.01)

            **Examples**

            To change the market data location of the default context:

            >>> from gs_quant.instrument import IRCap
            >>> cap = IRCap('5y', 'GBP')
            >>>
            >>> new_cap = cap.clone(cap_rate=0.01)
        """
        return replace(self, **kwargs)

    @classmethod
    def properties(cls) -> set:
        """The public property names of this class"""
        return set(f[:-1] if f[-1] == '_' else f for f in cls._fields_by_name().keys())

    @classmethod
    def properties_init(cls) -> set:
        """The public property names of this class"""
        return set(f[:-1] if f[-1] == '_' else f for f, v in cls._fields_by_name().items() if v.init)

    def as_dict(self, as_camel_case: bool = False) -> dict:
        """Dictionary of the public, non-null properties and values"""

        # to_dict() converts all the values to JSON type, does camel case and name mappings
        # asdict() does not convert values or case of the keys or do name mappings

        ret = {}
        field_mappings = {v: k for k, v in self._field_mappings().items()}

        for key in self.__fields_by_name.keys():
            value = __getattribute__(self, key)
            key = field_mappings.get(key, key)

            if value is not None:
                if as_camel_case:
                    key = camelize(key, uppercase_first_letter=False)

                ret[key] = value

        return ret

    @classmethod
    def default_instance(cls):
        """
        Construct a default instance of this type
        """
        required = {f.name: None if f.default == MISSING else f.default for f in fields(cls) if f.init}
        return cls(**required)

    def from_instance(self, instance):
        """
        Copy the values from an existing instance of the same type to our self
        :param instance: from which to copy:
        :return:
        """
        if not isinstance(instance, type(self)):
            raise ValueError('Can only use from_instance with an object of the same type')

        for fld in fields(self.__class__):
            if fld.init:
                __setattr__(self, fld.name, __getattribute__(instance, fld.name))


@dataclass
class Priceable(Base):
    """
    Interface for objects that can be resolved and priced.

    Every method here raises NotImplementedError; concrete subclasses provide the implementations.
    """

    def resolve(self, in_place: bool = True):
        """
        Resolve non-supplied properties of an instrument

        **Examples**

        >>> from gs_quant.instrument import IRSwap
        >>>
        >>> swap = IRSwap('Pay', '10y', 'USD')
        >>> rate = swap.fixedRate

        rate is None

        >>> swap.resolve()
        >>> rate = swap.fixedRate

        rate is now the solved fixed rate
        """
        raise NotImplementedError

    def dollar_price(self):
        """
        Present value in USD

        :return: a float or a future, depending on whether the current PricingContext is async, or has been entered

        **Examples**

        >>> from gs_quant.instrument import IRCap
        >>>
        >>> cap = IRCap('1y', 'EUR')
        >>> price = cap.dollar_price()

        price is the present value in USD (a float)

        >>> cap_usd = IRCap('1y', 'USD')
        >>> cap_eur = IRCap('1y', 'EUR')
        >>>
        >>> from gs_quant.markets import PricingContext
        >>>
        >>> with PricingContext():
        >>>     price_usd_f = cap_usd.dollar_price()
        >>>     price_eur_f = cap_eur.dollar_price()
        >>>
        >>> price_usd = price_usd_f.result()
        >>> price_eur = price_eur_f.result()

        price_usd_f and price_eur_f are futures, price_usd and price_eur are floats
        """
        raise NotImplementedError

    def price(self):
        """
        Present value in local currency. Note that this is not yet supported on all instruments

        **Examples**

        >>> from gs_quant.instrument import IRSwap
        >>>
        >>> swap = IRSwap('Pay', '10y', 'EUR')
        >>> price = swap.price()

        price is the present value in EUR (a float)
        """
        raise NotImplementedError

    def calc(self, risk_measure, fn=None):
        """
        Calculate the value of the risk_measure

        :param risk_measure: the risk measure to compute, e.g. IRDelta (from gs_quant.risk)
        :param fn: a function for post-processing results
        :return: a float or dataframe, depending on whether the value is scalar or structured, or a future thereof
                 (depending on how PricingContext is being used)

        **Examples**

        >>> from gs_quant.instrument import IRCap
        >>> from gs_quant.risk import IRDelta
        >>>
        >>> cap = IRCap('1y', 'USD')
        >>> delta = cap.calc(IRDelta)

        delta is a dataframe

        >>> from gs_quant.instrument import EqOption
        >>> from gs_quant.risk import EqDelta
        >>>
        >>> option = EqOption('.SPX', '3m', 'ATMF', 'Call', 'European')
        >>> delta = option.calc(EqDelta)

        delta is a float

        >>> from gs_quant.markets import PricingContext
        >>>
        >>> cap_usd = IRCap('1y', 'USD')
        >>> cap_eur = IRCap('1y', 'EUR')
        >>>
        >>> with PricingContext():
        >>>     usd_delta_f = cap_usd.calc(IRDelta)
        >>>     eur_delta_f = cap_eur.calc(IRDelta)
        >>>
        >>> usd_delta = usd_delta_f.result()
        >>> eur_delta = eur_delta_f.result()

        usd_delta_f and eur_delta_f are futures, usd_delta and eur_delta are dataframes
        """
        raise NotImplementedError
class __ScenarioMeta(ABCMeta, ContextMeta):
    """
    Metaclass deriving from both ABCMeta and ContextMeta.

    Scenario's bases need both metaclasses (presumably ContextBase uses ContextMeta - confirm).
    """
    pass


@dataclass
class Scenario(Base, ContextBase, ABC, metaclass=__ScenarioMeta):

    def __lt__(self, other):
        # Order by repr: for named scenarios this is the name; unnamed ones fall back to their parameter
        # description. (Previously this compared bound __repr__ methods, which always differ between distinct
        # objects, then compared .name, raising TypeError when a name was None.)
        if not isinstance(other, Scenario):
            return NotImplemented
        return repr(self) < repr(other)

    def __repr__(self):
        if self.name:
            return self.name
        else:
            params = self.as_dict()
            sorted_keys = sorted(params.keys(), key=lambda x: x.lower())
            # repr() of nested Base values, rather than embedding their bound __repr__ method object
            params = ', '.join(
                [f'{k}:{repr(params[k]) if isinstance(params[k], Base) else params[k]}' for k in sorted_keys])
            return self.scenario_type + '(' + params + ')'


@dataclass
class RiskMeasureParameter(Base, ABC):
    pass


@dataclass
class InstrumentBase(Base, ABC):

    quantity_: InitVar[float] = field(default=1, init=False)

    @property
    @abstractmethod
    def provider(self):
        ...

    @property
    def instrument_quantity(self) -> float:
        return self.quantity_

    @property
    def resolution_key(self) -> Optional[RiskKey]:
        """The RiskKey this instrument was resolved with, or None if it has not been resolved."""
        try:
            return self.__resolution_key
        except AttributeError:
            return None

    @property
    def unresolved(self):
        """The instrument as it was before resolution, or None if not set."""
        try:
            return self.__unresolved
        except AttributeError:
            return None

    @property
    def metadata(self):
        try:
            return self.__metadata
        except AttributeError:
            return None

    @metadata.setter
    def metadata(self, value):
        self.__metadata = value

    def from_instance(self, instance):
        self.__resolution_key = None
        super().from_instance(instance)
        # Read through the properties: the private attributes only exist once assigned, so accessing
        # instance.__unresolved directly raised AttributeError for an instance that was never resolved.
        self.__unresolved = instance.unresolved
        self.__resolution_key = instance.resolution_key

    def resolved(self, values: dict, resolution_key: RiskKey):
        """
        Return a new instrument built from this one's (camelCase) values overlaid with *values*.

        The new instrument keeps this one's name, records a shallow copy of this instrument as
        ``unresolved`` and stores *resolution_key*.
        """
        all_values = self.as_dict(True)
        all_values.update(values)
        new_instrument = self.from_dict(all_values)
        new_instrument.name = self.name
        new_instrument.__unresolved = copy.copy(self)
        new_instrument.__resolution_key = resolution_key
        return new_instrument

    def clone(self, **kwargs):
        """Clone via Base.clone, carrying over unresolved, metadata and resolution_key."""
        new_instrument = super().clone(**kwargs)
        new_instrument.__unresolved = self.unresolved
        new_instrument.metadata = self.metadata
        new_instrument.__resolution_key = self.resolution_key
        return new_instrument


@dataclass
class Market(ABC):

    def __hash__(self):
        return hash(self.market or self.location)

    def __eq__(self, other):
        return (self.market or self.location) == (other.market or other.location)

    def __lt__(self, other):
        return repr(self) < repr(other)

    @property
    @abstractmethod
    def market(self):
        ...

    @property
    @abstractmethod
    def location(self):
        ...

    def to_dict(self):
        return self.market.to_dict()


class Sentinel:
    """A named marker object; sentinels compare equal when their names are equal."""

    def __init__(self, name: str):
        self.__name = name

    def __eq__(self, other):
        # other.__name (i.e. other._Sentinel__name) only exists on Sentinels; previously any comparison
        # against another type (e.g. `value in (None,)`) raised AttributeError.
        if not isinstance(other, Sentinel):
            return NotImplemented
        return self.__name == other.__name

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash consistently with __eq__ instead
        return hash(self.__name)


@dataclass
class QuoteReport(Base, ABC):
    pass


@dataclass
class CustomComments(Base, ABC):
    pass


def get_enum_value(enum_type: EnumMeta, value: Union[EnumBase, str]):
    """
    Convert *value* to a member of *enum_type* where possible.

    :param enum_type: the Enum class to convert to
    :param value: an enum member, a raw value accepted by ``enum_type(...)``, or None
    :return: None for None; *value* itself if already a member of *enum_type*; otherwise ``enum_type(value)``,
             or *value* unchanged (with a warning logged) if *enum_type* rejects it with ValueError
    """
    # Identity check: `value in (None,)` also invoked value.__eq__(None)
    if value is None:
        return None

    if isinstance(value, enum_type):
        return value

    try:
        enum_value = enum_type(value)
    except ValueError:
        _logger.warning('Setting value to {}, which is not a valid entry in {}'.format(value, enum_type))
        enum_value = value

    return enum_value


# Yes, I know this is a little evil ...
# Registers process-wide dataclasses_json encoders/decoders (global_config) for the types below.
global_config.encoders[dt.date] = dt.date.isoformat
global_config.encoders[Optional[dt.date]] = encode_date_or_str
global_config.decoders[dt.date] = decode_optional_date
global_config.decoders[Optional[dt.date]] = decode_optional_date
global_config.encoders[Union[dt.date, str]] = encode_date_or_str
global_config.encoders[Optional[Union[dt.date, str]]] = encode_date_or_str
global_config.decoders[Union[dt.date, str]] = decode_date_or_str
global_config.decoders[Optional[Union[dt.date, str]]] = decode_date_or_str

global_config.encoders[dt.datetime] = encode_datetime
global_config.encoders[Optional[dt.datetime]] = encode_datetime
global_config.decoders[dt.datetime] = decode_datetime
global_config.decoders[Optional[dt.datetime]] = decode_datetime

global_config.decoders[Union[float, str]] = decode_float_or_str
global_config.decoders[Optional[Union[float, str]]] = decode_float_or_str

global_config.decoders[InstrumentBase] = decode_instrument
global_config.decoders[Optional[InstrumentBase]] = decode_instrument

global_config.decoders[QuoteReport] = decode_quote_report
global_config.decoders[Optional[Tuple[QuoteReport, ...]]] = decode_quote_reports

global_config.decoders[CustomComments] = decode_custom_comment
global_config.decoders[Optional[Tuple[CustomComments, ...]]] = decode_custom_comments

global_config.encoders[Market] = encode_dictable
global_config.encoders[Optional[Market]] = encode_dictable