pytorch/torch/csrc/jit/fuser/codegen.h
Natalia Gimelshein ed47b85d3b Allow fusion of float function arguments (#18087)
Summary:
Allows functions like `def fn(x, p: float)` to be fused. Fixes #9940 and #11186. Only float (not integer) arguments are fused, which simplifies assembling the arguments for the fusion launch.
CPU fusion is disabled in CI, so this change is not tested there, but I tested it locally.
cc t-vi, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18087

Differential Revision: D14581206

Pulled By: wanchaol

fbshipit-source-id: ccb0cf79b1751706f9b2cdf1715115eae5a39fb6
2019-03-22 13:52:33 -07:00


#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/fuser/arg_spec.h>
#include <torch/csrc/jit/fuser/partition_desc.h>
#include <torch/csrc/jit/fuser/tensor_desc.h>
#include <torch/csrc/jit/ir.h>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
// Creates a CPU or CUDA kernel for the given graph.
// Returns the C++ or CUDA string implementing the kernel.
TORCH_API std::string generateKernel(
    const std::string& name,
    const Graph& graph,
    const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>&
        inputs,
    const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
    const bool use_cuda);
} // namespace fuser
} // namespace jit
} // namespace torch
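
The `inputs` parameter pairs each value with an optional `TensorDesc`. Given the commit summary, a plausible reading is that tensor inputs carry a descriptor while float scalar arguments carry none. The sketch below illustrates that pattern with standalone stand-in types. `ValueStub`, `TensorDescStub`, `InputStub`, and `emitFormals` are hypothetical names invented for this example, `std::optional` stands in for `c10::optional`, and the emitted parameter strings are illustrative only. It is not the code in codegen.cpp.

```cpp
#include <cstddef>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the fuser's Value and TensorDesc types.
struct ValueStub {
  std::string name;
};
struct TensorDescStub {
  std::string scalar_type; // e.g. "float"
  int n_dims;
};

// Mirrors the shape of the `inputs` element type: a value plus an optional
// descriptor. Assumption: no descriptor means a float scalar argument.
using InputStub = std::pair<const ValueStub*, std::optional<TensorDescStub>>;

// Builds a comma-separated formal parameter list for a kernel signature.
// Inputs with a descriptor become tensor parameters; inputs without one
// become plain float parameters.
std::string emitFormals(const std::vector<InputStub>& inputs) {
  std::ostringstream out;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    if (i > 0) {
      out << ", ";
    }
    const ValueStub* value = inputs[i].first;
    const std::optional<TensorDescStub>& desc = inputs[i].second;
    if (desc.has_value()) {
      out << "TensorInfo<" << desc->scalar_type << ", " << desc->n_dims
          << "> " << value->name;
    } else {
      out << "float " << value->name;
    }
  }
  return out.str();
}
```

For a tensor `x` with two dimensions and a float scalar `p`, this sketch would produce one tensor parameter followed by one `float` parameter. This corresponds to the `def fn(x, p: float)` case from the commit summary, under the assumptions stated above.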