mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Closes https://github.com/caffe2/caffe2/pull/1260 Differential Revision: D5906739 Pulled By: Yangqing fbshipit-source-id: e482ba9ba60b5337d9165f28f7ec68d4518a0902
69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import copy
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import core
|
|
|
|
|
|
def rewrite_init_net_simple(net):
    """Place every operator of an initialization net on the MKLDNN device.

    The proto is modified in place; nothing is returned.

    Args:
        net: a net proto whose ``op`` field holds the operators to retarget.
    """
    operators = net.op
    for op_def in operators:
        device = op_def.device_option
        device.device_type = caffe2_pb2.MKLDNN
|
|
|
|
|
|
def rewrite_run_net_simple(net):
    """Wrap ``net`` with CPU<->MKL copy ops so it can run on MKLDNN.

    Simple rewrite for now: assume the entire graph can be executed with
    MKL, so just insert copy ops for ``external_input[0]`` and the single
    external output. The proto is modified in place:

    * a ``CopyCPUToMKL`` op is prepended that copies the input blob into
      ``<input>__MKL__``, and the first op is rewired to read that blob;
    * the last op is rewired to write ``<output>__MKL__`` (in whichever of
      its output slots held the external output), and a ``CopyMKLToCPU``
      op is appended that copies it back into the original output blob;
    * every op, including the two inserted copies, gets the MKLDNN device
      type.

    Args:
        net: a net proto with at least one op, at least one external input
            (the first of which must be the first input of the first op),
            and exactly one external output produced by the last op.

    Raises:
        Exception: if the net has no external inputs or no ops, if
            ``external_input[0]`` is not the first input of the first op,
            or if the external output is not an output of the last op.
        ValueError: if the net does not have exactly one external output.
    """
    def mkl_tmp(name):
        return "{}__MKL__".format(name)

    if not net.external_input:
        raise Exception("Net has no external inputs to rewrite")
    input_blob = net.external_input[0]
    # Unpacking raises ValueError unless there is exactly one external output.
    (output_blob,) = net.external_output
    if not net.op:
        raise Exception("Net has no ops to rewrite")

    # Check for an empty input list explicitly so a malformed first op
    # reports this error instead of an IndexError.
    if not net.op[0].input or input_blob != net.op[0].input[0]:
        raise Exception(
            "Input blob: {} is not consumed by first op: {}".format(
                input_blob, net.op[0]))
    last_outputs = list(net.op[-1].output)
    if output_blob not in last_outputs:
        # Report the op itself; indexing its outputs here could fail when
        # the op has none.
        raise Exception(
            "Output blob: {} is not produced by last op: {}".format(
                output_blob, net.op[-1]))
    # The external output may sit in any output slot of the last op, not
    # necessarily the first; rewire exactly that slot.
    output_index = last_outputs.index(output_blob)

    # Modify input/outputs to point to copied MKL blobs.
    copy_input_op = core.CreateOperator(
        "CopyCPUToMKL", input_blob, mkl_tmp(input_blob))
    net.op[0].input[0] = mkl_tmp(input_blob)
    copy_output_op = core.CreateOperator(
        "CopyMKLToCPU", mkl_tmp(output_blob), output_blob)
    net.op[-1].output[output_index] = mkl_tmp(output_blob)

    # Prepend/append the copy ops by snapshotting the current ops, clearing
    # the repeated field and re-adding everything in the new order.
    ops = [copy_input_op] + net.op[:] + [copy_output_op]
    del net.op[:]
    net.op.extend(ops)
    for op in net.op:
        op.device_option.device_type = caffe2_pb2.MKLDNN
|
|
|
|
|
|
def rewrite_model_helper_simple(model):
    """Return an MKLDNN-targeted copy of ``model``.

    The argument is deep-copied first, so the caller's model is left
    untouched. On the copy, every parameter-initialization op is moved to
    MKLDNN, and the main net is wrapped with CPU<->MKL copy ops by
    ``rewrite_run_net_simple``. Any exception raised by those rewrites
    propagates to the caller.

    Args:
        model: a model helper exposing ``param_init_net`` and ``net``, each
            with a ``Proto()`` accessor.

    Returns:
        The rewritten deep copy of ``model``.
    """
    rewritten = copy.deepcopy(model)
    # All parameter initialization should run on MKL.
    init_proto = rewritten.param_init_net.Proto()
    rewrite_init_net_simple(init_proto)
    run_proto = rewritten.net.Proto()
    rewrite_run_net_simple(run_proto)
    return rewritten
|