Commit 67ed868d authored by Linus Torvalds's avatar Linus Torvalds

Merge tag '5.17-rc-ksmbd-server-fixes' of git://git.samba.org/ksmbd

Pull ksmbd server fixes from Steve French:

 - authentication fix

 - RDMA (smbdirect) fixes (including fix for a memory corruption, and
   some performance improvements)

 - multiple improvements for multichannel

 - misc fixes, including crediting (flow control) improvements

 - cleanup fixes, including some kernel doc fixes

* tag '5.17-rc-ksmbd-server-fixes' of git://git.samba.org/ksmbd: (23 commits)
  ksmbd: fix guest connection failure with nautilus
  ksmbd: uninitialized variable in create_socket()
  ksmbd: smbd: fix missing client's memory region invalidation
  ksmbd: add smb-direct shutdown
  ksmbd: smbd: change the default maximum read/write, receive size
  ksmbd: smbd: create MR pool
  ksmbd: add reserved room in ipc request/response
  ksmbd: smbd: call rdma_accept() under CM handler
  ksmbd: limits exceeding the maximum allowable outstanding requests
  ksmbd: move credit charge deduction under processing request
  ksmbd: add support for smb2 max credit parameter
  ksmbd: set 445 port to smbdirect port by default
  ksmbd: register ksmbd ib client with ib_register_client()
  ksmbd: Fix smb2_get_name() kernel-doc comment
  ksmbd: Delete an invalid argument description in smb2_populate_readdir_entry()
  ksmbd: Fix smb2_set_info_file() kernel-doc comment
  ksmbd: Fix buffer_check_err() kernel-doc comment
  ksmbd: fix multi session connection failure
  ksmbd: set both ipv4 and ipv6 in FSCTL_QUERY_NETWORK_INTERFACE_INFO
  ksmbd: set RSS capable in FSCTL_QUERY_NETWORK_INTERFACE_INFO
  ...
