aboutsummaryrefslogtreecommitdiff
path: root/lib/mbedtls-2.27.0/scripts/mbedtls_dev/psa_storage.py
blob: 45f0380e7b819e41d6d7d848448cdf007f8db012 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""Knowledge about the PSA key store as implemented in Mbed TLS.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import struct
from typing import Dict, List, Optional, Set, Union
import unittest

from mbedtls_dev import c_build_helper


class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An expression is built either from an integer, whose value is known
    immediately, or from a string containing a C expression. The normalized
    form of every string expression is registered in the class-wide
    `unknown_values` set. The first time the value of an unresolved
    expression is requested, all registered expressions are evaluated in one
    batch through `c_build_helper.get_c_expression_values` and the results
    are stored in the class-wide `value_cache`.

    Strings that are plain integer literals (decimal, hexadecimal or octal,
    without suffix or surrounding whitespace) are evaluated directly in
    Python and never require the batch evaluation.
    """

    def __init__(self, content: Union[int, str]):
        """Build an expression from an integer or from a C expression string.

        * An integer is rendered as a zero-padded hexadecimal literal with
          4 digits for values up to 0xffff and 8 digits above, and its value
          is recorded immediately.
        * A string is stored verbatim; its normalized form is added to
          `unknown_values` and its value is left unknown until `value` is
          called.
        """
        if isinstance(content, int):
            digits = 8 if content > 0xffff else 4
            # The width passed to format() includes the 2-character '0x' prefix.
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            # `unknown_values` is a class attribute mutated in place, so the
            # registration is shared by all `Expr` instances.
            self.unknown_values.add(self.normalize(content))
            self.value_if_known = None

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are passed, in sorted order, to
        `c_build_helper.get_c_expression_values` in a single call, with
        `<psa/crypto.h>` as the header and `include` as the include path.
        Each returned string is converted with `int(v, 0)` and stored under
        the corresponding expression. `unknown_values` is emptied afterwards.
        """
        expressions = sorted(self.unknown_values)
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=['include']) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        The canonical form is the expression with all whitespace removed.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    @staticmethod
    def _literal_value(string: str) -> Optional[int]:
        """Return the value of `string` if it is a plain C integer literal.

        Recognized forms are unsuffixed decimal, hexadecimal (`0x...`,
        case-insensitive) and octal (leading `0`) literals with no
        surrounding whitespace. Return `None` for anything else, so that the
        caller can fall back to the cached evaluation.
        """
        if re.match(r'(0+|[1-9][0-9]*|0x[0-9a-f]+)\Z', string, re.I):
            return int(string, 0)
        # In C, a leading 0 followed by more digits denotes an octal literal.
        # `int(string, 0)` rejects that spelling, so parse it explicitly.
        if re.match(r'0[0-7]+\Z', string):
            return int(string, 8)
        return None

    def value(self) -> int:
        """Return the numerical value of the expression.

        The value is computed at most once per object and then remembered in
        `value_if_known`. Plain integer literals are parsed directly. Any
        other expression is looked up in `value_cache` by its normalized
        form, after calling `update_cache` if it is not there yet.
        """
        if self.value_if_known is None:
            literal = self._literal_value(self.string)
            if literal is not None:
                self.value_if_known = literal
            else:
                normalized = self.normalize(self.string)
                if normalized not in self.value_cache:
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known

# Type alias for the inputs accepted wherever this module expects a C
# expression: a string containing a C expression, an integer, or an `Expr`
# object that has already been built.
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""

def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` into an `Expr` object.

    An existing `Expr` is handed back unchanged (the very same object, not a
    copy). Anything else, i.e. an integer or a string containing a
    C expression, is wrapped in a newly constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)


class Key:
    """A PSA crypto key object together with its encoding in storage.

    The attributes (lifetime, type, usage and algorithm policy) may be given
    as integers, as strings containing C expressions, or as `Expr` objects;
    they are all stored as `Expr` objects. The key identifier is recorded
    but is not part of the encoding produced by `bytes`.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        """Record the key's metadata and material.

        All arguments are keyword-only. `version` defaults to
        `LATEST_VERSION` when omitted or `None`. `lifetime`, `type`, `usage`,
        `alg` and `alg2` are converted with `as_expr`.
        """
        if version is None:
            version = self.LATEST_VERSION
        self.version = version
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        self.material = material #type: bytes

    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Encode integer-valued arguments as bytes, following `fmt`.

        This behaves like `struct.pack`, except that:
        * `fmt` must not carry an endianness prefix: little-endian with
          standard sizes is always used.
        * Each argument may be an `Expr` object, in which case its numerical
          value is encoded.
        * Only integer-valued elements are supported.
        """
        numbers = []
        for arg in args:
            # Resolve expressions to their numerical value; pass ints through.
            numbers.append(arg.value() if isinstance(arg, Expr) else arg)
        # '<' selects little-endian byte order and standard sizes.
        return struct.pack('<' + fmt, *numbers)

    def bytes(self) -> bytes:
        """Return the representation of the key in storage as a byte array.

        This is the content of the PSA storage file. When PSA storage is
        implemented over stdio files, this does not include any wrapping made
        by the PSA-storage-over-stdio-file implementation.

        For format version 0, the encoding is the magic string, the version,
        the lifetime, type, bit size, usage and the two algorithm policies,
        then the length of the key material followed by the material itself.
        Any other version raises `NotImplementedError`.
        """
        header = self.MAGIC + self.pack('L', self.version)
        if self.version != 0:
            raise NotImplementedError
        attributes = self.pack(
            'LHHLLL',
            self.lifetime, self.type, self.bits,
            self.usage, self.alg, self.alg2,
        )
        length_prefix = self.pack('L', len(self.material))
        material = length_prefix + self.material
        return header + attributes + material

    def hex(self) -> str:
        """Return the representation of the key as a hexadecimal string.

        This is the hexadecimal form of the byte string returned by
        `self.bytes()`.
        """
        encoded = self.bytes()
        return encoded.hex()

    def location_value(self) -> int:
        """The numerical value of the location encoded in the key's lifetime."""
        lifetime_value = self.lifetime.value()
        # The location is the lifetime value with its low 8 bits dropped.
        return lifetime_value >> 8


class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """Smoke tests checking the storage encoding produced by the `Key` class."""

    def _check_encoding(self, key, expected_hex):
        """Assert that both `key.bytes()` and `key.hex()` match `expected_hex`."""
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_numerical(self):
        """All attributes are given as integers."""
        subject = Key(
            version=0,
            id=1,
            lifetime=0x00000001,
            type=0x2400,
            bits=128,
            usage=0x00000300,
            alg=0x05500200,
            alg2=0x04c01000,
            material=b'@ABCDEFGHIJKLMNO',
        )
        self._check_encoding(
            subject,
            '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f')

    def test_names(self):
        """Lifetime and type are given as symbolic C names."""
        byte_count = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        subject = Key(
            version=0,
            id=1,
            lifetime='PSA_KEY_LIFETIME_PERSISTENT',
            type='PSA_KEY_TYPE_RAW_DATA',
            bits=byte_count*8,
            usage=0,
            alg=0,
            alg2=0,
            material=b'\x00' * byte_count,
        )
        self._check_encoding(
            subject,
            '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * byte_count)

    def test_defaults(self):
        """Version, identifier and lifetime are left at their default values."""
        subject = Key(
            type=0x1001,
            bits=8,
            usage=0,
            alg=0,
            alg2=0,
            material=b'\x2a',
        )
        self._check_encoding(
            subject,
            '505341004b455900000000000100000001100800000000000000000000000000010000002a')