mirror of
https://github.com/saymrwulf/zipline.git
synced 2026-05-14 20:58:10 +00:00
BUG: Use calendar to disambiguate symbol lookups in fetcher.
This commit is contained in:
parent
1bcb3140ae
commit
f2d49d72fe
3 changed files with 44 additions and 8 deletions
|
|
@ -61,48 +61,55 @@ class FetcherTestCase(WithResponses,
|
||||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||||
'symbol': 'AAPL',
|
'symbol': 'AAPL',
|
||||||
'asset_type': 'equity',
|
|
||||||
'exchange': 'nasdaq'
|
'exchange': 'nasdaq'
|
||||||
},
|
},
|
||||||
3766: {
|
3766: {
|
||||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||||
'symbol': 'IBM',
|
'symbol': 'IBM',
|
||||||
'asset_type': 'equity',
|
|
||||||
'exchange': 'nasdaq'
|
'exchange': 'nasdaq'
|
||||||
},
|
},
|
||||||
5061: {
|
5061: {
|
||||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||||
'symbol': 'MSFT',
|
'symbol': 'MSFT',
|
||||||
'asset_type': 'equity',
|
|
||||||
'exchange': 'nasdaq'
|
'exchange': 'nasdaq'
|
||||||
},
|
},
|
||||||
14848: {
|
14848: {
|
||||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||||
'symbol': 'YHOO',
|
'symbol': 'YHOO',
|
||||||
'asset_type': 'equity',
|
|
||||||
'exchange': 'nasdaq'
|
'exchange': 'nasdaq'
|
||||||
},
|
},
|
||||||
25317: {
|
25317: {
|
||||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||||
'symbol': 'DELL',
|
'symbol': 'DELL',
|
||||||
'asset_type': 'equity',
|
|
||||||
'exchange': 'nasdaq'
|
'exchange': 'nasdaq'
|
||||||
},
|
},
|
||||||
13: {
|
13: {
|
||||||
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
'end_date': pd.Timestamp('2010-01-01', tz='UTC'),
|
'end_date': pd.Timestamp('2010-01-01', tz='UTC'),
|
||||||
'symbol': 'NFLX',
|
'symbol': 'NFLX',
|
||||||
'asset_type': 'equity',
|
|
||||||
'exchange': 'nasdaq'
|
'exchange': 'nasdaq'
|
||||||
|
},
|
||||||
|
9999999: {
|
||||||
|
'start_date': pd.Timestamp('2006-01-01', tz='UTC'),
|
||||||
|
'end_date': pd.Timestamp('2007-01-01', tz='UTC'),
|
||||||
|
'symbol': 'AAPL',
|
||||||
|
'exchange': 'non_us_exchange'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
orient='index',
|
orient='index',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_exchanges_info(cls, *args, **kwargs):
|
||||||
|
return pd.DataFrame.from_records([
|
||||||
|
{'exchange': 'nasdaq', 'country_code': 'US'},
|
||||||
|
{'exchange': 'non_us_exchange', 'country_code': 'CA'},
|
||||||
|
])
|
||||||
|
|
||||||
def run_algo(self, code, sim_params=None):
|
def run_algo(self, code, sim_params=None):
|
||||||
if sim_params is None:
|
if sim_params is None:
|
||||||
sim_params = self.sim_params
|
sim_params = self.sim_params
|
||||||
|
|
@ -603,7 +610,7 @@ def initialize(context):
|
||||||
date_column = 'Settlement Date',
|
date_column = 'Settlement Date',
|
||||||
date_format = '%m/%d/%y')
|
date_format = '%m/%d/%y')
|
||||||
context.nflx = symbol('NFLX')
|
context.nflx = symbol('NFLX')
|
||||||
context.aapl = symbol('AAPL')
|
context.aapl = symbol('AAPL', country_code='US')
|
||||||
|
|
||||||
def handle_data(context, data):
|
def handle_data(context, data):
|
||||||
assert np.isnan(data.current(context.nflx, 'invalid_column'))
|
assert np.isnan(data.current(context.nflx, 'invalid_column'))
|
||||||
|
|
|
||||||
|
|
@ -793,6 +793,7 @@ class TradingAlgorithm(object):
|
||||||
mask=True,
|
mask=True,
|
||||||
symbol_column=None,
|
symbol_column=None,
|
||||||
special_params_checker=None,
|
special_params_checker=None,
|
||||||
|
country_code=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Fetch a csv from a remote url and register the data so that it is
|
"""Fetch a csv from a remote url and register the data so that it is
|
||||||
queryable from the ``data`` object.
|
queryable from the ``data`` object.
|
||||||
|
|
@ -828,6 +829,8 @@ class TradingAlgorithm(object):
|
||||||
argument is the name of the column in the preprocessed dataframe
|
argument is the name of the column in the preprocessed dataframe
|
||||||
containing the symbols. This will be used along with the date
|
containing the symbols. This will be used along with the date
|
||||||
information to map the sids in the asset finder.
|
information to map the sids in the asset finder.
|
||||||
|
country_code : str, optional
|
||||||
|
Country code to use to disambiguate symbol lookups.
|
||||||
**kwargs
|
**kwargs
|
||||||
Forwarded to :func:`pandas.read_csv`.
|
Forwarded to :func:`pandas.read_csv`.
|
||||||
|
|
||||||
|
|
@ -836,6 +839,10 @@ class TradingAlgorithm(object):
|
||||||
csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV
|
csv_data_source : zipline.sources.requests_csv.PandasRequestsCSV
|
||||||
A requests source that will pull data from the url specified.
|
A requests source that will pull data from the url specified.
|
||||||
"""
|
"""
|
||||||
|
if country_code is None:
|
||||||
|
country_code = self.default_fetch_csv_country_code(
|
||||||
|
self.trading_calendar,
|
||||||
|
)
|
||||||
|
|
||||||
# Show all the logs every time fetcher is used.
|
# Show all the logs every time fetcher is used.
|
||||||
csv_data_source = PandasRequestsCSV(
|
csv_data_source = PandasRequestsCSV(
|
||||||
|
|
@ -853,6 +860,7 @@ class TradingAlgorithm(object):
|
||||||
mask,
|
mask,
|
||||||
symbol_column,
|
symbol_column,
|
||||||
data_frequency=self.data_frequency,
|
data_frequency=self.data_frequency,
|
||||||
|
country_code=country_code,
|
||||||
special_params_checker=special_params_checker,
|
special_params_checker=special_params_checker,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
@ -2367,6 +2375,12 @@ class TradingAlgorithm(object):
|
||||||
"""
|
"""
|
||||||
return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC)
|
return _DEFAULT_DOMAINS.get(calendar.name, domain.GENERIC)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_fetch_csv_country_code(calendar):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
return _DEFAULT_FETCH_CSV_COUNTRY_CODES.get(calendar.name)
|
||||||
|
|
||||||
##################
|
##################
|
||||||
# End Pipeline API
|
# End Pipeline API
|
||||||
##################
|
##################
|
||||||
|
|
@ -2384,3 +2398,9 @@ class TradingAlgorithm(object):
|
||||||
|
|
||||||
# Map from calendar name to default domain for that calendar.
|
# Map from calendar name to default domain for that calendar.
|
||||||
_DEFAULT_DOMAINS = {d.calendar_name: d for d in domain.BUILT_IN_DOMAINS}
|
_DEFAULT_DOMAINS = {d.calendar_name: d for d in domain.BUILT_IN_DOMAINS}
|
||||||
|
# Map from calendar name to default country code for that calendar.
|
||||||
|
_DEFAULT_FETCH_CSV_COUNTRY_CODES = {
|
||||||
|
d.calendar_name: d.country_code for d in domain.BUILT_IN_DOMAINS
|
||||||
|
}
|
||||||
|
# Include us_futures, which doesn't have a pipeline domain.
|
||||||
|
_DEFAULT_FETCH_CSV_COUNTRY_CODES['us_futures'] = 'US'
|
||||||
|
|
|
||||||
|
|
@ -157,6 +157,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
||||||
mask,
|
mask,
|
||||||
symbol_column,
|
symbol_column,
|
||||||
data_frequency,
|
data_frequency,
|
||||||
|
country_code,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
self.start_date = start_date
|
self.start_date = start_date
|
||||||
|
|
@ -167,6 +168,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
self.symbol_column = symbol_column or "symbol"
|
self.symbol_column = symbol_column or "symbol"
|
||||||
self.data_frequency = data_frequency
|
self.data_frequency = data_frequency
|
||||||
|
self.country_code = country_code
|
||||||
|
|
||||||
invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
|
invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
|
||||||
if invalid_kwargs:
|
if invalid_kwargs:
|
||||||
|
|
@ -272,7 +274,11 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
||||||
return numpy.nan
|
return numpy.nan
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.finder.lookup_symbol(uppered, as_of_date=None)
|
return self.finder.lookup_symbol(
|
||||||
|
uppered,
|
||||||
|
as_of_date=None,
|
||||||
|
country_code=self.country_code,
|
||||||
|
)
|
||||||
except MultipleSymbolsFound:
|
except MultipleSymbolsFound:
|
||||||
# Fill conflicted entries with zeros to mark that they need to be
|
# Fill conflicted entries with zeros to mark that they need to be
|
||||||
# resolved by date.
|
# resolved by date.
|
||||||
|
|
@ -342,6 +348,7 @@ class PandasCSV(with_metaclass(ABCMeta, object)):
|
||||||
# Replacing tzinfo here is necessary because of the
|
# Replacing tzinfo here is necessary because of the
|
||||||
# timezone metadata bug described below.
|
# timezone metadata bug described below.
|
||||||
row['dt'].replace(tzinfo=pytz.utc),
|
row['dt'].replace(tzinfo=pytz.utc),
|
||||||
|
country_code=self.country_code,
|
||||||
|
|
||||||
# It's possible that no asset comes back here if our
|
# It's possible that no asset comes back here if our
|
||||||
# lookup date is from before any asset held the
|
# lookup date is from before any asset held the
|
||||||
|
|
@ -470,6 +477,7 @@ class PandasRequestsCSV(PandasCSV):
|
||||||
mask,
|
mask,
|
||||||
symbol_column,
|
symbol_column,
|
||||||
data_frequency,
|
data_frequency,
|
||||||
|
country_code,
|
||||||
special_params_checker=None,
|
special_params_checker=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
|
|
@ -503,6 +511,7 @@ class PandasRequestsCSV(PandasCSV):
|
||||||
mask,
|
mask,
|
||||||
symbol_column,
|
symbol_column,
|
||||||
data_frequency,
|
data_frequency,
|
||||||
|
country_code=country_code,
|
||||||
**remaining_kwargs
|
**remaining_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue