Projekt

Obecné

Profil

Stáhnout (16.3 KB) Statistiky
| Větev: | Tag: | Revize:
1
from sqlite3 import Connection, Error
2
from typing import Dict, List
3

    
4
from src.exceptions.database_exception import DatabaseException
5
from injector import inject
6
from src.constants import *
7
from src.model.certificate import Certificate
8

    
9

    
10
class CertificateRepository:
11

    
12
    @inject
13
    def __init__(self, connection: Connection):
14
        """
15
        Constructor of the CertificateRepository object
16

    
17
        :param connection: Instance of the Connection object
18
        """
19
        self.connection = connection
20
        self.cursor = connection.cursor()
21

    
22
    def create(self, certificate: Certificate):
23
        """
24
        Creates a certificate.
25
        For root certificate (CA) the parent certificate id is modified to the same id (id == parent_id).
26

    
27
        :param certificate: Instance of the Certificate object
28

    
29
        :return: the result of whether the creation was successful
30
        """
31

    
32
        try:
33
            sql = (f"INSERT INTO {TAB_CERTIFICATES} "
34
                   f"({COL_COMMON_NAME},"
35
                   f"{COL_VALID_FROM},"
36
                   f"{COL_VALID_TO},"
37
                   f"{COL_PEM_DATA},"
38
                   f"{COL_PRIVATE_KEY_ID},"
39
                   f"{COL_TYPE_ID},"
40
                   f"{COL_PARENT_ID})"
41
                   f"VALUES(?,?,?,?,?,?,?)")
42
            values = [certificate.common_name,
43
                      certificate.valid_from,
44
                      certificate.valid_to,
45
                      certificate.pem_data,
46
                      certificate.private_key_id,
47
                      certificate.type_id,
48
                      certificate.parent_id]
49
            self.cursor.execute(sql, values)
50
            self.connection.commit()
51

    
52
            last_id: int = self.cursor.lastrowid
53

    
54
            # TODO assure that this is correct
55
            if certificate.type_id == ROOT_CA_ID:
56
                certificate.parent_id = last_id
57
                self.update(last_id, certificate)
58
            else:
59
                for usage_id, usage_value in certificate.usages.items():
60
                    if usage_value:
61
                        sql = (f"INSERT INTO {TAB_CERTIFICATE_USAGES} "
62
                               f"({COL_CERTIFICATE_ID},"
63
                               f"{COL_USAGE_TYPE_ID}) "
64
                               f"VALUES (?,?)")
65
                        values = [last_id, usage_id]
66
                        self.cursor.execute(sql, values)
67
                        self.connection.commit()
68
        except Error as e:
69
            raise DatabaseException(e)
70

    
71
        return last_id
72

    
73
    def read(self, certificate_id: int):
74
        """
75
        Reads (selects) a certificate.
76

    
77
        :param certificate_id: ID of specific certificate
78

    
79
        :return: instance of the Certificate object
80
        """
81

    
82
        try:
83
            sql = (f"SELECT * FROM {TAB_CERTIFICATES} "
84
                   f"WHERE {COL_ID} = ?")
85
            values = [certificate_id]
86
            self.cursor.execute(sql, values)
87
            certificate_row = self.cursor.fetchone()
88

    
89
            if certificate_row is None:
90
                return None
91

    
92
            sql = (f"SELECT * FROM {TAB_CERTIFICATE_USAGES} "
93
                   f"WHERE {COL_CERTIFICATE_ID} = ?")
94
            self.cursor.execute(sql, values)
95
            usage_rows = self.cursor.fetchall()
96

    
97
            usage_dict: Dict[int, bool] = {}
98
            for usage_row in usage_rows:
99
                usage_dict[usage_row[2]] = True
100

    
101
            certificate: Certificate = Certificate(certificate_row[0],
102
                                                   certificate_row[1],
103
                                                   certificate_row[2],
104
                                                   certificate_row[3],
105
                                                   certificate_row[4],
106
                                                   certificate_row[7],
107
                                                   certificate_row[8],
108
                                                   certificate_row[9],
109
                                                   usage_dict,
110
                                                   certificate_row[5],
111
                                                   certificate_row[6])
