Commit f6a6bf7c authored by Aurelien Aptel's avatar Aurelien Aptel Committed by Steve French

cifs: switch servers depending on binding state

Currently a lot of the code to initialize a connection & session uses
the cifs_ses as input. But depending on if we are opening a new session
or a new channel we need to use different server pointers.

Add a "binding" flag in cifs_ses and a helper function that returns
the server ptr a session should use (only in the sess establishment
code path).
Signed-off-by: default avatarAurelien Aptel <aaptel@suse.com>
Signed-off-by: default avatarSteve French <stfrench@microsoft.com>
parent f780bd3f
...@@ -98,7 +98,7 @@ struct key_type cifs_spnego_key_type = { ...@@ -98,7 +98,7 @@ struct key_type cifs_spnego_key_type = {
struct key * struct key *
cifs_get_spnego_key(struct cifs_ses *sesInfo) cifs_get_spnego_key(struct cifs_ses *sesInfo)
{ {
struct TCP_Server_Info *server = sesInfo->server; struct TCP_Server_Info *server = cifs_ses_server(sesInfo);
struct sockaddr_in *sa = (struct sockaddr_in *) &server->dstaddr; struct sockaddr_in *sa = (struct sockaddr_in *) &server->dstaddr;
struct sockaddr_in6 *sa6 = (struct sockaddr_in6 *) &server->dstaddr; struct sockaddr_in6 *sa6 = (struct sockaddr_in6 *) &server->dstaddr;
char *description, *dp; char *description, *dp;
......
...@@ -994,6 +994,7 @@ struct cifs_ses { ...@@ -994,6 +994,7 @@ struct cifs_ses {
bool sign; /* is signing required? */ bool sign; /* is signing required? */
bool need_reconnect:1; /* connection reset, uid now invalid */ bool need_reconnect:1; /* connection reset, uid now invalid */
bool domainAuto:1; bool domainAuto:1;
bool binding:1; /* are we binding the session? */
__u16 session_flags; __u16 session_flags;
__u8 smb3signingkey[SMB3_SIGN_KEY_SIZE]; __u8 smb3signingkey[SMB3_SIGN_KEY_SIZE];
__u8 smb3encryptionkey[SMB3_SIGN_KEY_SIZE]; __u8 smb3encryptionkey[SMB3_SIGN_KEY_SIZE];
...@@ -1021,6 +1022,15 @@ struct cifs_ses { ...@@ -1021,6 +1022,15 @@ struct cifs_ses {
atomic_t chan_seq; /* round robin state */ atomic_t chan_seq; /* round robin state */
}; };
static inline
struct TCP_Server_Info *cifs_ses_server(struct cifs_ses *ses)
{
if (ses->binding)
return ses->chans[ses->chan_count].server;
else
return ses->server;
}
static inline bool static inline bool
cap_unix(struct cifs_ses *ses) cap_unix(struct cifs_ses *ses)
{ {
......
...@@ -5185,7 +5185,7 @@ int ...@@ -5185,7 +5185,7 @@ int
cifs_negotiate_protocol(const unsigned int xid, struct cifs_ses *ses) cifs_negotiate_protocol(const unsigned int xid, struct cifs_ses *ses)
{ {
int rc = 0; int rc = 0;
struct TCP_Server_Info *server = ses->server; struct TCP_Server_Info *server = cifs_ses_server(ses);
if (!server->ops->need_neg || !server->ops->negotiate) if (!server->ops->need_neg || !server->ops->negotiate)
return -ENOSYS; return -ENOSYS;
...@@ -5212,7 +5212,7 @@ cifs_setup_session(const unsigned int xid, struct cifs_ses *ses, ...@@ -5212,7 +5212,7 @@ cifs_setup_session(const unsigned int xid, struct cifs_ses *ses,
struct nls_table *nls_info) struct nls_table *nls_info)
{ {
int rc = -ENOSYS; int rc = -ENOSYS;
struct TCP_Server_Info *server = ses->server; struct TCP_Server_Info *server = cifs_ses_server(ses);
ses->capabilities = server->capabilities; ses->capabilities = server->capabilities;
if (linuxExtEnabled == 0) if (linuxExtEnabled == 0)
......
...@@ -342,6 +342,7 @@ int decode_ntlmssp_challenge(char *bcc_ptr, int blob_len, ...@@ -342,6 +342,7 @@ int decode_ntlmssp_challenge(char *bcc_ptr, int blob_len,
void build_ntlmssp_negotiate_blob(unsigned char *pbuffer, void build_ntlmssp_negotiate_blob(unsigned char *pbuffer,
struct cifs_ses *ses) struct cifs_ses *ses)
{ {
struct TCP_Server_Info *server = cifs_ses_server(ses);
NEGOTIATE_MESSAGE *sec_blob = (NEGOTIATE_MESSAGE *)pbuffer; NEGOTIATE_MESSAGE *sec_blob = (NEGOTIATE_MESSAGE *)pbuffer;
__u32 flags; __u32 flags;
...@@ -354,9 +355,9 @@ void build_ntlmssp_negotiate_blob(unsigned char *pbuffer, ...@@ -354,9 +355,9 @@ void build_ntlmssp_negotiate_blob(unsigned char *pbuffer,
NTLMSSP_NEGOTIATE_128 | NTLMSSP_NEGOTIATE_UNICODE | NTLMSSP_NEGOTIATE_128 | NTLMSSP_NEGOTIATE_UNICODE |
NTLMSSP_NEGOTIATE_NTLM | NTLMSSP_NEGOTIATE_EXTENDED_SEC | NTLMSSP_NEGOTIATE_NTLM | NTLMSSP_NEGOTIATE_EXTENDED_SEC |
NTLMSSP_NEGOTIATE_SEAL; NTLMSSP_NEGOTIATE_SEAL;
if (ses->server->sign) if (server->sign)
flags |= NTLMSSP_NEGOTIATE_SIGN; flags |= NTLMSSP_NEGOTIATE_SIGN;
if (!ses->server->session_estab || ses->ntlmssp->sesskey_per_smbsess) if (!server->session_estab || ses->ntlmssp->sesskey_per_smbsess)
flags |= NTLMSSP_NEGOTIATE_KEY_XCH; flags |= NTLMSSP_NEGOTIATE_KEY_XCH;
sec_blob->NegotiateFlags = cpu_to_le32(flags); sec_blob->NegotiateFlags = cpu_to_le32(flags);
......
...@@ -310,7 +310,7 @@ smb2_negotiate(const unsigned int xid, struct cifs_ses *ses) ...@@ -310,7 +310,7 @@ smb2_negotiate(const unsigned int xid, struct cifs_ses *ses)
{ {
int rc; int rc;
ses->server->CurrentMid = 0; cifs_ses_server(ses)->CurrentMid = 0;
rc = SMB2_negotiate(xid, ses); rc = SMB2_negotiate(xid, ses);
/* BB we probably don't need to retry with modern servers */ /* BB we probably don't need to retry with modern servers */
if (rc == -EAGAIN) if (rc == -EAGAIN)
......
...@@ -791,7 +791,7 @@ SMB2_negotiate(const unsigned int xid, struct cifs_ses *ses) ...@@ -791,7 +791,7 @@ SMB2_negotiate(const unsigned int xid, struct cifs_ses *ses)
struct kvec rsp_iov; struct kvec rsp_iov;
int rc = 0; int rc = 0;
int resp_buftype; int resp_buftype;
struct TCP_Server_Info *server = ses->server; struct TCP_Server_Info *server = cifs_ses_server(ses);
int blob_offset, blob_length; int blob_offset, blob_length;
char *security_blob; char *security_blob;
int flags = CIFS_NEG_OP; int flags = CIFS_NEG_OP;
...@@ -813,7 +813,7 @@ SMB2_negotiate(const unsigned int xid, struct cifs_ses *ses) ...@@ -813,7 +813,7 @@ SMB2_negotiate(const unsigned int xid, struct cifs_ses *ses)
memset(server->preauth_sha_hash, 0, SMB2_PREAUTH_HASH_SIZE); memset(server->preauth_sha_hash, 0, SMB2_PREAUTH_HASH_SIZE);
memset(ses->preauth_sha_hash, 0, SMB2_PREAUTH_HASH_SIZE); memset(ses->preauth_sha_hash, 0, SMB2_PREAUTH_HASH_SIZE);
if (strcmp(ses->server->vals->version_string, if (strcmp(server->vals->version_string,
SMB3ANY_VERSION_STRING) == 0) { SMB3ANY_VERSION_STRING) == 0) {
req->Dialects[0] = cpu_to_le16(SMB30_PROT_ID); req->Dialects[0] = cpu_to_le16(SMB30_PROT_ID);
req->Dialects[1] = cpu_to_le16(SMB302_PROT_ID); req->Dialects[1] = cpu_to_le16(SMB302_PROT_ID);
...@@ -829,7 +829,7 @@ SMB2_negotiate(const unsigned int xid, struct cifs_ses *ses) ...@@ -829,7 +829,7 @@ SMB2_negotiate(const unsigned int xid, struct cifs_ses *ses)
total_len += 8; total_len += 8;
} else { } else {
/* otherwise send specific dialect */ /* otherwise send specific dialect */
req->Dialects[0] = cpu_to_le16(ses->server->vals->protocol_id); req->Dialects[0] = cpu_to_le16(server->vals->protocol_id);
req->DialectCount = cpu_to_le16(1); req->DialectCount = cpu_to_le16(1);
total_len += 2; total_len += 2;
} }
...@@ -1171,7 +1171,7 @@ SMB2_sess_alloc_buffer(struct SMB2_sess_data *sess_data) ...@@ -1171,7 +1171,7 @@ SMB2_sess_alloc_buffer(struct SMB2_sess_data *sess_data)
int rc; int rc;
struct cifs_ses *ses = sess_data->ses; struct cifs_ses *ses = sess_data->ses;
struct smb2_sess_setup_req *req; struct smb2_sess_setup_req *req;
struct TCP_Server_Info *server = ses->server; struct TCP_Server_Info *server = cifs_ses_server(ses);
unsigned int total_len; unsigned int total_len;
rc = smb2_plain_req_init(SMB2_SESSION_SETUP, NULL, (void **) &req, rc = smb2_plain_req_init(SMB2_SESSION_SETUP, NULL, (void **) &req,
...@@ -1258,22 +1258,23 @@ SMB2_sess_establish_session(struct SMB2_sess_data *sess_data) ...@@ -1258,22 +1258,23 @@ SMB2_sess_establish_session(struct SMB2_sess_data *sess_data)
{ {
int rc = 0; int rc = 0;
struct cifs_ses *ses = sess_data->ses; struct cifs_ses *ses = sess_data->ses;
struct TCP_Server_Info *server = cifs_ses_server(ses);
mutex_lock(&ses->server->srv_mutex); mutex_lock(&server->srv_mutex);
if (ses->server->ops->generate_signingkey) { if (server->ops->generate_signingkey) {
rc = ses->server->ops->generate_signingkey(ses); rc = server->ops->generate_signingkey(ses);
if (rc) { if (rc) {
cifs_dbg(FYI, cifs_dbg(FYI,
"SMB3 session key generation failed\n"); "SMB3 session key generation failed\n");
mutex_unlock(&ses->server->srv_mutex); mutex_unlock(&server->srv_mutex);
return rc; return rc;
} }
} }
if (!ses->server->session_estab) { if (!server->session_estab) {
ses->server->sequence_number = 0x2; server->sequence_number = 0x2;
ses->server->session_estab = true; server->session_estab = true;
} }
mutex_unlock(&ses->server->srv_mutex); mutex_unlock(&server->srv_mutex);
cifs_dbg(FYI, "SMB2/3 session established successfully\n"); cifs_dbg(FYI, "SMB2/3 session established successfully\n");
spin_lock(&GlobalMid_Lock); spin_lock(&GlobalMid_Lock);
...@@ -1509,7 +1510,7 @@ SMB2_select_sec(struct cifs_ses *ses, struct SMB2_sess_data *sess_data) ...@@ -1509,7 +1510,7 @@ SMB2_select_sec(struct cifs_ses *ses, struct SMB2_sess_data *sess_data)
{ {
int type; int type;
type = smb2_select_sectype(ses->server, ses->sectype); type = smb2_select_sectype(cifs_ses_server(ses), ses->sectype);
cifs_dbg(FYI, "sess setup type %d\n", type); cifs_dbg(FYI, "sess setup type %d\n", type);
if (type == Unspecified) { if (type == Unspecified) {
cifs_dbg(VFS, cifs_dbg(VFS,
...@@ -1537,7 +1538,7 @@ SMB2_sess_setup(const unsigned int xid, struct cifs_ses *ses, ...@@ -1537,7 +1538,7 @@ SMB2_sess_setup(const unsigned int xid, struct cifs_ses *ses,
const struct nls_table *nls_cp) const struct nls_table *nls_cp)
{ {
int rc = 0; int rc = 0;
struct TCP_Server_Info *server = ses->server; struct TCP_Server_Info *server = cifs_ses_server(ses);
struct SMB2_sess_data *sess_data; struct SMB2_sess_data *sess_data;
cifs_dbg(FYI, "Session Setup\n"); cifs_dbg(FYI, "Session Setup\n");
...@@ -1563,7 +1564,7 @@ SMB2_sess_setup(const unsigned int xid, struct cifs_ses *ses, ...@@ -1563,7 +1564,7 @@ SMB2_sess_setup(const unsigned int xid, struct cifs_ses *ses,
/* /*
* Initialize the session hash with the server one. * Initialize the session hash with the server one.
*/ */
memcpy(ses->preauth_sha_hash, ses->server->preauth_sha_hash, memcpy(ses->preauth_sha_hash, server->preauth_sha_hash,
SMB2_PREAUTH_HASH_SIZE); SMB2_PREAUTH_HASH_SIZE);
while (sess_data->func) while (sess_data->func)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment