BUG: Use calendar to disambiguate symbol lookups in fetcher.

This commit is contained in:
Scott Sanderson 2018-10-02 11:38:45 -04:00
parent 1bcb3140ae
commit f2d49d72fe
3 changed files with 44 additions and 8 deletions

View file

@ -61,48 +61,55 @@ class FetcherTestCase(WithResponses,
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'AAPL',
'asset_type': 'equity',
'exchange': 'nasdaq'
},
3766: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'IBM',
'asset_type': 'equity',
'exchange': 'nasdaq'
},
5061: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'MSFT',
'asset_type': 'equity',
'exchange': 'nasdaq'
},
14848: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'YHOO',
'asset_type': 'equity',
'exchange': 'nasdaq'
},
25317: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'DELL',
'asset_type': 'equity',
'exchange': 'nasdaq'
},
13: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2010-01-01', tz='UTC'),
'symbol': 'NFLX',
'asset_type': 'equity',
'exchange': 'nasdaq'
},
9999999: {
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
'symbol': 'AAPL',
'exchange': 'non_us_exchange'
}
},
orient='index',
)
@classmethod
def make_exchanges_info(cls, *args, **kwargs):
    """Build the exchange metadata frame used by this test case.

    One row per exchange, with ``exchange`` and ``country_code`` columns:
    'nasdaq' is mapped to 'US' and 'non_us_exchange' to 'CA'.  Positional
    and keyword arguments are accepted but not used.
    """
    exchange_rows = [
        {'exchange': 'nasdaq', 'country_code': 'US'},
        {'exchange': 'non_us_exchange', 'country_code': 'CA'},
    ]
    return pd.DataFrame.from_records(exchange_rows)
def run_algo(self, code, sim_params=None):
if sim_params is None:
sim_params = self.sim_params
@ -603,7 +610,7 @@ def initialize(context):
date_column = 'Settlement Date',
date_format = '%m/%d/%y')
context.nflx = symbol('NFLX')
context.aapl = symbol('AAPL')
context.aapl = symbol('AAPL', country_code='US')
def handle_data(context, data):
assert np.isnan(data.current(context.nflx, 'invalid_column'))

View file

@ -793,6 +793,7 @@ class TradingAlgorithm(object):
mask=True,
symbol_column=None,
special_params_checker=None,
country_code=None,
**kwargs):
"""Fetch a csv from a remote url and register the data so that it is
queryable from the ``data`` object.
@ -828,6 +829,8 @@ class TradingAlgorithm(object):
argument is the name of the column in the preprocessed dataframe
containing the symbols. This will be used along with the date
information to map the sids in the asset finder.
country_code : str, optional
Country code to use to disambiguate symbol lookups.
**kwargs
Forwarded to :func:`pandas.read_csv`.
@ -836,6 +839,10 @@ class TradingAlgorithm(object):
csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV
A requests source that will pull data from the url specified.
"""
if country_code is None:
country_code = self.default_fetch_csv_country_code(
self.trading_calendar,
)
# Show all the logs every time fetcher is used.
csv_data_source = PandasRequestsCSV(
@ -853,6 +860,7 @@ class TradingAlgorithm(object):
mask,
symbol_column,
data_frequency=self.data_frequency,
country_code=country_code,
special_params_checker=special_params_checker,
**kwargs
)
@ -2367,6 +2375,12 @@ class TradingAlgorithm(object):
"""
return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC)
@staticmethod
def default_fetch_csv_country_code(calendar):
    """Get the default ``fetch_csv`` country code for a calendar.

    Parameters
    ----------
    calendar : object
        Calendar to look up. Only its ``name`` attribute is read.

    Returns
    -------
    country_code : str or None
        The entry for ``calendar.name`` in
        ``_DEFAULT_FETCH_CSV_COUNTRY_CODES``, or None if the calendar name
        has no entry there.
    """
    calendar_name = calendar.name
    return _DEFAULT_FETCH_CSV_COUNTRY_CODES.get(calendar_name)
##################
# End Pipeline API
##################
@ -2384,3 +2398,9 @@ class TradingAlgorithm(object):
# Calendar name -> default domain for that calendar.
_DEFAULT_DOMAINS = dict(
    (d.calendar_name, d) for d in domain.BUILT_IN_DOMAINS
)

# Calendar name -> default country code for that calendar.
# 'us_futures' is added explicitly because it doesn't have a pipeline
# domain to derive a country code from.
_DEFAULT_FETCH_CSV_COUNTRY_CODES = dict(
    ((d.calendar_name, d.country_code) for d in domain.BUILT_IN_DOMAINS),
    us_futures='US',
)

View file

@ -157,6 +157,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
mask,
symbol_column,
data_frequency,
country_code,
**kwargs):
self.start_date = start_date
@ -167,6 +168,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
self.mask = mask
self.symbol_column = symbol_column or "symbol"
self.data_frequency = data_frequency
self.country_code = country_code
invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
if invalid_kwargs:
@ -272,7 +274,11 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
return numpy.nan
try:
return self.finder.lookup_symbol(uppered, as_of_date=None)
return self.finder.lookup_symbol(
uppered,
as_of_date=None,
country_code=self.country_code,
)
except MultipleSymbolsFound:
# Fill conflicted entries with zeros to mark that they need to be
# resolved by date.
@ -342,6 +348,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
# Replacing tzinfo here is necessary because of the
# timezone metadata bug described below.
row['dt'].replace(tzinfo=pytz.utc),
country_code=self.country_code,
# It's possible that no asset comes back here if our
# lookup date is from before any asset held the
@ -470,6 +477,7 @@ class PandasRequestsCSV(PandasCSV):
mask,
symbol_column,
data_frequency,
country_code,
special_params_checker=None,
**kwargs):
@ -503,6 +511,7 @@ class PandasRequestsCSV(PandasCSV):
mask,
symbol_column,
data_frequency,
country_code=country_code,
**remaining_kwargs
)