Source code for gs_quant.analytics.datagrid.datagrid

"""
Copyright 2019 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied.  See the License for the
specific language governing permissions and limitations
under the License.
"""
import asyncio
import datetime as dt
import json
import logging
import webbrowser
from collections import defaultdict
from dataclasses import asdict
from numbers import Number
from typing import List, Dict, Optional, Tuple, Union, Set

import numpy as np
from pandas import DataFrame, Series, concat

from gs_quant.analytics.common import DATAGRID_HELP_MSG
from gs_quant.analytics.common.helpers import resolve_entities, get_entity_rdate_key, get_entity_rdate_key_from_rdate, \
    get_rdate_cache_key
from gs_quant.analytics.core.processor import DataQueryInfo, MeasureQueryInfo
from gs_quant.analytics.core.processor_result import ProcessorResult
from gs_quant.analytics.core.query_helpers import aggregate_queries, fetch_query, build_query_string, valid_dimensions
from gs_quant.analytics.datagrid.data_cell import DataCell
from gs_quant.analytics.datagrid.data_column import DataColumn, ColumnFormat, MultiColumnGroup
from gs_quant.analytics.datagrid.data_row import DataRow, DimensionsOverride, ProcessorOverride, Override, \
    ValueOverride, RowSeparator
from gs_quant.analytics.datagrid.serializers import row_from_dict
from gs_quant.analytics.datagrid.utils import DataGridSort, SortOrder, SortType, DataGridFilter, FilterOperation, \
    FilterCondition, get_utc_now
from gs_quant.analytics.processors import CoordinateProcessor, EntityProcessor
from gs_quant.datetime.relative_date import RelativeDate
from gs_quant.entities.entitlements import Entitlements
from gs_quant.entities.entity import Entity
from gs_quant.errors import MqValueError
from gs_quant.session import GsSession, OAuth2Session
from gs_quant.target.common import Entitlements as Entitlements_

_logger = logging.getLogger(__name__)

API = '/data/grids'
DATAGRID_HEADERS: Dict[str, str] = {'Content-Type': 'application/json;charset=utf-8'}


class DataGrid:
    """
    DataGrid is an object for fetching Marquee data and applying processors (functions). DataGrids can be
    persisted via the DataGrid API and utilized on the Marquee Markets platform.

    :param name: Name of the DataGrid
    :param rows: List of DataGrid rows for the grid
    :param columns: List of DataGrid columns for the grid
    :param id_: Unique identifier of the grid
    :param entitlements: Marquee entitlements of the grid for the Marquee Markets platform
    :param sorts: Optional list of DataGridSort. Use this if you want to sort your columns.
    :param filters: Optional list of DataGridFilter. Use this to filter a column's data.
    :param multiColumnGroups: Optional list of MultiColumnGroup. Useful to group columns for heatmaps.

    **Usage**

    To create a DataGrid, we define two components, rows and columns:

    >>> from gs_quant.markets.securities import Asset, AssetIdentifier
    >>> from gs_quant.data.coordinate import DataCoordinate, DataMeasure, DataFrequency
    >>> from gs_quant.analytics.datagrid.datagrid import DataGrid
    >>> from gs_quant.analytics.datagrid.data_row import DataRow
    >>> from gs_quant.analytics.datagrid.data_column import DataColumn
    >>> from gs_quant.analytics.processors import EntityProcessor, LastProcessor
    >>>
    >>> GS = Asset.get("GS UN", AssetIdentifier.BLOOMBERG_ID)
    >>> AAPL = Asset.get("AAPL UW", AssetIdentifier.BLOOMBERG_ID)
    >>> rows = [
    >>>     DataRow(GS),
    >>>     DataRow(AAPL)
    >>> ]
    >>> trade_price = DataCoordinate(
    >>>     measure=DataMeasure.TRADE_PRICE,
    >>>     frequency=DataFrequency.REAL_TIME,
    >>> )
    >>>
    >>> col_0 = DataColumn(name="Name", processor=EntityProcessor(field="short_name"))
    >>> col_1 = DataColumn(name="Last", processor=LastProcessor(trade_price))
    >>> columns = [col_0, col_1]
    >>>
    >>> datagrid = DataGrid(name="Example DataGrid", rows=rows, columns=columns)
    >>> datagrid.initialize()
    >>> datagrid.poll()
    >>> print(datagrid.to_frame())

    **Documentation**

    Full Documentation and examples can be found here:
    https://developer.gs.com/docs/gsquant/tutorials/Data/DataGrid/
    """

    def __init__(self,
                 name: str,
                 rows: List[Union[DataRow, RowSeparator]],
                 columns: List[DataColumn],
                 *,
                 id_: str = None,
                 entitlements: Union[Entitlements, Entitlements_] = None,
                 polling_time: int = None,
                 sorts: Optional[List[DataGridSort]] = None,
                 filters: Optional[List[DataGridFilter]] = None,
                 multiColumnGroups: Optional[List[MultiColumnGroup]] = None,
                 **kwargs):
        self.id_ = id_
        self.entitlements = entitlements
        self.name = name
        self.rows = rows
        self.columns = columns
        self.sorts = sorts or []
        self.filters = filters or []
        self.multiColumnGroups = multiColumnGroups
        self.polling_time = polling_time or 0

        # store the graph, data queries to leaf processors and results
        self._primary_column_index: int = kwargs.get('primary_column_index', 0)
        self._cells: List[DataCell] = []
        self._data_queries: List[Union[DataQueryInfo, MeasureQueryInfo]] = []
        self._entity_cells: List[DataCell] = []
        self._coord_processor_cells: List[DataCell] = []
        self._value_cells: List[DataCell] = []
        self.entity_map: Dict[str, Entity] = {}

        # RDate Mappings
        self.rdate_entity_map: Dict[str, Set[Tuple]] = defaultdict(set)
        self.rule_cache: Dict[str, dt.date] = {}

        self.results: List[List[DataCell]] = []
        self.is_initialized: bool = False
        print(DATAGRID_HELP_MSG)

    def get_id(self) -> Optional[str]:
        """Get the unique DataGrid identifier. It will only exist if the DataGrid has been persisted."""
        return self.id_

    def initialize(self) -> None:
        """
        Initializes the DataGrid.

        Iterates over all rows and columns, preparing cell structures. Cells then contain a graph and data
        queries to leaf processors. Upon providing data to a leaf, the leaf processor is calculated and
        propagated up the graph to the cell level.
        """
        all_queries: List[Union[DataQueryInfo, MeasureQueryInfo]] = []
        entity_cells: List[DataCell] = []
        current_row_group = None

        # Loop over rows, columns
        for row_index, row in enumerate(self.rows):
            if isinstance(row, RowSeparator):
                current_row_group = row.name
                continue
            entity: Entity = row.entity
            if isinstance(entity, Entity):
                self.entity_map[entity.get_marquee_id()] = entity
            else:
                self.entity_map[''] = entity
            cells: List[DataCell] = []
            row_overrides = row.overrides

            for column_index, column in enumerate(self.columns):
                column_name = column.name
                column_processor = column.processor

                # Get all the data coordinate overrides and apply the processor override if it exists
                data_overrides, value_override, processor_override = _get_overrides(row_overrides, column_name)

                # Create the cell
                cell: DataCell = DataCell(column_name,
                                          column_processor,
                                          entity,
                                          data_overrides,
                                          column_index,
                                          row_index,
                                          current_row_group)

                if processor_override:
                    # Check if there is a processor override and apply if so
                    cell.processor = processor_override
                if value_override:
                    cell.value = ProcessorResult(True, value_override.value)
                    cell.updated_time = get_utc_now()
                elif isinstance(column_processor, EntityProcessor):
                    # store these cells to fetch entity data during poll
                    entity_cells.append(cell)
                elif isinstance(column_processor, CoordinateProcessor):
                    # store these cells to fetch entity data during poll
                    if len(data_overrides):
                        # Get the last in the list if more than 1 override is given
                        cell.processor.children['a'].set_dimensions(data_overrides[-1].dimensions)
                    self._coord_processor_cells.append(cell)
                elif column_processor.measure_processor:
                    all_queries.append(MeasureQueryInfo(attr='', entity=entity, processor=column_processor))
                else:
                    # append the required queries to the map
                    cell.build_cell_graph(all_queries, self.rdate_entity_map)

                cells.append(cell)

            self._cells.extend(cells)
            self.results.append(cells)

        self._data_queries = all_queries
        self._entity_cells = entity_cells
        self.is_initialized = True

    def poll(self) -> None:
        """ Poll the data queries required to process this grid.

            Set the results at the leaf processors
        """
        self._resolve_rdates()
        self._resolve_queries()
        self._process_special_cells()
        self._fetch_queries()

    def save(self) -> str:
        """
        Saves the DataGrid. If the DataGrid has already been created, the DataGrid will be updated.
        If the DataGrid has not been created it will be added to the DataGrid service.

        :return: Unique identifier of the DataGrid
        """
        datagrid_json = self.__as_json()
        if self.id_:
            response = GsSession.current._put(f'{API}/{self.id_}', datagrid_json, request_headers=DATAGRID_HEADERS)
        else:
            response = GsSession.current._post(f'{API}', datagrid_json, request_headers=DATAGRID_HEADERS)
            self.id_ = response['id']
        return DataGrid.from_dict(response).id_

    def create(self):
        """
        Creates a new DataGrid even if the DataGrid already exists.
        If the DataGrid has already been persisted, the DataGrid id will be replaced with the newly persisted DataGrid.

        :return: New DataGrid unique identifier
        """
        datagrid_json = self.__as_json()
        response = GsSession.current._post(f'{API}', datagrid_json, request_headers=DATAGRID_HEADERS)
        self.id_ = response['id']
        return response['id']

    def delete(self):
        """
        Deletes the DataGrid if it has been persisted.

        :return: None
        """
        if self.id_:
            GsSession.current._delete(f'{API}/{self.id_}', request_headers=DATAGRID_HEADERS)
        else:
            raise MqValueError('DataGrid has not been persisted.')

    def open(self):
        """
        Opens the DataGrid in the default browser.

        :return: None
        """
        if self.id_ is None:
            raise MqValueError('DataGrid must be created or saved before opening.')
        domain = GsSession.current.domain.replace(".web", "")
        if domain == 'https://api.gs.com':
            domain = 'https://marquee.gs.com'
        url = f'{domain}/s/markets/grids/{self.id_}'
        webbrowser.open(url)

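    # Illustrative persistence workflow (a sketch, limited to the methods defined above): once a
    # grid has been built and initialized as in the class docstring example, it can be persisted
    # to the DataGrid service and viewed on Marquee. Assumes the active GsSession is entitled to
    # create grids.
    #
    # >>> grid_id = datagrid.save()    # POSTs to /data/grids, or PUTs when already persisted
    # >>> datagrid.open()              # opens the grid in the default browser
    # >>> new_id = datagrid.create()   # forces creation of a brand new grid, replacing self.id_
    # >>> datagrid.delete()            # removes the persisted grid from the service
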
    @property
    def polling_time(self):
        return self.__polling_time

    @polling_time.setter
    def polling_time(self, value):
        if value is None:
            self.__polling_time = 0
        elif value != 0 and value < 5000:
            raise MqValueError('polling_time must be at least 5000ms.')
        else:
            self.__polling_time = value

    def _process_special_cells(self) -> None:
        """ Processes Coordinate and Entity cells

        :return: None
        """
        # fetch entity cells
        for cell in self._entity_cells:
            try:
                cell.value = cell.processor.process(cell.entity)
            except Exception as e:
                cell.value = f'Error Calculating processor {cell.processor.__class__.__name__} ' \
                             f'for entity: {cell.entity.get_marquee_id()} due to {e}'
            cell.updated_time = get_utc_now()

        for cell in self._coord_processor_cells:
            try:
                cell.value = cell.processor.process()
            except Exception as e:
                cell.value = f'Error Calculating processor {cell.processor.__class__.__name__} ' \
                             f'for entity: {cell.entity.get_marquee_id()} due to {e}'
            cell.updated_time = get_utc_now()

    def _resolve_rdates(self, rule_cache: Dict = None):
        # TODO: Thread this...
        rule_cache = rule_cache or {}
        # Default to no calendar for rdate for external and oauth
        calendar = [] if not GsSession.current.is_internal() and isinstance(GsSession.current, OAuth2Session) else None

        for entity_id, rules in self.rdate_entity_map.items():
            entity = self.entity_map.get(entity_id)
            currencies = None
            exchanges = None
            if isinstance(entity, Entity):
                entity_dict = entity.get_entity()
                currency = entity_dict.get("currency")
                exchange = entity_dict.get("exchange")
                currencies = [currency] if currency else None
                exchanges = [exchange] if exchange else None

            for rule_base_date_tuple in rules:
                rule, base_date = rule_base_date_tuple[0], rule_base_date_tuple[1]
                cache_key = get_rdate_cache_key(rule_base_date_tuple[0], rule_base_date_tuple[1],
                                                currencies, exchanges)
                date_value = rule_cache.get(cache_key)
                if date_value is None:
                    if base_date:
                        base_date = dt.datetime.strptime(base_date, "%Y-%m-%d").date()
                    date_value = RelativeDate(rule, base_date).apply_rule(currencies=currencies,
                                                                          exchanges=exchanges,
                                                                          holiday_calendar=calendar)
                    rule_cache[cache_key] = date_value

                self.rule_cache[get_entity_rdate_key(entity_id, rule, base_date)] = date_value

    def _resolve_queries(self, availability_cache: Dict = None) -> None:
        """ Resolves the dataset_id for each data query

            This is used to query data thereafter
        """
        availability_cache = availability_cache or {}
        for query in self._data_queries:
            entity = query.entity
            if isinstance(entity, str) or isinstance(query, MeasureQueryInfo):
                # If we were unable to fetch entity (404/403) or if we're processing a measure processor
                continue
            query = query.query
            coord = query.coordinate
            entity_dimension = entity.data_dimension
            entity_id = entity.get_marquee_id()

            query_start = query.start
            query_end = query.end

            if isinstance(query_start, RelativeDate):
                key = get_entity_rdate_key_from_rdate(entity_id, query_start)
                query.start = self.rule_cache[key]

            if isinstance(query_end, RelativeDate):
                key = get_entity_rdate_key_from_rdate(entity_id, query_end)
                query.end = self.rule_cache[key]

            if entity_dimension not in coord.dimensions:
                if coord.dataset_id:
                    # don't need to fetch the data set if user supplied it
                    coord.set_dimensions({entity_dimension: entity.get_marquee_id()})
                    query.coordinate = coord
                else:
                    # Need to resolve the dataset from availability
                    entity_id = entity.get_marquee_id()
                    try:
                        raw_availability = availability_cache.get(entity_id)
                        if raw_availability is None:
                            raw_availability: Dict = GsSession.current._get(f'/data/measures/{entity_id}/availability')
                            availability_cache[entity.get_marquee_id()] = raw_availability
                        query.coordinate = entity.get_data_coordinate(measure=coord.measure,
                                                                      dimensions=coord.dimensions,
                                                                      frequency=coord.frequency,
                                                                      availability=raw_availability)
                    except Exception as e:
                        _logger.info(
                            f'Could not get DataCoordinate with {coord} for entity {entity_id} due to {e}')

    def _fetch_queries(self):
        query_aggregations = aggregate_queries(self._data_queries)

        for dataset_id, query_map in query_aggregations.items():
            for query in query_map.values():
                df = fetch_query(query)
                for query_dimensions, query_infos in query['queries'].items():
                    if valid_dimensions(query_dimensions, df):
                        queried_df = df.query(build_query_string(query_dimensions))
                        for query_info in query_infos:
                            measure = query_info.query.coordinate.measure
                            query_info.data = queried_df[measure if isinstance(measure, str) else measure.value]
                    else:
                        for query_info in query_infos:
                            query_info.data = Series(dtype=float)

        for query_info in self._data_queries:
            if isinstance(query_info, MeasureQueryInfo):
                asyncio.get_event_loop().run_until_complete(
                    query_info.processor.calculate(query_info.attr,
                                                   ProcessorResult(True, None),
                                                   self.rule_cache,
                                                   query_info=query_info))
            elif query_info.data is None or len(query_info.data) == 0:
                asyncio.get_event_loop().run_until_complete(
                    query_info.processor.calculate(query_info.attr,
                                                   ProcessorResult(False,
                                                                   f'No data found for Coordinate '
                                                                   f'{query_info.query.coordinate}'),
                                                   self.rule_cache))
            else:
                asyncio.get_event_loop().run_until_complete(
                    query_info.processor.calculate(query_info.attr,
                                                   ProcessorResult(True, query_info.data),
                                                   self.rule_cache))

    @staticmethod
    def aggregate_queries(query_infos):
        mappings = defaultdict(dict)
        for query_info in query_infos:
            query = query_info.query
            coordinate = query.coordinate
            dataset_mappings = mappings[coordinate.dataset_id]
            query_key = query.get_range_string()
            dataset_mappings.setdefault(query_key, {
                'parameters': {},
                'queries': {}
            })
            queries = dataset_mappings[query_key]['queries']
            queries[coordinate.get_dimensions()] = query_info
            parameters = dataset_mappings[query_key]['parameters']
            for dimension, value in coordinate.dimensions.items():
                parameters.setdefault(dimension, set())
                parameters[dimension].add(value)
        return mappings

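    # For reference, a sketch of the mapping shape built by aggregate_queries above
    # (placeholder keys are illustrative):
    #
    # {
    #     '<dataset_id>': {
    #         '<range string from query.get_range_string()>': {
    #             'parameters': {'<dimension>': {'<value>', ...}},
    #             'queries': {'<coordinate dimensions>': '<DataQueryInfo>', ...}
    #         }
    #     }
    # }
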
    def _post_process(self) -> DataFrame:
        columns = self.columns
        results = defaultdict(list)

        for row in self.results:
            if len(row):
                results['rowGroup'].append(row[0].row_group or '')
            for column in row:
                column_value = column.value
                if column_value.success is True:
                    column_data = column_value.data
                    if isinstance(column_data, Number):
                        format_: ColumnFormat = columns[column.column_index].format_
                        results[column.name].append(round(column_data, format_.precision))
                    else:
                        results[column.name].append(column_data)
                else:
                    results[column.name].append(np.NaN)

        df = DataFrame.from_dict(results)
        row_groups = list(df['rowGroup'].unique())
        sub_dfs = []
        for row_group in row_groups:
            sub_df = self.__handle_filters(df[df['rowGroup'] == row_group])
            sub_df = self.__handle_sorts(sub_df)
            sub_dfs.append(sub_df)
        df = concat(sub_dfs)
        df.set_index(['rowGroup', df.index], inplace=True)
        df.rename_axis(index=['', ''], inplace=True)
        return df

    def __handle_sorts(self, df):
        """ Handles sorting of the dataframe

        :param df: incoming dataframe to be sorted
        :return: dataframe with sorting applied if any
        """
        for sort in self.sorts:
            ascending = True if sort.order == SortOrder.ASCENDING else False
            if sort.sortType == SortType.ABSOLUTE_VALUE:
                df = df.reindex(df[sort.columnName].abs().sort_values(ascending=ascending,
                                                                      na_position='last').index)
            else:
                df = df.sort_values(by=sort.columnName, ascending=ascending, na_position='last')
        return df

    def __handle_filters(self, df) -> DataFrame:
        """ Handles filtering the dataframe

        :param df: incoming dataframe to be filtered
        :return: dataframe with filters applied if any
        """
        if not len(df):
            return df
        starting_df = df.copy()
        running_df = df
        for filter_ in self.filters:
            filter_value = filter_.value
            if filter_value is None:
                continue
            filter_condition = filter_.condition
            if filter_condition == FilterCondition.OR:
                df = starting_df
            else:
                df = running_df

            column_name = filter_.columnName
            operation = filter_.operation

            if operation == FilterOperation.TOP:
                df = df.sort_values(by=column_name, ascending=False, na_position='last').head(filter_value)
            elif operation == FilterOperation.BOTTOM:
                df = df.sort_values(by=column_name, ascending=True, na_position='last').head(filter_value)
            elif operation == FilterOperation.ABSOLUTE_TOP:
                df = df.reindex(df[column_name].abs().sort_values(ascending=False,
                                                                  na_position='last').index).head(filter_value)
            elif operation == FilterOperation.ABSOLUTE_BOTTOM:
                df = df.reindex(df[column_name].abs().sort_values(ascending=True,
                                                                  na_position='last').index).head(filter_value)
            elif operation == FilterOperation.EQUALS:
                if not isinstance(filter_value, list):
                    filter_value = [filter_value]
                # Special case to handle different types of floats
                if isinstance(filter_value[0], str):
                    df = df.loc[df[column_name].isin(filter_value)]
                else:
                    # Add a tolerance for the special case to handle different types of floats
                    df = df[np.isclose(df[column_name].values[:, None], filter_value, atol=1e-10).any(axis=1)]
            elif operation == FilterOperation.NOT_EQUALS:
                if not isinstance(filter_value, list):
                    filter_value = [filter_value]
                if isinstance(filter_value[0], str):
                    df = df.loc[~df[column_name].isin(filter_value)]
                else:
                    # Add a tolerance for the special case to handle different types of floats
                    df = df[~np.isclose(df[column_name].values[:, None], filter_value, atol=1e-10).any(axis=1)]
            elif operation == FilterOperation.GREATER_THAN:
                df = df[df[column_name] > filter_value]
            elif operation == FilterOperation.LESS_THAN:
                df = df[df[column_name] < filter_value]
            elif operation == FilterOperation.LESS_THAN_EQUALS:
                df = df[df[column_name] <= filter_value]
            elif operation == FilterOperation.GREATER_THAN_EQUALS:
                df = df[df[column_name] >= filter_value]
            else:
                raise MqValueError(f'Invalid Filter operation Type: {operation}')

            if filter_.condition == FilterCondition.OR:
                # Need to merge the results
                running_df = running_df.merge(df, how='outer')
            else:
                running_df = df

        return running_df

    def to_frame(self) -> DataFrame:
        """
        Returns the results of the DataGrid data fetching and applied processors.

        :return: DataFrame of results
        """
        if not self.is_initialized:
            _logger.info("Grid has not been initialized. Run DataGrid.initialize() first.")
            return DataFrame()
        return self._post_process()

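    # Usage sketch: after initialize() and poll(), to_frame() returns a DataFrame indexed by
    # row group (an empty string when no RowSeparator is used) and row position, with one
    # column per DataColumn name.
    #
    # >>> df = datagrid.to_frame()
    # >>> df.loc['']    # rows that do not sit under a named RowSeparator group
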
    @classmethod
    def from_dict(cls, obj, reference_list: Optional[List] = None):
        id_ = obj.get('id', None)
        name = obj.get('name', '')
        parameters = obj.get('parameters', {})
        entitlements = Entitlements_.from_dict(obj.get('entitlements', {}))

        # If a reference list is given, then the entities will be resolved by the caller
        if reference_list is not None:
            should_resolve_entities = False
        else:
            should_resolve_entities = True
            reference_list = []

        rows = [row_from_dict(row, reference_list) for row in parameters.get('rows', [])]
        columns = [DataColumn.from_dict(column, reference_list) for column in parameters.get('columns', [])]
        sorts = [DataGridSort.from_dict(sort) for sort in parameters.get('sorts', [])]
        filters = [DataGridFilter.from_dict(filter_) for filter_ in parameters.get('filters', [])]
        multi_column_groups = [MultiColumnGroup.from_dict(group)
                               for group in parameters.get('multiColumnGroups', [])]

        if should_resolve_entities:
            resolve_entities(reference_list)

        return DataGrid(name=name,
                        rows=rows,
                        columns=columns,
                        id_=id_,
                        entitlements=entitlements,
                        primary_column_index=parameters.get('primaryColumnIndex', 0),
                        polling_time=parameters.get('pollingTime', 0),
                        multiColumnGroups=multi_column_groups,
                        sorts=sorts,
                        filters=filters)

    def as_dict(self):
        datagrid = {
            'name': self.name,
            'parameters': {
                'rows': [row.as_dict() for row in self.rows],
                'columns': [column.as_dict() for column in self.columns],
                'primaryColumnIndex': self._primary_column_index,
                'pollingTime': self.polling_time or 0
            }
        }

        if self.entitlements:
            if isinstance(self.entitlements, Entitlements_):
                datagrid['entitlements'] = self.entitlements.as_dict()
            elif isinstance(self.entitlements, Entitlements):
                datagrid['entitlements'] = self.entitlements.to_dict()
            else:
                datagrid['entitlements'] = self.entitlements

        if len(self.sorts):
            datagrid['parameters']['sorts'] = [asdict(sort) for sort in self.sorts]
        if len(self.filters):
            datagrid['parameters']['filters'] = [asdict(filter_) for filter_ in self.filters]
        if self.multiColumnGroups:
            datagrid['parameters']['multiColumnGroups'] = [group.asdict() for group in self.multiColumnGroups]

        return datagrid

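    # Round-trip sketch: as_dict() produces the payload persisted by save()/create(), and
    # from_dict() rebuilds a DataGrid from the same shape (an 'id' key is only present on
    # responses returned by the service).
    #
    # >>> payload = datagrid.as_dict()          # {'name': ..., 'parameters': {'rows': [...], 'columns': [...],
    # >>>                                        #  'primaryColumnIndex': ..., 'pollingTime': ...}, ...}
    # >>> copy_ = DataGrid.from_dict(payload)   # entities are re-resolved from the row/column dicts
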
    def set_primary_column_index(self, index: int):
        """
        Sets the primary column index, which determines which column will expand to fill any additional
        horizontal space.

        :param index: index of the column to make primary
        :return: None
        """
        self._primary_column_index = index

    def set_sorts(self, sorts: List[DataGridSort]):
        """
        Set the sorts parameter of the grid response

        :param sorts: value of grid sorts
        :return: None
        """
        self.sorts = sorts

    def add_sort(self, sort: DataGridSort, index: int = None):
        """
        Add a sort to the grid response

        :param sort: DataGridSort
        :param index: index at which to add the sort, defaults to end of sorts list
        :return: None
        """
        if index is not None:
            self.sorts.insert(index, sort)
        else:
            self.sorts.append(sort)

    def set_filters(self, filters: List[DataGridFilter]):
        """
        Set the filters parameter of the grid response

        :param filters: value of grid filters
        :return: None
        """
        self.filters = filters

    def add_filter(self, filter_: DataGridFilter, index: int = None):
        """
        Add a filter to the grid response

        :param filter_: DataGridFilter
        :param index: index at which to add the filter, defaults to end of filters list
        :return: None
        """
        if index is not None:
            self.filters.insert(index, filter_)
        else:
            self.filters.append(filter_)

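    # Sketch of configuring sorts and filters before rendering results. Constructor fields below
    # mirror the attributes read by __handle_sorts/__handle_filters (columnName, order, operation,
    # value); the exact dataclass signatures live in gs_quant.analytics.datagrid.utils.
    #
    # >>> datagrid.set_sorts([DataGridSort(columnName="Last", order=SortOrder.ASCENDING)])
    # >>> datagrid.add_filter(DataGridFilter(columnName="Last", operation=FilterOperation.TOP, value=5))
    # >>> datagrid.to_frame()    # sorts and filters are applied per row group in _post_process
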
    def __as_json(self) -> str:
        return json.dumps(self.as_dict())


def _get_overrides(row_overrides: List[Override],
                   column_name: str) -> Tuple[List[DimensionsOverride],
                                              Optional[ValueOverride],
                                              Optional[ProcessorOverride]]:
    if not row_overrides:
        return [], None, None

    dimensions_overrides, value_override, processor_override = [], None, None
    for override in row_overrides:
        if column_name in override.column_names:
            if isinstance(override, DimensionsOverride):
                dimensions_overrides.append(override)
            elif isinstance(override, ValueOverride):
                value_override = override
            elif isinstance(override, ProcessorOverride):
                processor_override = override.processor

    return dimensions_overrides, value_override, processor_override
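

# Illustrative note (not part of the module's API): _get_overrides partitions a row's overrides by
# type for a single column name. Given a row whose overrides include, say, a DimensionsOverride and
# a ValueOverride whose column_names contain "Last" (the override constructors live in
# gs_quant.analytics.datagrid.data_row), the call below returns the dimensions overrides, the value
# override (or None), and the processor taken from any ProcessorOverride (or None):
#
# >>> dims, value, processor = _get_overrides(row.overrides, "Last")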