112
        except Error as e:
113
            raise DatabaseException(e)
114

    
115
        return certificate
116

    
117
    def read_all(self, filter_type: int = None):
118
        """
119
        Reads (selects) all certificates (with type).
120

    
121
        :param filter_type: ID of certificate type from CertificateTypes table
122

    
123
        :return: list of certificates
124
        """
125

    
126
        try:
127
            sql_extension = ""
128
            values = []
129
            if filter_type is not None:
130
                sql_extension = (f" WHERE {COL_TYPE_ID} = ("
131
                                 f"SELECT {COL_ID} FROM {TAB_CERTIFICATE_TYPES} WHERE {COL_ID} = ?)")
132
                values = [filter_type]
133

    
134
            sql = f"SELECT * FROM {TAB_CERTIFICATES}{sql_extension}"
135
            self.cursor.execute(sql, values)
136
            certificate_rows = self.cursor.fetchall()
137

    
138
            certificates: List[Certificate] = []
139
            for certificate_row in certificate_rows:
140
                sql = (f"SELECT * FROM {TAB_CERTIFICATE_USAGES} "
141
                       f"WHERE {COL_CERTIFICATE_ID} = ?")
142
                values = [certificate_row[0]]
143
                self.cursor.execute(sql, values)
144
                usage_rows = self.cursor.fetchall()
145

    
146
                usage_dict: Dict[int, bool] = {}
147
                for usage_row in usage_rows:
148
                    usage_dict[usage_row[2]] = True
149

    
150
                certificates.append(Certificate(certificate_row[0],
151
                                                certificate_row[1],
152
                                                certificate_row[2],
153
                                                certificate_row[3],
154
                                                certificate_row[4],
155
                                                certificate_row[7],
156
                                                certificate_row[8],
157
                                                certificate_row[9],
158
                                                usage_dict,
159
                                                certificate_row[5],
160
                                                certificate_row[6]))
161
        except Error as e:
162
            raise DatabaseException(e)
163

    
164
        return certificates
165

    
166
    def update(self, certificate_id: int, certificate: Certificate):
167
        """
168
        Updates a certificate.
169
        If the parameter of certificate (Certificate object) is not to be changed,
170
        the same value must be specified.
171

    
172
        :param certificate_id: ID of specific certificate
173
        :param certificate: Instance of the Certificate object
174

    
175
        :return: the result of whether the updation was successful
176
        """
177

    
178
        try:
179
            sql = (f"UPDATE {TAB_CERTIFICATES} "
180
                   f"SET {COL_COMMON_NAME} = ?, "
181
                   f"{COL_VALID_FROM} = ?, "
182
                   f"{COL_VALID_TO} = ?, "
183
                   f"{COL_PEM_DATA} = ?, "
184
                   f"{COL_PRIVATE_KEY_ID} = ?, "
185
                   f"{COL_TYPE_ID} = ?, "
186
                   f"{COL_PARENT_ID} = ? "
187
                   f"WHERE {COL_ID} = ?")
188
            values = [certificate.common_name,
189
                      certificate.valid_from,
190
                      certificate.valid_to,
191
                      certificate.pem_data,
192
                      certificate.private_key_id,
193
                      certificate.type_id,
194
                      certificate.parent_id,
195
                      certificate_id]
196
            self.cursor.execute(sql, values)
197
            self.connection.commit()
198

    
199
            sql = (f"DELETE FROM {TAB_CERTIFICATE_USAGES} "
200
                   f"WHERE {COL_CERTIFICATE_ID} = ?")
201
            values = [certificate_id]
202
            self.cursor.execute(sql, values)
203
            self.connection.commit()
204

    
205
            # iterate over usage pairs
206
            for usage_id, usage_value in certificate.usages.items():
207
                if usage_value:
208
                    sql = (f"INSERT INTO {TAB_CERTIFICATE_USAGES} "
209
                           f"({COL_CERTIFICATE_ID},"
210
                           f"{COL_USAGE_TYPE_ID}) "
211
                           f"VALUES (?,?)")
