diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp
index 65a79865a65..050442aea89 100644
--- a/aten/src/ATen/native/TensorProperties.cpp
+++ b/aten/src/ATen/native/TensorProperties.cpp
@@ -57,7 +57,7 @@ bool cudnn_is_acceptable(const Tensor& self) {
 Tensor detach(const Tensor& self) {
 #ifndef USE_STATIC_DISPATCH
   // this just exists to give us a hook in VariableType and an entry in Declarations.yaml
-  AT_ERROR("detach is not implemented for Tensor");
+  //AT_ERROR("detach is not implemented for Tensor");
 #endif
   // this is no-op for USE_STATIC_DISPATCH mode
   return self;
@@ -66,7 +66,7 @@ Tensor detach(const Tensor& self) {
 Tensor & detach_(Tensor & self) {
 #ifndef USE_STATIC_DISPATCH
   // this just exists to give us a hook in VariableType and an entry in Declarations.yaml
-  AT_ERROR("detach_ is not implemented for Tensor");
+  //AT_ERROR("detach_ is not implemented for Tensor");
 #endif
   // this is no-op for USE_STATIC_DISPATCH mode
   return self;
diff --git a/c10/core/impl/LocalTensorTypeSet.cpp b/c10/core/impl/LocalTensorTypeSet.cpp
index f5c328c98d1..e144cace040 100644
--- a/c10/core/impl/LocalTensorTypeSet.cpp
+++ b/c10/core/impl/LocalTensorTypeSet.cpp
@@ -5,6 +5,8 @@
 namespace c10 {
 namespace impl {
 
+C10_DEFINE_bool(disable_variable_dispatch, false, "This flag forcibly disables the Variable code paths from executing, which currently breaks profiling in the process.");
+
 namespace {
 
 /// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
@@ -23,6 +25,12 @@ static PODLocalTensorTypeSet raw_local_tensor_type_set;
 } // anonymous namespace
 
 LocalTensorTypeSet tls_local_tensor_type_set() {
+  // Hack until variable performance is fixed
+  if (FLAGS_disable_variable_dispatch) {
+    raw_local_tensor_type_set.set_excluded(
+        raw_local_tensor_type_set.excluded().add(
+            TensorTypeId::VariableTensorId));
+  }
   return raw_local_tensor_type_set;
 }
 
diff --git a/c10/core/impl/LocalTensorTypeSet.h b/c10/core/impl/LocalTensorTypeSet.h
index 728cc8afe3f..2cf5d3993d9 100644
--- a/c10/core/impl/LocalTensorTypeSet.h
+++ b/c10/core/impl/LocalTensorTypeSet.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/core/TensorTypeSet.h>
+#include <c10/util/Flags.h>
 
 // TLS management for TensorTypeSet (the "local" TensorTypeSet(s))
 //
@@ -22,6 +23,8 @@
 namespace c10 {
 namespace impl {
 
+C10_DECLARE_bool(disable_variable_dispatch);
+
 // POD version of LocalTensorTypeSet. Declared here just so that
 // we can put it in the guards.
 struct C10_API PODLocalTensorTypeSet {