Development of the ocr part of AOI
Samo Penic
2018-11-17 0d97e9b2d738682ed0aa6349b43a9719e0ca0aa9
commit | author | age
e555c0 1 from pyzbar.pyzbar import decode
0d97e9 2 from .sid_process import getSID
e555c0 3 import cv2
SP 4 import numpy as np
0d97e9 5
SP 6 import pkg_resources
7
8 markerfile = '/template.png'  # always use slash
9 markerfilename = pkg_resources.resource_filename(__name__, markerfile)
10
e555c0 11
SP 12
511c2e 13 class Paper:
762a5e 14     def __init__(self, filename=None, sid_classifier=None, settings=None):
511c2e 15         self.filename = filename
SP 16         self.invalid = None
17         self.QRData = None
e0996e 18         self.settings = {"answer_threshold": 0.25} if settings is None else settings
511c2e 19         self.errors = []
SP 20         self.warnings = []
e0996e 21         self.sid = None
762a5e 22         self.sid_classifier = sid_classifier
511c2e 23         if filename is not None:
SP 24             self.loadImage(filename)
25             self.runOcr()
e555c0 26
511c2e 27     def loadImage(self, filename, rgbchannel=0):
SP 28         self.img = cv2.imread(filename, rgbchannel)
29         if self.img is None:
30             self.errors.append("File could not be loaded!")
31             self.invalid = True
32             return
33         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 34
0d97e9 35     def saveImage(self, filename="/tmp/debug_image.png"):
511c2e 36         cv2.imwrite(filename, self.img)
e555c0 37
511c2e 38     def runOcr(self):
SP 39         if self.invalid == True:
40             return
41         self.decodeQRandRotate()
42         self.imgTreshold()
43         skewAngle = 0
44         #         try:
45         #             skewAngle=self.getSkewAngle()
46         #         except:
47         #             self.errors.append("Could not determine skew angle!")
48         #         self.rotateAngle(skewAngle)
e555c0 49
511c2e 50         self.generateAnswerMatrix()
e555c0 51
511c2e 52         self.saveImage()
e555c0 53
511c2e 54     def decodeQRandRotate(self):
SP 55         if self.invalid == True:
56             return
57         blur = cv2.blur(self.img, (3, 3))
58         d = decode(blur)
59         self.img = blur
60         if len(d) == 0:
61             self.errors.append("QR code could not be found!")
62             self.data = None
63             self.invalid = True
64             return
e0996e 65         if(len(d)>1): #if there are multiple codes, get first ean or qr code available.
SP 66             for dd in d:
67                 if(dd.type=="EAN13" or dd.type=="QR"):
68                     d[0]=dd
69                     break
511c2e 70         self.QRDecode = d
SP 71         self.QRData = d[0].data
72         xpos = d[0].rect.left
73         ypos = d[0].rect.top
74         # check if image is rotated wrongly
82ec6d 75         if xpos > self.imgHeight / 2.0 and ypos > self.imgWidth / 2.0:
511c2e 76             self.rotateAngle(180)
e555c0 77
511c2e 78     def rotateAngle(self, angle=0):
e0996e 79         # rot_mat = cv2.getRotationMatrix2D(
82ec6d 80         #    (self.imgHeight / 2, self.imgWidth / 2), angle, 1.0
e0996e 81         # )
511c2e 82         rot_mat = cv2.getRotationMatrix2D(
e0996e 83             (self.imgWidth / 2, self.imgHeight / 2), angle, 1.0
511c2e 84         )
SP 85         result = cv2.warpAffine(
86             self.img,
87             rot_mat,
82ec6d 88             (self.imgWidth, self.imgHeight),
511c2e 89             flags=cv2.INTER_CUBIC,
SP 90             borderMode=cv2.BORDER_CONSTANT,
91             borderValue=(255, 255, 255),
92         )
e555c0 93
511c2e 94         self.img = result
SP 95         self.imgHeight, self.imgWidth = self.img.shape[0:2]
e555c0 96
511c2e 97         # todo, make better tresholding
e555c0 98
511c2e 99     def imgTreshold(self):
SP 100         (self.thresh, self.bwimg) = cv2.threshold(
101             self.img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
102         )
e555c0 103
511c2e 104     def getSkewAngle(self):
SP 105         neg = 255 - self.bwimg  # get negative image
0d97e9 106         cv2.imwrite("/tmp/debug_1.png", neg)
e555c0 107
511c2e 108         angle_counter = 0  # number of angles
SP 109         angle = 0.0  # collects sum of angles
110         cimg = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)
e555c0 111
511c2e 112         # get all the Hough lines
SP 113         for line in cv2.HoughLinesP(neg, 1, np.pi / 180, 325):
114             x1, y1, x2, y2 = line[0]
115             cv2.line(cimg, (x1, y1), (x2, y2), (0, 0, 255), 2)
116             # calculate the angle (in radians)
117             this_angle = np.arctan2(y2 - y1, x2 - x1)
118             if this_angle and abs(this_angle) <= 10:
119                 # filtered zero degree and outliers
120                 angle += this_angle
121                 angle_counter += 1
e555c0 122
511c2e 123                 # the skew is calculated of the mean of the total angles, #try block helps with division by zero.
SP 124         try:
125             skew = np.rad2deg(
126                 angle / angle_counter
127             )  # the 1.2 factor is just experimental....
128         except:
129             skew = 0
e555c0 130
0d97e9 131         cv2.imwrite("/tmp/debug_2.png", cimg)
511c2e 132         return skew
e555c0 133
511c2e 134     def locateUpMarkers(self, threshold=0.85, height=200):
0d97e9 135         template = cv2.imread(markerfilename, 0)
511c2e 136         w, h = template.shape[::-1]
SP 137         crop_img = self.img[0:height, :]
138         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
139         loc = np.where(res >= threshold)
140         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
141         # remove false matching of the squares in qr code
142         loc_filtered_x = []
143         loc_filtered_y = []
144         if len(loc[0]) == 0:
145             min_y = -1
146         else:
147             min_y = np.min(loc[0])
148             for pt in zip(*loc[::-1]):
149                 if pt[1] < min_y + 20:
150                     loc_filtered_y.append(pt[1])
151                     loc_filtered_x.append(pt[0])
152                     # order by x coordinate
153             loc_filtered_x, loc_filtered_y = zip(
154                 *sorted(zip(loc_filtered_x, loc_filtered_y))
155             )
02e0f7 156             # loc=[loc_filtered_y,loc_filtered_x]
SP 157             # remove duplicates
511c2e 158             a = np.diff(loc_filtered_x) > 40
SP 159             a = np.append(a, True)
160             loc_filtered_x = np.array(loc_filtered_x)
161             loc_filtered_y = np.array(loc_filtered_y)
162             loc = [loc_filtered_y[a], loc_filtered_x[a]]
163             for pt in zip(*loc[::-1]):
164                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 165
0d97e9 166         cv2.imwrite("/tmp/debug_3.png", cimg)
e555c0 167
511c2e 168         self.xMarkerLocations = loc
SP 169         return loc
e555c0 170
511c2e 171     def locateRightMarkers(self, threshold=0.85, width=200):
0d97e9 172         template = cv2.imread(markerfilename, 0)
511c2e 173         w, h = template.shape[::-1]
SP 174         crop_img = self.img[:, -width:]
175         res = cv2.matchTemplate(crop_img, template, cv2.TM_CCOEFF_NORMED)
176         loc = np.where(res >= threshold)
177         cimg = cv2.cvtColor(crop_img, cv2.COLOR_GRAY2BGR)
178         # remove false matching of the squares in qr code
179         loc_filtered_x = []
180         loc_filtered_y = []
181         if len(loc[1]) == 0:
182             min_x = -1
183         else:
184             max_x = np.max(loc[1])
185             for pt in zip(*loc[::-1]):
186                 if pt[1] > max_x - 20:
187                     loc_filtered_y.append(pt[1])
188                     loc_filtered_x.append(pt[0])
189                     # order by y coordinate
190             loc_filtered_y, loc_filtered_x = zip(
191                 *sorted(zip(loc_filtered_y, loc_filtered_x))
192             )
193             # loc=[loc_filtered_y,loc_filtered_x]
194             # remove duplicates
195             a = np.diff(loc_filtered_y) > 40
196             a = np.append(a, True)
197             loc_filtered_x = np.array(loc_filtered_x)
198             loc_filtered_y = np.array(loc_filtered_y)
199             loc = [loc_filtered_y[a], loc_filtered_x[a]]
200             for pt in zip(*loc[::-1]):
201                 cv2.rectangle(cimg, pt, (pt[0] + w, pt[1] + h), (0, 255, 255), 2)
e555c0 202
0d97e9 203         cv2.imwrite("/tmp/debug_4.png", cimg)
e555c0 204
511c2e 205         self.yMarkerLocations = [loc[0], loc[1] + self.imgWidth - width]
SP 206         return self.yMarkerLocations
e555c0 207
511c2e 208     def generateAnswerMatrix(self):
SP 209         self.locateUpMarkers()
210         self.locateRightMarkers()
e555c0 211
511c2e 212         roixoff = 10
SP 213         roiyoff = 5
214         roiwidth = 50
215         roiheight = roiwidth
216         totpx = roiwidth * roiheight
e555c0 217
511c2e 218         self.answerMatrix = []
SP 219         for y in self.yMarkerLocations[0]:
220             oneline = []
221             for x in self.xMarkerLocations[1]:
222                 roi = self.bwimg[
223                     y - roiyoff : y + int(roiheight - roiyoff),
224                     x - roixoff : x + int(roiwidth - roixoff),
225                 ]
226                 # cv2.imwrite('ans_x'+str(x)+'_y_'+str(y)+'.png',roi)
227                 black = totpx - cv2.countNonZero(roi)
228                 oneline.append(black / totpx)
229             self.answerMatrix.append(oneline)
9efc18 230
SP 231     def get_enhanced_sid(self):
02e0f7 232         if self.sid_classifier is None:
SP 233             return "x"
762a5e 234         if self.settings is not None:
e0996e 235             sid_mask = self.settings.get("sid_mask", None)
SP 236         es, err, warn = getSID(
02e0f7 237             self.img[
d5c694 238                 int(0.04 * self.imgHeight) : int(0.095 * self.imgHeight),
02e0f7 239                 int(0.7 * self.imgWidth) : int(0.99 * self.imgWidth),
SP 240             ],
241             self.sid_classifier,
e0996e 242             sid_mask,
02e0f7 243         )
5cb7c1 244         [self.errors.append(e) for e in err]
SP 245         [self.warnings.append(w) for w in warn]
02e0f7 246         return es
0436f6 247
SP 248     def get_code_data(self):
cf921b 249         if self.QRData is None:
SP 250             self.errors.append("Could not read QR or EAN code! Not an exam?")
e0996e 251             retval = {
SP 252                 "exam_id": None,
253                 "page_no": None,
254                 "paper_id": None,
255                 "faculty_id": None,
256                 "sid": None,
0436f6 257             }
e0996e 258             return retval
SP 259         qrdata = bytes.decode(self.QRData, "utf8")
260         if self.QRDecode[0].type == "EAN13":
261             return {
262                 "exam_id": int(qrdata[0:7]),
263                 "page_no": int(qrdata[7]),
264                 "paper_id": int(qrdata[-5:-1]),
265                 "faculty_id": None,
266                 "sid": None,
267             }
268         else:
269             data = qrdata.split(",")
270             retval = {
271                 "exam_id": int(data[1]),
272                 "page_no": int(data[3]),
273                 "paper_id": int(data[2]),
274                 "faculty_id": int(data[0]),
275             }
276             if len(data) > 4:
277                 retval["sid"] = data[4]
0436f6 278
SP 279             return retval
280
281     def get_paper_ocr_data(self):
e0996e 282         data = self.get_code_data()
SP 283         data["qr"] = self.QRData
284         data["errors"] = self.errors
285         data["warnings"] = self.warnings
286         data["up_position"] = (
287             list(self.xMarkerLocations[1] / self.imgWidth),
288             list(self.yMarkerLocations[1] / self.imgHeight),
289         )
290         data["right_position"] = (
291             list(self.xMarkerLocations[1] / self.imgWidth),
292             list(self.yMarkerLocations[1] / self.imgHeight),
293         )
294         data["ans_matrix"] = (
295             (np.array(self.answerMatrix) > self.settings["answer_threshold"]) * 1
296         ).tolist()
297         if data["sid"] is None and data["page_no"] == 0:
298             data["sid"] = self.get_enhanced_sid()
0436f6 299         return data