From f2d49d72feef15a1a62d50f0004bbf2d00d55d2a Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Tue, 2 Oct 2018 11:38:45 -0400 Subject: [PATCH] BUG: Use calendar to disambiguate symbol lookups in fetcher. --- tests/test_fetcher.py | 21 ++++++++++++++------- zipline/algorithm.py | 20 ++++++++++++++++++++ zipline/sources/requests_csv.py | 11 ++++++++++- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py index fa8054de..bc9347fa 100644 --- a/tests/test_fetcher.py +++ b/tests/test_fetcher.py @@ -61,48 +61,55 @@ class FetcherTestCase(WithResponses, 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'symbol': 'AAPL', - 'asset_type': 'equity', 'exchange': 'nasdaq' }, 3766: { 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'symbol': 'IBM', - 'asset_type': 'equity', 'exchange': 'nasdaq' }, 5061: { 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'symbol': 'MSFT', - 'asset_type': 'equity', 'exchange': 'nasdaq' }, 14848: { 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'symbol': 'YHOO', - 'asset_type': 'equity', 'exchange': 'nasdaq' }, 25317: { 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'symbol': 'DELL', - 'asset_type': 'equity', 'exchange': 'nasdaq' }, 13: { 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'end_date': pd.Timestamp('2010-01-01', tz='UTC'), 'symbol': 'NFLX', - 'asset_type': 'equity', 'exchange': 'nasdaq' + }, + 9999999: { + 'start_date': pd.Timestamp('2006-01-01', tz='UTC'), + 'end_date': pd.Timestamp('2007-01-01', tz='UTC'), + 'symbol': 'AAPL', + 'exchange': 'non_us_exchange' } }, orient='index', ) + @classmethod + def make_exchanges_info(cls, *args, **kwargs): + return pd.DataFrame.from_records([ + {'exchange': 'nasdaq', 'country_code': 'US'}, + {'exchange': 'non_us_exchange', 'country_code': 'CA'}, + ]) + def run_algo(self, code, sim_params=None): if sim_params is None: sim_params = self.sim_params @@ -603,7 +610,7 @@ def initialize(context): date_column = 'Settlement Date', date_format = '%m/%d/%y') context.nflx = symbol('NFLX') - context.aapl = symbol('AAPL') + context.aapl = symbol('AAPL', country_code='US') def handle_data(context, data): assert np.isnan(data.current(context.nflx, 'invalid_column')) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 3c0d8913..4b69a5bc 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -793,6 +793,7 @@ class TradingAlgorithm(object): mask=True, symbol_column=None, special_params_checker=None, + country_code=None, **kwargs): """Fetch a csv from a remote url and register the data so that it is queryable from the ``data`` object. @@ -828,6 +829,8 @@ class TradingAlgorithm(object): argument is the name of the column in the preprocessed dataframe containing the symbols. This will be used along with the date information to map the sids in the asset finder. + country_code : str, optional + Country code to use to disambiguate symbol lookups. **kwargs Forwarded to :func:`pandas.read_csv`. @@ -836,6 +839,10 @@ class TradingAlgorithm(object): csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV A requests source that will pull data from the url specified. """ + if country_code is None: + country_code = self.default_fetch_csv_country_code( + self.trading_calendar, + ) # Show all the logs every time fetcher is used. csv_data_source = PandasRequestsCSV( @@ -853,6 +860,7 @@ class TradingAlgorithm(object): mask, symbol_column, data_frequency=self.data_frequency, + country_code=country_code, special_params_checker=special_params_checker, **kwargs ) @@ -2367,6 +2375,12 @@ class TradingAlgorithm(object): """ return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC) + @staticmethod + def default_fetch_csv_country_code(calendar): + """ + """ + return _DEFAULT_FETCH_CSV_COUNTRY_CODES.get(calendar.name) + ################## # End Pipeline API ################## @@ -2384,3 +2398,9 @@ class TradingAlgorithm(object): # Map from calendar name to default domain for that calendar. _DEFAULT_DOMAINS = {d.calendar_name: d for d in domain.BUILT_IN_DOMAINS} +# Map from calendar name to default country code for that calendar. +_DEFAULT_FETCH_CSV_COUNTRY_CODES = { + d.calendar_name: d.country_code for d in domain.BUILT_IN_DOMAINS +} +# Include us_futures, which doesn't have a pipeline domain. +_DEFAULT_FETCH_CSV_COUNTRY_CODES['us_futures'] = 'US' diff --git a/zipline/sources/requests_csv.py b/zipline/sources/requests_csv.py index 633ad253..1ea36f69 100644 --- a/zipline/sources/requests_csv.py +++ b/zipline/sources/requests_csv.py @@ -157,6 +157,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)): mask, symbol_column, data_frequency, + country_code, **kwargs): self.start_date = start_date @@ -167,6 +168,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)): self.mask = mask self.symbol_column = symbol_column or "symbol" self.data_frequency = data_frequency + self.country_code = country_code invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS if invalid_kwargs: @@ -272,7 +274,11 @@ class PandasCSV(with_metaclass(ABCMeta, object)): return numpy.nan try: - return self.finder.lookup_symbol(uppered, as_of_date=None) + return self.finder.lookup_symbol( + uppered, + as_of_date=None, + country_code=self.country_code, + ) except MultipleSymbolsFound: # Fill conflicted entries with zeros to mark that they need to be # resolved by date. @@ -342,6 +348,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)): # Replacing tzinfo here is necessary because of the # timezone metadata bug described below. row['dt'].replace(tzinfo=pytz.utc), + country_code=self.country_code, # It's possible that no asset comes back here if our # lookup date is from before any asset held the @@ -470,6 +477,7 @@ class PandasRequestsCSV(PandasCSV): mask, symbol_column, data_frequency, + country_code, special_params_checker=None, **kwargs): @@ -503,6 +511,7 @@ class PandasRequestsCSV(PandasCSV): mask, symbol_column, data_frequency, + country_code=country_code, **remaining_kwargs )