下面是我使用的代码:
# Accumulator for the generated pseudo-code text.
# NOTE: a `global` statement out here does not make the name global inside
# nested functions. The original version of `recurse` assigned to
# `output_res` without declaring it global, which made the name local to
# `recurse` and raised UnboundLocalError on the first `+=`.
global output_res
output_res = ""


def recurse(left, right, threshold, features, node, depth):
    """Append an if/else rendering of the subtree at *node* to ``output_res``.

    Nodes whose ``threshold`` entry is not ``-2`` are rendered as
    ``if ( <feature> <= <threshold> ) { ... } else { ... }``. Their children
    (ids other than ``-1``) are rendered one nesting level deeper. Nodes with
    threshold ``-2`` are rendered as one ``return`` line per nonzero entry of
    ``value[node]``.

    Parameters
    ----------
    left, right : indexable by node id
        Child node ids; ``-1`` means "no child" and is not recursed into.
    threshold : indexable by node id
        Split value per node; ``-2`` selects the ``return`` rendering.
    features : indexable by node id
        Name used in the ``if`` line. It is concatenated directly, so it must
        be ``str``.
    node : int
        Id of the node to render.
    depth : int
        Nesting level; each line is prefixed with ``spacer_base * depth``.

    Returns
    -------
    str
        The whole accumulated ``output_res`` (everything rendered so far, not
        just this subtree). Each rendered statement ends with ``"\\n"``,
        matching the ``print`` calls this code originally used.

    Relies on names not defined here: ``spacer_base``, ``value``,
    ``target_names`` and ``np``.
    """
    # Required: without it the `+=` statements below make `output_res`
    # a local of this function and the first read raises UnboundLocalError.
    global output_res
    spacer = spacer_base * depth
    if threshold[node] != -2:
        output_res += (
            spacer + "if ( " + features[node] + " <= "
            + str(threshold[node]) + " ) {\n"
        )
        if left[node] != -1:
            recurse(left, right, threshold, features, left[node], depth + 1)
        output_res += spacer + "}\n" + spacer + "else {\n"
        if right[node] != -1:
            recurse(left, right, threshold, features, right[node], depth + 1)
        output_res += spacer + "}\n"
    else:
        target = value[node]
        # Compute the nonzero coordinates once and reuse them for both the
        # label index (second coordinate) and the counts.
        nonzero = np.nonzero(target)
        for i, v in zip(nonzero[1], target[nonzero]):
            target_name = target_names[i]
            target_count = int(v)
            output_res += (
                spacer + "return " + str(target_name) + " ( "
                + str(target_count) + " examples )\n"
            )
    return output_res


recurse(left, right, threshold, features, 0, 0)
正如你所见,`recurse()` 是一个递归函数。我的目标是获取 `output_res`,并在主函数中使用它,但使用此代码时出现以下错误:
`UnboundLocalError: local variable 'output_res' referenced before assignment`(在赋值前引用了局部变量“output_res”)
更新:
我找到了我想要的确切解决方案:
temp_list = []
def recurse(temp_list, left, right, threshold, features, node, depth):
spacer = spacer_base * depth
if (threshold[node] != -2):
temp_list.append(spacer + "if ( " + features[node] + " <= " + \
str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (temp_list, left, right, threshold, features, left[node], depth+1)
temp_list.append(spacer + "}\n" + spacer +"else {")
if right[node] != -1:
recurse (temp_list, left, right, threshold, features, right[node], depth+1)
temp_list.append(spacer + "}")
else:
target = value[node]
for i, v in zip(np.nonzero(target)[1], target[np.nonzero(target)]):
target_name = target_names[i]
target_count = int(v)
temp_list.append(spacer + "return " + str(target_name) + " ( " + \
str(target_count) + " examples )")
recurse(temp_list, left, right, threshold, features, 0, 0)
return '\n'.join(temp_list)