目錄
- 函數(shù)原型
- 參數(shù)介紹
- mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)
- args (tuple or torch.Tensor)
- f
- export_params (bool, default True)
- verbose (bool, default False)
- training (enum, default TrainingMode.EVAL)
- input_names (list of str, default empty list)
- output_names (list of str, default empty list)
- operator_export_type (enum, default None)
- opset_version (int, default 9)
- do_constant_folding (bool, default False)
- example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)
- dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)
- keep_initializers_as_inputs (bool, default None)
- custom_opsets (dict<str, int>, default empty dict)
- Torch.onnx.export執(zhí)行流程:
- 總結(jié)
函數(shù)原型
參數(shù)介紹
mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)
需要轉(zhuǎn)換得模型,支持得模型類型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction
args (tuple or torch.Tensor)
args可以被設(shè)置成三種形式
1.一個(gè)tuple
args = (x, y, z)
這個(gè)tuple應(yīng)該與模型得輸入相對應(yīng),任何非Tensor得輸入都會(huì)被硬編碼入onnx模型,所有Tensor類型得參數(shù)會(huì)被當(dāng)做onnx模型得輸入。
2.一個(gè)Tensor
args = torch.Tensor([1, 2, 3])
一般這種情況下模型只有一個(gè)輸入
3.一個(gè)帶有字典得tuple
args = (x, {'y': input_y, 'z': input_z})
這種情況下,所有字典之前得參數(shù)會(huì)被當(dāng)做“非關(guān)鍵字”參數(shù)傳入網(wǎng)絡(luò),字典種得鍵值對會(huì)被當(dāng)做關(guān)鍵字參數(shù)傳入網(wǎng)絡(luò)。如果網(wǎng)絡(luò)中得關(guān)鍵字參數(shù)未出現(xiàn)在此字典中,將會(huì)使用默認(rèn)值,如果沒有設(shè)定默認(rèn)值,則會(huì)被指定為None。
NOTE:
一個(gè)特殊情況,當(dāng)網(wǎng)絡(luò)本身最后一個(gè)參數(shù)為字典時(shí),直接在tuple最后寫一個(gè)字典則會(huì)被誤認(rèn)為關(guān)鍵字傳參。所以,可以通過在tuple最后添加一個(gè)空字典來解決。
#錯(cuò)誤寫法: torch.onnx.export( model, (x, # WRONG: will be interpreted as named arguments {y: z}), "test.onnx.pb") # 糾正 torch.onnx.export( model, (x, {y: z}, {}), "test.onnx.pb")
f
一個(gè)文件類對象或一個(gè)路徑字符串,二進(jìn)制得protocol buffer將被寫入此文件
export_params (bool, default True)
如果為True則導(dǎo)出模型得參數(shù)。如果想導(dǎo)出一個(gè)未訓(xùn)練得模型,則設(shè)為False
verbose (bool, default False)
如果為True,則打印一些轉(zhuǎn)換日志,并且onnx模型中會(huì)包含doc_string信息。
training (enum, default TrainingMode.EVAL)
枚舉類型包括:
TrainingMode.EVAL - 以推理模式導(dǎo)出模型。
TrainingMode.PRESERVE - 如果model.training為False,則以推理模式導(dǎo)出;否則以訓(xùn)練模式導(dǎo)出。
TrainingMode.TRAINING - 以訓(xùn)練模式導(dǎo)出,此模式將禁止一些影響訓(xùn)練得優(yōu)化操作。
input_names (list of str, default empty list)
按順序分配給onnx圖得輸入節(jié)點(diǎn)得名稱列表。
output_names (list of str, default empty list)
按順序分配給onnx圖得輸出節(jié)點(diǎn)得名稱列表。
operator_export_type (enum, default None)
默認(rèn)為OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,則默認(rèn)為OperatorExportTypes.ONNX_ATEN_FALLBACK。
枚舉類型包括:
OperatorExportTypes.ONNX - 將所有操作導(dǎo)出為ONNX操作。
OperatorExportTypes.ONNX_FALLTHROUGH - 試圖將所有操作導(dǎo)出為ONNX操作,但碰到無法轉(zhuǎn)換得操作(如onnx未實(shí)現(xiàn)得操作),則將操作導(dǎo)出為“自定義操作”,為了使導(dǎo)出得模型可用,運(yùn)行時(shí)必須支持這些自定義操作。支持自定義操作方法見鏈接。
OperatorExportTypes.ONNX_ATEN - 所有ATen操作導(dǎo)出為ATen操作,ATen是Pytorch得內(nèi)建tensor庫,所以這將使得模型直接使用Pytorch實(shí)現(xiàn)。(此方法轉(zhuǎn)換得模型只能被Caffe2直接使用)
OperatorExportTypes.ONNX_ATEN_FALLBACK - 試圖將所有得ATen操作也轉(zhuǎn)換為ONNX操作,如果無法轉(zhuǎn)換則轉(zhuǎn)換為ATen操作(此方法轉(zhuǎn)換得模型只能被Caffe2直接使用)。例如:
# 轉(zhuǎn)換前:graph(%0 : Float): %3 : int = prim::Constant[value=0]() # conversion unsupported %4 : Float = aten::triu(%0, %3) # conversion supported %5 : Float = aten::mul(%4, %0) return (%5) # 轉(zhuǎn)換后:graph(%0 : Float): %1 : Long() = onnx::Constant[value={0}]() # not converted %2 : Float = aten::ATen[operator="triu"](%0, %1) # converted %3 : Float = onnx::Mul(%2, %0) return (%3)
opset_version (int, default 9)
默認(rèn)是9。值必須等于_onnx_main_opset或在_onnx_stable_opsets之內(nèi)。具體可在torch/onnx/symbolic_helper.py中找到。例如:
_default_onnx_opset_version = 9 _onnx_main_opset = 13 _onnx_stable_opsets = [7, 8, 9, 10, 11, 12] _export_onnx_opset_version = _default_onnx_opset_version
do_constant_folding (bool, default False)
是否使用“常量折疊”優(yōu)化。常量折疊將使用一些算好得常量來優(yōu)化一些輸入全為常量得節(jié)點(diǎn)。
example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)
當(dāng)需輸入模型為ScriptModule 或 ScriptFunction時(shí)必須提供。此參數(shù)用于確定輸出得類型和形狀,而不跟蹤(tracing )模型得執(zhí)行。
dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)
通過以下規(guī)則設(shè)置動(dòng)態(tài)得維度:
KEY(str) - 必須是input_names或output_names指定得名稱,用來指定哪個(gè)變量需要使用到動(dòng)態(tài)尺寸。
VALUE(dict or list) - 如果是一個(gè)dict,dict中得key是變量得某個(gè)維度,dict中得value是我們給這個(gè)維度取得名稱。如果是一個(gè)list,則list中得元素都表示此變量得某個(gè)維度。
具體可參考如下示例:
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) # 以動(dòng)態(tài)尺寸模式導(dǎo)出模型 torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], }) ### 導(dǎo)出后得節(jié)點(diǎn)信息 ##input input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1... ##outputoutput { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0...
keep_initializers_as_inputs (bool, default None)
NONE
custom_opsets (dict<str, int>, default empty dict)
NONE
Torch.onnx.export執(zhí)行流程:
1、如果輸入到torch.onnx.export得模型是nn.Module類型,則默認(rèn)會(huì)將模型使用torch.jit.trace轉(zhuǎn)換為ScriptModule
2、使用args參數(shù)和torch.jit.trace將模型轉(zhuǎn)換為ScriptModule,torch.jit.trace不能處理模型中得循環(huán)和if語句
3、如果模型中存在循環(huán)或者if語句,在執(zhí)行torch.onnx.export之前先使用torch.jit.script將nn.Module轉(zhuǎn)換為ScriptModule
4、模型轉(zhuǎn)換成onnx之后,預(yù)測結(jié)果與之前會(huì)有稍微得差別,這些差別往往不會(huì)改變模型得預(yù)測結(jié)果,比如預(yù)測得概率在小數(shù)點(diǎn)之后五六位有差別。
總結(jié)
到此這篇關(guān)于Python torch.onnx.export用法詳細(xì)介紹得內(nèi)容就介紹到這了,更多相關(guān)Python torch.onnx.export用法內(nèi)容請搜索之家以前得內(nèi)容或繼續(xù)瀏覽下面得相關(guān)內(nèi)容希望大家以后多多支持之家!