# stolen from https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
# Rebuilds a runnable PyTorch nn.Module from an ONNX model by mapping each
# ONNX graph node onto an equivalent PyTorch layer or function.

import onnx
import struct
import torch
import torch.nn as nn
import warnings
from collections import OrderedDict

# ONNX TensorProto data-type enum -> (struct format char, size in bytes).
data_type_tab = {
    1: ['f', 4],   # FLOAT
    2: ['B', 1],   # UINT8
    3: ['b', 1],   # INT8
    4: ['H', 2],   # UINT16
    5: ['h', 2],   # INT16
    6: ['i', 4],   # INT32
    7: ['q', 8],   # INT64
    10: ['e', 2],  # FLOAT16
    11: ['d', 8],  # DOUBLE
    12: ['I', 4],  # UINT32
    13: ['Q', 8],  # UINT64
}

def empty(x):
    # Identity placeholder op.
    return x


def _slim422(l4):
    # Collapse a 4-element ONNX pads list into the 2-element (pad_h, pad_w)
    # tuple that nn.Conv2d / nn.MaxPool2d expect. Asymmetric padding is only
    # approximated (the original marks this as bad code).
    assert len(l4) == 4
    p0, p1 = l4[::2]
    if l4[0] == 0:  # TODO bad code: approximate a zero begin-pad by halving the end-pad
        p0 = l4[2] // 2
        if l4[2] == 1:
            p0 = 1
    if l4[1] == 0:  # TODO bad code
        p1 = l4[3] // 2
        if l4[3] == 1:
            p1 = 1
    return p0, p1

def _check_attr(attrs, attr_map):
    # Warn about ONNX attributes this parser does not know how to translate.
    for attr in attrs:
        if attr.name not in attr_map:
            warnings.warn("Missing {} in parser's attr_map.".format(attr.name))

def unpack_weights(initializer):
    # Decode every graph initializer into a torch.Tensor, keyed by name.
    ret = {}
    for i in initializer:
        name = i.name
        dtype = i.data_type
        shape = list(i.dims)
        if dtype not in data_type_tab:
            warnings.warn("This data type {} is not supported yet.".format(dtype))
            continue
        fmt, size = data_type_tab[dtype]
        if len(i.raw_data) == 0:
            # Some initializers store typed fields instead of raw bytes.
            if dtype == 1:
                data_list = i.float_data
            elif dtype == 7:
                data_list = i.int64_data
            else:
                warnings.warn("No-raw-data type {} not supported yet.".format(dtype))
                continue
        else:
            data_list = struct.unpack('<' + fmt * (len(i.raw_data) // size), i.raw_data)
        t = torch.tensor(data_list)
        if len(shape) != 0:
            t = t.view(*shape)
        ret[name] = t
    return ret

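# A minimal sketch of inspecting the unpacked initializers (the file name
# 'model.onnx' is hypothetical):
#
#   m = onnx.load('model.onnx')
#   for name, tensor in unpack_weights(m.graph.initializer).items():
#       print(name, tuple(tensor.shape), tensor.dtype)
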
def rebuild_conv(node, weights, nodes):
    # Maps ONNX Conv attribute names to nn.Conv2d keyword arguments.
    rebuild_conv.conv_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "group": "groups",
        "dilations": "dilation",
    }
    assert len(node.output) == 1
    with_bias = False
    if len(node.input) == 3:
        with_bias = True
        bias_name = node.input[2]
        bias = weights[bias_name]

    weight_name = node.input[1]
    weight = weights[weight_name]
    in_channels = weight.shape[1]
    out_channels = weight.shape[0]
    kwargs = {}
    for att in node.attribute:
        kwargs[rebuild_conv.conv_attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    groups = 1 if 'groups' not in kwargs else kwargs['groups']
    # For grouped convolutions the weight's second dim is in_channels // groups.
    in_channels *= groups
    conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
    conv.weight.data = weight
    if with_bias:
        conv.bias.data = bias
    return conv, node.input[:1], node.output

def rebuild_dropout(node, weights, nodes):
    ratio = node.attribute[0].f
    return nn.Dropout2d(p=ratio), node.input, node.output

def rebuild_relu(node, weights, nodes):
    return nn.ReLU(), node.input, node.output

def rebuild_maxpool(node, weights, nodes):
    # Maps ONNX MaxPool attribute names to nn.MaxPool2d keyword arguments.
    rebuild_maxpool.mp_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
    }
    kwargs = {}
    for att in node.attribute:
        kwargs[rebuild_maxpool.mp_attr_map[att.name]] = list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    mp = nn.MaxPool2d(**kwargs)
    return mp, node.input, node.output

def rebuild_add(node, weights, nodes):
    def add(a, b):
        return a + b
    return add, node.input, node.output

def rebuild_transpose(node, weights, nodes):
    perm = list(node.attribute[0].ints)

    def transpose(x):
        return x.permute(*perm)
    return transpose, node.input, node.output

def rebuild_flatten(node, weights, nodes):
    if len(node.attribute) == 0:
        d = 1
    else:
        d = node.attribute[0].i

    def flatten(x):
        # Keep the first d dims and flatten the rest into one. (ONNX Flatten
        # actually produces a 2-D output; the two agree for the common d == 1.)
        o_shape = []
        for i in range(d):
            o_shape.append(x.shape[i])
        o_shape.append(-1)
        return x.view(*o_shape)
    return flatten, node.input, node.output

def rebuild_gemm(node, weights, nodes):
    # Gemm from a typical PyTorch export (transB=1) stores the weight as
    # [out_features, in_features], which is exactly nn.Linear's layout.
    # The alpha/beta/transA attributes are ignored here.
    weight = weights[node.input[1]]
    bias = weights[node.input[2]]
    in_feats = weight.shape[1]
    out_feats = weight.shape[0]
    linear = nn.Linear(in_features=in_feats, out_features=out_feats)
    linear.weight.data = weight
    linear.bias.data = bias
    return linear, node.input[:1], node.output

# TODO: rebuild matmul properly, find a way to include the transposed weights
def rebuild_matmul(node, weights, nodes):
    weight = weights[node.input[1]]
    bias = None
    # Look ahead for the node consuming this MatMul's output (e.g. an Add
    # carrying the bias) so the pair can be folded into one nn.Linear.
    for next_node in nodes:
        if next_node.input[0] == node.output[0]:
            bias = weights[next_node.input[1]]
            break
    in_feats = weight.shape[0]
    out_feats = weight.shape[1]
    linear = nn.Linear(in_features=in_feats, out_features=out_feats,
                       bias=bias is not None)
    # MatMul weights are [in, out]; nn.Linear stores [out, in].
    linear.weight.data = weight.transpose(1, 0)
    if bias is not None:
        linear.bias.data = bias
    return linear, node.input[:1], node.output

def rebuild_pad(node, weights, nodes):
    mode = node.attribute[0].s
    pads = list(node.attribute[1].ints)
    value = node.attribute[2].f
    assert mode == b'constant'  # TODO: constant mode only
    assert sum(pads[:4]) == 0   # TODO: 2-D spatial padding only
    pad = nn.ConstantPad2d(pads[4:], value)
    return pad, node.input, node.output

def rebuild_constant(node, weights, nodes):
    # Decode the literal tensor stored on the Constant node itself.
    raw_data = node.attribute[0].t.raw_data
    data_type = node.attribute[0].t.data_type
    fmt, size = data_type_tab[data_type]
    data = struct.unpack('<' + fmt * (len(raw_data) // size), raw_data)
    if len(data) == 1:
        data = data[0]

    def constant():
        return torch.tensor(data)
    return constant, [], node.output

def rebuild_sum(node, weights, nodes):
    def sum_all(*inputs):
        # Accumulate out of place so the input tensors are not mutated.
        ret = inputs[0]
        for i in inputs[1:]:
            ret = ret + i
        return ret
    return sum_all, node.input, node.output

def rebuild_shape(node, weights, nodes):
    def shape(x):
        return torch.tensor(list(x.shape))
    return shape, node.input, node.output

def rebuild_gather(node, weights, nodes):
    axis = node.attribute[0].i

    def gather(x, idx):
        # NOTE: torch.gather implements ONNX GatherElements; it only matches
        # ONNX Gather for index shapes where the two coincide.
        return torch.gather(x, axis, idx)
    return gather, node.input, node.output

def _nd_unsqueeze(x, dims):
    # Insert size-1 dims in ascending order so each index refers to the
    # final (output) dimension numbering.
    dims = sorted(dims)
    for d in dims:
        x = torch.unsqueeze(x, dim=d)
    return x

def rebuild_unsqueeze(node, weights, nodes):
    axes = list(node.attribute[0].ints)

    def unsqueeze(x):
        return _nd_unsqueeze(x, axes)

    return unsqueeze, node.input, node.output

def rebuild_mul(node, weights, nodes):
    def mul(a, b):
        return a * b
    return mul, node.input, node.output

def rebuild_softmax(node, weights, nodes):
    def f_softmax(x):
        # Hardcodes dim=1 (the ONNX axis attribute is ignored) and computes
        # in double precision before casting back to float.
        return x.softmax(dim=1, dtype=torch.double).float()
    return f_softmax, node.input, node.output

def rebuild_reshape(node, weights, nodes):
    def reshape(x, s):
        data_shape = x.shape
        onnx_shape = s.tolist()
        pt_shape = []
        for idx, d in enumerate(onnx_shape):
            if d == 0:
                # In ONNX Reshape, 0 means "copy this dim from the input".
                pt_shape.append(data_shape[idx])
            else:
                pt_shape.append(d)
        return torch.reshape(x, pt_shape)
    return reshape, node.input, node.output

def rebuild_op(node, weights, nodes):
    # Dispatch by op_type, e.g. 'Conv' -> rebuild_conv.
    op_type = node.op_type
    return globals()['rebuild_' + op_type.lower()](node, weights, nodes)

def construct_pytorch_nodes(graph, weights):
    ret = []
    for single_node in graph.node:
        ret.append(rebuild_op(single_node, weights, graph.node))
    return ret

def resolve_deps(name, deps, inter_tensors):
    # Recursively evaluate the op producing `name`, memoizing every
    # intermediate result in inter_tensors.
    if name in inter_tensors:
        return
    op, deps_names = deps[name]
    args = []
    for deps_name in deps_names:
        resolve_deps(deps_name, deps, inter_tensors)
        args.append(inter_tensors[deps_name])
    inter_tensors[name] = op(*args)

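# The deps table maps each output tensor name to (callable, input names); with
# hypothetical names: deps['conv1_out'] = (conv_module, ['input_0']).
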
class DependencyModule(nn.Module):
    # TODO: register all layers and convert everything to a dict
    def __init__(self, onnx_model, input_name=None):
        super(DependencyModule, self).__init__()
        self.deps = {}
        self.inter_tensors = dict()
        self.weights = unpack_weights(onnx_model.graph.initializer)
        nodes = construct_pytorch_nodes(onnx_model.graph, self.weights)
        # TODO: ModuleDict with proper names
        layers = OrderedDict()
        for idx, (node, inputs, outputs) in enumerate(nodes):
            if isinstance(node, nn.Module):
                layers['layer' + str(idx)] = node
            for output_name in outputs:
                self.deps[output_name] = (node, inputs)

        self.input_name = onnx_model.graph.input[0].name  # TODO: only a single graph input is supported
        self.output_name = onnx_model.graph.output[0].name  # TODO: only a single graph output is supported
        if input_name is not None:
            self.input_name = input_name
        # Registering through nn.ModuleDict makes the rebuilt layers'
        # parameters visible to .parameters(), .to(), state_dict(), etc.
        self.layers = nn.ModuleDict(layers)

    def forward(self, input):
        # Seed the tensor table with the weights plus the graph input, then
        # recursively resolve the graph output.
        self.inter_tensors = self.weights.copy()
        self.inter_tensors[self.input_name] = input
        resolve_deps(self.output_name, self.deps, self.inter_tensors)
        return self.inter_tensors[self.output_name]

def convert(onnx_path):
    onnx_model = onnx.load(onnx_path)
    reconstruct_model = DependencyModule(onnx_model)
    return reconstruct_model
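
# Example usage (a minimal sketch; 'model.onnx' and the input shape are
# hypothetical and depend on the exported model):
#
#   if __name__ == '__main__':
#       model = convert('model.onnx')
#       model.eval()
#       dummy = torch.randn(1, 3, 224, 224)
#       with torch.no_grad():
#           out = model(dummy)
#       print(out.shape)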