mirror of
https://git.oceanpay.cc/danial/kami_jd_ck.git
synced 2025-12-18 22:11:07 +00:00
204 lines
6.6 KiB
Python
204 lines
6.6 KiB
Python
import base64
|
|
import os
|
|
import platform
|
|
|
|
import onnxruntime
|
|
import numpy as np
|
|
import cv2
|
|
|
|
from logger import get_logger
|
|
|
|
|
|
# Module-level logger obtained from the project's ``logger`` module; used, for
# example, to report which ONNX model files are being loaded.
logger = get_logger()
|
|
|
|
|
|
def imgtest(img, res_list, centre_xy):
    """Draw numbered boxes and centre markers onto *img* for visual debugging.

    :param img: OpenCV image; it is drawn on in place.
    :param res_list: iterable of ``[x1, y1, x2, y2]`` boxes; each is outlined
        and labelled with its 1-based position in the list.
    :param centre_xy: iterable of ``[x, y]`` points; each is marked with a
        thick zero-length rectangle (a small square dot).
    :return: the same *img* object, after drawing.
    """
    green = (0, 255, 0)
    for label, (left, top, right, bottom) in enumerate(res_list, start=1):
        cv2.rectangle(img, (left, top), (right, bottom), green, 2)
        cv2.putText(img, str(label), (left, top), cv2.FONT_HERSHEY_SIMPLEX, 1, green, 2)
    for point in centre_xy:
        # Same start and end corner with thickness 10 renders as a dot.
        dot = (point[0], point[1])
        cv2.rectangle(img, dot, dot, green, 10)
    return img
|
|
|
|
|
|
class siamese:
    """Pairwise image matcher backed by the ONNX model ``test.onnx``.

    The model file is read from ``/app/utils`` on Linux and ``./utils`` on
    other systems. :meth:`predict_siamese` converts the model's first output
    element into an integer score in ``[0, 100]``.
    """

    def __init__(self):
        # The model directory depends on the host operating system.
        base_dir = r"/app/utils" if platform.system() == "Linux" else "./utils"
        model_file = os.path.join(base_dir, 'test.onnx')
        logger.info(f"推理模型路径:{model_file}")
        self.siamese_model = onnxruntime.InferenceSession(
            model_file, providers=["CPUExecutionProvider"]
        )

    def sigmoid(self, x):
        """Return the element-wise logistic function ``1 / (1 + exp(-x))``."""
        denominator = 1 + np.exp(-x)
        return 1 / denominator

    def dispose_img(self, img):
        """Resize *img* to 60x60 and pack it as a ``(1, C, 60, 60)`` float32 batch.

        Pixel values are divided by 255. Assumes an ``H x W x C`` image
        (presumably uint8 BGR) -- TODO confirm the channel count the model expects.
        """
        resized = cv2.resize(img, (60, 60))
        normalised = resized.astype(np.float32) / 255.0
        # HWC -> CHW, then add a leading batch axis.
        return normalised.transpose(2, 0, 1)[np.newaxis, ...]

    def predict_siamese(self, image_1, image_2):
        """Score how well *image_1* matches *image_2*.

        :return: ``int(sigmoid(logit) * 100)`` where ``logit`` is element
            ``[0][0]`` of the model's first output.
        """
        feeds = {"x1": self.dispose_img(image_1), "x2": self.dispose_img(image_2)}
        logits = self.siamese_model.run(None, feeds)[0]
        return int(self.sigmoid(logits)[0][0] * 100)
