【譯】使用 Python 編寫虛擬機解釋器

jopen 9年前發布 | 12K 次閱讀 Python

原文地址: Making a simple VM interpreter in Python

更新:根據大家的評論我對代碼做了輕微的改動。感謝 robin-gvx、 bs4h 和 Dagur,具體代碼見 這里

Stack Machine 本身并沒有任何的寄存器,它將所需要處理的值全部放入堆棧中而后進行處理。Stack Machine 雖然簡單但是卻十分強大,這也是為神馬 Python,Java,PostScript,Forth 和其他語言都選擇它作為自己的虛擬機的原因。

首先,我們先來談談堆棧。我們需要一個指令指針棧用于保存返回地址。這樣當我們調用了一個子例程(比如調用一個函數)的時候我們就能夠返回到我們開始調用的地方了。我們可以使用自修改代碼( self-modifying code )來做這件事,恰如 Donald Knuth 發起的 MIX 所做的那樣。但是如果這么做的話你不得不自己維護堆棧從而保證遞歸能正常工作。在這篇文章中,我并不會真正的實現子例程調用,但是要實現它其實并不難(可以考慮把實現它當成練習)。

有了堆棧之后你會省很多事兒。舉個例子來說,考慮這樣一個表達式 (2+3)*4 。在 Stack Machine 上與這個表達式等價的代碼為 2 3 + 4 * 。首先,將 2 和 3 推入堆棧中,接下來的是操作符 + ,此時讓堆棧彈出這兩個數值,再把它兩加合之后的結果重新入棧。然后將 4 入堆,而后讓堆棧彈出兩個數值,再把他們相乘之后的結果重新入棧。多么簡單啊!

讓我們開始寫一個簡單的堆棧類吧。讓這個類繼承 collections.deque :

from collections import deque

class Stack(deque):
push = deque.append

def top(self):
    return self[-1]

現在我們有了 push 、 pop 和 top 這三個方法。 top 方法用于查看棧頂元素。

接下來,我們實現虛擬機這個類。在虛擬機中我們需要兩個堆棧以及一些內存空間來存儲程序本身(譯者注:這里的程序請結合下文理解)。得益于 Pyhton 的動態類型我們可以往 list 中放入任何類型。唯一的問題是我們無法區分出哪些是字符串哪些是內置函數。正確的做法是只將真正的 Python 函數放入 list 中。我可能會在將來實現這一點。

我們同時還需要一個指令指針指向程序中下一個要執行的代碼。

class Machine:
def __init__(self, code):
    self.data_stack = Stack()
    self.return_addr_stack = Stack()
    self.instruction_pointer = 0
    self.code = code

這時候我們增加一些方便使用的函數省得以后多敲鍵盤。

def pop(self):
    return self.data_stack.pop()
def push(self, value):
    self.data_stack.push(value)
def top(self):
    return self.data_stack.top()

然后我們增加一個 dispatch 函數來完成每一個操作碼做的事兒(我們并不是真正的使用操作碼,只是動態展開它,你懂的)。首先,增加一個解釋器所必須的循環:

def run(self):
    while self.instruction_pointer < len(self.code):
        opcode = self.code[self.instruction_pointer]
        self.instruction_pointer += 1
        self.dispatch(opcode)

誠如您所見的,這貨只好好的做一件事兒,即獲取下一條指令,讓指令指針執自增,然后根據操作碼分別處理。 dispatch 函數的代碼稍微長了一點。

def dispatch(self, op):
dispatch_map = {
    "%":        self.mod,
    "*":        self.mul,
    "+":        self.plus,
    "-":        self.minus,
    "/":        self.div,
    "==":      self.eq,
    "cast_int": self.cast_int,
    "cast_str": self.cast_str,
    "drop":  self.drop,
    "dup":    self.dup,
    "if":      self.if_stmt,
    "jmp":    self.jmp,
    "over":  self.over,
    "print":    self.print_,
    "println":  self.println,
    "read":  self.read,
    "stack":    self.dump_stack,
    "swap":  self.swap,
}
if op in dispatch_map:
    dispatch_map[op]()
elif isinstance(op, int):
    # push numbers on the data stack
    self.push(op)
elif isinstance(op, str) and op[0]==op[-1]=='"':
    # push quoted strings on the data stack
    self.push(op[1:-1])
else:
    raise RuntimeError("Unknown opcode: '%s'" % op)

基本上,這段代碼只是根據操作碼查找是都有對應的處理函數,例如 * 對應 self.mul , drop 對應 self.drop , dup 對應 self.dup 。順便說一句,你在這里看到的這段代碼其實本質上就是簡單版的 Forth 。而且, Forth 語言還是值得您看看的。

總之捏,它一但發現操作碼是 * 的話就直接調用 self.mul 并執行它。就像這樣:

def mul(self):
    self.push(self.pop() * self.pop())

其他的函數也是類似這樣的。如果我們在 dispatch_map 中查找不到相應操作函數,我們首先檢查他是不是數字類型,如果是的話直接入棧;如果是被引號括起來的字符串的話也是同樣處理--直接入棧。

截止現在,恭喜你,一個虛擬機就完成了。

讓我們定義更多的操作,然后使用我們剛完成的虛擬機和 p-code 語言來寫程序。

# Allow to use "print" as a name for our own method:
from __future__ import print_function
# ...
def plus(self):
    self.push(self.pop() + self.pop())
def minus(self):
    last = self.pop()
    self.push(self.pop() - last)
def mul(self):
    self.push(self.pop() * self.pop())
def div(self):
    last = self.pop()
    self.push(self.pop() / last)
def print(self):
    sys.stdout.write(str(self.pop()))
    sys.stdout.flush()
def println(self):
    sys.stdout.write("%s\n" % self.pop())
    sys.stdout.flush()

讓我們用我們的虛擬機寫個與 print((2+3)*4) 等同效果的例子。

Machine([2, 3, "+", 4, "*", "println"]).run()你可以試著運行它。

現在引入一個新的操作 jump , 即 go-to 操作

def jmp(self):
    addr = self.pop()
    if isinstance(addr, int) and 0 <= addr < len(self.code):
        self.instruction_pointer = addr
    else:
        raise RuntimeError("JMP address must be a valid integer.")

它只改變指令指針的值。我們再看看分支跳轉是怎么做的。

def if_stmt(self):
    false_clause = self.pop()
    true_clause = self.pop()
    test = self.pop()
    self.push(true_clause if test else false_clause)

這同樣也是很直白的。如果你想要添加一個條件跳轉,你只要簡單的執行 test-value true-value false-value IF JMP 就可以了.(分支處理是很常見的操作,許多虛擬機都提供類似 JNE 這樣的操作。 JNE 是 jump if not equal 的縮寫)。

下面的程序要求使用者輸入兩個數字,然后打印出他們的和和乘積。

Machine([
'"Enter a number: "', "print", "read", "cast_int",
'"Enter another number: "', "print", "read", "cast_int",
"over", "over",
'"Their sum is: "', "print", "+", "println",
'"Their product is: "', "print", "*", "println"
]).run()

over 、 read 和 cast_int 這三個操作是長這樣滴:

def cast_int(self):
    self.push(int(self.pop()))
def over(self):
    b = self.pop()
    a = self.pop()
    self.push(a)
    self.push(b)
    self.push(a)
def read(self):
    self.push(raw_input())

以下這一段程序要求使用者輸入一個數字,然后打印出這個數字是奇數還是偶數。

Machine([
'"Enter a number: "', "print", "read", "cast_int",
'"The number "', "print", "dup", "print", '" is "', "print",
2, "%", 0, "==", '"even."', '"odd."', "if", "println",
0, "jmp" # loop forever!
]).run()

這里有個小練習給你去實現:增加 call 和 return 這兩個操作碼。 call 操作碼將會做如下事情 :將當前地址推入返回堆棧中,然后調用 self.jmp() 。 return 操作碼將會做如下事情:返回堆棧彈棧,將彈棧出來元素的值賦予指令指針(這個值可以讓你跳轉回去或者從 call 調用中返回)。當你完成這兩個命令,那么你的虛擬機就可以調用子例程了。

一個簡單的解析器

創造一個模仿上述程序的小型語言。我們將把它編譯成我們的機器碼。

 import tokenize
 from StringIO import StringIO

# ...

def parse(text):
tokens =   tokenize.generate_tokens(StringIO(text).readline)
for toknum, tokval, _, _, _ in tokens:
    if toknum == tokenize.NUMBER:
        yield int(tokval)
    elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
        yield tokval
    elif toknum == tokenize.ENDMARKER:
        break
    else:
        raise RuntimeError("Unknown token %s: '%s'" %
                (tokenize.tok_name[toknum], tokval))