parents c5a0b6e4 ac090d9c
...@@ -21,101 +21,11 @@ ...@@ -21,101 +21,11 @@
#include "ksmbd_spnego_negtokeninit.asn1.h" #include "ksmbd_spnego_negtokeninit.asn1.h"
#include "ksmbd_spnego_negtokentarg.asn1.h" #include "ksmbd_spnego_negtokentarg.asn1.h"
#define SPNEGO_OID_LEN 7
#define NTLMSSP_OID_LEN 10 #define NTLMSSP_OID_LEN 10
#define KRB5_OID_LEN 7
#define KRB5U2U_OID_LEN 8
#define MSKRB5_OID_LEN 7
static unsigned long SPNEGO_OID[7] = { 1, 3, 6, 1, 5, 5, 2 };
static unsigned long NTLMSSP_OID[10] = { 1, 3, 6, 1, 4, 1, 311, 2, 2, 10 };
static unsigned long KRB5_OID[7] = { 1, 2, 840, 113554, 1, 2, 2 };
static unsigned long KRB5U2U_OID[8] = { 1, 2, 840, 113554, 1, 2, 2, 3 };
static unsigned long MSKRB5_OID[7] = { 1, 2, 840, 48018, 1, 2, 2 };
static char NTLMSSP_OID_STR[NTLMSSP_OID_LEN] = { 0x2b, 0x06, 0x01, 0x04, 0x01, static char NTLMSSP_OID_STR[NTLMSSP_OID_LEN] = { 0x2b, 0x06, 0x01, 0x04, 0x01,
0x82, 0x37, 0x02, 0x02, 0x0a }; 0x82, 0x37, 0x02, 0x02, 0x0a };
static bool
asn1_subid_decode(const unsigned char **begin, const unsigned char *end,
unsigned long *subid)
{
const unsigned char *ptr = *begin;
unsigned char ch;
*subid = 0;
do {
if (ptr >= end)
return false;
ch = *ptr++;
*subid <<= 7;
*subid |= ch & 0x7F;
} while ((ch & 0x80) == 0x80);
*begin = ptr;
return true;
}
static bool asn1_oid_decode(const unsigned char *value, size_t vlen,
unsigned long **oid, size_t *oidlen)
{
const unsigned char *iptr = value, *end = value + vlen;
unsigned long *optr;
unsigned long subid;
vlen += 1;
if (vlen < 2 || vlen > UINT_MAX / sizeof(unsigned long))
goto fail_nullify;
*oid = kmalloc(vlen * sizeof(unsigned long), GFP_KERNEL);
if (!*oid)
return false;
optr = *oid;
if (!asn1_subid_decode(&iptr, end, &subid))
goto fail;
if (subid < 40) {
optr[0] = 0;
optr[1] = subid;
} else if (subid < 80) {
optr[0] = 1;
optr[1] = subid - 40;
} else {
optr[0] = 2;
optr[1] = subid - 80;
}
*oidlen = 2;
optr += 2;
while (iptr < end) {
if (++(*oidlen) > vlen)
goto fail;
if (!asn1_subid_decode(&iptr, end, optr++))
goto fail;
}
return true;
fail:
kfree(*oid);
fail_nullify:
*oid = NULL;
return false;
}
static bool oid_eq(unsigned long *oid1, unsigned int oid1len,
unsigned long *oid2, unsigned int oid2len)
{
if (oid1len != oid2len)
return false;
return memcmp(oid1, oid2, oid1len) == 0;
}
int int
ksmbd_decode_negTokenInit(unsigned char *security_blob, int length, ksmbd_decode_negTokenInit(unsigned char *security_blob, int length,
struct ksmbd_conn *conn) struct ksmbd_conn *conn)
...@@ -252,26 +162,18 @@ int build_spnego_ntlmssp_auth_blob(unsigned char **pbuffer, u16 *buflen, ...@@ -252,26 +162,18 @@ int build_spnego_ntlmssp_auth_blob(unsigned char **pbuffer, u16 *buflen,
int ksmbd_gssapi_this_mech(void *context, size_t hdrlen, unsigned char tag, int ksmbd_gssapi_this_mech(void *context, size_t hdrlen, unsigned char tag,
const void *value, size_t vlen) const void *value, size_t vlen)
{ {
unsigned long *oid; enum OID oid;
size_t oidlen;
int err = 0;
if (!asn1_oid_decode(value, vlen, &oid, &oidlen)) { oid = look_up_OID(value, vlen);
err = -EBADMSG; if (oid != OID_spnego) {
goto out;
}
if (!oid_eq(oid, oidlen, SPNEGO_OID, SPNEGO_OID_LEN))
err = -EBADMSG;
kfree(oid);
out:
if (err) {
char buf[50]; char buf[50];
sprint_oid(value, vlen, buf, sizeof(buf)); sprint_oid(value, vlen, buf, sizeof(buf));
ksmbd_debug(AUTH, "Unexpected OID: %s\n", buf); ksmbd_debug(AUTH, "Unexpected OID: %s\n", buf);
return -EBADMSG;
} }
return err;
return 0;
} }
int ksmbd_neg_token_init_mech_type(void *context, size_t hdrlen, int ksmbd_neg_token_init_mech_type(void *context, size_t hdrlen,
...@@ -279,37 +181,31 @@ int ksmbd_neg_token_init_mech_type(void *context, size_t hdrlen, ...@@ -279,37 +181,31 @@ int ksmbd_neg_token_init_mech_type(void *context, size_t hdrlen,
size_t vlen) size_t vlen)
{ {
struct ksmbd_conn *conn = context; struct ksmbd_conn *conn = context;
unsigned long *oid; enum OID oid;
size_t oidlen;
int mech_type; int mech_type;
char buf[50];
if (!asn1_oid_decode(value, vlen, &oid, &oidlen)) oid = look_up_OID(value, vlen);
goto fail; if (oid == OID_ntlmssp) {
if (oid_eq(oid, oidlen, NTLMSSP_OID, NTLMSSP_OID_LEN))
mech_type = KSMBD_AUTH_NTLMSSP; mech_type = KSMBD_AUTH_NTLMSSP;
else if (oid_eq(oid, oidlen, MSKRB5_OID, MSKRB5_OID_LEN)) } else if (oid == OID_mskrb5) {
mech_type = KSMBD_AUTH_MSKRB5; mech_type = KSMBD_AUTH_MSKRB5;
else if (oid_eq(oid, oidlen, KRB5_OID, KRB5_OID_LEN)) } else if (oid == OID_krb5) {
mech_type = KSMBD_AUTH_KRB5; mech_type = KSMBD_AUTH_KRB5;
else if (oid_eq(oid, oidlen, KRB5U2U_OID, KRB5U2U_OID_LEN)) } else if (oid == OID_krb5u2u) {
mech_type = KSMBD_AUTH_KRB5U2U; mech_type = KSMBD_AUTH_KRB5U2U;
else } else {
goto fail; char buf[50];
sprint_oid(value, vlen, buf, sizeof(buf));
ksmbd_debug(AUTH, "Unexpected OID: %s\n", buf);
return -EBADMSG;
}
conn->auth_mechs |= mech_type; conn->auth_mechs |= mech_type;
if (conn->preferred_auth_mech == 0) if (conn->preferred_auth_mech == 0)
conn->preferred_auth_mech = mech_type; conn->preferred_auth_mech = mech_type;
kfree(oid);
return 0; return 0;
fail:
kfree(oid);
sprint_oid(value, vlen, buf, sizeof(buf));
ksmbd_debug(AUTH, "Unexpected OID: %s\n", buf);
return -EBADMSG;
} }
int ksmbd_neg_token_init_mech_token(void *context, size_t hdrlen, int ksmbd_neg_token_init_mech_token(void *context, size_t hdrlen,
......
...@@ -215,7 +215,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash, ...@@ -215,7 +215,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
* Return: 0 on success, error number on error * Return: 0 on success, error number on error
*/ */
int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
int blen, char *domain_name) int blen, char *domain_name, char *cryptkey)
{ {
char ntlmv2_hash[CIFS_ENCPWD_SIZE]; char ntlmv2_hash[CIFS_ENCPWD_SIZE];
char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE]; char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE];
...@@ -256,7 +256,7 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, ...@@ -256,7 +256,7 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
goto out; goto out;
} }
memcpy(construct, sess->ntlmssp.cryptkey, CIFS_CRYPTO_KEY_SIZE); memcpy(construct, cryptkey, CIFS_CRYPTO_KEY_SIZE);
memcpy(construct + CIFS_CRYPTO_KEY_SIZE, &ntlmv2->blob_signature, blen); memcpy(construct + CIFS_CRYPTO_KEY_SIZE, &ntlmv2->blob_signature, blen);
rc = crypto_shash_update(CRYPTO_HMACMD5(ctx), construct, len); rc = crypto_shash_update(CRYPTO_HMACMD5(ctx), construct, len);
...@@ -295,7 +295,8 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, ...@@ -295,7 +295,8 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
* Return: 0 on success, error number on error * Return: 0 on success, error number on error
*/ */
int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
int blob_len, struct ksmbd_session *sess) int blob_len, struct ksmbd_conn *conn,
struct ksmbd_session *sess)
{ {
char *domain_name; char *domain_name;
unsigned int nt_off, dn_off; unsigned int nt_off, dn_off;
...@@ -324,7 +325,7 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, ...@@ -324,7 +325,7 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
/* TODO : use domain name that imported from configuration file */ /* TODO : use domain name that imported from configuration file */
domain_name = smb_strndup_from_utf16((const char *)authblob + dn_off, domain_name = smb_strndup_from_utf16((const char *)authblob + dn_off,
dn_len, true, sess->conn->local_nls); dn_len, true, conn->local_nls);
if (IS_ERR(domain_name)) if (IS_ERR(domain_name))
return PTR_ERR(domain_name); return PTR_ERR(domain_name);
...@@ -333,7 +334,7 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, ...@@ -333,7 +334,7 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
domain_name); domain_name);
ret = ksmbd_auth_ntlmv2(sess, (struct ntlmv2_resp *)((char *)authblob + nt_off), ret = ksmbd_auth_ntlmv2(sess, (struct ntlmv2_resp *)((char *)authblob + nt_off),
nt_len - CIFS_ENCPWD_SIZE, nt_len - CIFS_ENCPWD_SIZE,
domain_name); domain_name, conn->ntlmssp.cryptkey);
kfree(domain_name); kfree(domain_name);
return ret; return ret;
} }
...@@ -347,7 +348,7 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, ...@@ -347,7 +348,7 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
* *
*/ */
int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob, int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
int blob_len, struct ksmbd_session *sess) int blob_len, struct ksmbd_conn *conn)
{ {
if (blob_len < sizeof(struct negotiate_message)) { if (blob_len < sizeof(struct negotiate_message)) {
ksmbd_debug(AUTH, "negotiate blob len %d too small\n", ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
...@@ -361,7 +362,7 @@ int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob, ...@@ -361,7 +362,7 @@ int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
return -EINVAL; return -EINVAL;
} }
sess->ntlmssp.client_flags = le32_to_cpu(negblob->NegotiateFlags); conn->ntlmssp.client_flags = le32_to_cpu(negblob->NegotiateFlags);
return 0; return 0;
} }
...@@ -375,14 +376,14 @@ int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob, ...@@ -375,14 +376,14 @@ int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
*/ */
unsigned int unsigned int
ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob, ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
struct ksmbd_session *sess) struct ksmbd_conn *conn)
{ {
struct target_info *tinfo; struct target_info *tinfo;
wchar_t *name; wchar_t *name;
__u8 *target_name; __u8 *target_name;
unsigned int flags, blob_off, blob_len, type, target_info_len = 0; unsigned int flags, blob_off, blob_len, type, target_info_len = 0;
int len, uni_len, conv_len; int len, uni_len, conv_len;
int cflags = sess->ntlmssp.client_flags; int cflags = conn->ntlmssp.client_flags;
memcpy(chgblob->Signature, NTLMSSP_SIGNATURE, 8); memcpy(chgblob->Signature, NTLMSSP_SIGNATURE, 8);
chgblob->MessageType = NtLmChallenge; chgblob->MessageType = NtLmChallenge;
...@@ -403,7 +404,7 @@ ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob, ...@@ -403,7 +404,7 @@ ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
if (cflags & NTLMSSP_REQUEST_TARGET) if (cflags & NTLMSSP_REQUEST_TARGET)
flags |= NTLMSSP_REQUEST_TARGET; flags |= NTLMSSP_REQUEST_TARGET;
if (sess->conn->use_spnego && if (conn->use_spnego &&
(cflags & NTLMSSP_NEGOTIATE_EXTENDED_SEC)) (cflags & NTLMSSP_NEGOTIATE_EXTENDED_SEC))
flags |= NTLMSSP_NEGOTIATE_EXTENDED_SEC; flags |= NTLMSSP_NEGOTIATE_EXTENDED_SEC;
...@@ -414,7 +415,7 @@ ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob, ...@@ -414,7 +415,7 @@ ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
return -ENOMEM; return -ENOMEM;
conv_len = smb_strtoUTF16((__le16 *)name, ksmbd_netbios_name(), len, conv_len = smb_strtoUTF16((__le16 *)name, ksmbd_netbios_name(), len,
sess->conn->local_nls); conn->local_nls);
if (conv_len < 0 || conv_len > len) { if (conv_len < 0 || conv_len > len) {
kfree(name); kfree(name);
return -EINVAL; return -EINVAL;
...@@ -430,8 +431,8 @@ ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob, ...@@ -430,8 +431,8 @@ ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
chgblob->TargetName.BufferOffset = cpu_to_le32(blob_off); chgblob->TargetName.BufferOffset = cpu_to_le32(blob_off);
/* Initialize random conn challenge */ /* Initialize random conn challenge */
get_random_bytes(sess->ntlmssp.cryptkey, sizeof(__u64)); get_random_bytes(conn->ntlmssp.cryptkey, sizeof(__u64));
memcpy(chgblob->Challenge, sess->ntlmssp.cryptkey, memcpy(chgblob->Challenge, conn->ntlmssp.cryptkey,
CIFS_CRYPTO_KEY_SIZE); CIFS_CRYPTO_KEY_SIZE);
/* Add Target Information to security buffer */ /* Add Target Information to security buffer */
......
...@@ -38,16 +38,16 @@ struct kvec; ...@@ -38,16 +38,16 @@ struct kvec;
int ksmbd_crypt_message(struct ksmbd_conn *conn, struct kvec *iov, int ksmbd_crypt_message(struct ksmbd_conn *conn, struct kvec *iov,
unsigned int nvec, int enc); unsigned int nvec, int enc);
void ksmbd_copy_gss_neg_header(void *buf); void ksmbd_copy_gss_neg_header(void *buf);
int ksmbd_auth_ntlm(struct ksmbd_session *sess, char *pw_buf);
int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
int blen, char *domain_name); int blen, char *domain_name, char *cryptkey);
int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
int blob_len, struct ksmbd_session *sess); int blob_len, struct ksmbd_conn *conn,
struct ksmbd_session *sess);
int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob, int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
int blob_len, struct ksmbd_session *sess); int blob_len, struct ksmbd_conn *conn);
unsigned int unsigned int
ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob, ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
struct ksmbd_session *sess); struct ksmbd_conn *conn);
int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob, int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
int in_len, char *out_blob, int *out_len); int in_len, char *out_blob, int *out_len);
int ksmbd_sign_smb2_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov, int ksmbd_sign_smb2_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
......
...@@ -62,6 +62,7 @@ struct ksmbd_conn *ksmbd_conn_alloc(void) ...@@ -62,6 +62,7 @@ struct ksmbd_conn *ksmbd_conn_alloc(void)
atomic_set(&conn->req_running, 0); atomic_set(&conn->req_running, 0);
atomic_set(&conn->r_count, 0); atomic_set(&conn->r_count, 0);
conn->total_credits = 1; conn->total_credits = 1;
conn->outstanding_credits = 1;
init_waitqueue_head(&conn->req_running_q); init_waitqueue_head(&conn->req_running_q);
INIT_LIST_HEAD(&conn->conns_list); INIT_LIST_HEAD(&conn->conns_list);
...@@ -386,17 +387,24 @@ int ksmbd_conn_transport_init(void) ...@@ -386,17 +387,24 @@ int ksmbd_conn_transport_init(void)
static void stop_sessions(void) static void stop_sessions(void)
{ {
struct ksmbd_conn *conn; struct ksmbd_conn *conn;
struct ksmbd_transport *t;
again: again:
read_lock(&conn_list_lock); read_lock(&conn_list_lock);
list_for_each_entry(conn, &conn_list, conns_list) { list_for_each_entry(conn, &conn_list, conns_list) {
struct task_struct *task; struct task_struct *task;
task = conn->transport->handler; t = conn->transport;
task = t->handler;
if (task) if (task)
ksmbd_debug(CONN, "Stop session handler %s/%d\n", ksmbd_debug(CONN, "Stop session handler %s/%d\n",
task->comm, task_pid_nr(task)); task->comm, task_pid_nr(task));
conn->status = KSMBD_SESS_EXITING; conn->status = KSMBD_SESS_EXITING;
if (t->ops->shutdown) {
read_unlock(&conn_list_lock);
t->ops->shutdown(t);
read_lock(&conn_list_lock);
}
} }
read_unlock(&conn_list_lock); read_unlock(&conn_list_lock);
......
...@@ -61,8 +61,8 @@ struct ksmbd_conn { ...@@ -61,8 +61,8 @@ struct ksmbd_conn {
atomic_t req_running; atomic_t req_running;
/* References which are made for this Server object*/ /* References which are made for this Server object*/
atomic_t r_count; atomic_t r_count;
unsigned short total_credits; unsigned int total_credits;
unsigned short max_credits; unsigned int outstanding_credits;
spinlock_t credits_lock; spinlock_t credits_lock;
wait_queue_head_t req_running_q; wait_queue_head_t req_running_q;
/* Lock to protect requests list*/ /* Lock to protect requests list*/
...@@ -72,12 +72,7 @@ struct ksmbd_conn { ...@@ -72,12 +72,7 @@ struct ksmbd_conn {
int connection_type; int connection_type;
struct ksmbd_stats stats; struct ksmbd_stats stats;
char ClientGUID[SMB2_CLIENT_GUID_SIZE]; char ClientGUID[SMB2_CLIENT_GUID_SIZE];
union { struct ntlmssp_auth ntlmssp;
/* pending trans request table */
struct trans_state *recent_trans;
/* Used by ntlmssp */
char *ntlmssp_cryptkey;
};
spinlock_t llist_lock; spinlock_t llist_lock;
struct list_head lock_list; struct list_head lock_list;
...@@ -122,6 +117,7 @@ struct ksmbd_conn_ops { ...@@ -122,6 +117,7 @@ struct ksmbd_conn_ops {
struct ksmbd_transport_ops { struct ksmbd_transport_ops {
int (*prepare)(struct ksmbd_transport *t); int (*prepare)(struct ksmbd_transport *t);
void (*disconnect)(struct ksmbd_transport *t); void (*disconnect)(struct ksmbd_transport *t);
void (*shutdown)(struct ksmbd_transport *t);
int (*read)(struct ksmbd_transport *t, char *buf, unsigned int size); int (*read)(struct ksmbd_transport *t, char *buf, unsigned int size);
int (*writev)(struct ksmbd_transport *t, struct kvec *iovs, int niov, int (*writev)(struct ksmbd_transport *t, struct kvec *iovs, int niov,
int size, bool need_invalidate_rkey, int size, bool need_invalidate_rkey,
......
...@@ -103,6 +103,8 @@ struct ksmbd_startup_request { ...@@ -103,6 +103,8 @@ struct ksmbd_startup_request {
* we set the SPARSE_FILES bit (0x40). * we set the SPARSE_FILES bit (0x40).
*/ */
__u32 sub_auth[3]; /* Subauth value for Security ID */ __u32 sub_auth[3]; /* Subauth value for Security ID */
__u32 smb2_max_credits; /* MAX credits */
__u32 reserved[128]; /* Reserved room */
__u32 ifc_list_sz; /* interfaces list size */ __u32 ifc_list_sz; /* interfaces list size */
__s8 ____payload[]; __s8 ____payload[];
}; };
...@@ -113,7 +115,7 @@ struct ksmbd_startup_request { ...@@ -113,7 +115,7 @@ struct ksmbd_startup_request {
* IPC request to shutdown ksmbd server. * IPC request to shutdown ksmbd server.
*/ */
struct ksmbd_shutdown_request { struct ksmbd_shutdown_request {
__s32 reserved; __s32 reserved[16];
}; };
/* /*
...@@ -122,6 +124,7 @@ struct ksmbd_shutdown_request { ...@@ -122,6 +124,7 @@ struct ksmbd_shutdown_request {
struct ksmbd_login_request { struct ksmbd_login_request {
__u32 handle; __u32 handle;
__s8 account[KSMBD_REQ_MAX_ACCOUNT_NAME_SZ]; /* user account name */ __s8 account[KSMBD_REQ_MAX_ACCOUNT_NAME_SZ]; /* user account name */
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
...@@ -135,6 +138,7 @@ struct ksmbd_login_response { ...@@ -135,6 +138,7 @@ struct ksmbd_login_response {
__u16 status; __u16 status;
__u16 hash_sz; /* hash size */ __u16 hash_sz; /* hash size */
__s8 hash[KSMBD_REQ_MAX_HASH_SZ]; /* password hash */ __s8 hash[KSMBD_REQ_MAX_HASH_SZ]; /* password hash */
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
...@@ -143,6 +147,7 @@ struct ksmbd_login_response { ...@@ -143,6 +147,7 @@ struct ksmbd_login_response {
struct ksmbd_share_config_request { struct ksmbd_share_config_request {
__u32 handle; __u32 handle;
__s8 share_name[KSMBD_REQ_MAX_SHARE_NAME]; /* share name */ __s8 share_name[KSMBD_REQ_MAX_SHARE_NAME]; /* share name */
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
...@@ -157,6 +162,7 @@ struct ksmbd_share_config_response { ...@@ -157,6 +162,7 @@ struct ksmbd_share_config_response {
__u16 force_directory_mode; __u16 force_directory_mode;
__u16 force_uid; __u16 force_uid;
__u16 force_gid; __u16 force_gid;
__u32 reserved[128]; /* Reserved room */
__u32 veto_list_sz; __u32 veto_list_sz;
__s8 ____payload[]; __s8 ____payload[];
}; };
...@@ -187,6 +193,7 @@ struct ksmbd_tree_connect_request { ...@@ -187,6 +193,7 @@ struct ksmbd_tree_connect_request {
__s8 account[KSMBD_REQ_MAX_ACCOUNT_NAME_SZ]; __s8 account[KSMBD_REQ_MAX_ACCOUNT_NAME_SZ];
__s8 share[KSMBD_REQ_MAX_SHARE_NAME]; __s8 share[KSMBD_REQ_MAX_SHARE_NAME];
__s8 peer_addr[64]; __s8 peer_addr[64];
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
...@@ -196,6 +203,7 @@ struct ksmbd_tree_connect_response { ...@@ -196,6 +203,7 @@ struct ksmbd_tree_connect_response {
__u32 handle; __u32 handle;
__u16 status; __u16 status;
__u16 connection_flags; __u16 connection_flags;
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
...@@ -204,6 +212,7 @@ struct ksmbd_tree_connect_response { ...@@ -204,6 +212,7 @@ struct ksmbd_tree_connect_response {
struct ksmbd_tree_disconnect_request { struct ksmbd_tree_disconnect_request {
__u64 session_id; /* session id */ __u64 session_id; /* session id */
__u64 connect_id; /* tree connection id */ __u64 connect_id; /* tree connection id */
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
...@@ -212,6 +221,7 @@ struct ksmbd_tree_disconnect_request { ...@@ -212,6 +221,7 @@ struct ksmbd_tree_disconnect_request {
struct ksmbd_logout_request { struct ksmbd_logout_request {
__s8 account[KSMBD_REQ_MAX_ACCOUNT_NAME_SZ]; /* user account name */ __s8 account[KSMBD_REQ_MAX_ACCOUNT_NAME_SZ]; /* user account name */
__u32 account_flags; __u32 account_flags;
__u32 reserved[16]; /* Reserved room */
}; };
/* /*
......
...@@ -67,3 +67,13 @@ int ksmbd_anonymous_user(struct ksmbd_user *user) ...@@ -67,3 +67,13 @@ int ksmbd_anonymous_user(struct ksmbd_user *user)
return 1; return 1;
return 0; return 0;
} }
bool ksmbd_compare_user(struct ksmbd_user *u1, struct ksmbd_user *u2)
{
if (strcmp(u1->name, u2->name))
return false;
if (memcmp(u1->passkey, u2->passkey, u1->passkey_sz))
return false;
return true;
}
...@@ -64,4 +64,5 @@ struct ksmbd_user *ksmbd_login_user(const char *account); ...@@ -64,4 +64,5 @@ struct ksmbd_user *ksmbd_login_user(const char *account);
struct ksmbd_user *ksmbd_alloc_user(struct ksmbd_login_response *resp); struct ksmbd_user *ksmbd_alloc_user(struct ksmbd_login_response *resp);
void ksmbd_free_user(struct ksmbd_user *user); void ksmbd_free_user(struct ksmbd_user *user);
int ksmbd_anonymous_user(struct ksmbd_user *user); int ksmbd_anonymous_user(struct ksmbd_user *user);
bool ksmbd_compare_user(struct ksmbd_user *u1, struct ksmbd_user *u2);
#endif /* __USER_CONFIG_MANAGEMENT_H__ */ #endif /* __USER_CONFIG_MANAGEMENT_H__ */
...@@ -45,7 +45,6 @@ struct ksmbd_session { ...@@ -45,7 +45,6 @@ struct ksmbd_session {
int state; int state;
__u8 *Preauth_HashValue; __u8 *Preauth_HashValue;
struct ntlmssp_auth ntlmssp;
char sess_key[CIFS_KEY_SIZE]; char sess_key[CIFS_KEY_SIZE];
struct hlist_node hlist; struct hlist_node hlist;
......
...@@ -289,7 +289,7 @@ static int smb2_validate_credit_charge(struct ksmbd_conn *conn, ...@@ -289,7 +289,7 @@ static int smb2_validate_credit_charge(struct ksmbd_conn *conn,
unsigned int req_len = 0, expect_resp_len = 0, calc_credit_num, max_len; unsigned int req_len = 0, expect_resp_len = 0, calc_credit_num, max_len;
unsigned short credit_charge = le16_to_cpu(hdr->CreditCharge); unsigned short credit_charge = le16_to_cpu(hdr->CreditCharge);
void *__hdr = hdr; void *__hdr = hdr;
int ret; int ret = 0;
switch (hdr->Command) { switch (hdr->Command) {
case SMB2_QUERY_INFO: case SMB2_QUERY_INFO:
...@@ -326,21 +326,27 @@ static int smb2_validate_credit_charge(struct ksmbd_conn *conn, ...@@ -326,21 +326,27 @@ static int smb2_validate_credit_charge(struct ksmbd_conn *conn,
ksmbd_debug(SMB, "Insufficient credit charge, given: %d, needed: %d\n", ksmbd_debug(SMB, "Insufficient credit charge, given: %d, needed: %d\n",
credit_charge, calc_credit_num); credit_charge, calc_credit_num);
return 1; return 1;
} else if (credit_charge > conn->max_credits) { } else if (credit_charge > conn->vals->max_credits) {
ksmbd_debug(SMB, "Too large credit charge: %d\n", credit_charge); ksmbd_debug(SMB, "Too large credit charge: %d\n", credit_charge);
return 1; return 1;
} }
spin_lock(&conn->credits_lock); spin_lock(&conn->credits_lock);
if (credit_charge <= conn->total_credits) { if (credit_charge > conn->total_credits) {
conn->total_credits -= credit_charge;
ret = 0;
} else {
ksmbd_debug(SMB, "Insufficient credits granted, given: %u, granted: %u\n", ksmbd_debug(SMB, "Insufficient credits granted, given: %u, granted: %u\n",
credit_charge, conn->total_credits); credit_charge, conn->total_credits);
ret = 1; ret = 1;
} }
if ((u64)conn->outstanding_credits + credit_charge > conn->vals->max_credits) {
ksmbd_debug(SMB, "Limits exceeding the maximum allowable outstanding requests, given : %u, pending : %u\n",
credit_charge, conn->outstanding_credits);
ret = 1;
} else
conn->outstanding_credits += credit_charge;
spin_unlock(&conn->credits_lock); spin_unlock(&conn->credits_lock);
return ret; return ret;
} }
......
...@@ -19,6 +19,7 @@ static struct smb_version_values smb21_server_values = { ...@@ -19,6 +19,7 @@ static struct smb_version_values smb21_server_values = {
.max_read_size = SMB21_DEFAULT_IOSIZE, .max_read_size = SMB21_DEFAULT_IOSIZE,
.max_write_size = SMB21_DEFAULT_IOSIZE, .max_write_size = SMB21_DEFAULT_IOSIZE,
.max_trans_size = SMB21_DEFAULT_IOSIZE, .max_trans_size = SMB21_DEFAULT_IOSIZE,
.max_credits = SMB2_MAX_CREDITS,
.large_lock_type = 0, .large_lock_type = 0,
.exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE, .exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE,
.shared_lock_type = SMB2_LOCKFLAG_SHARED, .shared_lock_type = SMB2_LOCKFLAG_SHARED,
...@@ -44,6 +45,7 @@ static struct smb_version_values smb30_server_values = { ...@@ -44,6 +45,7 @@ static struct smb_version_values smb30_server_values = {
.max_read_size = SMB3_DEFAULT_IOSIZE, .max_read_size = SMB3_DEFAULT_IOSIZE,
.max_write_size = SMB3_DEFAULT_IOSIZE, .max_write_size = SMB3_DEFAULT_IOSIZE,
.max_trans_size = SMB3_DEFAULT_TRANS_SIZE, .max_trans_size = SMB3_DEFAULT_TRANS_SIZE,
.max_credits = SMB2_MAX_CREDITS,
.large_lock_type = 0, .large_lock_type = 0,
.exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE, .exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE,
.shared_lock_type = SMB2_LOCKFLAG_SHARED, .shared_lock_type = SMB2_LOCKFLAG_SHARED,
...@@ -70,6 +72,7 @@ static struct smb_version_values smb302_server_values = { ...@@ -70,6 +72,7 @@ static struct smb_version_values smb302_server_values = {
.max_read_size = SMB3_DEFAULT_IOSIZE, .max_read_size = SMB3_DEFAULT_IOSIZE,
.max_write_size = SMB3_DEFAULT_IOSIZE, .max_write_size = SMB3_DEFAULT_IOSIZE,
.max_trans_size = SMB3_DEFAULT_TRANS_SIZE, .max_trans_size = SMB3_DEFAULT_TRANS_SIZE,
.max_credits = SMB2_MAX_CREDITS,
.large_lock_type = 0, .large_lock_type = 0,
.exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE, .exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE,
.shared_lock_type = SMB2_LOCKFLAG_SHARED, .shared_lock_type = SMB2_LOCKFLAG_SHARED,
...@@ -96,6 +99,7 @@ static struct smb_version_values smb311_server_values = { ...@@ -96,6 +99,7 @@ static struct smb_version_values smb311_server_values = {
.max_read_size = SMB3_DEFAULT_IOSIZE, .max_read_size = SMB3_DEFAULT_IOSIZE,
.max_write_size = SMB3_DEFAULT_IOSIZE, .max_write_size = SMB3_DEFAULT_IOSIZE,
.max_trans_size = SMB3_DEFAULT_TRANS_SIZE, .max_trans_size = SMB3_DEFAULT_TRANS_SIZE,
.max_credits = SMB2_MAX_CREDITS,
.large_lock_type = 0, .large_lock_type = 0,
.exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE, .exclusive_lock_type = SMB2_LOCKFLAG_EXCLUSIVE,
.shared_lock_type = SMB2_LOCKFLAG_SHARED, .shared_lock_type = SMB2_LOCKFLAG_SHARED,
...@@ -197,7 +201,6 @@ void init_smb2_1_server(struct ksmbd_conn *conn) ...@@ -197,7 +201,6 @@ void init_smb2_1_server(struct ksmbd_conn *conn)
conn->ops = &smb2_0_server_ops; conn->ops = &smb2_0_server_ops;
conn->cmds = smb2_0_server_cmds; conn->cmds = smb2_0_server_cmds;
conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds); conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds);
conn->max_credits = SMB2_MAX_CREDITS;
conn->signing_algorithm = SIGNING_ALG_HMAC_SHA256_LE; conn->signing_algorithm = SIGNING_ALG_HMAC_SHA256_LE;
if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES) if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES)
...@@ -215,7 +218,6 @@ void init_smb3_0_server(struct ksmbd_conn *conn) ...@@ -215,7 +218,6 @@ void init_smb3_0_server(struct ksmbd_conn *conn)
conn->ops = &smb3_0_server_ops; conn->ops = &smb3_0_server_ops;
conn->cmds = smb2_0_server_cmds; conn->cmds = smb2_0_server_cmds;
conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds); conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds);
conn->max_credits = SMB2_MAX_CREDITS;
conn->signing_algorithm = SIGNING_ALG_AES_CMAC_LE; conn->signing_algorithm = SIGNING_ALG_AES_CMAC_LE;
if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES) if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES)
...@@ -240,7 +242,6 @@ void init_smb3_02_server(struct ksmbd_conn *conn) ...@@ -240,7 +242,6 @@ void init_smb3_02_server(struct ksmbd_conn *conn)
conn->ops = &smb3_0_server_ops; conn->ops = &smb3_0_server_ops;
conn->cmds = smb2_0_server_cmds; conn->cmds = smb2_0_server_cmds;
conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds); conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds);
conn->max_credits = SMB2_MAX_CREDITS;
conn->signing_algorithm = SIGNING_ALG_AES_CMAC_LE; conn->signing_algorithm = SIGNING_ALG_AES_CMAC_LE;
if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES) if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES)
...@@ -265,7 +266,6 @@ int init_smb3_11_server(struct ksmbd_conn *conn) ...@@ -265,7 +266,6 @@ int init_smb3_11_server(struct ksmbd_conn *conn)
conn->ops = &smb3_11_server_ops; conn->ops = &smb3_11_server_ops;
conn->cmds = smb2_0_server_cmds; conn->cmds = smb2_0_server_cmds;
conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds); conn->max_cmds = ARRAY_SIZE(smb2_0_server_cmds);
conn->max_credits = SMB2_MAX_CREDITS;
conn->signing_algorithm = SIGNING_ALG_AES_CMAC_LE; conn->signing_algorithm = SIGNING_ALG_AES_CMAC_LE;
if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES) if (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_LEASES)
...@@ -304,3 +304,11 @@ void init_smb2_max_trans_size(unsigned int sz) ...@@ -304,3 +304,11 @@ void init_smb2_max_trans_size(unsigned int sz)
smb302_server_values.max_trans_size = sz; smb302_server_values.max_trans_size = sz;
smb311_server_values.max_trans_size = sz; smb311_server_values.max_trans_size = sz;
} }
void init_smb2_max_credits(unsigned int sz)
{
smb21_server_values.max_credits = sz;
smb30_server_values.max_credits = sz;
smb302_server_values.max_credits = sz;
smb311_server_values.max_credits = sz;
}
...@@ -299,16 +299,15 @@ int smb2_set_rsp_credits(struct ksmbd_work *work) ...@@ -299,16 +299,15 @@ int smb2_set_rsp_credits(struct ksmbd_work *work)
struct smb2_hdr *req_hdr = ksmbd_req_buf_next(work); struct smb2_hdr *req_hdr = ksmbd_req_buf_next(work);
struct smb2_hdr *hdr = ksmbd_resp_buf_next(work); struct smb2_hdr *hdr = ksmbd_resp_buf_next(work);
struct ksmbd_conn *conn = work->conn; struct ksmbd_conn *conn = work->conn;
unsigned short credits_requested; unsigned short credits_requested, aux_max;
unsigned short credit_charge, credits_granted = 0; unsigned short credit_charge, credits_granted = 0;
unsigned short aux_max, aux_credits;
if (work->send_no_response) if (work->send_no_response)
return 0; return 0;
hdr->CreditCharge = req_hdr->CreditCharge; hdr->CreditCharge = req_hdr->CreditCharge;
if (conn->total_credits > conn->max_credits) { if (conn->total_credits > conn->vals->max_credits) {
hdr->CreditRequest = 0; hdr->CreditRequest = 0;
pr_err("Total credits overflow: %d\n", conn->total_credits); pr_err("Total credits overflow: %d\n", conn->total_credits);
return -EINVAL; return -EINVAL;
...@@ -316,6 +315,14 @@ int smb2_set_rsp_credits(struct ksmbd_work *work) ...@@ -316,6 +315,14 @@ int smb2_set_rsp_credits(struct ksmbd_work *work)
credit_charge = max_t(unsigned short, credit_charge = max_t(unsigned short,
le16_to_cpu(req_hdr->CreditCharge), 1); le16_to_cpu(req_hdr->CreditCharge), 1);
if (credit_charge > conn->total_credits) {
ksmbd_debug(SMB, "Insufficient credits granted, given: %u, granted: %u\n",
credit_charge, conn->total_credits);
return -EINVAL;
}
conn->total_credits -= credit_charge;
conn->outstanding_credits -= credit_charge;
credits_requested = max_t(unsigned short, credits_requested = max_t(unsigned short,
le16_to_cpu(req_hdr->CreditRequest), 1); le16_to_cpu(req_hdr->CreditRequest), 1);
...@@ -325,16 +332,14 @@ int smb2_set_rsp_credits(struct ksmbd_work *work) ...@@ -325,16 +332,14 @@ int smb2_set_rsp_credits(struct ksmbd_work *work)
* TODO: Need to adjuct CreditRequest value according to * TODO: Need to adjuct CreditRequest value according to
* current cpu load * current cpu load
*/ */
aux_credits = credits_requested - 1;
if (hdr->Command == SMB2_NEGOTIATE) if (hdr->Command == SMB2_NEGOTIATE)
aux_max = 0; aux_max = 1;
else else
aux_max = conn->max_credits - credit_charge; aux_max = conn->vals->max_credits - credit_charge;
aux_credits = min_t(unsigned short, aux_credits, aux_max); credits_granted = min_t(unsigned short, credits_requested, aux_max);
credits_granted = credit_charge + aux_credits;
if (conn->max_credits - conn->total_credits < credits_granted) if (conn->vals->max_credits - conn->total_credits < credits_granted)
credits_granted = conn->max_credits - credits_granted = conn->vals->max_credits -
conn->total_credits; conn->total_credits;
conn->total_credits += credits_granted; conn->total_credits += credits_granted;
...@@ -610,16 +615,14 @@ static void destroy_previous_session(struct ksmbd_user *user, u64 id) ...@@ -610,16 +615,14 @@ static void destroy_previous_session(struct ksmbd_user *user, u64 id)
/** /**
* smb2_get_name() - get filename string from on the wire smb format * smb2_get_name() - get filename string from on the wire smb format
* @share: ksmbd_share_config pointer
* @src: source buffer * @src: source buffer
* @maxlen: maxlen of source string * @maxlen: maxlen of source string
* @nls_table: nls_table pointer * @local_nls: nls_table pointer
* *
* Return: matching converted filename on success, otherwise error ptr * Return: matching converted filename on success, otherwise error ptr
*/ */
static char * static char *
smb2_get_name(struct ksmbd_share_config *share, const char *src, smb2_get_name(const char *src, const int maxlen, struct nls_table *local_nls)
const int maxlen, struct nls_table *local_nls)
{ {
char *name; char *name;
...@@ -1303,7 +1306,7 @@ static int ntlm_negotiate(struct ksmbd_work *work, ...@@ -1303,7 +1306,7 @@ static int ntlm_negotiate(struct ksmbd_work *work,
int sz, rc; int sz, rc;
ksmbd_debug(SMB, "negotiate phase\n"); ksmbd_debug(SMB, "negotiate phase\n");
rc = ksmbd_decode_ntlmssp_neg_blob(negblob, negblob_len, work->sess); rc = ksmbd_decode_ntlmssp_neg_blob(negblob, negblob_len, work->conn);
if (rc) if (rc)
return rc; return rc;
...@@ -1313,7 +1316,7 @@ static int ntlm_negotiate(struct ksmbd_work *work, ...@@ -1313,7 +1316,7 @@ static int ntlm_negotiate(struct ksmbd_work *work,
memset(chgblob, 0, sizeof(struct challenge_message)); memset(chgblob, 0, sizeof(struct challenge_message));
if (!work->conn->use_spnego) { if (!work->conn->use_spnego) {
sz = ksmbd_build_ntlmssp_challenge_blob(chgblob, work->sess); sz = ksmbd_build_ntlmssp_challenge_blob(chgblob, work->conn);
if (sz < 0) if (sz < 0)
return -ENOMEM; return -ENOMEM;
...@@ -1329,7 +1332,7 @@ static int ntlm_negotiate(struct ksmbd_work *work, ...@@ -1329,7 +1332,7 @@ static int ntlm_negotiate(struct ksmbd_work *work,
return -ENOMEM; return -ENOMEM;
chgblob = (struct challenge_message *)neg_blob; chgblob = (struct challenge_message *)neg_blob;
sz = ksmbd_build_ntlmssp_challenge_blob(chgblob, work->sess); sz = ksmbd_build_ntlmssp_challenge_blob(chgblob, work->conn);
if (sz < 0) { if (sz < 0) {
rc = -ENOMEM; rc = -ENOMEM;
goto out; goto out;
...@@ -1450,28 +1453,30 @@ static int ntlm_authenticate(struct ksmbd_work *work) ...@@ -1450,28 +1453,30 @@ static int ntlm_authenticate(struct ksmbd_work *work)
ksmbd_free_user(user); ksmbd_free_user(user);
return 0; return 0;
} }
ksmbd_free_user(sess->user);
}
sess->user = user; if (!ksmbd_compare_user(sess->user, user)) {
if (user_guest(sess->user)) { ksmbd_free_user(user);
if (conn->sign) {
ksmbd_debug(SMB, "Guest login not allowed when signing enabled\n");
return -EPERM; return -EPERM;
} }
ksmbd_free_user(user);
} else {
sess->user = user;
}
if (user_guest(sess->user)) {
rsp->SessionFlags = SMB2_SESSION_FLAG_IS_GUEST_LE; rsp->SessionFlags = SMB2_SESSION_FLAG_IS_GUEST_LE;
} else { } else {
struct authenticate_message *authblob; struct authenticate_message *authblob;
authblob = user_authblob(conn, req); authblob = user_authblob(conn, req);
sz = le16_to_cpu(req->SecurityBufferLength); sz = le16_to_cpu(req->SecurityBufferLength);
rc = ksmbd_decode_ntlmssp_auth_blob(authblob, sz, sess); rc = ksmbd_decode_ntlmssp_auth_blob(authblob, sz, conn, sess);
if (rc) { if (rc) {
set_user_flag(sess->user, KSMBD_USER_FLAG_BAD_PASSWORD); set_user_flag(sess->user, KSMBD_USER_FLAG_BAD_PASSWORD);
ksmbd_debug(SMB, "authentication failed\n"); ksmbd_debug(SMB, "authentication failed\n");
return -EPERM; return -EPERM;
} }
}
/* /*
* If session state is SMB2_SESSION_VALID, We can assume * If session state is SMB2_SESSION_VALID, We can assume
...@@ -1484,7 +1489,8 @@ static int ntlm_authenticate(struct ksmbd_work *work) ...@@ -1484,7 +1489,8 @@ static int ntlm_authenticate(struct ksmbd_work *work)
return 0; return 0;
} }
if ((conn->sign || server_conf.enforced_signing) || if ((rsp->SessionFlags != SMB2_SESSION_FLAG_IS_GUEST_LE &&
(conn->sign || server_conf.enforced_signing)) ||
(req->SecurityMode & SMB2_NEGOTIATE_SIGNING_REQUIRED)) (req->SecurityMode & SMB2_NEGOTIATE_SIGNING_REQUIRED))
sess->sign = true; sess->sign = true;
...@@ -1504,7 +1510,6 @@ static int ntlm_authenticate(struct ksmbd_work *work) ...@@ -1504,7 +1510,6 @@ static int ntlm_authenticate(struct ksmbd_work *work)
*/ */
sess->sign = false; sess->sign = false;
} }
}
binding_session: binding_session:
if (conn->dialect >= SMB30_PROT_ID) { if (conn->dialect >= SMB30_PROT_ID) {
...@@ -2057,9 +2062,6 @@ int smb2_session_logoff(struct ksmbd_work *work) ...@@ -2057,9 +2062,6 @@ int smb2_session_logoff(struct ksmbd_work *work)
ksmbd_debug(SMB, "request\n"); ksmbd_debug(SMB, "request\n");
/* Got a valid session, set connection state */
WARN_ON(sess->conn != conn);
/* setting CifsExiting here may race with start_tcp_sess */ /* setting CifsExiting here may race with start_tcp_sess */
ksmbd_conn_set_need_reconnect(work); ksmbd_conn_set_need_reconnect(work);
ksmbd_close_session_fds(work); ksmbd_close_session_fds(work);
...@@ -2530,8 +2532,7 @@ int smb2_open(struct ksmbd_work *work) ...@@ -2530,8 +2532,7 @@ int smb2_open(struct ksmbd_work *work)
goto err_out1; goto err_out1;
} }
name = smb2_get_name(share, name = smb2_get_name(req->Buffer,
req->Buffer,
le16_to_cpu(req->NameLength), le16_to_cpu(req->NameLength),
work->conn->local_nls); work->conn->local_nls);
if (IS_ERR(name)) { if (IS_ERR(name)) {
...@@ -3392,7 +3393,6 @@ static int dentry_name(struct ksmbd_dir_info *d_info, int info_level) ...@@ -3392,7 +3393,6 @@ static int dentry_name(struct ksmbd_dir_info *d_info, int info_level)
* @conn: connection instance * @conn: connection instance
* @info_level: smb information level * @info_level: smb information level
* @d_info: structure included variables for query dir * @d_info: structure included variables for query dir
* @user_ns: user namespace
* @ksmbd_kstat: ksmbd wrapper of dirent stat information * @ksmbd_kstat: ksmbd wrapper of dirent stat information
* *
* if directory has many entries, find first can't read it fully. * if directory has many entries, find first can't read it fully.
...@@ -4018,6 +4018,7 @@ int smb2_query_dir(struct ksmbd_work *work) ...@@ -4018,6 +4018,7 @@ int smb2_query_dir(struct ksmbd_work *work)
* buffer_check_err() - helper function to check buffer errors * buffer_check_err() - helper function to check buffer errors
* @reqOutputBufferLength: max buffer length expected in command response * @reqOutputBufferLength: max buffer length expected in command response
* @rsp: query info response buffer contains output buffer length * @rsp: query info response buffer contains output buffer length
* @rsp_org: base response buffer pointer in case of chained response
* @infoclass_size: query info class response buffer size * @infoclass_size: query info class response buffer size
* *
* Return: 0 on success, otherwise error * Return: 0 on success, otherwise error
...@@ -5398,8 +5399,7 @@ static int smb2_rename(struct ksmbd_work *work, ...@@ -5398,8 +5399,7 @@ static int smb2_rename(struct ksmbd_work *work,
goto out; goto out;
} }
new_name = smb2_get_name(share, new_name = smb2_get_name(file_info->FileName,
file_info->FileName,
le32_to_cpu(file_info->FileNameLength), le32_to_cpu(file_info->FileNameLength),
local_nls); local_nls);
if (IS_ERR(new_name)) { if (IS_ERR(new_name)) {
...@@ -5510,8 +5510,7 @@ static int smb2_create_link(struct ksmbd_work *work, ...@@ -5510,8 +5510,7 @@ static int smb2_create_link(struct ksmbd_work *work,
if (!pathname) if (!pathname)
return -ENOMEM; return -ENOMEM;
link_name = smb2_get_name(share, link_name = smb2_get_name(file_info->FileName,
file_info->FileName,
le32_to_cpu(file_info->FileNameLength), le32_to_cpu(file_info->FileNameLength),
local_nls); local_nls);
if (IS_ERR(link_name) || S_ISDIR(file_inode(filp)->i_mode)) { if (IS_ERR(link_name) || S_ISDIR(file_inode(filp)->i_mode)) {
...@@ -5849,7 +5848,7 @@ static int set_file_mode_info(struct ksmbd_file *fp, ...@@ -5849,7 +5848,7 @@ static int set_file_mode_info(struct ksmbd_file *fp,
* smb2_set_info_file() - handler for smb2 set info command * smb2_set_info_file() - handler for smb2 set info command
* @work: smb work containing set info command buffer * @work: smb work containing set info command buffer
* @fp: ksmbd_file pointer * @fp: ksmbd_file pointer
* @info_class: smb2 set info class * @req: request buffer pointer
* @share: ksmbd_share_config pointer * @share: ksmbd_share_config pointer
* *
* Return: 0 on success, otherwise error * Return: 0 on success, otherwise error
...@@ -6121,25 +6120,33 @@ static noinline int smb2_read_pipe(struct ksmbd_work *work) ...@@ -6121,25 +6120,33 @@ static noinline int smb2_read_pipe(struct ksmbd_work *work)
return err; return err;
} }
static ssize_t smb2_read_rdma_channel(struct ksmbd_work *work, static int smb2_set_remote_key_for_rdma(struct ksmbd_work *work,
struct smb2_read_req *req, void *data_buf, struct smb2_buffer_desc_v1 *desc,
size_t length) __le32 Channel,
__le16 ChannelInfoOffset,
__le16 ChannelInfoLength)
{ {
struct smb2_buffer_desc_v1 *desc =
(struct smb2_buffer_desc_v1 *)&req->Buffer[0];
int err;
if (work->conn->dialect == SMB30_PROT_ID && if (work->conn->dialect == SMB30_PROT_ID &&
req->Channel != SMB2_CHANNEL_RDMA_V1) Channel != SMB2_CHANNEL_RDMA_V1)
return -EINVAL; return -EINVAL;
if (req->ReadChannelInfoOffset == 0 || if (ChannelInfoOffset == 0 ||
le16_to_cpu(req->ReadChannelInfoLength) < sizeof(*desc)) le16_to_cpu(ChannelInfoLength) < sizeof(*desc))
return -EINVAL; return -EINVAL;
work->need_invalidate_rkey = work->need_invalidate_rkey =
(req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE); (Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE);
work->remote_key = le32_to_cpu(desc->token); work->remote_key = le32_to_cpu(desc->token);
return 0;
}
static ssize_t smb2_read_rdma_channel(struct ksmbd_work *work,
struct smb2_read_req *req, void *data_buf,
size_t length)
{
struct smb2_buffer_desc_v1 *desc =
(struct smb2_buffer_desc_v1 *)&req->Buffer[0];
int err;
err = ksmbd_conn_rdma_write(work->conn, data_buf, length, err = ksmbd_conn_rdma_write(work->conn, data_buf, length,
le32_to_cpu(desc->token), le32_to_cpu(desc->token),
...@@ -6162,7 +6169,7 @@ int smb2_read(struct ksmbd_work *work) ...@@ -6162,7 +6169,7 @@ int smb2_read(struct ksmbd_work *work)
struct ksmbd_conn *conn = work->conn; struct ksmbd_conn *conn = work->conn;
struct smb2_read_req *req; struct smb2_read_req *req;
struct smb2_read_rsp *rsp; struct smb2_read_rsp *rsp;
struct ksmbd_file *fp; struct ksmbd_file *fp = NULL;
loff_t offset; loff_t offset;
size_t length, mincount; size_t length, mincount;
ssize_t nbytes = 0, remain_bytes = 0; ssize_t nbytes = 0, remain_bytes = 0;
...@@ -6176,6 +6183,18 @@ int smb2_read(struct ksmbd_work *work) ...@@ -6176,6 +6183,18 @@ int smb2_read(struct ksmbd_work *work)
return smb2_read_pipe(work); return smb2_read_pipe(work);
} }
if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
req->Channel == SMB2_CHANNEL_RDMA_V1) {
err = smb2_set_remote_key_for_rdma(work,
(struct smb2_buffer_desc_v1 *)
&req->Buffer[0],
req->Channel,
req->ReadChannelInfoOffset,
req->ReadChannelInfoLength);
if (err)
goto out;
}
fp = ksmbd_lookup_fd_slow(work, le64_to_cpu(req->VolatileFileId), fp = ksmbd_lookup_fd_slow(work, le64_to_cpu(req->VolatileFileId),
le64_to_cpu(req->PersistentFileId)); le64_to_cpu(req->PersistentFileId));
if (!fp) { if (!fp) {
...@@ -6361,21 +6380,6 @@ static ssize_t smb2_write_rdma_channel(struct ksmbd_work *work, ...@@ -6361,21 +6380,6 @@ static ssize_t smb2_write_rdma_channel(struct ksmbd_work *work,
desc = (struct smb2_buffer_desc_v1 *)&req->Buffer[0]; desc = (struct smb2_buffer_desc_v1 *)&req->Buffer[0];
if (work->conn->dialect == SMB30_PROT_ID &&
req->Channel != SMB2_CHANNEL_RDMA_V1)
return -EINVAL;
if (req->Length != 0 || req->DataOffset != 0)
return -EINVAL;
if (req->WriteChannelInfoOffset == 0 ||
le16_to_cpu(req->WriteChannelInfoLength) < sizeof(*desc))
return -EINVAL;
work->need_invalidate_rkey =
(req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE);
work->remote_key = le32_to_cpu(desc->token);
data_buf = kvmalloc(length, GFP_KERNEL | __GFP_ZERO); data_buf = kvmalloc(length, GFP_KERNEL | __GFP_ZERO);
if (!data_buf) if (!data_buf)
return -ENOMEM; return -ENOMEM;
...@@ -6422,6 +6426,20 @@ int smb2_write(struct ksmbd_work *work) ...@@ -6422,6 +6426,20 @@ int smb2_write(struct ksmbd_work *work)
return smb2_write_pipe(work); return smb2_write_pipe(work);
} }
if (req->Channel == SMB2_CHANNEL_RDMA_V1 ||
req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
if (req->Length != 0 || req->DataOffset != 0)
return -EINVAL;
err = smb2_set_remote_key_for_rdma(work,
(struct smb2_buffer_desc_v1 *)
&req->Buffer[0],
req->Channel,
req->WriteChannelInfoOffset,
req->WriteChannelInfoLength);
if (err)
goto out;
}
if (!test_tree_conn_flag(work->tcon, KSMBD_TREE_CONN_FLAG_WRITABLE)) { if (!test_tree_conn_flag(work->tcon, KSMBD_TREE_CONN_FLAG_WRITABLE)) {
ksmbd_debug(SMB, "User does not have write permission\n"); ksmbd_debug(SMB, "User does not have write permission\n");
err = -EACCES; err = -EACCES;
...@@ -7243,15 +7261,10 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn, ...@@ -7243,15 +7261,10 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn,
struct sockaddr_storage_rsp *sockaddr_storage; struct sockaddr_storage_rsp *sockaddr_storage;
unsigned int flags; unsigned int flags;
unsigned long long speed; unsigned long long speed;
struct sockaddr_in6 *csin6 = (struct sockaddr_in6 *)&conn->peer_addr;
rtnl_lock(); rtnl_lock();
for_each_netdev(&init_net, netdev) { for_each_netdev(&init_net, netdev) {
if (out_buf_len < bool ipv4_set = false;
nbytes + sizeof(struct network_interface_info_ioctl_rsp)) {
rtnl_unlock();
return -ENOSPC;
}
if (netdev->type == ARPHRD_LOOPBACK) if (netdev->type == ARPHRD_LOOPBACK)
continue; continue;
...@@ -7259,12 +7272,20 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn, ...@@ -7259,12 +7272,20 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn,
flags = dev_get_flags(netdev); flags = dev_get_flags(netdev);
if (!(flags & IFF_RUNNING)) if (!(flags & IFF_RUNNING))
continue; continue;
ipv6_retry:
if (out_buf_len <
nbytes + sizeof(struct network_interface_info_ioctl_rsp)) {
rtnl_unlock();
return -ENOSPC;
}
nii_rsp = (struct network_interface_info_ioctl_rsp *) nii_rsp = (struct network_interface_info_ioctl_rsp *)
&rsp->Buffer[nbytes]; &rsp->Buffer[nbytes];
nii_rsp->IfIndex = cpu_to_le32(netdev->ifindex); nii_rsp->IfIndex = cpu_to_le32(netdev->ifindex);
nii_rsp->Capability = 0; nii_rsp->Capability = 0;
if (netdev->real_num_tx_queues > 1)
nii_rsp->Capability |= cpu_to_le32(RSS_CAPABLE);
if (ksmbd_rdma_capable_netdev(netdev)) if (ksmbd_rdma_capable_netdev(netdev))
nii_rsp->Capability |= cpu_to_le32(RDMA_CAPABLE); nii_rsp->Capability |= cpu_to_le32(RDMA_CAPABLE);
...@@ -7289,8 +7310,7 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn, ...@@ -7289,8 +7310,7 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn,
nii_rsp->SockAddr_Storage; nii_rsp->SockAddr_Storage;
memset(sockaddr_storage, 0, 128); memset(sockaddr_storage, 0, 128);
if (conn->peer_addr.ss_family == PF_INET || if (!ipv4_set) {
ipv6_addr_v4mapped(&csin6->sin6_addr)) {
struct in_device *idev; struct in_device *idev;
sockaddr_storage->Family = cpu_to_le16(INTERNETWORK); sockaddr_storage->Family = cpu_to_le16(INTERNETWORK);
...@@ -7301,6 +7321,9 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn, ...@@ -7301,6 +7321,9 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn,
continue; continue;
sockaddr_storage->addr4.IPv4address = sockaddr_storage->addr4.IPv4address =
idev_ipv4_address(idev); idev_ipv4_address(idev);
nbytes += sizeof(struct network_interface_info_ioctl_rsp);
ipv4_set = true;
goto ipv6_retry;
} else { } else {
struct inet6_dev *idev6; struct inet6_dev *idev6;
struct inet6_ifaddr *ifa; struct inet6_ifaddr *ifa;
...@@ -7322,10 +7345,9 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn, ...@@ -7322,10 +7345,9 @@ static int fsctl_query_iface_info_ioctl(struct ksmbd_conn *conn,
break; break;
} }
sockaddr_storage->addr6.ScopeId = 0; sockaddr_storage->addr6.ScopeId = 0;
}
nbytes += sizeof(struct network_interface_info_ioctl_rsp); nbytes += sizeof(struct network_interface_info_ioctl_rsp);
} }
}
rtnl_unlock(); rtnl_unlock();
/* zero if this is last one */ /* zero if this is last one */
......
...@@ -980,6 +980,7 @@ int init_smb3_11_server(struct ksmbd_conn *conn); ...@@ -980,6 +980,7 @@ int init_smb3_11_server(struct ksmbd_conn *conn);
void init_smb2_max_read_size(unsigned int sz); void init_smb2_max_read_size(unsigned int sz);
void init_smb2_max_write_size(unsigned int sz); void init_smb2_max_write_size(unsigned int sz);
void init_smb2_max_trans_size(unsigned int sz); void init_smb2_max_trans_size(unsigned int sz);
void init_smb2_max_credits(unsigned int sz);
bool is_smb2_neg_cmd(struct ksmbd_work *work); bool is_smb2_neg_cmd(struct ksmbd_work *work);
bool is_smb2_rsp(struct ksmbd_work *work); bool is_smb2_rsp(struct ksmbd_work *work);
......
...@@ -365,6 +365,7 @@ struct smb_version_values { ...@@ -365,6 +365,7 @@ struct smb_version_values {
__u32 max_read_size; __u32 max_read_size;
__u32 max_write_size; __u32 max_write_size;
__u32 max_trans_size; __u32 max_trans_size;
__u32 max_credits;
__u32 large_lock_type; __u32 large_lock_type;
__u32 exclusive_lock_type; __u32 exclusive_lock_type;
__u32 shared_lock_type; __u32 shared_lock_type;
......
...@@ -301,6 +301,8 @@ static int ipc_server_config_on_startup(struct ksmbd_startup_request *req) ...@@ -301,6 +301,8 @@ static int ipc_server_config_on_startup(struct ksmbd_startup_request *req)
init_smb2_max_write_size(req->smb2_max_write); init_smb2_max_write_size(req->smb2_max_write);
if (req->smb2_max_trans) if (req->smb2_max_trans)
init_smb2_max_trans_size(req->smb2_max_trans); init_smb2_max_trans_size(req->smb2_max_trans);
if (req->smb2_max_credits)
init_smb2_max_credits(req->smb2_max_credits);
ret = ksmbd_set_netbios_name(req->netbios_name); ret = ksmbd_set_netbios_name(req->netbios_name);
ret |= ksmbd_set_server_string(req->server_string); ret |= ksmbd_set_server_string(req->server_string);
......
...@@ -34,7 +34,8 @@ ...@@ -34,7 +34,8 @@
#include "smbstatus.h" #include "smbstatus.h"
#include "transport_rdma.h" #include "transport_rdma.h"
#define SMB_DIRECT_PORT 5445 #define SMB_DIRECT_PORT_IWARP 5445
#define SMB_DIRECT_PORT_INFINIBAND 445
#define SMB_DIRECT_VERSION_LE cpu_to_le16(0x0100) #define SMB_DIRECT_VERSION_LE cpu_to_le16(0x0100)
...@@ -60,6 +61,10 @@ ...@@ -60,6 +61,10 @@
* as defined in [MS-SMBD] 3.1.1.1 * as defined in [MS-SMBD] 3.1.1.1
* Those may change after a SMB_DIRECT negotiation * Those may change after a SMB_DIRECT negotiation
*/ */
/* Set 445 port to SMB Direct port by default */
static int smb_direct_port = SMB_DIRECT_PORT_INFINIBAND;
/* The local peer's maximum number of credits to grant to the peer */ /* The local peer's maximum number of credits to grant to the peer */
static int smb_direct_receive_credit_max = 255; static int smb_direct_receive_credit_max = 255;
...@@ -75,10 +80,18 @@ static int smb_direct_max_fragmented_recv_size = 1024 * 1024; ...@@ -75,10 +80,18 @@ static int smb_direct_max_fragmented_recv_size = 1024 * 1024;
/* The maximum single-message size which can be received */ /* The maximum single-message size which can be received */
static int smb_direct_max_receive_size = 8192; static int smb_direct_max_receive_size = 8192;
static int smb_direct_max_read_write_size = 1024 * 1024; static int smb_direct_max_read_write_size = 1048512;
static int smb_direct_max_outstanding_rw_ops = 8; static int smb_direct_max_outstanding_rw_ops = 8;
static LIST_HEAD(smb_direct_device_list);
static DEFINE_RWLOCK(smb_direct_device_lock);
struct smb_direct_device {
struct ib_device *ib_dev;
struct list_head list;
};
static struct smb_direct_listener { static struct smb_direct_listener {
struct rdma_cm_id *cm_id; struct rdma_cm_id *cm_id;
} smb_direct_listener; } smb_direct_listener;
...@@ -415,6 +428,7 @@ static void free_transport(struct smb_direct_transport *t) ...@@ -415,6 +428,7 @@ static void free_transport(struct smb_direct_transport *t)
if (t->qp) { if (t->qp) {
ib_drain_qp(t->qp); ib_drain_qp(t->qp);
ib_mr_pool_destroy(t->qp, &t->qp->rdma_mrs);
ib_destroy_qp(t->qp); ib_destroy_qp(t->qp);
} }
...@@ -555,6 +569,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc) ...@@ -555,6 +569,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
} }
t->negotiation_requested = true; t->negotiation_requested = true;
t->full_packet_received = true; t->full_packet_received = true;
enqueue_reassembly(t, recvmsg, 0);
wake_up_interruptible(&t->wait_status); wake_up_interruptible(&t->wait_status);
break; break;
case SMB_DIRECT_MSG_DATA_TRANSFER: { case SMB_DIRECT_MSG_DATA_TRANSFER: {
...@@ -1438,6 +1453,15 @@ static void smb_direct_disconnect(struct ksmbd_transport *t) ...@@ -1438,6 +1453,15 @@ static void smb_direct_disconnect(struct ksmbd_transport *t)
free_transport(st); free_transport(st);
} }
static void smb_direct_shutdown(struct ksmbd_transport *t)
{
struct smb_direct_transport *st = smb_trans_direct_transfort(t);
ksmbd_debug(RDMA, "smb-direct shutdown cm_id=%p\n", st->cm_id);
smb_direct_disconnect_rdma_work(&st->disconnect_work);
}
static int smb_direct_cm_handler(struct rdma_cm_id *cm_id, static int smb_direct_cm_handler(struct rdma_cm_id *cm_id,
struct rdma_cm_event *event) struct rdma_cm_event *event)
{ {
...@@ -1581,19 +1605,13 @@ static int smb_direct_accept_client(struct smb_direct_transport *t) ...@@ -1581,19 +1605,13 @@ static int smb_direct_accept_client(struct smb_direct_transport *t)
pr_err("error at rdma_accept: %d\n", ret); pr_err("error at rdma_accept: %d\n", ret);
return ret; return ret;
} }
wait_event_interruptible(t->wait_status,
t->status != SMB_DIRECT_CS_NEW);
if (t->status != SMB_DIRECT_CS_CONNECTED)
return -ENOTCONN;
return 0; return 0;
} }
static int smb_direct_negotiate(struct smb_direct_transport *t) static int smb_direct_prepare_negotiation(struct smb_direct_transport *t)
{ {
int ret; int ret;
struct smb_direct_recvmsg *recvmsg; struct smb_direct_recvmsg *recvmsg;
struct smb_direct_negotiate_req *req;
recvmsg = get_free_recvmsg(t); recvmsg = get_free_recvmsg(t);
if (!recvmsg) if (!recvmsg)
...@@ -1603,43 +1621,19 @@ static int smb_direct_negotiate(struct smb_direct_transport *t) ...@@ -1603,43 +1621,19 @@ static int smb_direct_negotiate(struct smb_direct_transport *t)
ret = smb_direct_post_recv(t, recvmsg); ret = smb_direct_post_recv(t, recvmsg);
if (ret) { if (ret) {
pr_err("Can't post recv: %d\n", ret); pr_err("Can't post recv: %d\n", ret);
goto out; goto out_err;
} }
t->negotiation_requested = false; t->negotiation_requested = false;
ret = smb_direct_accept_client(t); ret = smb_direct_accept_client(t);
if (ret) { if (ret) {
pr_err("Can't accept client\n"); pr_err("Can't accept client\n");
goto out; goto out_err;
} }
smb_direct_post_recv_credits(&t->post_recv_credits_work.work); smb_direct_post_recv_credits(&t->post_recv_credits_work.work);
return 0;
ksmbd_debug(RDMA, "Waiting for SMB_DIRECT negotiate request\n"); out_err:
ret = wait_event_interruptible_timeout(t->wait_status,
t->negotiation_requested ||
t->status == SMB_DIRECT_CS_DISCONNECTED,
SMB_DIRECT_NEGOTIATE_TIMEOUT * HZ);
if (ret <= 0 || t->status == SMB_DIRECT_CS_DISCONNECTED) {
ret = ret < 0 ? ret : -ETIMEDOUT;
goto out;
}
ret = smb_direct_check_recvmsg(recvmsg);
if (ret == -ECONNABORTED)
goto out;
req = (struct smb_direct_negotiate_req *)recvmsg->packet;
t->max_recv_size = min_t(int, t->max_recv_size,
le32_to_cpu(req->preferred_send_size));
t->max_send_size = min_t(int, t->max_send_size,
le32_to_cpu(req->max_receive_size));
t->max_fragmented_send_size =
le32_to_cpu(req->max_fragmented_size);
ret = smb_direct_send_negotiate_response(t, ret);
out:
if (recvmsg)
put_recvmsg(t, recvmsg); put_recvmsg(t, recvmsg);
return ret; return ret;
} }
...@@ -1724,7 +1718,9 @@ static int smb_direct_init_params(struct smb_direct_transport *t, ...@@ -1724,7 +1718,9 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES; cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES; cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
cap->max_inline_data = 0; cap->max_inline_data = 0;
cap->max_rdma_ctxs = 0; cap->max_rdma_ctxs =
rdma_rw_mr_factor(device, t->cm_id->port_num, max_pages) *
smb_direct_max_outstanding_rw_ops;
return 0; return 0;
} }
...@@ -1806,6 +1802,7 @@ static int smb_direct_create_qpair(struct smb_direct_transport *t, ...@@ -1806,6 +1802,7 @@ static int smb_direct_create_qpair(struct smb_direct_transport *t,
{ {
int ret; int ret;
struct ib_qp_init_attr qp_attr; struct ib_qp_init_attr qp_attr;
int pages_per_rw;
t->pd = ib_alloc_pd(t->cm_id->device, 0); t->pd = ib_alloc_pd(t->cm_id->device, 0);
if (IS_ERR(t->pd)) { if (IS_ERR(t->pd)) {
...@@ -1853,6 +1850,23 @@ static int smb_direct_create_qpair(struct smb_direct_transport *t, ...@@ -1853,6 +1850,23 @@ static int smb_direct_create_qpair(struct smb_direct_transport *t,
t->qp = t->cm_id->qp; t->qp = t->cm_id->qp;
t->cm_id->event_handler = smb_direct_cm_handler; t->cm_id->event_handler = smb_direct_cm_handler;
pages_per_rw = DIV_ROUND_UP(t->max_rdma_rw_size, PAGE_SIZE) + 1;
if (pages_per_rw > t->cm_id->device->attrs.max_sgl_rd) {
int pages_per_mr, mr_count;
pages_per_mr = min_t(int, pages_per_rw,
t->cm_id->device->attrs.max_fast_reg_page_list_len);
mr_count = DIV_ROUND_UP(pages_per_rw, pages_per_mr) *
atomic_read(&t->rw_avail_ops);
ret = ib_mr_pool_init(t->qp, &t->qp->rdma_mrs, mr_count,
IB_MR_TYPE_MEM_REG, pages_per_mr, 0);
if (ret) {
pr_err("failed to init mr pool count %d pages %d\n",
mr_count, pages_per_mr);
goto err;
}
}
return 0; return 0;
err: err:
if (t->qp) { if (t->qp) {
...@@ -1877,6 +1891,49 @@ static int smb_direct_create_qpair(struct smb_direct_transport *t, ...@@ -1877,6 +1891,49 @@ static int smb_direct_create_qpair(struct smb_direct_transport *t,
static int smb_direct_prepare(struct ksmbd_transport *t) static int smb_direct_prepare(struct ksmbd_transport *t)
{ {
struct smb_direct_transport *st = smb_trans_direct_transfort(t); struct smb_direct_transport *st = smb_trans_direct_transfort(t);
struct smb_direct_recvmsg *recvmsg;
struct smb_direct_negotiate_req *req;
int ret;
ksmbd_debug(RDMA, "Waiting for SMB_DIRECT negotiate request\n");
ret = wait_event_interruptible_timeout(st->wait_status,
st->negotiation_requested ||
st->status == SMB_DIRECT_CS_DISCONNECTED,
SMB_DIRECT_NEGOTIATE_TIMEOUT * HZ);
if (ret <= 0 || st->status == SMB_DIRECT_CS_DISCONNECTED)
return ret < 0 ? ret : -ETIMEDOUT;
recvmsg = get_first_reassembly(st);
if (!recvmsg)
return -ECONNABORTED;
ret = smb_direct_check_recvmsg(recvmsg);
if (ret == -ECONNABORTED)
goto out;
req = (struct smb_direct_negotiate_req *)recvmsg->packet;
st->max_recv_size = min_t(int, st->max_recv_size,
le32_to_cpu(req->preferred_send_size));
st->max_send_size = min_t(int, st->max_send_size,
le32_to_cpu(req->max_receive_size));
st->max_fragmented_send_size =
le32_to_cpu(req->max_fragmented_size);
st->max_fragmented_recv_size =
(st->recv_credit_max * st->max_recv_size) / 2;
ret = smb_direct_send_negotiate_response(st, ret);
out:
spin_lock_irq(&st->reassembly_queue_lock);
st->reassembly_queue_length--;
list_del(&recvmsg->list);
spin_unlock_irq(&st->reassembly_queue_lock);
put_recvmsg(st, recvmsg);
return ret;
}
static int smb_direct_connect(struct smb_direct_transport *st)
{
int ret; int ret;
struct ib_qp_cap qp_cap; struct ib_qp_cap qp_cap;
...@@ -1898,13 +1955,11 @@ static int smb_direct_prepare(struct ksmbd_transport *t) ...@@ -1898,13 +1955,11 @@ static int smb_direct_prepare(struct ksmbd_transport *t)
return ret; return ret;
} }
ret = smb_direct_negotiate(st); ret = smb_direct_prepare_negotiation(st);
if (ret) { if (ret) {
pr_err("Can't negotiate: %d\n", ret); pr_err("Can't negotiate: %d\n", ret);
return ret; return ret;
} }
st->status = SMB_DIRECT_CS_CONNECTED;
return 0; return 0;
} }
...@@ -1920,6 +1975,7 @@ static bool rdma_frwr_is_supported(struct ib_device_attr *attrs) ...@@ -1920,6 +1975,7 @@ static bool rdma_frwr_is_supported(struct ib_device_attr *attrs)
static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id) static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id)
{ {
struct smb_direct_transport *t; struct smb_direct_transport *t;
int ret;
if (!rdma_frwr_is_supported(&new_cm_id->device->attrs)) { if (!rdma_frwr_is_supported(&new_cm_id->device->attrs)) {
ksmbd_debug(RDMA, ksmbd_debug(RDMA,
...@@ -1932,18 +1988,23 @@ static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id) ...@@ -1932,18 +1988,23 @@ static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id)
if (!t) if (!t)
return -ENOMEM; return -ENOMEM;
ret = smb_direct_connect(t);
if (ret)
goto out_err;
KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop, KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop,
KSMBD_TRANS(t)->conn, "ksmbd:r%u", KSMBD_TRANS(t)->conn, "ksmbd:r%u",
SMB_DIRECT_PORT); smb_direct_port);
if (IS_ERR(KSMBD_TRANS(t)->handler)) { if (IS_ERR(KSMBD_TRANS(t)->handler)) {
int ret = PTR_ERR(KSMBD_TRANS(t)->handler); ret = PTR_ERR(KSMBD_TRANS(t)->handler);
pr_err("Can't start thread\n"); pr_err("Can't start thread\n");
free_transport(t); goto out_err;
return ret;
} }
return 0; return 0;
out_err:
free_transport(t);
return ret;
} }
static int smb_direct_listen_handler(struct rdma_cm_id *cm_id, static int smb_direct_listen_handler(struct rdma_cm_id *cm_id,
...@@ -2007,12 +2068,65 @@ static int smb_direct_listen(int port) ...@@ -2007,12 +2068,65 @@ static int smb_direct_listen(int port)
return ret; return ret;
} }
static int smb_direct_ib_client_add(struct ib_device *ib_dev)
{
struct smb_direct_device *smb_dev;
/* Set 5445 port if device type is iWARP(No IB) */
if (ib_dev->node_type != RDMA_NODE_IB_CA)
smb_direct_port = SMB_DIRECT_PORT_IWARP;
if (!ib_dev->ops.get_netdev ||
!rdma_frwr_is_supported(&ib_dev->attrs))
return 0;
smb_dev = kzalloc(sizeof(*smb_dev), GFP_KERNEL);
if (!smb_dev)
return -ENOMEM;
smb_dev->ib_dev = ib_dev;
write_lock(&smb_direct_device_lock);
list_add(&smb_dev->list, &smb_direct_device_list);
write_unlock(&smb_direct_device_lock);
ksmbd_debug(RDMA, "ib device added: name %s\n", ib_dev->name);
return 0;
}
static void smb_direct_ib_client_remove(struct ib_device *ib_dev,
void *client_data)
{
struct smb_direct_device *smb_dev, *tmp;
write_lock(&smb_direct_device_lock);
list_for_each_entry_safe(smb_dev, tmp, &smb_direct_device_list, list) {
if (smb_dev->ib_dev == ib_dev) {
list_del(&smb_dev->list);
kfree(smb_dev);
break;
}
}
write_unlock(&smb_direct_device_lock);
}
static struct ib_client smb_direct_ib_client = {
.name = "ksmbd_smb_direct_ib",
.add = smb_direct_ib_client_add,
.remove = smb_direct_ib_client_remove,
};
int ksmbd_rdma_init(void) int ksmbd_rdma_init(void)
{ {
int ret; int ret;
smb_direct_listener.cm_id = NULL; smb_direct_listener.cm_id = NULL;
ret = ib_register_client(&smb_direct_ib_client);
if (ret) {
pr_err("failed to ib_register_client\n");
return ret;
}
/* When a client is running out of send credits, the credits are /* When a client is running out of send credits, the credits are
* granted by the server's sending a packet using this queue. * granted by the server's sending a packet using this queue.
* This avoids the situation that a clients cannot send packets * This avoids the situation that a clients cannot send packets
...@@ -2023,7 +2137,7 @@ int ksmbd_rdma_init(void) ...@@ -2023,7 +2137,7 @@ int ksmbd_rdma_init(void)
if (!smb_direct_wq) if (!smb_direct_wq)
return -ENOMEM; return -ENOMEM;
ret = smb_direct_listen(SMB_DIRECT_PORT); ret = smb_direct_listen(smb_direct_port);
if (ret) { if (ret) {
destroy_workqueue(smb_direct_wq); destroy_workqueue(smb_direct_wq);
smb_direct_wq = NULL; smb_direct_wq = NULL;
...@@ -2036,36 +2150,67 @@ int ksmbd_rdma_init(void) ...@@ -2036,36 +2150,67 @@ int ksmbd_rdma_init(void)
return 0; return 0;
} }
int ksmbd_rdma_destroy(void) void ksmbd_rdma_destroy(void)
{ {
if (smb_direct_listener.cm_id) if (!smb_direct_listener.cm_id)
return;
ib_unregister_client(&smb_direct_ib_client);
rdma_destroy_id(smb_direct_listener.cm_id); rdma_destroy_id(smb_direct_listener.cm_id);
smb_direct_listener.cm_id = NULL; smb_direct_listener.cm_id = NULL;
if (smb_direct_wq) { if (smb_direct_wq) {
destroy_workqueue(smb_direct_wq); destroy_workqueue(smb_direct_wq);
smb_direct_wq = NULL; smb_direct_wq = NULL;
} }
return 0;
} }
bool ksmbd_rdma_capable_netdev(struct net_device *netdev) bool ksmbd_rdma_capable_netdev(struct net_device *netdev)
{ {
struct ib_device *ibdev; struct smb_direct_device *smb_dev;
int i;
bool rdma_capable = false; bool rdma_capable = false;
read_lock(&smb_direct_device_lock);
list_for_each_entry(smb_dev, &smb_direct_device_list, list) {
for (i = 0; i < smb_dev->ib_dev->phys_port_cnt; i++) {
struct net_device *ndev;
ndev = smb_dev->ib_dev->ops.get_netdev(smb_dev->ib_dev,
i + 1);
if (!ndev)
continue;
if (ndev == netdev) {
dev_put(ndev);
rdma_capable = true;
goto out;
}
dev_put(ndev);
}
}
out:
read_unlock(&smb_direct_device_lock);
if (rdma_capable == false) {
struct ib_device *ibdev;
ibdev = ib_device_get_by_netdev(netdev, RDMA_DRIVER_UNKNOWN); ibdev = ib_device_get_by_netdev(netdev, RDMA_DRIVER_UNKNOWN);
if (ibdev) { if (ibdev) {
if (rdma_frwr_is_supported(&ibdev->attrs)) if (rdma_frwr_is_supported(&ibdev->attrs))
rdma_capable = true; rdma_capable = true;
ib_device_put(ibdev); ib_device_put(ibdev);
} }
}
return rdma_capable; return rdma_capable;
} }
static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = { static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = {
.prepare = smb_direct_prepare, .prepare = smb_direct_prepare,
.disconnect = smb_direct_disconnect, .disconnect = smb_direct_disconnect,
.shutdown = smb_direct_shutdown,
.writev = smb_direct_writev, .writev = smb_direct_writev,
.read = smb_direct_read, .read = smb_direct_read,
.rdma_read = smb_direct_rdma_read, .rdma_read = smb_direct_rdma_read,
......
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
#ifndef __KSMBD_TRANSPORT_RDMA_H__ #ifndef __KSMBD_TRANSPORT_RDMA_H__
#define __KSMBD_TRANSPORT_RDMA_H__ #define __KSMBD_TRANSPORT_RDMA_H__
#define SMB_DIRECT_PORT 5445
/* SMB DIRECT negotiation request packet [MS-SMBD] 2.2.1 */ /* SMB DIRECT negotiation request packet [MS-SMBD] 2.2.1 */
struct smb_direct_negotiate_req { struct smb_direct_negotiate_req {
__le16 min_version; __le16 min_version;
...@@ -52,7 +50,7 @@ struct smb_direct_data_transfer { ...@@ -52,7 +50,7 @@ struct smb_direct_data_transfer {
#ifdef CONFIG_SMB_SERVER_SMBDIRECT #ifdef CONFIG_SMB_SERVER_SMBDIRECT
int ksmbd_rdma_init(void); int ksmbd_rdma_init(void);
int ksmbd_rdma_destroy(void); void ksmbd_rdma_destroy(void);
bool ksmbd_rdma_capable_netdev(struct net_device *netdev); bool ksmbd_rdma_capable_netdev(struct net_device *netdev);
#else #else
static inline int ksmbd_rdma_init(void) { return 0; } static inline int ksmbd_rdma_init(void) { return 0; }
......
...@@ -404,7 +404,7 @@ static int create_socket(struct interface *iface) ...@@ -404,7 +404,7 @@ static int create_socket(struct interface *iface)
&ksmbd_socket); &ksmbd_socket);
if (ret) { if (ret) {
pr_err("Can't create socket for ipv4: %d\n", ret); pr_err("Can't create socket for ipv4: %d\n", ret);
goto out_error; goto out_clear;
} }
sin.sin_family = PF_INET; sin.sin_family = PF_INET;
...@@ -462,6 +462,7 @@ static int create_socket(struct interface *iface) ...@@ -462,6 +462,7 @@ static int create_socket(struct interface *iface)
out_error: out_error:
tcp_destroy_socket(ksmbd_socket); tcp_destroy_socket(ksmbd_socket);
out_clear:
iface->ksmbd_socket = NULL; iface->ksmbd_socket = NULL;
return ret; return ret;
} }
......
...@@ -96,16 +96,6 @@ struct ksmbd_file { ...@@ -96,16 +96,6 @@ struct ksmbd_file {
int durable_timeout; int durable_timeout;
/* for SMB1 */
int pid;
/* conflict lock fail count for SMB1 */
unsigned int cflock_cnt;
/* last lock failure start offset for SMB1 */
unsigned long long llock_fstart;
int dirent_offset;
/* if ls is happening on directory, below is valid*/ /* if ls is happening on directory, below is valid*/
struct ksmbd_readdir_data readdir_data; struct ksmbd_readdir_data readdir_data;
int dot_dotdot[2]; int dot_dotdot[2];
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment