目錄
- 核心概念
- direct_session
- direct_session.h
- direct_session.cc
1. 核心概念
讀過之前文章的讀者應該還記得,session是一個執行代理。我們把計算圖和輸入交給session,由它來調度執行器,執行計算產生結果。TF給我們提供了一個最簡單的執行器direction_session。按照當前的理解,我們覺得direction_session的實現應該是非常簡單而直接的,畢竟執行器的複雜結構我們在executor那篇已經見到了。但實際上,問題的痛點在於,有時候我們只是希望以計算圖中某些節點為輸入,某些節點為輸出,來執行圖中的一小部分計算,而不需要執行整張圖,另外一個方面,這種對圖部分執行的任務,在同一張圖上可能同時存在多個。為了應對這種情況,direct_session就衍生出了很多輔助資料。
2. direct_session2.1 direct_session.h
DirectSession類提供了豐富的資料和介面,以下為了表達簡潔,我們略去了部分函數的形參:
class DirectSession : public Session { public: DirectionSession(const SessionOptions& options, const Device* device_mgr, DirectSessionFactory* factory); Status Create(const GraphDef& graph) override; Status Extend(const GraphDef& graph) override; Status Run(...) override;//運行圖 Status PRunSetup(...);//部分運行圖準備 Status PRun(...);//部分運行圖 Status Reset(const std::vector<string>& containers);//清空device_mgr中的containers,如果containers本身就是空的,那麼清空預設容器 Status ListDevice(...) override; Status Close() overrides; Status LocalDeviceManager(const DeviceMgr** output) overrides; void ExportCostModels(...); private: Status MaybeInitializeExecutionState(...);//給定graph之後,如果執行器狀態沒有初始化,則初始化基礎的執行器狀態 Status GetOrCreateExecutors(...);//對於一組給定的輸入和輸出,在一個給定的執行器集合中檢索,是否存在合適的執行器,如果沒有,則創造一個 Status CreateGraphs(...);//給定graph_def_和裝置,以及輸入和輸出,創造多張圖,這些新建立的圖共用一個公用的函數庫flib_def Status ExtendLocked(const GraphDef& graph);//Extend的內部執行類 Status ResourceHandleToInputTensor(...); Status SendPRunInputs(...);//將更多的輸入提供給執行器,啟動後續的執行 Status RecvPRunOutputs(...);//從執行器中擷取更多的輸出,它會等待直到輸出張量計算完成 Status CheckFetch(...);//檢查需求的輸出能否根據給定的輸入計算出來 Status WaitForNotification(...); Status CheckNotClosed(); const SessionOptions options_; //裝置相關的結構 const std::unique_ptr<const DeviceMgr> device_mgr_; std::vector<Device*> devices_; DeviceSet device_set_; string session_handle_; bool graph_created_ GUARDED_BY(graph_def_lock_) = false; mutex graph_def_lock_; GraphDef graph_def_ GUARDED_BY(graph_def_lock_); std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;//被用來執行op的線程池,用一個布爾值來標誌,是否擁有這個線程池 Status init_error_; bool sync_on_finish_ = true;//如果為真,阻塞線程直到裝置已經完成了某個步驟內的所有隊列中的操作 void SchedClosure(thread::ThreadPool* pool, std::function<void()> c);//線上程池中調度c mutex executor_lock_;//保護執行器 std::unordered_map<string, std::shared_ptr<ExecutorsAndkeys>> executor_ GUARDED_BY(executor_lock_);//由簽名映射到它的執行器,簽名包括了部分執行圖的輸入和輸出,由這兩個就能唯一確定一個部分執行圖 std::unordered_map<string, std::shared_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_);//從簽名到部分執行狀態,每一個部分執行都會有一個專門儲存其狀態的結構 SessionState session_state_;//儲存了所有當前在會話中正在存活的張量 DirectSessionFactory* const factory_; CancellationManager* cancellation_manager_; std::unordered_map<string, string> stateful_placements_ GUARDED_BY(graph_def_lock_);//對於有狀態的節點(比如params和queue),儲存節點名稱到節點所在裝置的映射,一旦這些節點被放置在了某個裝置上,是不允許再移動的 std::unique_ptr<SimpleGraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_);//放置整張圖時使用 std::unique_ptr<FunctionLibraryDefinition> flib_def_;//在任何的重寫或最佳化之前的函數庫,特別是,CreateGraphs函數會修改函數庫 mutex closed_lock_; bool closed_ GUARDED_BY(closed_lock_) = false;//如果會話已經被關閉,則為true //為這個會話產生唯一的名字 std::atomic<int64> edge_name_counter_ = {0}; std::atomic<int64> handle_name_counter_ = {0}; static std::atomic_int_fast64_t step_id_counter_;//為所有的會話產生唯一的step id const int64 operation_timeout_in_ms_ = 0;//全域對阻塞操作的逾時閾值 CostModelManager cost_model_manager_;//為當前會話中執行的圖管理所有的損失模型}
可見,DirectSession裡面的很多內容都是為部分執行準備的。由於計算圖僅是一個計算的規劃,我們可以通過為同一張圖選取不同的輸入和輸出,來執行不同的計算。而不同的計算需要不同的執行器,也需要不同的儲存結構來儲存各個計算的目前狀態。為此,TF專門給出了幾個結構體,首先我們來看一下對不同計算執行器的封裝:
//為每一個partition準備的執行器和函數執行階段程式庫struct PerPartionExecutorAndLib { Graph* graph = nullptr; std::unique_ptr<FunctionLibraryRuntime> flib; std::unique_ptr<Executor> executor;};//為每一次計算提供的資料結構struct ExecutorsAndKeys { std::atomic_int_fast64_t step_count; std::unique_ptr<Graph> graph; NameNodeMap name_to_node; std::unique_ptr<FunctionLibraryDefinition> flib_def; std::vector<PerPartitionExecutorsAndLib> items; std::unordered_map<string, size_t> input_name_to_index; std::unordered_map<string, string> input_name_to_rendezvous_key; std::unordered_map<string, size_t> output_name_to_index; std::unordered_map<string, string> output_name_to_rendezvous_key; DataTypeVector input_types; DataTypeVector output_types;};
對於一張計算圖來說,我們的每一次計算的執行,不論是完整圖的計算還是部分圖的計算,都有可能是跨裝置的,因此都需要先做節點放置,把圖的節點分割到不同的裝置上,每一個裝置上放置了一個圖的partition,每個partition有對應的運行時函數庫和執行器。而對於每一種計算來說,我們需要一個vector把不同partition的資訊儲存起來。
另外,剛才提到我們還需要為每一次計算提供儲存目前狀態的結構,下面就來看一下:
//對於每一個partition內的執行,會話儲存了一個RunStatestruct RunState { mutex mu_; Status status GUARDED_BY(mu_); IntraProcessRendezvous* rendez = nullptr; std::unique_ptr<StepStatsCollector> collector; Notification executors_done; std::unordered_map<string, bool> pending_inputs;//如果已經提供了輸入,則為true std::unordered_map<string, bool> pending_outputs;//如果已經獲得了輸出,則為true TensorStore tensor_store; ScopedStepContainer step-container; //...};struct RunStateArgs { RunStateArgs(const DebugOption& options) : debug_options(options) {} bool is_partial_run = false; string handle; std::unique_ptr<Graph> graph; const DebugOptions& debug_options;};
其中,RunState為每一個partition的執行提供了狀態儲存的功能,而RunStateArgs則為前者提供了用於調試的參數和配置。
2.2 direct_session.cc
在源檔案裡,給出了DirectSessionFactory的定義,它提供了對於DirectSession進行產生和管理的功能,簡要摘錄如下:
class DirectSessionFactory : public SessionFactory { public: Session* NewSession(const SessionOptions& options) override; Status Reset(...) override; void Deregister(const DirectSession* session); private: mutex session_lock_; std::vector<DirectSession*> session_ GUARDED_BY(sessions_lock_);//用於儲存產生的DirectSession};
另外,還提供了一個對於直接工廠註冊的類:
class DirectSessionRegistrar { public: DirectSessionRegistrar() { SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory()); }};static DirectSessionRegistrar registrar;
下面,我們會按照順序對DirectSession內重要的函數,進行拆解,由於部分函數細節比較多,除了核心代碼之外,我們僅給出功能解釋:
DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, DirectSessionFactory* const factory){ //根據options準備線程池 //根據device_mgr準備device_和device_set_和每個裝置的op_segment()}Status DirectSession::Run(...){ //提取對於當前會話的本次啟動並執行輸入的名稱 //檢查對於所需的輸入輸出,是否已經存在現成的執行器 //構造一個調用幀(call frame),方便會話與執行器之間傳遞輸入和輸出 //建立一個運行時狀態的結構(RunState) //開始並存執行,核心代碼如下 for(const auto& item : executors_and_keys->items){ item.executor->RunAsync(args, barrier->Get()); } //擷取輸出 //儲存本次運行中我們希望儲存的輸出張量 //建立並返回損失模型(cost model) //如果RunOptions中有相關配置,輸出分割後的圖}Status DirectSession::GetOrCreateExecutors(...){ //快速尋找路徑 //慢尋找路徑,對輸入和輸出做排序,使得相同輸入和輸出集合會得到相同的簽名 //如果未找到,則建立這個執行器並緩衝 //構建執行圖,核心代碼如下 CreateGraphs(options, &graphs, &ek->flib_def, run_state_args, &ek->input_types, &ek->output_types)); //為各子圖準備運行時環境}Status DirectSession::CreateGraphs(...){ //前期預先處理 //圖分割演算法,核心代碼如下 Partition(popts, &client_graph->graph, &partitions); //檢查分割結果的有效性 //圖最佳化遍曆,核心代碼如下 OptimizationPassRegistry::Global()->RunGrouping(OptimizationPassRegistry::POST_PARTITIONING, optimization_options); //允許裝置重寫它擁有的子圖}
可見,具體的執行過程是在Run函數內部,調用executor->RunAsync函數來實現的,在具體執行之前,我們還需要通過GetOrCreateExecutors函數獲得執行器,在這個函數內部,我們通過CreateGraphs函數對原圖進行了分割,並利用圖最佳化遍曆演算法對圖進行了最佳化。