diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 083b22ed016..e6493224816 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -321,14 +321,14 @@ inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) { } inline IValue createGenericDict( - py::handle obj, + py::dict obj, const TypePtr& key_type, const TypePtr& value_type) { c10::impl::GenericDict elems(key_type, value_type); elems.reserve(py::len(obj)); - for (auto key : obj) { + for (auto entry : obj) { elems.insert( - toIValue(key, key_type), toIValue(obj[key], value_type)); + toIValue(entry.first, key_type), toIValue(entry.second, value_type)); } return IValue(std::move(elems)); } @@ -445,7 +445,7 @@ inline IValue toIValue( case TypeKind::DictType: { const auto& dict_type = type->expect(); return createGenericDict( - obj, dict_type->getKeyType(), dict_type->getValueType()); + py::cast(obj), dict_type->getKeyType(), dict_type->getValueType()); } case TypeKind::OptionalType: { // check if it's a none obj since optional accepts NoneType