TensorFlow架構與設計:會話生命周期

uvtn3309as2 7年前發布 | 40K 次閱讀 TensorFlow C/C++

TensorFlow的系統結構以C API為界,將整個系統分為「前端」和「后端」兩個子系統:

  • 前端系統:提供編程模型,負責構造計算圖;
  • 后端系統:提供運行時環境,負責執行計算圖。

系統架構

前端系統主要扮演Client的角色,主要負責計算圖的構造,并管理Session生命周期過程。

前端系統是一個支持多語言的編程環境,并提供統一的編程模型支撐用戶構造計算圖。Client通過Session,連接TensorFlow后端的「運行時」,啟動計算圖的執行過程。

后端系統是TensorFlow的運行時系統,主要負責計算圖的執行過程,包括計算圖的剪枝,設備分配,子圖計算等過程。

本文首先以Session創建為例,揭示前端Python與后端C/C++系統實現的通道,闡述TensorFlow多語言編程的奧秘。隨后,以Python前端,C API橋梁,C++后端為生命線,闡述Session的生命周期過程。

Swig: 幕后英雄

前端多語言編程環境與后端C/C++實現系統的通道歸功于Swig的包裝器。TensorFlow使用Bazel的構建工具,在編譯之前啟動Swig的代碼生成過程,通過tf_session.i自動生成了兩個適配(Wrapper)文件:

  • pywrap_tensorflow.py: 負責對接上層Python調用;
  • pywrap_tensorflow.cpp: 負責對接下層C實現。

此外,pywrap_tensorflow.py模塊首次被加載時,自動地加載_pywrap_tensorflow.so的動態鏈接庫。從而實現了pywrap_tensorflow.py到pywrap_tensorflow.cpp的函數調用關系。

在pywrap_tensorflow.cpp的實現中,靜態注冊了一個函數符號表。在運行時,按照Python的函數名稱,匹配找到對應的C函數實現,最終轉調到c_api.c的具體實現。

Swig代碼生成器

編程接口:Python

當Client要啟動計算圖的執行過程時,先創建了一個Session實例,進而調用父類BaseSession的構造函數。

# tensorflow/python/client/session.py
class Session(BaseSession):
  def __init__(self, target='', graph=None, config=None):
    super(Session, self).__init__(target, graph, config=config)
    # ignoring others

在BaseSession的構造函數中,將調用pywrap_tensorflow模塊中的函數。其中,pywrap_tensorflow模塊自動由Swig生成。

# tensorflow/python/client/session.py
from tensorflow.python import pywrap_tensorflow as tf_session

class BaseSession(SessionInterface):
  def __init__(self, target='', graph=None, config=None):
    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        self._session = tf_session.TF_NewDeprecatedSession(opts, status)
    finally:
      tf_session.TF_DeleteSessionOptions(opts)
    # ignoring others

生成代碼:Swig

pywrap_tensorflow.py

在pywrap_tensorflow模塊中,通過_pywrap_tensorflow將在_pywrap_tensorflow.so中調用對應的C++函數實現。

# tensorflow/bazel-bin/tensorflow/python/pywrap_tensorflow.py
def TF_NewDeprecatedSession(arg1, status):
    return _pywrap_tensorflow.TF_NewDeprecatedSession(arg1, status)
pywrap_tensorflow.cpp

在pywrap_tensorflow.cpp的具體實現中,它靜態注冊了函數調用的符號表,實現Python的函數名稱到C++實現函數的具體映射。

# tensorflow/bazel-bin/tensorflow/python/pywrap_tensorflow.cpp
static PyMethodDef SwigMethods[] = {
    ...
     {"TF_NewDeprecatedSession", _wrap_TF_NewDeprecatedSession, METH_VARARGS, NULL},
}

PyObject *_wrap_TF_NewDeprecatedSession(
  PyObject *self, PyObject *args) {
  TF_SessionOptions* arg1 = ... 
  TF_Status* arg2 = ...

  TF_DeprecatedSession* result = TF_NewDeprecatedSession(arg1, arg2);
  // ignoring others implements
}

最終,自動生成的pywrap_tensorflow.cpp僅僅負責函數調用的轉發,最終將調用底層C系統向上提供的API接口。

C API:橋梁

c_api.h是TensorFlow的后端執行系統面向前端開放的公共API接口之一,自此將進入TensorFlow后端系統的浩瀚天空。

