BUG: Use calendar to disambiguate symbol lookups in fetcher.

This commit is contained in:
Scott Sanderson 2018-10-02 11:38:45 -04:00
parent 1bcb3140ae
commit f2d49d72fe
3 changed files with 44 additions and 8 deletions

View file

@@ -61,48 +61,55 @@ class FetcherTestCase(WithResponses,
'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'AAPL', 'symbol': 'AAPL',
'asset_type': 'equity',
'exchange': 'nasdaq' 'exchange': 'nasdaq'
}, },
3766: { 3766: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'IBM', 'symbol': 'IBM',
'asset_type': 'equity',
'exchange': 'nasdaq' 'exchange': 'nasdaq'
}, },
5061: { 5061: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'MSFT', 'symbol': 'MSFT',
'asset_type': 'equity',
'exchange': 'nasdaq' 'exchange': 'nasdaq'
}, },
14848: { 14848: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'YHOO', 'symbol': 'YHOO',
'asset_type': 'equity',
'exchange': 'nasdaq' 'exchange': 'nasdaq'
}, },
25317: { 25317: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'), 'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'DELL', 'symbol': 'DELL',
'asset_type': 'equity',
'exchange': 'nasdaq' 'exchange': 'nasdaq'
}, },
13: { 13: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'), 'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2010-01-01', tz='UTC'), 'end_date': pd.Timestamp('2010-01-01', tz='UTC'),
'symbol': 'NFLX', 'symbol': 'NFLX',
'asset_type': 'equity',
'exchange': 'nasdaq' 'exchange': 'nasdaq'
},
9999999: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'AAPL',
'exchange': 'non_us_exchange'
} }
}, },
orient='index', orient='index',
) )
@classmethod
def make_exchanges_info(cls, *args, **kwargs):
    """Test fixture: exchange table with one US and one non-US exchange.

    Extra positional/keyword args from the fixture machinery are accepted
    and ignored.
    """
    records = [
        {'exchange': name, 'country_code': country}
        for name, country in (('nasdaq', 'US'), ('non_us_exchange', 'CA'))
    ]
    return pd.DataFrame.from_records(records)
def run_algo(self, code, sim_params=None): def run_algo(self, code, sim_params=None):
if sim_params is None: if sim_params is None:
sim_params = self.sim_params sim_params = self.sim_params
@@ -603,7 +610,7 @@ def initialize(context):
date_column = 'Settlement Date', date_column = 'Settlement Date',
date_format = '%m/%d/%y') date_format = '%m/%d/%y')
context.nflx = symbol('NFLX') context.nflx = symbol('NFLX')
context.aapl = symbol('AAPL') context.aapl = symbol('AAPL', country_code='US')
def handle_data(context, data): def handle_data(context, data):
assert np.isnan(data.current(context.nflx, 'invalid_column')) assert np.isnan(data.current(context.nflx, 'invalid_column'))

View file

@@ -793,6 +793,7 @@ class TradingAlgorithm(object):
mask=True, mask=True,
symbol_column=None, symbol_column=None,
special_params_checker=None, special_params_checker=None,
country_code=None,
**kwargs): **kwargs):
"""Fetch a csv from a remote url and register the data so that it is """Fetch a csv from a remote url and register the data so that it is
queryable from the ``data`` object. queryable from the ``data`` object.
@@ -828,6 +829,8 @@ class TradingAlgorithm(object):
argument is the name of the column in the preprocessed dataframe argument is the name of the column in the preprocessed dataframe
containing the symbols. This will be used along with the date containing the symbols. This will be used along with the date
information to map the sids in the asset finder. information to map the sids in the asset finder.
country_code : str, optional
Country code to use to disambiguate symbol lookups.
**kwargs **kwargs
Forwarded to :func:`pandas.read_csv`. Forwarded to :func:`pandas.read_csv`.
@@ -836,6 +839,10 @@ class TradingAlgorithm(object):
csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV
A requests source that will pull data from the url specified. A requests source that will pull data from the url specified.
""" """
if country_code is None:
country_code = self.default_fetch_csv_country_code(
self.trading_calendar,
)
# Show all the logs every time fetcher is used. # Show all the logs every time fetcher is used.
csv_data_source = PandasRequestsCSV( csv_data_source = PandasRequestsCSV(
@@ -853,6 +860,7 @@ class TradingAlgorithm(object):
mask, mask,
symbol_column, symbol_column,
data_frequency=self.data_frequency, data_frequency=self.data_frequency,
country_code=country_code,
special_params_checker=special_params_checker, special_params_checker=special_params_checker,
**kwargs **kwargs
) )
@@ -2367,6 +2375,12 @@ class TradingAlgorithm(object):
""" """
return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC) return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC)
@staticmethod
def default_fetch_csv_country_code(calendar):
    """
    Get the default country_code used to disambiguate ``fetch_csv``
    symbol lookups for algorithms running on ``calendar``.

    Parameters
    ----------
    calendar : TradingCalendar
        The calendar whose ``name`` keys the default-country table.

    Returns
    -------
    country_code : str or None
        Two-letter country code for ``calendar``, or None if the calendar
        has no registered default, in which case symbol lookups are not
        restricted by country.
    """
    # Falls back to None (dict.get default) for calendars with no entry.
    return _DEFAULT_FETCH_CSV_COUNTRY_CODES.get(calendar.name)