一個簡單的優化:常量折疊

常量折疊( Constant folding )是窺孔優化( peephole optimization )的一個例子,也即是說再在編譯期間可以針對某些明顯的代碼片段做些預計算的工作。比如,對于涉及到常量的數學表達式例如 2 3 + 就可以很輕松的實現這種優化。

def constant_fold(code):
"""Constant-folds simple mathematical expressions like 2 3 + to 5."""
while True:
    # Find two consecutive numbers and an arithmetic operator
    for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
    if isinstance(a, int) and isinstance(b, int) \
        and op in {"+", "-", "*", "/"}:
        m = Machine((a, b, op))
        m.run()
        code[i:i+3] = [m.top()]
        print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
        break
    else:
    break
return code

采用常量折疊遇到唯一問題就是我們不得不更新跳轉地址,但在很多情況這是很難辦到的(例如: test cast_int jmp )。針對這個問題有很多解決方法,其中一個簡單的方法就是只允許跳轉到程序中的命名標簽上,然后在優化之后解析出他們真正的地址。

如果你實現了 Forth words ,也即函數,你可以做更多的優化,比如刪除可能永遠不會被用到的程序代碼( dead code elimination

REPL

我們可以創造一個簡單的 PERL,就像這樣

def repl():
print('Hit CTRL+D or type "exit" to quit.')
while True:
    try:
        source = raw_input("> ")
        code = list(parse(source))
        code = constant_fold(code)
        Machine(code).run()
    except (RuntimeError, IndexError) as e:
        print("IndexError: %s" % e)
    except KeyboardInterrupt:
        print("\nKeyboardInterrupt")

用一些簡單的程序來測試我們的 REPL

> 2 3 + 4 * println
Constant-folded 2+3 to 5
Constant-folded 5*4 to 20
20
> 12 dup * println
144
> "Hello, world!" dup println println
Hello, world!
Hello, world!
你可以看到,常量折疊看起來運轉正常。在第一個例子中,它把整個程序優化成這樣 20 println。

下一步

當你添加完 call 和 return 之后,你便可以讓使用者定義自己的函數了。在 Forth 中函數被稱為 words,他們以冒號開頭緊接著是名字然后以分號結束。例如,一個整數平方的 word 是長這樣滴

: square dup * ;

實際上,你可以試試把這一段放在程序中,比如 Gforth

$ gforth
Gforth 0.7.3, Copyright (C) 1995-2008 Free Software Foundation, Inc.
Gforth comes with ABSOLUTELY NO WARRANTY; for details type `license'
Type `bye' to exit
: square dup * ;  ok
12 square . 144  ok

你可以在解析器中通過發現 : 來支持這一點。一旦你發現一個冒號,你必須記錄下它的名字及其地址(比如:在程序中的位置)然后把他們插入到符號表( symbol table )中。簡單起見,你甚至可以把整個函數的代碼(包括分號)放在字典中,譬如:

symbol_table = {
"square": ["dup", "*"]
# ...
    }

當你完成了解析的工作,你可以 連接 你的程序:遍歷整個主程序并且在符號表中尋找自定義函數的地方。一旦你找到一個并且它沒有在主程序的后面出現,那么你可以把它附加到主程序的后面。然后用 <address> call 替換掉 square ,這里的 <address> 是函數插入的地址。

為了保證程序能正常執行,你應該考慮剔除 jmp 操作。否則的話,你不得不解析它們。它確實能執行,但是你得按照用戶編寫程序的順序保存它們。舉例來說,你想在子例程之間移動,你要格外小心。你可能需要添加 exit 函數用于停止程序(可能需要告訴操作系統返回值),這樣主程序就不會繼續執行以至于跑到子例程中。

實際上,一個好的程序空間布局很有可能把主程序當成一個名為 main 的子例程。或者由你決定搞成什么樣子。

如您所見,這一切都是很有趣的,而且通過這一過程你也學會了很多關于代碼生成、鏈接、程序空間布局相關的知識。

更多能做的事兒

你可以使用 Python 字節碼生成庫來嘗試將虛擬機代碼為原生的 Python 字節碼。或者用 Java 實現運行在 JVM 上面,這樣你就可以自由使用 JITing

同樣的,你也可以嘗試下 register machine 。你可以嘗試用棧幀( stack frames )實現調用棧( call stack ),并基于此建立調用會話。

最后,如果你不喜歡類似 Forth 這樣的語言,你可以創造運行于這個虛擬機之上的自定義語言。譬如,你可以把類似 (2+3)*4 這樣的中綴表達式轉化成 2 3 + 4 * 然后生成代碼。你也可以允許 C 風格的代碼塊 { ... } 這樣的話,語句 if ( test ) { ... } else { ... } 將會被翻譯成

<true/false test>
<address of true block>
<address of false block>
if
jmp

<true block>
<address of end of entire if-statement> jmp

<false block>
<address of end of entire if-statement> jmp

例子,

Address  Code
-------  ----
 0       2 3 >
 3       7        # Address of true-block
 4       11       # Address of false-block
 5       if
 6       jmp      # Conditional jump based on test

# True-block

7     "Two is greater than three."  
8       println
9       15       # Continue main program
10       jmp

# False-block ("else { ... }")
11       "Two is less than three."
12       println
13       15       # Continue main program
14       jmp

# If-statement finished, main program continues here
15       ...

對了,你還需要添加比較操作符 != < <= > >= 。

我已經在我的 C++ stack machine 實現了這些東東,你可以參考下。

我已經把這里呈現出來的代碼搞成了個項目 Crianza ,它使用了更多的優化和實驗性質的模型來吧程序編譯成 Python 字節碼。

祝好運!

完整的代碼

下面是全部的代碼,兼容 Python 2 和 Python 3

你可以通過 這里 得到它。

#!/usr/bin/env python
# coding: utf-8
"""
A simple VM interpreter.
Code from the post at http://csl.name/post/vm/
This version should work on both Python 2 and 3.
"""
from __future__ import print_function
from collections import deque
from io import StringIO
import sys
import tokenize
def get_input(*args, **kw):
"""Read a string from standard input."""
if sys.version[0] == "2":
    return raw_input(*args, **kw)
else:
    return input(*args, **kw)
class Stack(deque):
push = deque.append
def top(self):
    return self[-1]
class Machine:
def __init__(self, code):
    self.data_stack = Stack()
    self.return_stack = Stack()
    self.instruction_pointer = 0
    self.code = code
def pop(self):
    return self.data_stack.pop()
def push(self, value):
    self.data_stack.push(value)
def top(self):
    return self.data_stack.top()
def run(self):
    while self.instruction_pointer < len(self.code):
        opcode = self.code[self.instruction_pointer]
        self.instruction_pointer += 1
        self.dispatch(opcode)
def dispatch(self, op):
    dispatch_map = {
        "%":        self.mod,
        "*":        self.mul,
        "+":        self.plus,
        "-":        self.minus,
        "/":        self.div,
        "==":      self.eq,
        "cast_int": self.cast_int,
        "cast_str": self.cast_str,
        "drop":  self.drop,
        "dup":    self.dup,
        "exit":  self.exit,
        "if":      self.if_stmt,
        "jmp":    self.jmp,
        "over":  self.over,
        "print":    self.print,
        "println":  self.println,
        "read":  self.read,
        "stack":    self.dump_stack,
        "swap":  self.swap,
    }
    if op in dispatch_map:
        dispatch_map[op]()
    elif isinstance(op, int):
        self.push(op) # push numbers on stack
    elif isinstance(op, str) and op[0]==op[-1]=='"':
        self.push(op[1:-1]) # push quoted strings on stack
    else:
        raise RuntimeError("Unknown opcode: '%s'" % op)
# OPERATIONS FOLLOW:
def plus(self):
    self.push(self.pop() + self.pop())
def exit(self):
    sys.exit(0)
def minus(self):
    last = self.pop()
    self.push(self.pop() - last)
def mul(self):
    self.push(self.pop() * self.pop())
def div(self):
    last = self.pop()
    self.push(self.pop() / last)
def mod(self):
    last = self.pop()
    self.push(self.pop() % last)
def dup(self):
    self.push(self.top())
def over(self):
    b = self.pop()
    a = self.pop()
    self.push(a)
    self.push(b)
    self.push(a)
def drop(self):
    self.pop()
def swap(self):
    b = self.pop()
    a = self.pop()
    self.push(b)
    self.push(a)
def print(self):
    sys.stdout.write(str(self.pop()))
    sys.stdout.flush()
def println(self):
    sys.stdout.write("%s\n" % self.pop())
    sys.stdout.flush()
def read(self):
    self.push(get_input())
def cast_int(self):
    self.push(int(self.pop()))
def cast_str(self):
    self.push(str(self.pop()))
def eq(self):
    self.push(self.pop() == self.pop())
def if_stmt(self):
    false_clause = self.pop()
    true_clause = self.pop()
    test = self.pop()
    self.push(true_clause if test else false_clause)
def jmp(self):
    addr = self.pop()
    if isinstance(addr, int) and 0 <= addr < len(self.code):
        self.instruction_pointer = addr
    else:
        raise RuntimeError("JMP address must be a valid integer.")
def dump_stack(self):
    print("Data stack (top first):")
    for v in reversed(self.data_stack):
        print(" - type %s, value '%s'" % (type(v), v))
def parse(text):
# Note that the tokenizer module is intended for parsing Python source
# code, so if you're going to expand on the parser, you may have to use
# another tokenizer.
if sys.version[0] == "2":
    stream = StringIO(unicode(text))
else:
    stream = StringIO(text)
tokens = tokenize.generate_tokens(stream.readline)
for toknum, tokval, _, _, _ in tokens:
    if toknum == tokenize.NUMBER:
        yield int(tokval)
    elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
        yield tokval
    elif toknum == tokenize.ENDMARKER:
        break
    else:
        raise RuntimeError("Unknown token %s: '%s'" %
                (tokenize.tok_name[toknum], tokval))
def constant_fold(code):
"""Constant-folds simple mathematical expressions like 2 3 + to 5."""
while True:
    # Find two consecutive numbers and an arithmetic operator
    for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
        if isinstance(a, int) and isinstance(b, int) \
                and op in {"+", "-", "*", "/"}:
            m = Machine((a, b, op))
            m.run()
            code[i:i+3] = [m.top()]
            print("Constant-folded %s%s%s to %s" %  (a,op,b,m.top()))
            break
        else:
            break
        return code
def repl():
print('Hit CTRL+D or type "exit" to quit.')
while True:
    try:
        source = get_input("> ")
        code = list(parse(source))
        code = constant_fold(code)
        Machine(code).run()
    except (RuntimeError, IndexError) as e:
        print("IndexError: %s" % e)
    except KeyboardInterrupt:
        print("\nKeyboardInterrupt")
def test(code = [2, 3, "+", 5, "*", "println"]):
print("Code before optimization: %s" % str(code))
optimized = constant_fold(code)
print("Code after optimization: %s" % str(optimized))
print("Stack after running original program:")
a = Machine(code)
a.run()
a.dump_stack()
print("Stack after running optimized program:")
b = Machine(optimized)
b.run()
b.dump_stack()
result = a.data_stack == b.data_stack
print("Result: %s" % ("OK" if result else "FAIL"))
return result
def examples():
print("** Program 1: Runs the code for `print((2+3)*4)`")
Machine([2, 3, "+", 4, "*", "println"]).run()
print("\n** Program 2: Ask for numbers, computes sum and product.")
Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"Enter another number: "', "print", "read", "cast_int",
    "over", "over",
    '"Their sum is: "', "print", "+", "println",
    '"Their product is: "', "print", "*", "println"
]).run()
print("\n** Program 3: Shows branching and looping (use CTRL+D to exit).")
Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"The number "', "print", "dup", "print", '" is "', "print",
    2, "%", 0, "==", '"even."', '"odd."', "if", "println",
    0, "jmp" # loop forever!
]).run()
if __name__ == "__main__":
try:
    if len(sys.argv) > 1:
        cmd = sys.argv[1]
        if cmd == "repl":
            repl()
        elif cmd == "test":
            test()
            examples()
        else:
            print("Commands: repl, test")
    else:
        repl()
except EOFError:
    print("")
    

本文系 OneAPM 工程師編譯整理。OneAPM是中國基礎軟件領域的新興領軍企業,能幫助企業用戶和開發者輕松實現:緩慢的程序代碼和SQL語句的實時抓取。想閱讀更多技術文章,請訪問OneAPM 官方技術博客

 本文由用戶 jopen 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
 轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
 本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!