Commit d0d859bb authored by Farhan Ali, committed by Herbert Xu

crypto: virtio - Register an algo only if it's supported

Register a crypto algo with the Linux crypto layer only if
the algorithm is supported by the backend virtio-crypto
device.

Also route crypto requests to a virtio-crypto
device only if it can support the requested service
and algorithm.
Signed-off-by: Farhan Ali <alifm@linux.ibm.com>
Acked-by: Gonglei <arei.gonglei@huawei.com>
Acked-by: Christian Borntraeger <borntraeger@de.ibm.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent b551bac1
--- a/drivers/crypto/virtio/virtio_crypto_algs.c
+++ b/drivers/crypto/virtio/virtio_crypto_algs.c
@@ -49,12 +49,18 @@ struct virtio_crypto_sym_request {
 	bool encrypt;
 };
 
+struct virtio_crypto_algo {
+	uint32_t algonum;
+	uint32_t service;
+	unsigned int active_devs;
+	struct crypto_alg algo;
+};
+
 /*
  * The algs_lock protects the below global virtio_crypto_active_devs
  * and crypto algorithms registion.
  */
 static DEFINE_MUTEX(algs_lock);
-static unsigned int virtio_crypto_active_devs;
 
 static void virtio_crypto_ablkcipher_finalize_req(
 	struct virtio_crypto_sym_request *vc_sym_req,
 	struct ablkcipher_request *req,
@@ -312,15 +318,21 @@ static int virtio_crypto_ablkcipher_setkey(struct crypto_ablkcipher *tfm,
 					 unsigned int keylen)
 {
 	struct virtio_crypto_ablkcipher_ctx *ctx = crypto_ablkcipher_ctx(tfm);
+	uint32_t alg;
 	int ret;
 
+	ret = virtio_crypto_alg_validate_key(keylen, &alg);
+	if (ret)
+		return ret;
+
 	if (!ctx->vcrypto) {
 		/* New key */
 		int node = virtio_crypto_get_current_node();
 		struct virtio_crypto *vcrypto =
-				virtcrypto_get_dev_node(node);
+				virtcrypto_get_dev_node(node,
+				VIRTIO_CRYPTO_SERVICE_CIPHER, alg);
 		if (!vcrypto) {
-			pr_err("virtio_crypto: Could not find a virtio device in the system\n");
+			pr_err("virtio_crypto: Could not find a virtio device in the system or unsupported algo\n");
 			return -ENODEV;
 		}
@@ -571,57 +583,85 @@ static void virtio_crypto_ablkcipher_finalize_req(
 	virtcrypto_clear_request(&vc_sym_req->base);
 }
 
-static struct crypto_alg virtio_crypto_algs[] = { {
-	.cra_name = "cbc(aes)",
-	.cra_driver_name = "virtio_crypto_aes_cbc",
-	.cra_priority = 150,
-	.cra_flags = CRYPTO_ALG_TYPE_ABLKCIPHER | CRYPTO_ALG_ASYNC,
-	.cra_blocksize = AES_BLOCK_SIZE,
-	.cra_ctxsize  = sizeof(struct virtio_crypto_ablkcipher_ctx),
-	.cra_alignmask = 0,
-	.cra_module = THIS_MODULE,
-	.cra_type = &crypto_ablkcipher_type,
-	.cra_init = virtio_crypto_ablkcipher_init,
-	.cra_exit = virtio_crypto_ablkcipher_exit,
-	.cra_u = {
-		.ablkcipher = {
-			.setkey = virtio_crypto_ablkcipher_setkey,
-			.decrypt = virtio_crypto_ablkcipher_decrypt,
-			.encrypt = virtio_crypto_ablkcipher_encrypt,
-			.min_keysize = AES_MIN_KEY_SIZE,
-			.max_keysize = AES_MAX_KEY_SIZE,
-			.ivsize = AES_BLOCK_SIZE,
+static struct virtio_crypto_algo virtio_crypto_algs[] = { {
+	.algonum = VIRTIO_CRYPTO_CIPHER_AES_CBC,
+	.service = VIRTIO_CRYPTO_SERVICE_CIPHER,
+	.algo = {
+		.cra_name = "cbc(aes)",
+		.cra_driver_name = "virtio_crypto_aes_cbc",
+		.cra_priority = 150,
+		.cra_flags = CRYPTO_ALG_TYPE_ABLKCIPHER | CRYPTO_ALG_ASYNC,
+		.cra_blocksize = AES_BLOCK_SIZE,
+		.cra_ctxsize  = sizeof(struct virtio_crypto_ablkcipher_ctx),
+		.cra_alignmask = 0,
+		.cra_module = THIS_MODULE,
+		.cra_type = &crypto_ablkcipher_type,
+		.cra_init = virtio_crypto_ablkcipher_init,
+		.cra_exit = virtio_crypto_ablkcipher_exit,
+		.cra_u = {
+			.ablkcipher = {
+				.setkey = virtio_crypto_ablkcipher_setkey,
+				.decrypt = virtio_crypto_ablkcipher_decrypt,
+				.encrypt = virtio_crypto_ablkcipher_encrypt,
+				.min_keysize = AES_MIN_KEY_SIZE,
+				.max_keysize = AES_MAX_KEY_SIZE,
+				.ivsize = AES_BLOCK_SIZE,
+			},
 		},
 	},
 } };
 
-int virtio_crypto_algs_register(void)
+int virtio_crypto_algs_register(struct virtio_crypto *vcrypto)
 {
 	int ret = 0;
+	int i = 0;
 
 	mutex_lock(&algs_lock);
-	if (++virtio_crypto_active_devs != 1)
-		goto unlock;
 
-	ret = crypto_register_algs(virtio_crypto_algs,
-			ARRAY_SIZE(virtio_crypto_algs));
-	if (ret)
-		virtio_crypto_active_devs--;
+	for (i = 0; i < ARRAY_SIZE(virtio_crypto_algs); i++) {
+
+		uint32_t service = virtio_crypto_algs[i].service;
+		uint32_t algonum = virtio_crypto_algs[i].algonum;
+
+		if (!virtcrypto_algo_is_supported(vcrypto, service, algonum))
+			continue;
+
+		if (virtio_crypto_algs[i].active_devs == 0) {
+			ret = crypto_register_alg(&virtio_crypto_algs[i].algo);
+			if (ret)
+				goto unlock;
+		}
+
+		virtio_crypto_algs[i].active_devs++;
+		dev_info(&vcrypto->vdev->dev, "Registered algo %s\n",
+			 virtio_crypto_algs[i].algo.cra_name);
+	}
 
 unlock:
 	mutex_unlock(&algs_lock);
 	return ret;
 }
 
-void virtio_crypto_algs_unregister(void)
+void virtio_crypto_algs_unregister(struct virtio_crypto *vcrypto)
 {
+	int i = 0;
+
 	mutex_lock(&algs_lock);
-	if (--virtio_crypto_active_devs != 0)
-		goto unlock;
 
-	crypto_unregister_algs(virtio_crypto_algs,
-			ARRAY_SIZE(virtio_crypto_algs));
+	for (i = 0; i < ARRAY_SIZE(virtio_crypto_algs); i++) {
+
+		uint32_t service = virtio_crypto_algs[i].service;
+		uint32_t algonum = virtio_crypto_algs[i].algonum;
+
+		if (virtio_crypto_algs[i].active_devs == 0 ||
+		    !virtcrypto_algo_is_supported(vcrypto, service, algonum))
+			continue;
+
+		if (virtio_crypto_algs[i].active_devs == 1)
+			crypto_unregister_alg(&virtio_crypto_algs[i].algo);
+
+		virtio_crypto_algs[i].active_devs--;
+	}
 
-unlock:
 	mutex_unlock(&algs_lock);
 }
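The register/unregister pair above replaces the old global device counter with a per-algorithm active_devs count: an algorithm is registered with the crypto layer only on its 0 -> 1 device transition and unregistered only on 1 -> 0, so adding a second device, or removing one of several, leaves existing registrations untouched. A minimal standalone sketch of that refcount pattern (plain userspace C with illustrative names, not the driver's code):

#include <stdio.h>

struct algo_entry {
	const char *name;
	unsigned int active_devs;	/* devices currently backing this algo */
};

static void device_added(struct algo_entry *a, int supported)
{
	if (!supported)
		return;				/* skip algos this device lacks */
	if (a->active_devs++ == 0)
		printf("register %s\n", a->name);	/* 0 -> 1: register once */
}

static void device_removed(struct algo_entry *a, int supported)
{
	if (!supported || a->active_devs == 0)
		return;
	if (--a->active_devs == 0)
		printf("unregister %s\n", a->name);	/* 1 -> 0: unregister */
}

int main(void)
{
	struct algo_entry cbc = { "cbc(aes)", 0 };

	device_added(&cbc, 1);		/* first device registers cbc(aes) */
	device_added(&cbc, 1);		/* second device only bumps the count */
	device_removed(&cbc, 1);	/* one device still left: no unregister */
	device_removed(&cbc, 1);	/* last device gone: unregister */
	return 0;
}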
--- a/drivers/crypto/virtio/virtio_crypto_common.h
+++ b/drivers/crypto/virtio/virtio_crypto_common.h
@@ -116,7 +116,12 @@ int virtcrypto_dev_in_use(struct virtio_crypto *vcrypto_dev);
 int virtcrypto_dev_get(struct virtio_crypto *vcrypto_dev);
 void virtcrypto_dev_put(struct virtio_crypto *vcrypto_dev);
 int virtcrypto_dev_started(struct virtio_crypto *vcrypto_dev);
-struct virtio_crypto *virtcrypto_get_dev_node(int node);
+bool virtcrypto_algo_is_supported(struct virtio_crypto *vcrypto_dev,
+				  uint32_t service,
+				  uint32_t algo);
+struct virtio_crypto *virtcrypto_get_dev_node(int node,
+					      uint32_t service,
+					      uint32_t algo);
 int virtcrypto_dev_start(struct virtio_crypto *vcrypto);
 void virtcrypto_dev_stop(struct virtio_crypto *vcrypto);
 int virtio_crypto_ablkcipher_crypt_req(
@@ -136,7 +141,7 @@ static inline int virtio_crypto_get_current_node(void)
 	return node;
 }
 
-int virtio_crypto_algs_register(void);
-void virtio_crypto_algs_unregister(void);
+int virtio_crypto_algs_register(struct virtio_crypto *vcrypto);
+void virtio_crypto_algs_unregister(struct virtio_crypto *vcrypto);
 
 #endif /* _VIRTIO_CRYPTO_COMMON_H */
--- a/drivers/crypto/virtio/virtio_crypto_mgr.c
+++ b/drivers/crypto/virtio/virtio_crypto_mgr.c
@@ -181,14 +181,20 @@ int virtcrypto_dev_started(struct virtio_crypto *vcrypto_dev)
 /*
  * virtcrypto_get_dev_node() - Get vcrypto_dev on the node.
  * @node:  Node id the driver works.
+ * @service: Crypto service that needs to be supported by the
+ *	      dev
+ * @algo: The algorithm number that needs to be supported by the
+ *	  dev
  *
- * Function returns the virtio crypto device used fewest on the node.
+ * Function returns the virtio crypto device used fewest on the node,
+ * and supports the given crypto service and algorithm.
  *
  * To be used by virtio crypto device specific drivers.
  *
 * Return: pointer to vcrypto_dev or NULL if not found.
 */
-struct virtio_crypto *virtcrypto_get_dev_node(int node)
+struct virtio_crypto *virtcrypto_get_dev_node(int node, uint32_t service,
+					      uint32_t algo)
 {
 	struct virtio_crypto *vcrypto_dev = NULL, *tmp_dev;
 	unsigned long best = ~0;
@@ -199,7 +205,8 @@ struct virtio_crypto *virtcrypto_get_dev_node(int node)
 		if ((node == dev_to_node(&tmp_dev->vdev->dev) ||
 		     dev_to_node(&tmp_dev->vdev->dev) < 0) &&
-		    virtcrypto_dev_started(tmp_dev)) {
+		    virtcrypto_dev_started(tmp_dev) &&
+		    virtcrypto_algo_is_supported(tmp_dev, service, algo)) {
 			ctr = atomic_read(&tmp_dev->ref_count);
 			if (best > ctr) {
 				vcrypto_dev = tmp_dev;
@@ -214,7 +221,9 @@ struct virtio_crypto *virtcrypto_get_dev_node(int node)
 		/* Get any started device */
 		list_for_each_entry(tmp_dev,
 				virtcrypto_devmgr_get_head(), list) {
-			if (virtcrypto_dev_started(tmp_dev)) {
+			if (virtcrypto_dev_started(tmp_dev) &&
+			    virtcrypto_algo_is_supported(tmp_dev,
+			    service, algo)) {
 				vcrypto_dev = tmp_dev;
 				break;
 			}
@@ -240,7 +249,7 @@ struct virtio_crypto *virtcrypto_get_dev_node(int node)
  */
 int virtcrypto_dev_start(struct virtio_crypto *vcrypto)
 {
-	if (virtio_crypto_algs_register()) {
+	if (virtio_crypto_algs_register(vcrypto)) {
 		pr_err("virtio_crypto: Failed to register crypto algs\n");
 		return -EFAULT;
 	}
@@ -260,5 +269,65 @@ int virtcrypto_dev_start(struct virtio_crypto *vcrypto)
  */
 void virtcrypto_dev_stop(struct virtio_crypto *vcrypto)
 {
-	virtio_crypto_algs_unregister();
+	virtio_crypto_algs_unregister(vcrypto);
+}
+
+/*
+ * vcrypto_algo_is_supported()
+ * @vcrypto: Pointer to virtio crypto device.
+ * @service: The bit number for service validate.
+ *	      See VIRTIO_CRYPTO_SERVICE_*
+ * @algo : The bit number for the algorithm to validate.
+ *
+ *
+ * Validate if the virtio crypto device supports a service and
+ * algo.
+ *
+ * Return true if device supports a service and algo.
+ */
+bool virtcrypto_algo_is_supported(struct virtio_crypto *vcrypto,
+				  uint32_t service,
+				  uint32_t algo)
+{
+	uint32_t service_mask = 1u << service;
+	uint32_t algo_mask = 0;
+	bool low = true;
+
+	if (algo > 31) {
+		algo -= 32;
+		low = false;
+	}
+
+	if (!(vcrypto->crypto_services & service_mask))
+		return false;
+
+	switch (service) {
+	case VIRTIO_CRYPTO_SERVICE_CIPHER:
+		if (low)
+			algo_mask = vcrypto->cipher_algo_l;
+		else
+			algo_mask = vcrypto->cipher_algo_h;
+		break;
+
+	case VIRTIO_CRYPTO_SERVICE_HASH:
+		algo_mask = vcrypto->hash_algo;
+		break;
+
+	case VIRTIO_CRYPTO_SERVICE_MAC:
+		if (low)
+			algo_mask = vcrypto->mac_algo_l;
+		else
+			algo_mask = vcrypto->mac_algo_h;
+		break;
+
+	case VIRTIO_CRYPTO_SERVICE_AEAD:
+		algo_mask = vcrypto->aead_algo;
+		break;
+	}
+
+	if (!(algo_mask & (1u << algo)))
+		return false;
+
+	return true;
 }
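virtcrypto_algo_is_supported() treats service and algo as bit numbers: the service selects one bit in the 32-bit crypto_services word, and the algorithm selects one bit in a 64-bit per-service capability that the device config splits across a _l word (bits 0-31) and a _h word (bits 32-63), which is why algo values above 31 are shifted down and looked up in the high word. A self-contained worked example of that mask arithmetic; the two #define values below are assumptions for illustration only, the real bit numbers come from the virtio-crypto spec headers:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Assumed stand-ins for the spec's constants, for illustration only. */
#define SERVICE_CIPHER	0	/* plays the role of VIRTIO_CRYPTO_SERVICE_CIPHER */
#define CIPHER_AES_CBC	3	/* plays the role of VIRTIO_CRYPTO_CIPHER_AES_CBC */

int main(void)
{
	/* Capability words as a device might advertise them. */
	uint32_t crypto_services = 1u << SERVICE_CIPHER;	/* cipher only */
	uint32_t cipher_algo_l   = 1u << CIPHER_AES_CBC;	/* algo bits 0..31 */
	uint32_t cipher_algo_h   = 0;				/* algo bits 32..63 */

	uint32_t algo = CIPHER_AES_CBC;
	/* Same split the driver performs: bits above 31 live in the _h word. */
	uint32_t mask = (algo > 31) ? cipher_algo_h : cipher_algo_l;
	uint32_t bit  = (algo > 31) ? algo - 32 : algo;

	bool ok = (crypto_services & (1u << SERVICE_CIPHER)) &&
		  (mask & (1u << bit));
	printf("cbc(aes) supported: %s\n", ok ? "yes" : "no");	/* prints yes */
	return 0;
}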