// tensorflow/c/c_api.c
TF_DeprecatedSession* TF_NewDeprecatedSession(
  const TF_SessionOptions*, TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    return new TF_DeprecatedSession({session});
  } else {
    return NULL;
  }
}

后端系統:C++

NewSession將根據前端傳遞的Session.target,使用SessionFactory多態創建不同類型的Session(C++)對象。

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << s;
    return s;
  }
  *out_session = factory->NewSession(options);
  if (!*out_session) {
    return errors::Internal("Failed to create session.");
  }
  return Status::OK();
}

會話生命周期

下文以前端Python,橋梁C API,后端C++為生命線,理順三者之間的調用關系,闡述Session的生命周期過程。

在Python前端,Session的生命周期主要體現在:

Session._extend_graph(graph)

  • 創建Session(target)
  • 迭代執行Session.run(fetches, feed_dict)
  • Session.TF_Run(feeds, fetches, targets)
  • 關閉Session
  • 銷毀Session
sess = Session(target)
for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
sess.close()

相應地,C++后端,Session的生命周期主要體現在:

  • 根據target多態創建Session
  • Session.Create(graph):有且僅有一次
  • Session.Extend(graph):零次或多次
  • 迭代執行Session.Run(inputs, outputs, targets)
  • 關閉Session.Close
  • 銷毀Session對象
// create/load graph ...
tensorflow::GraphDef graph;

// local runtime, target is ""
tensorflow::SessionOptions options;

// create Session
std::unique_ptr<tensorflow::Session> 
sess(tensorflow::NewSession(options));

// create graph at initialization.
tensorflow::Status s = sess->Create(graph);
if (!s.ok()) { ... }

// run step
std::vector<tensorflow::Tensor> outputs;
s = session->Run(
  {},               // inputs is empty 
  {"output:0"},     // outputs names
  {"update_state"}, // target names
  &outputs);        // output tensors
if (!s.ok()) { ... }

// close
session->Close();

創建會話

上文介紹了Session創建的詳細過程,從Python前端為起點,通過Swig自動生成的Python-C++的包裝器為媒介,實現了Python到TensorFlow的C API的調用。

其中,C API是前端系統與后端系統的分水嶺。后端C++系統根據前端傳遞的Session.target,使用SessionFactory多態創建Session(C++)對象。

創建會話

從嚴格的角色意義上劃分,GrpcSession依然扮演了Client的角色。它使用target,通過RPC協議與Master建立通信連接,因此,GrpcSession同時扮演了RPC Client的角色。

Session多態創建

創建/擴展圖

隨后,Python前端將調用Session.run接口,將構造好的計算圖,以GraphDef的形式發送給C++后端。

其中,前端每次調用Session.run接口時,都會試圖將新增節點的計算圖發送給后端系統,以便后端系統將新增節點的計算圖Extend到原來的計算圖中。特殊地,在首次調用Session.run時,將發送整個計算圖給后端系統。

后端系統首次調用Session.Extend時,轉調(或等價)Session.Create;以后,后端系統每次調用Session.Extend時將真正執行Extend的語義,將新增的計算圖的節點追加至原來的計算圖中。

隨后,后端將啟動計算圖執行的準備工作。

創建/擴展圖

迭代運行

接著,Python前端Session.run實現將Feed, Fetch列表準備好,傳遞給后端系統。后端系統調用Session.Run接口。

后端系統的一次Session.Run執行常常被稱為一次Step,Step的執行過程是TensorFlow運行時的核心。

每次Step,計算圖將正向計算網絡的輸出,反向傳遞梯度,并完成一次訓練參數的更新。首先,后端系統根據Feed, Fetch,對計算圖(常稱為Full Graph)進行剪枝,得到一個最小依賴的計算子圖(常稱為Client Graph)。

然后,運行時啟動設備分配算法,如果節點之間的邊橫跨設備,則將該邊分裂,插入相應的Send與Recv節點,實現跨設備節點的通信機制。

隨后,將分裂出來的子圖片段(常稱為Partition Graph)注冊到相應的設備上,并在本地設備上啟動子圖片段的執行過程。

Run Step

關閉會話

當計算圖執行完畢后,需要關閉Session,以便釋放后端的系統資源,包括隊列,IO等。會話關閉流程較為簡單,如下圖所示。

關閉會話

銷毀會話

最后,會話關閉之后,Python前端系統啟動GC,當Session.del被調用后,啟動后臺C++的Session對象銷毀過程。

銷毀會話

 

來自:http://www.iteye.com/news/32241

 

 本文由用戶 uvtn3309as2 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!