我使用的是 TensorFlow v2.9,并在做设备端训练(on-device training):先在 Python 中导出一个模块,然后从 C 调用其具体函数(concrete function)进行预测和训练。我也可以使用 GPU。
导出的 SavedModel 无法使用 Keras 的高级 API(例如 `predict` 或 `fit`),否则会失败,并出现以下错误:
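```
RuntimeError: Detected a call to `Model.predict` inside a `tf.function`. `Model.predict` is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.predict` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.
```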
因此,我把模型实现为一个自定义模块(`tf.Module`),预测时按照下面链接中的说明直接调用 `model(x)`:
https://www.tensorflow.org/lite/examples/on_device_training/overview
不知道为什么,`model(x)` 似乎无法正确处理批次(batch)。
例如,批大小为 `3` 时,`model(x)` 接收形状为 `(3, 4, 15, 15)` 的输入,其中一个输出的形状为 `(3, 1, 225)`。
如下所示,输出张量中的三个 `(1, 225)` 向量完全相同:
[
[
[-20.500122, -20.500196, -16.388021, -20.500189, -13.888604, -20.500208, -20.500103, -13.725816, -16.14115, -15.523373, -16.094854, -15.536175, -13.494872, -20.500164, -16.729692, -17.314562, -9.923043, -20.500137, -13.227316, -19.462494, -8.832517, -11.005514, -16.657751, -20.500229, -19.104895, -17.969429, -16.826006, -18.479736, -11.35681, -20.50018, -17.686893, -15.8137665, -20.500158, -20.498934, -11.30343, -12.114782, -6.9864135, -16.129002, -11.758956, -13.793568, -10.100338, -18.394066, -7.8771715, -18.867481, -13.54011, -20.500141, -18.142273, -13.827344, -12.14585, -8.751808, -7.360826, -7.8197165, -8.190978, -7.9918194, -7.1475286, -10.866553, -13.463445, -12.561472, -17.644833, -20.499897, -15.04738, -15.1495285, -15.757288, -10.316235, -6.4681287, -6.771983, -6.2083254, -5.169312, -5.9851274, -7.3863406, -5.7047515, -11.461843, -19.462492, -20.499823, -16.014748, -19.572166, -10.054104, -9.654353, -6.9895654, -6.523039, -3.4712281, -4.010914, -3.058044, -5.203539, -4.562346, -7.3472414, -8.2306795, -14.15948, -16.442978, -15.1097, -20.499994, -16.006512, -13.285485, -9.599341, -5.576161, -5.10128, -2.1091957, -2.6103199, -2.3030841, -4.3452697, -5.1566353, -6.7773423, -13.5079155, -18.91643, -20.49996, -20.50012, -20.500032, -15.034921, -7.0785294, -6.62519, -2.6741242, -3.3764887, -3.2719333, -3.4223785, -3.1113718, -6.607987, -6.7852387, -9.567825, -17.231964, -18.361439, -15.199417, -20.500113, -8.907006, -8.894981, -4.4610567, -5.3974047, -3.1986039, -3.308056, -2.5260184, -4.416704, -5.5637026, -8.839353, -7.404949, -18.09958, -20.499996, -20.500063, -12.94954, -17.1081, -7.8807735, -6.0368576, -4.0000243, -4.983799, -3.7624922, -3.9401622, -5.351621, -7.3347793, -6.7273192, -16.521574, -10.310918, -18.213472, -18.239689, -20.49987, -13.403644, -10.768933, -6.169673, -6.226465, -4.851883, -3.5755277, -5.7955694, -7.59566, -6.5219584, -15.287647, -9.992104, -20.49974, -11.737182, -20.500032, -13.7056465, -11.700055, -11.151376, -12.240701, 
-6.9801717, -9.907572, -9.89772, -7.7714005, -7.599248, -14.2966175, -10.805019, -14.946489, -15.138906, -20.49991, -16.84454, -20.500303, -15.745817, -11.974067, -14.362624, -13.677492, -6.8857694, -10.488706, -9.6858, -15.690493, -13.776093, -17.350763, -13.82417, -20.500122, -16.799477, -11.256063, -16.112524, -20.500021, -16.107948, -11.349038, -12.018146, -20.500145, -15.021783, -20.500141, -14.088732, -19.462494, -16.841585, -17.49845, -15.664743, -18.375904, -20.500162, -17.897068, -20.50004, -13.704247, -15.333616, -20.500124, -14.740182, -12.495611, -20.500069, -20.50013, -17.074047, -13.579008, -16.136011, -20.500244, -11.993184]
],
[
[-20.500122, -20.500196, -16.388021, -20.500189, -13.888604, -20.500208, -20.500103, -13.725816, -16.14115, -15.523373, -16.094854, -15.536175, -13.494872, -20.500164, -16.729692, -17.314562, -9.923043, -20.500137, -13.227316, -19.462494, -8.832517, -11.005514, -16.657751, -20.500229, -19.104895, -17.969429, -16.826006, -18.479736, -11.35681, -20.50018, -17.686893, -15.8137665, -20.500158, -20.498934, -11.30343, -12.114782, -6.9864135, -16.129002, -11.758956, -13.793568, -10.100338, -18.394066, -7.8771715, -18.867481, -13.54011, -20.500141, -18.142273, -13.827344, -12.14585, -8.751808, -7.360826, -7.8197165, -8.190978, -7.9918194, -7.1475286, -10.866553, -13.463445, -12.561472, -17.644833, -20.499897, -15.04738, -15.1495285, -15.757288, -10.316235, -6.4681287, -6.771983, -6.2083254, -5.169312, -5.9851274, -7.3863406, -5.7047515, -11.461843, -19.462492, -20.499823, -16.014748, -19.572166, -10.054104, -9.654353, -6.9895654, -6.523039, -3.4712281, -4.010914, -3.058044, -5.203539, -4.562346, -7.3472414, -8.2306795, -14.15948, -16.442978, -15.1097, -20.499994, -16.006512, -13.285485, -9.599341, -5.576161, -5.10128, -2.1091957, -2.6103199, -2.3030841, -4.3452697, -5.1566353, -6.7773423, -13.5079155, -18.91643, -20.49996, -20.50012, -20.500032, -15.034921, -7.0785294, -6.62519, -2.6741242, -3.3764887, -3.2719333, -3.4223785, -3.1113718, -6.607987, -6.7852387, -9.567825, -17.231964, -18.361439, -15.199417, -20.500113, -8.907006, -8.894981, -4.4610567, -5.3974047, -3.1986039, -3.308056, -2.5260184, -4.416704, -5.5637026, -8.839353, -7.404949, -18.09958, -20.499996, -20.500063, -12.94954, -17.1081, -7.8807735, -6.0368576, -4.0000243, -4.983799, -3.7624922, -3.9401622, -5.351621, -7.3347793, -6.7273192, -16.521574, -10.310918, -18.213472, -18.239689, -20.49987, -13.403644, -10.768933, -6.169673, -6.226465, -4.851883, -3.5755277, -5.7955694, -7.59566, -6.5219584, -15.287647, -9.992104, -20.49974, -11.737182, -20.500032, -13.7056465, -11.700055, -11.151376, -12.240701, 
-6.9801717, -9.907572, -9.89772, -7.7714005, -7.599248, -14.2966175, -10.805019, -14.946489, -15.138906, -20.49991, -16.84454, -20.500303, -15.745817, -11.974067, -14.362624, -13.677492, -6.8857694, -10.488706, -9.6858, -15.690493, -13.776093, -17.350763, -13.82417, -20.500122, -16.799477, -11.256063, -16.112524, -20.500021, -16.107948, -11.349038, -12.018146, -20.500145, -15.021783, -20.500141, -14.088732, -19.462494, -16.841585, -17.49845, -15.664743, -18.375904, -20.500162, -17.897068, -20.50004, -13.704247, -15.333616, -20.500124, -14.740182, -12.495611, -20.500069, -20.50013, -17.074047, -13.579008, -16.136011, -20.500244, -11.993184]
],
[
[-20.500122, -20.500196, -16.388021, -20.500189, -13.888604, -20.500208, -20.500103, -13.725816, -16.14115, -15.523373, -16.094854, -15.536175, -13.494872, -20.500164, -16.729692, -17.314562, -9.923043, -20.500137, -13.227316, -19.462494, -8.832517, -11.005514, -16.657751, -20.500229, -19.104895, -17.969429, -16.826006, -18.479736, -11.35681, -20.50018, -17.686893, -15.8137665, -20.500158, -20.498934, -11.30343, -12.114782, -6.9864135, -16.129002, -11.758956, -13.793568, -10.100338, -18.394066, -7.8771715, -18.867481, -13.54011, -20.500141, -18.142273, -13.827344, -12.14585, -8.751808, -7.360826, -7.8197165, -8.190978, -7.9918194, -7.1475286, -10.866553, -13.463445, -12.561472, -17.644833, -20.499897, -15.04738, -15.1495285, -15.757288, -10.316235, -6.4681287, -6.771983, -6.2083254, -5.169312, -5.9851274, -7.3863406, -5.7047515, -11.461843, -19.462492, -20.499823, -16.014748, -19.572166, -10.054104, -9.654353, -6.9895654, -6.523039, -3.4712281, -4.010914, -3.058044, -5.203539, -4.562346, -7.3472414, -8.2306795, -14.15948, -16.442978, -15.1097, -20.499994, -16.006512, -13.285485, -9.599341, -5.576161, -5.10128, -2.1091957, -2.6103199, -2.3030841, -4.3452697, -5.1566353, -6.7773423, -13.5079155, -18.91643, -20.49996, -20.50012, -20.500032, -15.034921, -7.0785294, -6.62519, -2.6741242, -3.3764887, -3.2719333, -3.4223785, -3.1113718, -6.607987, -6.7852387, -9.567825, -17.231964, -18.361439, -15.199417, -20.500113, -8.907006, -8.894981, -4.4610567, -5.3974047, -3.1986039, -3.308056, -2.5260184, -4.416704, -5.5637026, -8.839353, -7.404949, -18.09958, -20.499996, -20.500063, -12.94954, -17.1081, -7.8807735, -6.0368576, -4.0000243, -4.983799, -3.7624922, -3.9401622, -5.351621, -7.3347793, -6.7273192, -16.521574, -10.310918, -18.213472, -18.239689, -20.49987, -13.403644, -10.768933, -6.169673, -6.226465, -4.851883, -3.5755277, -5.7955694, -7.59566, -6.5219584, -15.287647, -9.992104, -20.49974, -11.737182, -20.500032, -13.7056465, -11.700055, -11.151376, -12.240701, 
-6.9801717, -9.907572, -9.89772, -7.7714005, -7.599248, -14.2966175, -10.805019, -14.946489, -15.138906, -20.49991, -16.84454, -20.500303, -15.745817, -11.974067, -14.362624, -13.677492, -6.8857694, -10.488706, -9.6858, -15.690493, -13.776093, -17.350763, -13.82417, -20.500122, -16.799477, -11.256063, -16.112524, -20.500021, -16.107948, -11.349038, -12.018146, -20.500145, -15.021783, -20.500141, -14.088732, -19.462494, -16.841585, -17.49845, -15.664743, -18.375904, -20.500162, -17.897068, -20.50004, -13.704247, -15.333616, -20.500124, -14.740182, -12.495611, -20.500069, -20.50013, -17.074047, -13.579008, -16.136011, -20.500244, -11.993184]
]
]
但是输入张量中的三个 `(4, 15, 15)` 子张量各不相同:
[[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 1 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]
[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]]
[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]]
请问这里出了什么问题?完整的源代码如下。
def create_model(board_width, board_height, debug=False):
    """Build the Renju policy/value network wrapped in a ``tf.Module``.

    The returned module exposes ``tf.function``s with fixed input signatures
    (``predict``, ``train``, ``save``, ``restore``,
    ``random_choose_with_dirichlet_noice``) so they can be exported as
    SavedModel signatures and invoked from C.

    Args:
        board_width: Number of board columns.
        board_height: Number of board rows.
        debug: If True, ``action_loss`` prints its inputs on every call and
            ``predict`` prints its input/output for batches larger than one,
            via ``tf.print(..., summarize=-1)``. Off by default because
            printing whole tensors on every call is very expensive.

    Returns:
        A ``RenjuModel`` instance; the underlying Keras model is ``.model``.
    """
    num_cells = board_height * board_width

    class RenjuModel(tf.Module):
        """Keras policy/value network plus the exported entry points."""

        def __init__(self):
            super().__init__()
            l2_penalty_beta = 1e-4

            def conv2d(name, filters, kernel_size):
                # Every conv layer shares padding, layout, ReLU and L2 penalty.
                return tf.keras.layers.Conv2D(
                    name=name,
                    filters=filters,
                    kernel_size=kernel_size,
                    padding="same",
                    data_format="channels_last",
                    activation=tf.keras.activations.relu,
                    kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta),
                )

            def dense(units, activation, name):
                # Every dense layer shares the same L2 penalty.
                return tf.keras.layers.Dense(
                    units,
                    activation=activation,
                    name=name,
                    kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta),
                )

            # NOTE: layer names and creation order are kept stable because
            # `save`/`restore` key checkpoint entries by variable name.

            # 1. Input: 4 planes of (H, W), transposed to channels-last for
            #    the conv stack.
            self.inputs = tf.keras.Input(
                shape=(4, board_height, board_width),
                dtype=tf.dtypes.float32,
                name="input")
            self.transposed_inputs = tf.keras.layers.Lambda(
                lambda x: tf.transpose(x, [0, 2, 3, 1]))(self.inputs)

            # 2. Shared convolutional trunk.
            self.conv1 = conv2d("conv1", 32, (3, 3))(self.transposed_inputs)
            self.conv2 = conv2d("conv2", 64, (3, 3))(self.conv1)
            self.conv3 = conv2d("conv3", 128, (3, 3))(self.conv2)

            # 3. Action (policy) head.
            # NOTE(review): if every ReLU unit of this 1x1 conv outputs 0 for
            # all inputs ("dying ReLU"), action_fc receives only zeros and
            # returns log_softmax(bias), i.e. the same row for every sample in
            # a batch. Inspect action_conv's activations when outputs do not
            # vary with the input; a linear activation here or BatchNorm
            # before the ReLU are common remedies.
            self.action_conv = conv2d("action_conv", 4, (1, 1))(self.conv3)
            # Reshape's target shape excludes the batch axis, so this yields
            # (batch, 1, 4 * H * W).
            self.action_conv_flat = tf.keras.layers.Reshape(
                (-1, 4 * num_cells), name="action_conv_flat")(self.action_conv)
            # Log-probability of playing each board cell: (batch, 1, H * W).
            self.action_fc = dense(num_cells, tf.nn.log_softmax, "action_fc")(
                self.action_conv_flat)

            # 4. Evaluation (value) head: (batch, 1, 1) in [-1, 1] via tanh.
            self.evaluation_conv = conv2d("evaluation_conv", 2, (1, 1))(self.conv3)
            self.evaluation_conv_flat = tf.keras.layers.Reshape(
                (-1, 2 * num_cells),
                name="evaluation_conv_flat")(self.evaluation_conv)
            self.evaluation_fc1 = dense(
                64, tf.keras.activations.relu, "evaluation_fc1")(
                    self.evaluation_conv_flat)
            self.evaluation_fc2 = dense(
                1, tf.keras.activations.tanh, "evaluation_fc2")(
                    self.evaluation_fc1)

            self.model = tf.keras.Model(
                inputs=self.inputs,
                outputs=[self.action_fc, self.evaluation_fc2],
                name="renju_model")
            self.model.summary()
            # Kept in a variable so the exported `train` signature can change
            # the optimizer's learning rate at run time.
            self.lr = tf.Variable(0.002, trainable=False, dtype=tf.dtypes.float32)
            self.model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=self.lr),
                loss=[self.action_loss, tf.keras.losses.MeanSquaredError()],
                metrics=['accuracy'])

        @tf.function(input_signature=[
            tf.TensorSpec([None, 1, num_cells], tf.float32),
            tf.TensorSpec([None, 1, num_cells], tf.float32),
        ])
        def action_loss(self, labels, predictions):
            """Cross-entropy between target move probabilities and predictions.

            Args:
                labels: Target move probabilities, (batch, 1, H*W).
                predictions: Log-probabilities from ``action_fc`` (log_softmax
                    output, not raw logits), (batch, 1, H*W).

            Returns:
                Scalar ``-mean_batch(sum_cells(labels * predictions))``.
            """
            if debug:
                tf.print(labels, summarize=-1)
                tf.print(predictions, summarize=-1)
            return tf.negative(tf.reduce_mean(
                tf.reduce_sum(tf.multiply(labels, predictions), 2)))

        @tf.function(input_signature=[
            tf.TensorSpec([None, 4, board_height, board_width], tf.float32),
        ])
        def predict(self, state_batch):
            """Forward pass in inference mode.

            Args:
                state_batch: Board planes, (batch, 4, H, W).

            Returns:
                ``[action_log_probs (batch, 1, H*W), value (batch, 1, 1)]``.
            """
            if debug:
                if tf.shape(state_batch)[0] > 1:
                    tf.print(state_batch, summarize=-1)
            x = self.model(state_batch)
            if debug:
                if tf.shape(state_batch)[0] > 1:
                    tf.print(x, summarize=-1)
            return x

        @tf.function(input_signature=[
            tf.TensorSpec(shape=[None, 4, board_height, board_width], dtype=tf.float32),
            tf.TensorSpec(shape=[None, 1, num_cells], dtype=tf.float32),
            # NOTE(review): a scalar spec means a single value target is
            # broadcast to every sample in the batch. If each position has its
            # own game result this presumably should be per-sample (e.g.
            # [None, 1, 1]); left as-is because it is part of the exported
            # C signature.
            tf.TensorSpec(shape=[], dtype=tf.float32),
            tf.TensorSpec(shape=[1], dtype=tf.float32),
        ])
        def train(self, state_batch, mcts_probs, winner_batch, lr):
            """Run one optimization step.

            Args:
                state_batch: Board planes, (batch, 4, H, W).
                mcts_probs: Target move probabilities, (batch, 1, H*W).
                winner_batch: Value-head target (a scalar under this signature).
                lr: One-element tensor with the learning rate for this step.

            Returns:
                ``(loss, entropy)``: the compiled loss including the L2
                regularization terms, and the policy entropy averaged over
                the whole batch.
            """
            self.lr.assign(lr[0])
            with tf.GradientTape() as tape:
                predictions = self.model(state_batch, training=True)
                # The per-output losses are the ones configured in `compile()`.
                loss = self.model.compiled_loss(
                    [mcts_probs, winner_batch], predictions,
                    regularization_losses=self.model.losses)
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.model.optimizer.apply_gradients(
                zip(gradients, self.model.trainable_variables))
            # predictions[0] holds log-probabilities of shape (batch, 1, H*W).
            # Sum over cells, then average over every sample (not just the
            # first one) to get the batch's mean policy entropy.
            log_probs = predictions[0]
            entropy = tf.negative(tf.reduce_mean(
                tf.reduce_sum(tf.exp(log_probs) * log_probs, axis=-1)))
            return (loss, entropy)

        @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
        def save(self, checkpoint_path):
            """Write every model weight to ``checkpoint_path``, keyed by name.

            Returns:
                ``checkpoint_path`` unchanged.
            """
            tensor_names = [weight.name for weight in self.model.weights]
            tensors_to_save = [weight.read_value() for weight in self.model.weights]
            tf.raw_ops.Save(
                filename=checkpoint_path, tensor_names=tensor_names,
                data=tensors_to_save, name='save')
            return checkpoint_path

        @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
        def restore(self, checkpoint_path):
            """Load every model weight, by name, from a file written by ``save``.

            Returns:
                ``checkpoint_path`` unchanged.
            """
            for var in self.model.weights:
                restored = tf.raw_ops.Restore(
                    file_pattern=checkpoint_path, tensor_name=var.name,
                    dt=var.dtype, name='restore')
                var.assign(restored)
            return checkpoint_path

        @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
        def random_choose_with_dirichlet_noice(self, probs):
            """Sample a move index from ``probs`` mixed with Dirichlet noise.

            The sampling distribution is ``0.75 * probs + 0.25 * noise`` with
            ``noise ~ Dirichlet(0.3, ..., 0.3)``.

            Args:
                probs: Move probabilities, (num_moves,).

            Returns:
                A 1-element int64 tensor holding the selected index.
            """
            concentration = 0.3 * tf.ones(tf.size(probs))
            dist = tfp.distributions.Dirichlet(concentration)
            p = 0.75 * probs + 0.25 * dist.sample(1)[0]
            samples = tf.random.categorical(tf.math.log([p]), 1)
            return samples[0]  # selected index

    return RenjuModel()
model = create_model(15, 15)

# Export the underlying Keras model and register each exported tf.function of
# the wrapper module as a named SavedModel signature (callable from C).
export_signatures = {
    fn_name: getattr(model, fn_name).get_concrete_function()
    for fn_name in (
        'predict',
        'train',
        'save',
        'restore',
        'random_choose_with_dirichlet_noice',
    )
}
model.model.save('renju_15x15_model',
                 save_format='tf',
                 signatures=export_signatures)