diff --git a/src/transformers/tools/base.py b/src/transformers/tools/base.py index b97bc4a43..add64e373 100644 --- a/src/transformers/tools/base.py +++ b/src/transformers/tools/base.py @@ -260,6 +260,24 @@ class Tool: tool_class = custom_tool["tool_class"] tool_class = get_class_from_dynamic_module(tool_class, repo_id, use_auth_token=token, **hub_kwargs) + if len(tool_class.name) == 0: + tool_class.name = custom_tool["name"] + if tool_class.name != custom_tool["name"]: + logger.warn( + f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool " + "configuration name." + ) + tool_class.name = custom_tool["name"] + + if len(tool_class.description) == 0: + tool_class.description = custom_tool["description"] + if tool_class.description != custom_tool["description"]: + logger.warn( + f"{tool_class.__name__} implements a different description in its configuration and class. Using the " + "tool configuration description." + ) + tool_class.description = custom_tool["description"] + if remote: return RemoteTool(model_repo_id, token=token, tool_class=tool_class) return tool_class(model_repo_id, token=token, **kwargs)