212
                    values = [certificate_id, usage_id]
213
                    self.cursor.execute(sql, values)
214
                    self.connection.commit()
215
        except Error as e:
216
            raise DatabaseException(e)
217

    
218
        return self.cursor.rowcount > 0
219

    
220
    def delete(self, certificate_id: int):
221
        """
222
        Deletes a certificate
223

    
224
        :param certificate_id: ID of specific certificate
225

    
226
        :return: the result of whether the deletion was successful
227
        """
228

    
229
        try:
230
            sql = (f"DELETE FROM {TAB_CERTIFICATES} "
231
                   f"WHERE {COL_ID} = ?")
232
            values = [certificate_id]
233
            self.cursor.execute(sql, values)
234
            self.connection.commit()
235
        except Error as e:
236
            raise DatabaseException(e)
237

    
238
        return self.cursor.rowcount > 0
239

    
240
    def set_certificate_revoked(
241
            self, certificate_id: int, revocation_date: str, revocation_reason: str = REV_REASON_UNSPECIFIED):
242
        """
243
        Revoke a certificate
244

    
245
        :param certificate_id: ID of specific certificate
246
        :param revocation_date: Date, when the certificate is revoked
247
        :param revocation_reason: Reason of the revocation
248

    
249
        :return:
250
            the result of whether the revocation was successful OR
251
            sqlite3.Error if an exception is thrown
252
        """
253

    
254
        try:
255
            if revocation_date != "" and revocation_reason == "":
256
                revocation_reason = REV_REASON_UNSPECIFIED
257
            elif revocation_date == "":
258
                return False
259

    
260
            sql = (f"UPDATE {TAB_CERTIFICATES} "
261
                   f"SET {COL_REVOCATION_DATE} = ?, "
262
                   f"{COL_REVOCATION_REASON} = ? "
263
                   f"WHERE {COL_ID} = ? AND ({COL_REVOCATION_DATE} IS NULL OR {COL_REVOCATION_DATE} = '')")
264
            values = [revocation_date,
265
                      revocation_reason,
266
                      certificate_id]
267
            self.cursor.execute(sql, values)
268
            self.connection.commit()
269
        except Error as e:
270
            raise DatabaseException(e)
271

    
272
        return self.cursor.rowcount > 0
273

    
274
    def clear_certificate_revocation(self, certificate_id: int):
275
        """
276
        Clear revocation of a certificate
277

    
278
        :param certificate_id: ID of specific certificate
279

    
280
        :return:
281
            the result of whether the clear revocation was successful OR
282
            sqlite3.Error if an exception is thrown
283
        """
284

    
285
        try:
286
            sql = (f"UPDATE {TAB_CERTIFICATES} "
287
                   f"SET {COL_REVOCATION_DATE} = '', "
288
                   f"{COL_REVOCATION_REASON} = '' "
289
                   f"WHERE {COL_ID} = ?")
290
            values = [certificate_id]
291
            self.cursor.execute(sql, values)
292
            self.connection.commit()
293
        except Error as e:
294
            raise DatabaseException(e)
295

    
296
        return self.cursor.rowcount > 0
297

    
298
    def get_all_revoked_by(self, certificate_id: int):
299
        """
300
        Get list of the revoked certificates that are direct descendants of the certificate with the ID
301

    
302
        :param certificate_id: ID of specific certificate
303

    
304
        :return:
305
            list of the certificates OR
306
            None if the list is empty OR
307
            sqlite3.Error if an exception is thrown
308
        """
309

    
310
        try:
311
            sql = (f"SELECT * FROM {TAB_CERTIFICATES} "
312
                   f"WHERE {COL_PARENT_ID} = ? AND {COL_REVOCATION_DATE} IS NOT NULL AND {COL_REVOCATION_DATE} != ''")
313
            values = [certificate_id]
314
            self.cursor.execute(sql, values)
315
            certificate_rows = self.cursor.fetchall()
316

    
317
            certificates: List[Certificate] = []
318
            for certificate_row in certificate_rows:
319
                sql = (f"SELECT * FROM {TAB_CERTIFICATE_USAGES} "
320
                       f"WHERE {COL_CERTIFICATE_ID} = ?")
321
                values = [certificate_row[0]]
322
                self.cursor.execute(sql, values)
323
                usage_rows = self.cursor.fetchall()
324

    
325
                usage_dict: Dict[int, bool] = {}
326
                for usage_row in usage_rows:
327
                    usage_dict[usage_row[2]] = True
328

    
329
                certificates.append(Certificate(certificate_row[0],
330
                                                certificate_row[1],
331
                                                certificate_row[2],
332
                                                certificate_row[3],
333
                                                certificate_row[4],
334
                                                certificate_row[7],
335
                                                certificate_row[8],
336
                                                certificate_row[9],
337
                                                usage_dict,
338
                                                certificate_row[5],
339
                                                certificate_row[6]))
340
        except Error as e:
341
            raise DatabaseException(e)
342

    
343
        return certificates
344

    
345
    def get_all_issued_by(self, certificate_id: int):
346
        """
347
        Get list of the certificates that are direct descendants of the certificate with the ID
348

    
349
        :param certificate_id: ID of specific certificate
350

    
351
        :return:
352
            list of the certificates OR
353
            None if the list is empty OR
354
            sqlite3.Error if an exception is thrown
355
        """
356

    
357
        try:
358
            sql = (f"SELECT * FROM {TAB_CERTIFICATES} "
359
                   f"WHERE {COL_PARENT_ID} = ? AND {COL_ID} != ?")
360
            values = [certificate_id, certificate_id]
361
            self.cursor.execute(sql, values)
362
            certificate_rows = self.cursor.fetchall()
363

    
364
            certificates: List[Certificate] = []
365
            for certificate_row in certificate_rows:
366
                sql = (f"SELECT * FROM {TAB_CERTIFICATE_USAGES} "
367
                       f"WHERE {COL_CERTIFICATE_ID} = ?")
368
                values = [certificate_row[0]]
369
                self.cursor.execute(sql, values)
370
                usage_rows = self.cursor.fetchall()
371

    
372
                usage_dict: Dict[int, bool] = {}
373
                for usage_row in usage_rows:
374
                    usage_dict[usage_row[2]] = True
375

    
376
                certificates.append(Certificate(certificate_row[0],
377
                                                certificate_row[1],
378
                                                certificate_row[2],
379
                                                certificate_row[3],
380
                                                certificate_row[4],
381
                                                certificate_row[7],
382
                                                certificate_row[8],
383
                                                certificate_row[9],
384
                                                usage_dict,
385
                                                certificate_row[5],
386
                                                certificate_row[6]))
387
        except Error as e:
388
            raise DatabaseException(e)
389

    
390
        return certificates
391

    
392
    def get_all_descendants_of(self, certificate_id: int):
393
        """
394
        Get a list of all certificates C such that the certificate identified by "certificate_id" belongs to a trust chain
395
        between C and its root certificate authority (i.e. is an ancestor of C).
396
        :param certificate_id: target certificate ID
397
        :return: list of all descendants
398
        """
399
        def dfs(children_of, this, collection: list):
400
            for child in children_of(this.certificate_id):
401
                dfs(children_of, child, collection)
402
            collection.append(this)
403

    
404
        subtree_root = self.read(certificate_id)
405
        if subtree_root is None:
406
            return None
407

    
408
        all_certs = []
409
        dfs(self.get_all_issued_by, subtree_root, all_certs)
410
        return all_certs
411

    
412
    def get_next_id(self) -> int:
413
        """
414
        Get identifier of the next certificate that will be inserted into the database
415
        :return: identifier of the next certificate that will be added into the database
416
        """
417
        # get next IDs of all tables
418
        self.cursor.execute("SELECT * FROM SQLITE_SEQUENCE")
419
        results = self.cursor.fetchall()
420

    
421
        # search for next ID in Certificates table and return it
422
        for result in results:
423
            if result[0] == TAB_CERTIFICATES:
424
                return result[1] + 1  # current last id + 1
425
        # if certificates table is not present in the query results, return 1
426
        return 1
(2-2/3)