標籤:
原文地址:[Making a simple VM interpreter in Python](https://csl.name/post/vm/)**更新:根據大家的評論我對代碼做了輕微的改動。感謝 robin-gvx、 bs4h 和 Dagur,具體代碼見[這裡](https://github.com/cslarsen/python-simple-vm)**Stack Machine 本身並沒有任何的寄存器,它將所需要處理的值全部放入堆棧中而後進行處理。Stack Machine 雖然簡單但是卻十分強大,這也是為神馬 Python,Java,PostScript,Forth 和其他語言都選擇它作為自己的虛擬機器的原因。首先,我們先來談談堆棧。我們需要一個指令指標棧用於儲存返回地址。這樣當我們調用了一個子常式(比如調用一個函數)的時候我們就能夠返回到我們開始調用的地方了。我們可以使用自修改代碼([self-modifying code](https://en.wikipedia.org/wiki/Self-modifying_code))來做這件事,恰如 Donald Knuth 發起的 [MIX](https://en.wikipedia.org/wiki/MIX) 所做的那樣。但是如果這麼做的話你不得不自己維護堆棧從而保證遞迴能正常工作。在這篇文章中,我並不會真正的實現子常式調用,但是要實現它其實並不難(可以考慮把實現它當成練習)。有了堆棧之後你會省很多事兒。舉個例子來說,考慮這樣一個運算式`(2+3)*4`。在 `Stack Machine` 上與這個運算式等價的代碼為 `2 3 + 4 *`。首先,將 `2` 和 `3` 推入堆棧中,接下來的是操作符 `+`,此時讓堆棧彈出這兩個數值,再把它兩加合之後的結果重新入棧。然後將 `4` 入堆,而後讓堆棧彈出兩個數值,再把他們相乘之後的結果重新入棧。多麼簡單啊!讓我們開始寫一個簡單的堆棧類吧。讓這個類繼承 `collections.deque`:from collections import dequeclass Stack(deque): push = deque.append def top(self): return self[-1]現在我們有了 `push`、`pop` 和 `top` 這三個方法。`top` 方法用於查看棧頂元素。接下來,我們實現虛擬機器這個類。在虛擬機器中我們需要兩個堆棧以及一些記憶體空間來儲存程式本身(譯者註:這裡的程式請結合下文理解)。得益於 Pyhton 的動態類型我們可以往 list 中放入任何類型。唯一的問題是我們無法區分出哪些是字串哪些是內建函數。正確的做法是只將真正的 Python 函數放入 list 中。我可能會在將來實現這一點。我們同時還需要一個指令指標指向程式中下一個要執行的代碼。class Machine: def __init__(self, code): self.data_stack = Stack() self.return_addr_stack = Stack() self.instruction_pointer = 0 self.code = code這時候我們增加一些方便使用的函數省得以後多敲鍵盤。 def pop(self): return self.data_stack.pop() def push(self, value): self.data_stack.push(value) def top(self): return self.data_stack.top()然後我們增加一個 `dispatch` 函數來完成每一個作業碼做的事兒(我們並不是真正的使用作業碼,只是動態展開它,你懂的)。首先,增加一個解譯器所必須的迴圈: def run(self): while self.instruction_pointer < len(self.code): opcode = self.code[self.instruction_pointer] self.instruction_pointer += 1 self.dispatch(opcode)誠如您所見的,這貨只好好的做一件事兒,即擷取下一條指令,讓指令指標執自增,然後根據作業碼分別處理。`dispatch` 函數的代碼稍微長了一點。 def dispatch(self, op): dispatch_map = { "%": self.mod, "*": self.mul, "+": self.plus, "-": self.minus, "/": self.div, "==": self.eq, "cast_int": self.cast_int, "cast_str": self.cast_str, "drop": self.drop, "dup": self.dup, "if": self.if_stmt, "jmp": self.jmp, "over": self.over, "print": self.print_, "println": self.println, "read": self.read, "stack": self.dump_stack, "swap": self.swap, } if op in dispatch_map: dispatch_map[op]() elif isinstance(op, int): # push numbers on the data stack self.push(op) elif isinstance(op, str) and op[0]==op[-1]==‘"‘: # push quoted strings on the data stack self.push(op[1:-1]) else: raise RuntimeError("Unknown opcode: ‘%s‘" % op)基本上,這段代碼只是根據作業碼尋找是都有對應的處理函數,例如 `*` 對應 `self.mul`,`drop` 對應 `self.drop`,`dup`對應 `self.dup`。順便說一句,你在這裡看到的這段代碼其實本質上就是簡單版的 `Forth`。而且,`Forth` 語言還是值得您看看的。總之捏,它一但發現作業碼是 `*` 的話就直接調用 `self.mul` 並執行它。就像這樣: def mul(self): self.push(self.pop() * self.pop())其他的函數也是類似這樣的。如果我們在 `dispatch_map` 中尋找不到相應操作函數,我們首先檢查他是不是數字類型,如果是的話直接入棧;如果是被引號括起來的字串的話也是同樣處理--直接入棧。截止現在,恭喜你,一個虛擬機器就完成了。讓我們定義更多的操作,然後使用我們剛完成的虛擬機器和 [p-code](https://en.wikipedia.org/wiki/P-code_machine) 語言來寫程式。 # Allow to use "print" as a name for our own method: from __future__ import print_function # ... def plus(self): self.push(self.pop() + self.pop()) def minus(self): last = self.pop() self.push(self.pop() - last) def mul(self): self.push(self.pop() * self.pop()) def div(self): last = self.pop() self.push(self.pop() / last) def print(self): sys.stdout.write(str(self.pop())) sys.stdout.flush() def println(self): sys.stdout.write("%s\n" % self.pop()) sys.stdout.flush()讓我們用我們的虛擬機器寫個與 `print((2+3)*4)` 等同效果的例子。Machine([2, 3, "+", 4, "*", "println"]).run()你可以試著運行它。現在引入一個新的操作 `jump`, 即 `go-to` 操作 def jmp(self): addr = self.pop() if isinstance(addr, int) and 0 <= addr < len(self.code): self.instruction_pointer = addr else: raise RuntimeError("JMP address must be a valid integer.")它只改變指令指標的值。我們再看看分支跳轉是怎麼做的。 def if_stmt(self): false_clause = self.pop() true_clause = self.pop() test = self.pop() self.push(true_clause if test else false_clause)這同樣也是很直白的。如果你想要添加一個條件跳轉,你只要簡單的執行 `test-value true-value false-value IF JMP` 就可以了.(分支處理是很常見的操作,許多虛擬機器都提供類似 `JNE` 這樣的操作。`JNE` 是 `jump if not equal` 的縮寫)。下面的程式要求使用者輸入兩個數字,然後列印出他們的和和乘積。Machine([ ‘"Enter a number: "‘, "print", "read", "cast_int", ‘"Enter another number: "‘, "print", "read", "cast_int", "over", "over", ‘"Their sum is: "‘, "print", "+", "println", ‘"Their product is: "‘, "print", "*", "println"]).run()`over`、`read` 和 `cast_int` 這三個操作是長這樣滴: def cast_int(self): self.push(int(self.pop())) def over(self): b = self.pop() a = self.pop() self.push(a) self.push(b) self.push(a) def read(self): self.push(raw_input())以下這一段程式要求使用者輸入一個數字,然後列印出這個數字是奇數還是偶數。Machine([ ‘"Enter a number: "‘, "print", "read", "cast_int", ‘"The number "‘, "print", "dup", "print", ‘" is "‘, "print", 2, "%", 0, "==", ‘"even."‘, ‘"odd."‘, "if", "println", 0, "jmp" # loop forever!]).run()這裡有個小練習給你去實現:增加 `call` 和 `return` 這兩個作業碼。`call` 作業碼將會做如下事情 :將當前地址推入返回堆棧中,然後調用 `self.jmp()`。`return` 作業碼將會做如下事情:返回堆棧彈棧,將彈棧出來元素的值賦予指令指標(這個值可以讓你跳回去或者從 `call` 調用中返回)。當你完成這兩個命令,那麼你的虛擬機器就可以調用子常式了。##一個簡單的解析器創造一個模仿上述程式的小型語言。我們將把它編譯成我們的機器碼。 import tokenize from StringIO import StringIO # ...def parse(text): tokens = tokenize.generate_tokens(StringIO(text).readline) for toknum, tokval, _, _, _ in tokens: if toknum == tokenize.NUMBER: yield int(tokval) elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]: yield tokval elif toknum == tokenize.ENDMARKER: break else: raise RuntimeError("Unknown token %s: ‘%s‘" % (tokenize.tok_name[toknum], tokval))## 一個簡單的最佳化:常量摺疊常量摺疊([Constant folding](https://en.wikipedia.org/wiki/Constant_folding))是窺孔最佳化([peephole optimization](https://en.wikipedia.org/wiki/Peephole_optimization))的一個例子,也即是說再在編譯期間可以針對某些明顯的程式碼片段做些預計算的工作。比如,對於涉及到常量的數學運算式例如 `2 3 +`就可以很輕鬆的實現這種最佳化。def constant_fold(code): """Constant-folds simple mathematical expressions like 2 3 + to 5.""" while True: # Find two consecutive numbers and an arithmetic operator for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])): if isinstance(a, int) and isinstance(b, int) and op in {"+", "-", "*", "/"}: m = Machine((a, b, op)) m.run() code[i:i+3] = [m.top()] print("Constant-folded %s%s%s to %s" % (a,op,b,m.top())) break else: break return code採用常量摺疊遇到唯一問題就是我們不得不更新跳轉地址,但在很多情況這是很難辦到的(例如:`test cast_int jmp`)。針對這個問題有很多解決方案,其中一個簡單的方法就是只允許跳轉到程式中的命名標籤上,然後在最佳化之後解析出他們真正的地址。如果你實現了 `Forth words`,也即函數,你可以做更多的最佳化,比如刪除可能永遠不會被用到的程式碼([dead code elimination](https://en.wikipedia.org/wiki/Dead_code_elimination))## REPL我們可以創造一個簡單的 PERL,就像這樣def repl(): print(‘Hit CTRL+D or type "exit" to quit.‘) while True: try: source = raw_input("> ") code = list(parse(source)) code = constant_fold(code) Machine(code).run() except (RuntimeError, IndexError) as e: print("IndexError: %s" % e) except KeyboardInterrupt: print("\nKeyboardInterrupt")用一些簡單的程式來測試我們的 REPL> 2 3 + 4 * printlnConstant-folded 2+3 to 5Constant-folded 5*4 to 2020> 12 dup * println144> "Hello, world!" dup println printlnHello, world!Hello, world!你可以看到,常量摺疊看起來運轉正常。在第一個例子中,它把整個程式最佳化成這樣 20 println。## 下一步當你添加完 `call` 和 `return` 之後,你便可以讓使用者定義自己的函數了。在[Forth](https://en.wikipedia.org/wiki/Forth_(programming_language)) 中函數被稱為 words,他們以冒號開頭緊接著是名字然後以分號結束。例如,一個整數平方的 word 是長這樣滴: square dup * ;實際上,你可以試試把這一段放在程式中,比如 Gforth$ gforthGforth 0.7.3, Copyright (C) 1995-2008 Free Software Foundation, Inc.Gforth comes with ABSOLUTELY NO WARRANTY; for details type `license‘Type `bye‘ to exit: square dup * ; ok12 square . 144 ok你可以在解析器中通過發現 `:` 來支援這一點。一旦你發現一個冒號,你必須記錄下它的名字及其地址(比如:在程式中的位置)然後把他們插入到符號表([symbol table](https://en.wikipedia.org/wiki/Symbol_table))中。簡單起見,你甚至可以把整個函數的代碼(包括分號)放在字典中,譬如:symbol_table = {"square": ["dup", "*"]# ...}當你完成瞭解析的工作,你可以[串連](https://en.wikipedia.org/wiki/Linker_(computing))你的程式:遍曆整個主程式並且在符號表中尋找自訂函數的地方。一旦你找到一個並且它沒有在主程式的後面出現,那麼你可以把它附加到主程式的後面。然後用 `<address> call` 替換掉 `square`,這裡的 `<address>` 是函數插入的地址。為了保證程式能正常執行,你應該考慮剔除 `jmp` 操作。否則的話,你不得不解析它們。它確實能執行,但是你得按照使用者編寫程式的順序儲存它們。舉例來說,你想在子常式之間移動,你要格外小心。你可能需要添加 `exit` 函數用於停止程式(可能需要告訴作業系統傳回值),這樣主程式就不會繼續執行以至於跑到子常式中。實際上,一個好的程式空間布局很有可能把主程式當成一個名為 `main` 的子常式。或者由你決定搞成什麼樣子。如您所見,這一切都是很有趣的,而且通過這一過程你也學會了很多關於代碼產生、連結、程式空間布局相關的知識。## 更多能做的事兒你可以使用 Python 位元組碼產生庫來嘗試將虛擬機器代碼為原生的 Python 位元組碼。或者用 Java 實現運行在 JVM 上面,這樣你就可以自由使用 [JITing](https://en.wikipedia.org/wiki/Just-in-time_compilation)。同樣的,你也可以嘗試下[register machine](https://en.wikipedia.org/wiki/Register_machine)。你可以嘗試用棧幀([stack frames](https://en.wikipedia.org/wiki/Call_stack#STACK-FRAME))實現調用棧([call stack](https://en.wikipedia.org/wiki/Call_stack)),並基於此建立調用會話。最後,如果你不喜歡類似 Forth 這樣的語言,你可以創造運行於這個虛擬機器之上的自訂語言。譬如,你可以把類似 `(2+3)*4` 這樣的中綴運算式轉化成 `2 3 + 4 *` 然後產生代碼。你也可以允許 C 風格的代碼塊 `{ ... }` 這樣的話,語句 `if ( test ) { ... } else { ... }` 將會被翻譯成<true/false test><address of true block><address of false block>ifjmp<true block><address of end of entire if-statement> jmp<false block><address of end of entire if-statement> jmp例子,Address Code------- ---- 0 2 3 > 3 7 # Address of true-block 4 11 # Address of false-block 5 if 6 jmp # Conditional jump based on test# True-block 7 "Two is greater than three."8 println9 15 # Continue main program10 jmp# False-block ("else { ... }")11 "Two is less than three."12 println13 15 # Continue main program14 jmp# If-statement finished, main program continues here15 ...對了,你還需要添加比較操作符 `!= < <= > >=`。我已經在我的 [C++ stack machine](https://github.com/cslarsen/stack-machine) 實現了這些東東,你可以參考下。我已經把這裡呈現出來的代碼搞成了個項目 [Crianza](https://github.com/cslarsen/crianza),它使用了更多的最佳化和實驗性質的模型來吧程式編譯成 Python 位元組碼。祝好運!##完整的代碼下面是全部的代碼,相容 Python 2 和 Python 3你可以通過 [這裡](https://github.com/cslarsen/python-simple-vm) 得到它。#!/usr/bin/env python# coding: utf-8"""A simple VM interpreter.Code from the post at http://csl.name/post/vm/This version should work on both Python 2 and 3."""from __future__ import print_functionfrom collections import dequefrom io import StringIOimport sysimport tokenizedef get_input(*args, **kw): """Read a string from standard input.""" if sys.version[0] == "2": return raw_input(*args, **kw) else: return input(*args, **kw)class Stack(deque): push = deque.append def top(self): return self[-1]class Machine: def __init__(self, code): self.data_stack = Stack() self.return_stack = Stack() self.instruction_pointer = 0 self.code = code def pop(self): return self.data_stack.pop() def push(self, value): self.data_stack.push(value) def top(self): return self.data_stack.top() def run(self): while self.instruction_pointer < len(self.code): opcode = self.code[self.instruction_pointer] self.instruction_pointer += 1 self.dispatch(opcode) def dispatch(self, op): dispatch_map = { "%": self.mod, "*": self.mul, "+": self.plus, "-": self.minus, "/": self.div, "==": self.eq, "cast_int": self.cast_int, "cast_str": self.cast_str, "drop": self.drop, "dup": self.dup, "exit": self.exit, "if": self.if_stmt, "jmp": self.jmp, "over": self.over, "print": self.print, "println": self.println, "read": self.read, "stack": self.dump_stack, "swap": self.swap, } if op in dispatch_map: dispatch_map[op]() elif isinstance(op, int): self.push(op) # push numbers on stack elif isinstance(op, str) and op[0]==op[-1]==‘"‘: self.push(op[1:-1]) # push quoted strings on stack else: raise RuntimeError("Unknown opcode: ‘%s‘" % op) # OPERATIONS FOLLOW: def plus(self): self.push(self.pop() + self.pop()) def exit(self): sys.exit(0) def minus(self): last = self.pop() self.push(self.pop() - last) def mul(self): self.push(self.pop() * self.pop()) def div(self): last = self.pop() self.push(self.pop() / last) def mod(self): last = self.pop() self.push(self.pop() % last) def dup(self): self.push(self.top()) def over(self): b = self.pop() a = self.pop() self.push(a) self.push(b) self.push(a) def drop(self): self.pop() def swap(self): b = self.pop() a = self.pop() self.push(b) self.push(a) def print(self): sys.stdout.write(str(self.pop())) sys.stdout.flush() def println(self): sys.stdout.write("%s\n" % self.pop()) sys.stdout.flush() def read(self): self.push(get_input()) def cast_int(self): self.push(int(self.pop())) def cast_str(self): self.push(str(self.pop())) def eq(self): self.push(self.pop() == self.pop()) def if_stmt(self): false_clause = self.pop() true_clause = self.pop() test = self.pop() self.push(true_clause if test else false_clause) def jmp(self): addr = self.pop() if isinstance(addr, int) and 0 <= addr < len(self.code): self.instruction_pointer = addr else: raise RuntimeError("JMP address must be a valid integer.") def dump_stack(self): print("Data stack (top first):") for v in reversed(self.data_stack): print(" - type %s, value ‘%s‘" % (type(v), v))def parse(text): # Note that the tokenizer module is intended for parsing Python source # code, so if you‘re going to expand on the parser, you may have to use # another tokenizer. if sys.version[0] == "2": stream = StringIO(unicode(text)) else: stream = StringIO(text) tokens = tokenize.generate_tokens(stream.readline) for toknum, tokval, _, _, _ in tokens: if toknum == tokenize.NUMBER: yield int(tokval) elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]: yield tokval elif toknum == tokenize.ENDMARKER: break else: raise RuntimeError("Unknown token %s: ‘%s‘" % (tokenize.tok_name[toknum], tokval))def constant_fold(code): """Constant-folds simple mathematical expressions like 2 3 + to 5.""" while True: # Find two consecutive numbers and an arithmetic operator for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])): if isinstance(a, int) and isinstance(b, int) and op in {"+", "-", "*", "/"}: m = Machine((a, b, op)) m.run() code[i:i+3] = [m.top()] print("Constant-folded %s%s%s to %s" % (a,op,b,m.top())) break else: break return codedef repl(): print(‘Hit CTRL+D or type "exit" to quit.‘) while True: try: source = get_input("> ") code = list(parse(source)) code = constant_fold(code) Machine(code).run() except (RuntimeError, IndexError) as e: print("IndexError: %s" % e) except KeyboardInterrupt: print("\nKeyboardInterrupt")def test(code = [2, 3, "+", 5, "*", "println"]): print("Code before optimization: %s" % str(code)) optimized = constant_fold(code) print("Code after optimization: %s" % str(optimized)) print("Stack after running original program:") a = Machine(code) a.run() a.dump_stack() print("Stack after running optimized program:") b = Machine(optimized) b.run() b.dump_stack() result = a.data_stack == b.data_stack print("Result: %s" % ("OK" if result else "FAIL")) return resultdef examples(): print("** Program 1: Runs the code for `print((2+3)*4)`") Machine([2, 3, "+", 4, "*", "println"]).run() print("\n** Program 2: Ask for numbers, computes sum and product.") Machine([ ‘"Enter a number: "‘, "print", "read", "cast_int", ‘"Enter another number: "‘, "print", "read", "cast_int", "over", "over", ‘"Their sum is: "‘, "print", "+", "println", ‘"Their product is: "‘, "print", "*", "println" ]).run() print("\n** Program 3: Shows branching and looping (use CTRL+D to exit).") Machine([ ‘"Enter a number: "‘, "print", "read", "cast_int", ‘"The number "‘, "print", "dup", "print", ‘" is "‘, "print", 2, "%", 0, "==", ‘"even."‘, ‘"odd."‘, "if", "println", 0, "jmp" # loop forever! ]).run()if __name__ == "__main__": try: if len(sys.argv) > 1: cmd = sys.argv[1] if cmd == "repl": repl() elif cmd == "test": test() examples() else: print("Commands: repl, test") else: repl() except EOFError: print("") **本文系[OneAPM](http://oneapm.com/index.html?utm_source=Common&utm_medium=Articles&utm_campaign=TechnicalArticles&from=matefijuno)工程師編譯整理。OneAPM是中國基礎軟體領域的新興領軍企業,能協助企業使用者和開發人員輕鬆實現:緩慢的程式碼和SQL語句的即時抓取。想閱讀更多技術文章,請訪問OneAPM[官方技術部落格](http://code.oneapm.com/?hmsr=media&hmmd=&hmpl=&hmkw=&hmci=)。**
【譯】使用 Python 編寫虛擬機器解譯器