mirror of
https://github.com/saymrwulf/zipline.git
synced 2026-05-14 20:58:10 +00:00
BUG: Use calendar to disambiguate symbol lookups in fetcher.
This commit is contained in:
parent
1bcb3140ae
commit
f2d49d72fe
3 changed files with 44 additions and 8 deletions
|
|
@ -61,48 +61,55 @@ class FetcherTestCase(WithResponses,
|
|||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||
'symbol': 'AAPL',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
3766: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||
'symbol': 'IBM',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
5061: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||
'symbol': 'MSFT',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
14848: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||
'symbol': 'YHOO',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
25317: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||
'symbol': 'DELL',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
13: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2010-01-01', tz='UTC'),
|
||||
'symbol': 'NFLX',
|
||||
'asset_type': 'equity',
|
||||
'exchange': 'nasdaq'
|
||||
},
|
||||
9999999: {
|
||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||
'symbol': 'AAPL',
|
||||
'exchange': 'non_us_exchange'
|
||||
}
|
||||
},
|
||||
orient='index',
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def make_exchanges_info(cls, *args, **kwargs):
|
||||
return pd.DataFrame.from_records([
|
||||
{'exchange': 'nasdaq', 'country_code': 'US'},
|
||||
{'exchange': 'non_us_exchange', 'country_code': 'CA'},
|
||||
])
|
||||
|
||||
def run_algo(self, code, sim_params=None):
|
||||
if sim_params is None:
|
||||
sim_params = self.sim_params
|
||||
|
|
@ -603,7 +610,7 @@ def initialize(context):
|
|||
date_column = 'Settlement Date',
|
||||
date_format = '%m/%d/%y')
|
||||
context.nflx = symbol('NFLX')
|
||||
context.aapl = symbol('AAPL')
|
||||
context.aapl = symbol('AAPL', country_code='US')
|
||||
|
||||
def handle_data(context, data):
|
||||
assert np.isnan(data.current(context.nflx, 'invalid_column'))
|
||||
|
|
|
|||
|
|
@ -793,6 +793,7 @@ class TradingAlgorithm(object):
|
|||
mask=True,
|
||||
symbol_column=None,
|
||||
special_params_checker=None,
|
||||
country_code=None,
|
||||
**kwargs):
|
||||
"""Fetch a csv from a remote url and register the data so that it is
|
||||
queryable from the ``data`` object.
|
||||
|
|
@ -828,6 +829,8 @@ class TradingAlgorithm(object):
|
|||
argument is the name of the column in the preprocessed dataframe
|
||||
containing the symbols. This will be used along with the date
|
||||
information to map the sids in the asset finder.
|
||||
country_code : str, optional
|
||||
Country code to use to disambiguate symbol lookups.
|
||||
**kwargs
|
||||
Forwarded to :func:`pandas.read_csv`.
|
||||
|
||||
|
|
@ -836,6 +839,10 @@ class TradingAlgorithm(object):
|
|||
csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV
|
||||
A requests source that will pull data from the url specified.
|
||||
"""
|
||||
if country_code is None:
|
||||
country_code = self.default_fetch_csv_country_code(
|
||||
self.trading_calendar,
|
||||
)
|
||||
|
||||
# Show all the logs every time fetcher is used.
|
||||
csv_data_source = PandasRequestsCSV(
|
||||
|
|
@ -853,6 +860,7 @@ class TradingAlgorithm(object):
|
|||
mask,
|
||||
symbol_column,
|
||||
data_frequency=self.data_frequency,
|
||||
country_code=country_code,
|
||||
special_params_checker=special_params_checker,
|
||||
**kwargs
|
||||
)
|
||||
|
|
@ -2367,6 +2375,12 @@ class TradingAlgorithm(object):
|
|||
"""
|
||||
return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC)
|
||||
|
||||
@staticmethod
|
||||
def default_fetch_csv_country_code(calendar):
|
||||
"""
|
||||
"""
|
||||
return _DEFAULT_FETCH_CSV_COUNTRY_CODES.get(calendar.name)
|
||||
|
||||
##################
|
||||
# End Pipeline API
|
||||
##################
|
||||
|
|
@ -2384,3 +2398,9 @@ class TradingAlgorithm(object):
|
|||
|
||||
# Map from calendar name to default domain for that calendar.
|
||||
_DEFAULT_DOMAINS = {d.calendar_name: d for d in domain.BUILT_IN_DOMAINS}
|
||||
# Map from calendar name to default country code for that calendar.
|
||||
_DEFAULT_FETCH_CSV_COUNTRY_CODES = {
|
||||
d.calendar_name: d.country_code for d in domain.BUILT_IN_DOMAINS
|
||||
}
|
||||
# Include us_futures, which doesn't have a pipeline domain.
|
||||
_DEFAULT_FETCH_CSV_COUNTRY_CODES['us_futures'] = 'US'
|
||||
|
|
|
|||
|
|
@ -157,6 +157,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
|||
mask,
|
||||
symbol_column,
|
||||
data_frequency,
|
||||
country_code,
|
||||
**kwargs):
|
||||
|
||||
self.start_date = start_date
|
||||
|
|
@ -167,6 +168,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
|||
self.mask = mask
|
||||
self.symbol_column = symbol_column or "symbol"
|
||||
self.data_frequency = data_frequency
|
||||
self.country_code = country_code
|
||||
|
||||
invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
|
||||
if invalid_kwargs:
|
||||
|
|
@ -272,7 +274,11 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
|||
return numpy.nan
|
||||
|
||||
try:
|
||||
return self.finder.lookup_symbol(uppered, as_of_date=None)
|
||||
return self.finder.lookup_symbol(
|
||||
uppered,
|
||||
as_of_date=None,
|
||||
country_code=self.country_code,
|
||||
)
|
||||
except MultipleSymbolsFound:
|
||||
# Fill conflicted entries with zeros to mark that they need to be
|
||||
# resolved by date.
|
||||
|
|
@ -342,6 +348,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
|||
# Replacing tzinfo here is necessary because of the
|
||||
# timezone metadata bug described below.
|
||||
row['dt'].replace(tzinfo=pytz.utc),
|
||||
country_code=self.country_code,
|
||||
|
||||
# It's possible that no asset comes back here if our
|
||||
# lookup date is from before any asset held the
|
||||
|
|
@ -470,6 +477,7 @@ class PandasRequestsCSV(PandasCSV):
|
|||
mask,
|
||||
symbol_column,
|
||||
data_frequency,
|
||||
country_code,
|
||||
special_params_checker=None,
|
||||
**kwargs):
|
||||
|
||||
|
|
@ -503,6 +511,7 @@ class PandasRequestsCSV(PandasCSV):
|
|||
mask,
|
||||
symbol_column,
|
||||
data_frequency,
|
||||
country_code=country_code,
|
||||
**remaining_kwargs
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue