diff --git a/include/net/tls.h b/include/net/tls.h
index e2d84b791e5caa2bf9ccb6f667d66ff4b180820f..4457bc67f2e66d4b9fab5d4749e3ad15ec4282e9 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -169,7 +169,11 @@ enum {
 };
 
 enum tls_context_flags {
-	TLS_RX_SYNC_RUNNING = 0,
+	/* tls_device_down was called after the netdev went down, device state
+	 * was released, and kTLS works in software, even though rx_conf is
+	 * still TLS_HW (needed for transition).
+	 */
+	TLS_RX_DEV_DEGRADED = 0,
 	/* tls_dev_del was called for the RX side, device state was released,
 	 * but tls_ctx->netdev might still be kept, because TX-side driver
 	 * resources might not be released yet. Used to prevent the second
@@ -203,7 +207,7 @@ struct tls_context {
 	union tls_crypto_context crypto_recv;
 
 	struct list_head list;
-	struct net_device *netdev;
+	struct net_device __rcu *netdev;
 	refcount_t refcount;
 
 	void *priv_ctx_tx;
@@ -237,6 +241,16 @@ struct tls_context {
 	void (*unhash)(struct sock *sk);
 };
 
+struct tls_context_wrapper {
+	struct tls_context ctx;
+	struct sock *sk;
+};
+
+static inline struct tls_context_wrapper *tls_ctx_wrapper(const struct tls_context *ctx)
+{
+	return (struct tls_context_wrapper *)ctx;
+}
+
 struct tls_offload_context_rx {
 	/* sw must be the first member of tls_offload_context_rx */
 	struct tls_sw_context_rx sw;
@@ -336,6 +350,9 @@ static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
 struct sk_buff *
 tls_validate_xmit_skb(struct sock *sk, struct net_device *dev,
 		      struct sk_buff *skb);
+struct sk_buff *
+tls_validate_xmit_skb_sw(struct sock *sk, struct net_device *dev,
+			 struct sk_buff *skb);
 
 static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
 {
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 228e3ce48d437291d59ee5c90a1cbaff9ebe17e8..82d5573cd054be0faa6f55af9a5d8934182d454f 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -48,6 +48,7 @@ static void tls_device_gc_task(struct work_struct *work);
 static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
 static LIST_HEAD(tls_device_gc_list);
 static LIST_HEAD(tls_device_list);
+static LIST_HEAD(tls_device_down_list);
 static DEFINE_SPINLOCK(tls_device_lock);
 
 static void tls_device_free_ctx(struct tls_context *ctx)
@@ -67,6 +68,7 @@ static void tls_device_free_ctx(struct tls_context *ctx)
 static void tls_device_gc_task(struct work_struct *work)
 {
 	struct tls_context *ctx, *tmp;
+	struct net_device *netdev;
 	unsigned long flags;
 	LIST_HEAD(gc_list);
 
@@ -75,7 +77,11 @@ static void tls_device_gc_task(struct work_struct *work)
 	spin_unlock_irqrestore(&tls_device_lock, flags);
 
 	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
-		struct net_device *netdev = ctx->netdev;
+		/* Safe, because this is the destroy flow, refcount is 0, so
+		 * tls_device_down can't store this field in parallel.
+		 */
+		netdev = rcu_dereference_protected(ctx->netdev,
+						   !refcount_read(&ctx->refcount));
 
 		if (netdev && ctx->tx_conf == TLS_HW) {
 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
@@ -95,7 +101,7 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
 	if (sk->sk_destruct != tls_device_sk_destruct) {
 		refcount_set(&ctx->refcount, 1);
 		dev_hold(netdev);
-		ctx->netdev = netdev;
+		RCU_INIT_POINTER(ctx->netdev, netdev);
 		spin_lock_irq(&tls_device_lock);
 		list_add_tail(&ctx->list, &tls_device_list);
 		spin_unlock_irq(&tls_device_lock);
@@ -571,12 +577,11 @@ static void tls_device_resync_rx(struct tls_context *tls_ctx,
 				 struct sock *sk, u32 seq, u64 rcd_sn)
 {
 	struct net_device *netdev;
 
-	if (WARN_ON(test_and_set_bit(TLS_RX_SYNC_RUNNING, &tls_ctx->flags)))
-		return;
-	netdev = READ_ONCE(tls_ctx->netdev);
+	rcu_read_lock();
+	netdev = rcu_dereference(tls_ctx->netdev);
 	if (netdev)
 		netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk, seq, rcd_sn);
-	clear_bit_unlock(TLS_RX_SYNC_RUNNING, &tls_ctx->flags);
+	rcu_read_unlock();
 }
 
@@ -589,6 +594,8 @@ void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
 
 	if (tls_ctx->rx_conf != TLS_HW)
 		return;
+	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
+		return;
 
 	rx_ctx = tls_offload_ctx_rx(tls_ctx);
 	resync_req = atomic64_read(&rx_ctx->resync_req);
@@ -699,7 +706,18 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
 
 	ctx->sw.decrypted |= is_decrypted;
 
-	/* Return immedeatly if the record is either entirely plaintext or
+	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
+		if (likely(is_encrypted || is_decrypted))
+			return 0;
+
+		/* After tls_device_down disables the offload, the next SKB will
+		 * likely have initial fragments decrypted, and final ones not
+		 * decrypted. We need to reencrypt that single SKB.
+		 */
+		return tls_device_reencrypt(sk, skb);
+	}
+
+	/* Return immediately if the record is either entirely plaintext or
 	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
 	 * record.
 	 */
@@ -945,7 +963,8 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
 	struct net_device *netdev;
 
 	down_read(&device_offload_lock);
-	netdev = tls_ctx->netdev;
+	netdev = rcu_dereference_protected(tls_ctx->netdev,
+					   lockdep_is_held(&device_offload_lock));
 	if (!netdev)
 		goto out;
 
@@ -954,7 +973,7 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
 
 	if (tls_ctx->tx_conf != TLS_HW) {
 		dev_put(netdev);
-		tls_ctx->netdev = NULL;
+		rcu_assign_pointer(tls_ctx->netdev, NULL);
 	} else {
 		set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
 	}
@@ -965,6 +984,7 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
 
 static int tls_device_down(struct net_device *netdev)
 {
+	struct tls_context_wrapper *ctx_wrapper;
 	struct tls_context *ctx, *tmp;
 	unsigned long flags;
 	LIST_HEAD(list);
@@ -974,7 +994,11 @@ static int tls_device_down(struct net_device *netdev)
 
 	spin_lock_irqsave(&tls_device_lock, flags);
 	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
-		if (ctx->netdev != netdev ||
+		struct net_device *ctx_netdev =
+			rcu_dereference_protected(ctx->netdev,
+						  lockdep_is_held(&device_offload_lock));
+
+		if (ctx_netdev != netdev ||
 		    !refcount_inc_not_zero(&ctx->refcount))
 			continue;
 
@@ -983,6 +1007,27 @@ static int tls_device_down(struct net_device *netdev)
 	spin_unlock_irqrestore(&tls_device_lock, flags);
 
 	list_for_each_entry_safe(ctx, tmp, &list, list) {
+		/* Stop offloaded TX and switch to the fallback.
+		 * tls_is_sk_tx_device_offloaded will return false.
+		 */
+		ctx_wrapper = tls_ctx_wrapper(ctx);
+		WRITE_ONCE(ctx_wrapper->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);
+
+		/* Stop the RX and TX resync.
+		 * tls_dev_resync_rx must not be called after tls_dev_del.
+		 */
+		rcu_assign_pointer(ctx->netdev, NULL);
+
+		/* Start skipping the RX resync logic completely. */
+		set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
+
+		/* Sync with inflight packets. After this point:
+		 * TX: no non-encrypted packets will be passed to the driver.
+		 * RX: resync requests from the driver will be ignored.
+		 */
+		synchronize_net();
+
+		/* Release the offload context on the driver side. */
 		if (ctx->tx_conf == TLS_HW)
 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
 							TLS_OFFLOAD_CTX_DIR_TX);
@@ -990,15 +1035,29 @@ static int tls_device_down(struct net_device *netdev)
 		    !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
 							TLS_OFFLOAD_CTX_DIR_RX);
-		WRITE_ONCE(ctx->netdev, NULL);
-		smp_mb__before_atomic(); /* pairs with test_and_set_bit() */
-		while (test_bit(TLS_RX_SYNC_RUNNING, &ctx->flags))
-			usleep_range(10, 200);
+
 		dev_put(netdev);
-		list_del_init(&ctx->list);
 
-		if (refcount_dec_and_test(&ctx->refcount))
+		/* Move the context to a separate list for two reasons:
+		 * 1. When the context is deallocated, list_del is called.
+		 * 2. It's no longer an offloaded context, so we don't want to
+		 *    run offload-specific code on this context.
+		 */
+		spin_lock_irqsave(&tls_device_lock, flags);
+		list_move_tail(&ctx->list, &tls_device_down_list);
+		spin_unlock_irqrestore(&tls_device_lock, flags);
+
+		/* Device contexts for RX and TX will be freed on sk_destruct
+		 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
+		 * Now release the ref taken above.
+		 */
+		if (refcount_dec_and_test(&ctx->refcount)) {
+			/* sk_destruct ran after tls_device_down took a ref, and
+			 * it returned early. Complete the destruction here.
+			 */
+			list_del(&ctx->list);
 			tls_device_free_ctx(ctx);
+		}
 	}
 
 	up_write(&device_offload_lock);
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 6cf832891b53edb829fc5b63a8ee049d272aef4d..319aaed05571f8dc600e1d9ed84b206374604e66 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -420,13 +420,20 @@ struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
 				      struct net_device *dev,
 				      struct sk_buff *skb)
 {
-	if (dev == tls_get_ctx(sk)->netdev)
+	if (dev == rcu_dereference_bh(tls_get_ctx(sk)->netdev))
 		return skb;
 
 	return tls_sw_fallback(sk, skb);
 }
 EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);
 
+struct sk_buff *tls_validate_xmit_skb_sw(struct sock *sk,
+					 struct net_device *dev,
+					 struct sk_buff *skb)
+{
+	return tls_sw_fallback(sk, skb);
+}
+
 int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 19646ef9f6f61eebcb61287f460a6f8596eac6f4..9182314cc99a6c74597ded9bdfa0bd3894d69355 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -550,16 +550,19 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
 static struct tls_context *create_ctx(struct sock *sk)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context_wrapper *ctx_wrapper;
 	struct tls_context *ctx;
 
-	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
-	if (!ctx)
+	ctx_wrapper = kzalloc(sizeof(*ctx_wrapper), GFP_ATOMIC);
+	if (!ctx_wrapper)
 		return NULL;
 
+	ctx = &ctx_wrapper->ctx;
 	icsk->icsk_ulp_data = ctx;
 	ctx->setsockopt = sk->sk_prot->setsockopt;
 	ctx->getsockopt = sk->sk_prot->getsockopt;
 	ctx->sk_proto_close = sk->sk_prot->close;
+	ctx_wrapper->sk = sk;
 	return ctx;
 }
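
Note on the TX fallback switch: tls_is_sk_tx_device_offloaded() identifies an offloaded socket purely by comparing the sk_validate_xmit_skb callback pointer against tls_validate_xmit_skb, which is why the WRITE_ONCE() in tls_device_down() that installs tls_validate_xmit_skb_sw makes the check return false without touching rx_conf/tx_conf. For illustration only, a sketch of that helper as it reads in mainline kernels of this era (the exact body in this tree may differ slightly, and it is not modified by this patch):

static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
{
	/* The socket counts as device-offloaded for TX only while its
	 * sk_validate_xmit_skb callback still points at
	 * tls_validate_xmit_skb; once tls_device_down() installs
	 * tls_validate_xmit_skb_sw, this comparison returns false.
	 */
	return sk_fullsock(sk) &&
	       (smp_load_acquire(&sk->sk_validate_xmit_skb) ==
	       &tls_validate_xmit_skb);
}

The subsequent synchronize_net() then waits out any sender still inside tls_validate_xmit_skb() dereferencing the old netdev under rcu_read_lock_bh() (via rcu_dereference_bh()), and any resync handler inside the rcu_read_lock() section of tls_device_resync_rx(), before tls_dev_del() releases the driver state.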