|
|
|
|
|
|
class yolo:
    """Object detector backed by the ONNX model ``yolo.onnx``.

    Input images are resized to ``self.w`` x ``self.h`` for inference, and the
    detected boxes are scaled back to the original image size.
    """

    def __init__(self):
        # Network input size (width, height) that images are resized to.
        self.w = 288
        self.h = 192
        current_os = platform.system()
        if current_os == "Linux":
            utils_path = r"/app/utils"
        else:
            utils_path = "./utils"
        yolo_path = os.path.join(utils_path, 'yolo.onnx')
        logger.info(f"识别模型路径:{yolo_path}")
        self.yolo_model = onnxruntime.InferenceSession(yolo_path, providers=["CPUExecutionProvider"])

    def nms(self, dets, thresh):
        """Greedy non-maximum suppression.

        :param dets: ``(N, >=5)`` array of ``[x1, y1, x2, y2, score, ...]`` rows.
        :param thresh: a box is dropped when its IoU with an already-kept box
            exceeds this value.
        :return: list of row indices into *dets* that survive, highest score first.
        """
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        # "+1" treats coordinates as inclusive pixel indices.
        areas = (y2 - y1 + 1) * (x2 - x1 + 1)
        scores = dets[:, 4]
        keep = []
        index = scores.argsort()[::-1]

        while index.size > 0:
            i = index[0]
            keep.append(i)
            x11 = np.maximum(x1[i], x1[index[1:]])
            y11 = np.maximum(y1[i], y1[index[1:]])
            x22 = np.minimum(x2[i], x2[index[1:]])
            y22 = np.minimum(y2[i], y2[index[1:]])

            w = np.maximum(0, x22 - x11 + 1)
            h = np.maximum(0, y22 - y11 + 1)

            overlaps = w * h
            ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
            idx = np.where(ious <= thresh)[0]
            # idx is relative to index[1:], hence the +1 shift.
            index = index[idx + 1]
        return keep

    def yolo_to_xy(self, x):
        """Convert ``[cx, cy, w, h, ...]`` rows to ``[x1, y1, x2, y2, ...]`` (returns a copy)."""
        y = np.copy(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2
        y[:, 1] = x[:, 1] - x[:, 3] / 2
        y[:, 2] = x[:, 0] + x[:, 2] / 2
        y[:, 3] = x[:, 1] + x[:, 3] / 2
        return y

    def filter_box(self, org_box, conf_thres, iou_thres):
        """Reduce raw network output to per-class NMS survivors.

        :param org_box: raw model output (an array or a list of arrays) whose
            last axis holds ``[cx, cy, w, h, conf, class_score_0, ...]``.
        :param conf_thres: rows with ``conf <= conf_thres`` are discarded.
        :param iou_thres: IoU threshold passed to :meth:`nms`.
        :return: ``(K, 6)`` array of ``[x1, y1, x2, y2, conf, class_id]`` rows,
            grouped by ascending class id and ordered by descending ``conf``
            within each class. ``K`` may be 0; the result is always 2-D.
        """
        raw = np.asarray(org_box)
        # Flatten any leading batch/output axes into a flat list of candidate
        # rows (np.squeeze would also collapse a lone candidate into 1-D).
        raw = raw.reshape(-1, raw.shape[-1])
        candidates = raw[raw[:, 4] > conf_thres]
        if len(candidates) == 0:
            # Keep two dimensions so callers can still index columns.
            return np.empty((0, 6), dtype=raw.dtype)

        class_ids = np.argmax(candidates[:, 5:], axis=1)
        detections = candidates[:, :6].copy()
        # Column 5 is reused to carry the winning class id.
        detections[:, 5] = class_ids

        output = []
        for class_id in np.unique(class_ids):
            class_boxes = self.yolo_to_xy(detections[class_ids == class_id])
            output.extend(class_boxes[k] for k in self.nms(class_boxes, iou_thres))
        return np.array(output)

    def extract_coordinate(self, box_data, shape):
        """Scale ``[x1, y1, x2, y2, ...]`` boxes from network space to image pixels.

        :param box_data: ``(N, >=4)`` array such as the one returned by
            :meth:`filter_box`; may be empty.
        :param shape: shape of the original image, ``(height, width, ...)``.
        :return: list of ``[x1, y1, x2, y2]`` Python ints; negative values are
            clamped to 0. Empty list when there are no boxes.
        """
        box_data = np.asarray(box_data)
        if box_data.size == 0:
            return []
        width_scale = shape[1] / self.w
        height_scale = shape[0] / self.h
        coords = box_data[..., :4].reshape(-1, 4).astype(np.float64)
        data = []
        # Scale in floating point and truncate only once at the end.
        # Truncating before scaling would multiply the rounding error by the scale.
        for x1, y1, x2, y2 in coords:
            data.append([
                max(int(x1 * width_scale), 0),
                max(int(y1 * height_scale), 0),
                max(int(x2 * width_scale), 0),
                max(int(y2 * height_scale), 0),
            ])
        return data

    def dispose_img(self, img):
        """Prepare *img* for the network.

        Resizes to ``(self.w, self.h)``, reverses the channel order
        (BGR -> RGB), converts HWC -> CHW, scales to float32 / 255, and adds a
        batch axis. Assumes a 3-D ``H x W x C`` image -- TODO confirm a
        3-channel input is guaranteed.

        :return: ``(batch, original_shape)``.
        """
        shape = img.shape
        img = cv2.resize(img, (self.w, self.h))
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB and HWC -> CHW
        img = img.astype(dtype=np.float32)
        img /= 255.0
        img = np.expand_dims(img, axis=0)
        return img, shape

    def discern(self, img, conf_thres=0.5, iou_thres=0.5):
        """Detect objects in *img* and return their boxes in original-image pixels.

        :param img: OpenCV image.
        :param conf_thres: confidence threshold forwarded to :meth:`filter_box`.
        :param iou_thres: NMS IoU threshold forwarded to :meth:`filter_box`.
        :return: list of ``[x1, y1, x2, y2]`` int boxes (possibly empty).
        """
        img, shape = self.dispose_img(img)
        outputs = self.yolo_model.run(None, {"images": img})
        outbox = self.filter_box(outputs, conf_thres, iou_thres)
        return self.extract_coordinate(outbox, shape)
|
|
|
|
|
|
class JdIcon:
    """Locate the icon to click on a background image.

    :meth:`main` detects candidate icons on the background with ``yolo``,
    scores each candidate against a reference icon with ``siamese``, and
    reports the best-scoring box together with its centre point.
    """

    # Default (x1, y1, x2, y2) region of the reference image that holds the
    # icon to be clicked.
    ICON_BOX = (35, 0, 75, 36)
    # Side length (pixels) the cropped reference icon is resized to.
    ICON_SIZE = 60

    def _models(self):
        """Return ``(yolo, siamese)``, constructing both on first use.

        The sessions are cached on this instance, so repeated ``main`` calls on
        the same ``JdIcon`` object do not reload the ONNX files.
        """
        models = getattr(self, "_cached_models", None)
        if models is None:
            models = (yolo(), siamese())
            self._cached_models = models
        return models

    @staticmethod
    def _decode_image(b64_data):
        """Decode base64-encoded image bytes with ``cv2.IMREAD_UNCHANGED``.

        :raises ValueError: if OpenCV cannot decode the data.
        """
        buffer = np.frombuffer(base64.b64decode(b64_data), np.uint8)
        image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imdecode reports failure by returning None instead of raising.
            raise ValueError("could not decode base64 image data")
        return image

    @staticmethod
    def _centre(box):
        """Return the integer centre ``[x, y]`` of an ``[x1, y1, x2, y2]`` box."""
        x1, y1, x2, y2 = box
        return [int((x1 + x2) / 2), int((y1 + y2) / 2)]

    def main(self, bg, icon, icon_box=None):
        """Find the candidate on *bg* that best matches the reference icon.

        :param bg: base64-encoded background image.
        :param icon: base64-encoded image containing the reference icon. The
            region *icon_box* is cropped out and resized to
            ``ICON_SIZE x ICON_SIZE`` before matching.
        :param icon_box: optional ``(x1, y1, x2, y2)`` crop region; defaults to
            ``ICON_BOX``.
        :return: dict with ``'image_box'`` (list holding the best
            ``[x1, y1, x2, y2]`` box), ``'centre_xy'`` (list holding that box's
            ``[x, y]`` centre), and ``'dataset'`` (always empty). All lists are
            empty when no usable candidate is detected.
        :raises ValueError: if either image cannot be decoded.
        """
        yolo_model, siamese_model = self._models()
        bg = self._decode_image(bg)
        icon = self._decode_image(icon)

        ix1, iy1, ix2, iy2 = self.ICON_BOX if icon_box is None else icon_box
        icon = cv2.resize(icon[iy1:iy2, ix1:ix2], (self.ICON_SIZE, self.ICON_SIZE))

        data = {'image_box': [], 'centre_xy': [], 'dataset': []}
        best_box, best_score = None, 0
        for box in yolo_model.discern(bg):
            x1, y1, x2, y2 = box
            crop = bg[y1:y2, x1:x2]
            if crop.size == 0:
                # A zero-width or zero-height box yields an empty crop, which cv2.resize rejects.
                continue
            score = siamese_model.predict_siamese(crop, icon)
            # Strict '>' keeps the earliest candidate when scores tie.
            if best_box is None or score > best_score:
                best_box, best_score = box, score

        if best_box is None:
            logger.warning("no usable icon candidate detected on background image")
            return data

        data['image_box'].append(best_box)
        data['centre_xy'].append(self._centre(best_box))
        return data
|