################## ##################
# End Pipeline API # End Pipeline API
################## ##################
@@ -2384,3 +2398,9 @@ class TradingAlgorithm(object):
# Map from calendar name to default domain for that calendar. # Map from calendar name to default domain for that calendar.
_DEFAULT_DOMAINS = {d.calendar_name: d for d in domain.BUILT_IN_DOMAINS} _DEFAULT_DOMAINS = {d.calendar_name: d for d in domain.BUILT_IN_DOMAINS}
# Map from calendar name to default country code for that calendar, used by
# TradingAlgorithm.default_fetch_csv_country_code to disambiguate fetch_csv
# symbol lookups. Built from the pipeline domains, each of which carries a
# calendar_name/country_code pair.
_DEFAULT_FETCH_CSV_COUNTRY_CODES = {
    d.calendar_name: d.country_code for d in domain.BUILT_IN_DOMAINS
}
# Include us_futures, which doesn't have a pipeline domain.
_DEFAULT_FETCH_CSV_COUNTRY_CODES['us_futures'] = 'US'

View file

@@ -157,6 +157,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
mask, mask,
symbol_column, symbol_column,
data_frequency, data_frequency,
country_code,
**kwargs): **kwargs):
self.start_date = start_date self.start_date = start_date
@@ -167,6 +168,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
self.mask = mask self.mask = mask
self.symbol_column = symbol_column or "symbol" self.symbol_column = symbol_column or "symbol"
self.data_frequency = data_frequency self.data_frequency = data_frequency
self.country_code = country_code
invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
if invalid_kwargs: if invalid_kwargs:
@@ -272,7 +274,11 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
return numpy.nan return numpy.nan
try: try:
return self.finder.lookup_symbol(uppered, as_of_date=None) return self.finder.lookup_symbol(
uppered,
as_of_date=None,
country_code=self.country_code,
)
except MultipleSymbolsFound: except MultipleSymbolsFound:
# Fill conflicted entries with zeros to mark that they need to be # Fill conflicted entries with zeros to mark that they need to be
# resolved by date. # resolved by date.
@@ -342,6 +348,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
# Replacing tzinfo here is necessary because of the # Replacing tzinfo here is necessary because of the
# timezone metadata bug described below. # timezone metadata bug described below.
row['dt'].replace(tzinfo=pytz.utc), row['dt'].replace(tzinfo=pytz.utc),
country_code=self.country_code,
# It's possible that no asset comes back here if our # It's possible that no asset comes back here if our
# lookup date is from before any asset held the # lookup date is from before any asset held the
@@ -470,6 +477,7 @@ class PandasRequestsCSV(PandasCSV):
mask, mask,
symbol_column, symbol_column,
data_frequency, data_frequency,
country_code,
special_params_checker=None, special_params_checker=None,
**kwargs): **kwargs):
@@ -503,6 +511,7 @@ class PandasRequestsCSV(PandasCSV):
mask, mask,
symbol_column, symbol_column,
data_frequency, data_frequency,
country_code=country_code,
**remaining_kwargs **remaining_kwargs
) )