diff --git a/include/linux/bpf.h b/include/linux/bpf.h index bb26c2e18092..ad4bb36d4c10 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -2484,6 +2484,7 @@ int bpf_dev_bound_kfunc_check(struct bpf_verifier_log *log, struct bpf_prog_aux *prog_aux); void *bpf_dev_bound_resolve_kfunc(struct bpf_prog *prog, u32 func_id); int bpf_prog_dev_bound_init(struct bpf_prog *prog, union bpf_attr *attr); +int bpf_prog_dev_bound_inherit(struct bpf_prog *new_prog, struct bpf_prog *old_prog); void bpf_dev_bound_netdev_unregister(struct net_device *dev); static inline bool bpf_prog_is_dev_bound(const struct bpf_prog_aux *aux) @@ -2496,6 +2497,8 @@ static inline bool bpf_prog_is_offloaded(const struct bpf_prog_aux *aux) return aux->offload_requested; } +bool bpf_prog_dev_bound_match(const struct bpf_prog *lhs, const struct bpf_prog *rhs); + static inline bool bpf_map_is_offloaded(struct bpf_map *map) { return unlikely(map->ops == &bpf_map_offload_ops); @@ -2535,6 +2538,12 @@ static inline int bpf_prog_dev_bound_init(struct bpf_prog *prog, return -EOPNOTSUPP; } +static inline int bpf_prog_dev_bound_inherit(struct bpf_prog *new_prog, + struct bpf_prog *old_prog) +{ + return -EOPNOTSUPP; +} + static inline void bpf_dev_bound_netdev_unregister(struct net_device *dev) { } @@ -2549,6 +2558,11 @@ static inline bool bpf_prog_is_offloaded(struct bpf_prog_aux *aux) return false; } +static inline bool bpf_prog_dev_bound_match(const struct bpf_prog *lhs, const struct bpf_prog *rhs) +{ + return false; +} + static inline bool bpf_map_is_offloaded(struct bpf_map *map) { return false; diff --git a/kernel/bpf/offload.c b/kernel/bpf/offload.c index 3e173c694bbb..e87cab2ed710 100644 --- a/kernel/bpf/offload.c +++ b/kernel/bpf/offload.c @@ -187,12 +187,49 @@ static void __bpf_offload_dev_netdev_unregister(struct bpf_offload_dev *offdev, kfree(ondev); } -int bpf_prog_dev_bound_init(struct bpf_prog *prog, union bpf_attr *attr) +static int __bpf_prog_dev_bound_init(struct bpf_prog *prog, struct net_device *netdev) { struct bpf_offload_netdev *ondev; struct bpf_prog_offload *offload; int err; + offload = kzalloc(sizeof(*offload), GFP_USER); + if (!offload) + return -ENOMEM; + + offload->prog = prog; + offload->netdev = netdev; + + ondev = bpf_offload_find_netdev(offload->netdev); + if (!ondev) { + if (bpf_prog_is_offloaded(prog->aux)) { + err = -EINVAL; + goto err_free; + } + + /* When only binding to the device, explicitly + * create an entry in the hashtable. + */ + err = __bpf_offload_dev_netdev_register(NULL, offload->netdev); + if (err) + goto err_free; + ondev = bpf_offload_find_netdev(offload->netdev); + } + offload->offdev = ondev->offdev; + prog->aux->offload = offload; + list_add_tail(&offload->offloads, &ondev->progs); + + return 0; +err_free: + kfree(offload); + return err; +} + +int bpf_prog_dev_bound_init(struct bpf_prog *prog, union bpf_attr *attr) +{ + struct net_device *netdev; + int err; + if (attr->prog_type != BPF_PROG_TYPE_SCHED_CLS && attr->prog_type != BPF_PROG_TYPE_XDP) return -EINVAL; @@ -204,49 +241,48 @@ int bpf_prog_dev_bound_init(struct bpf_prog *prog, union bpf_attr *attr) attr->prog_flags & BPF_F_XDP_DEV_BOUND_ONLY) return -EINVAL; - offload = kzalloc(sizeof(*offload), GFP_USER); - if (!offload) - return -ENOMEM; + netdev = dev_get_by_index(current->nsproxy->net_ns, attr->prog_ifindex); + if (!netdev) + return -EINVAL; - offload->prog = prog; - - offload->netdev = dev_get_by_index(current->nsproxy->net_ns, - attr->prog_ifindex); - err = bpf_dev_offload_check(offload->netdev); + err = bpf_dev_offload_check(netdev); if (err) - goto err_maybe_put; + goto out; prog->aux->offload_requested = !(attr->prog_flags & BPF_F_XDP_DEV_BOUND_ONLY); down_write(&bpf_devs_lock); - ondev = bpf_offload_find_netdev(offload->netdev); - if (!ondev) { - if (bpf_prog_is_offloaded(prog->aux)) { - err = -EINVAL; - goto err_unlock; - } + err = __bpf_prog_dev_bound_init(prog, netdev); + up_write(&bpf_devs_lock); - /* When only binding to the device, explicitly - * create an entry in the hashtable. - */ - err = __bpf_offload_dev_netdev_register(NULL, offload->netdev); - if (err) - goto err_unlock; - ondev = bpf_offload_find_netdev(offload->netdev); +out: + dev_put(netdev); + return err; +} + +int bpf_prog_dev_bound_inherit(struct bpf_prog *new_prog, struct bpf_prog *old_prog) +{ + int err; + + if (!bpf_prog_is_dev_bound(old_prog->aux)) + return 0; + + if (bpf_prog_is_offloaded(old_prog->aux)) + return -EINVAL; + + new_prog->aux->dev_bound = old_prog->aux->dev_bound; + new_prog->aux->offload_requested = old_prog->aux->offload_requested; + + down_write(&bpf_devs_lock); + if (!old_prog->aux->offload) { + err = -EINVAL; + goto out; } - offload->offdev = ondev->offdev; - prog->aux->offload = offload; - list_add_tail(&offload->offloads, &ondev->progs); - dev_put(offload->netdev); - up_write(&bpf_devs_lock); - return 0; -err_unlock: + err = __bpf_prog_dev_bound_init(new_prog, old_prog->aux->offload->netdev); + +out: up_write(&bpf_devs_lock); -err_maybe_put: - if (offload->netdev) - dev_put(offload->netdev); - kfree(offload); return err; } @@ -675,6 +711,22 @@ bool bpf_offload_dev_match(struct bpf_prog *prog, struct net_device *netdev) } EXPORT_SYMBOL_GPL(bpf_offload_dev_match); +bool bpf_prog_dev_bound_match(const struct bpf_prog *lhs, const struct bpf_prog *rhs) +{ + bool ret; + + if (bpf_prog_is_offloaded(lhs->aux) != bpf_prog_is_offloaded(rhs->aux)) + return false; + + down_read(&bpf_devs_lock); + ret = lhs->aux->offload && rhs->aux->offload && + lhs->aux->offload->netdev && + lhs->aux->offload->netdev == rhs->aux->offload->netdev; + up_read(&bpf_devs_lock); + + return ret; +} + bool bpf_offload_prog_map_match(struct bpf_prog *prog, struct bpf_map *map) { struct bpf_offloaded_map *offmap; diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c index fdf4ff3d5a7f..d5ffa7a01dfb 100644 --- a/kernel/bpf/syscall.c +++ b/kernel/bpf/syscall.c @@ -2605,6 +2605,13 @@ static int bpf_prog_load(union bpf_attr *attr, bpfptr_t uattr) goto free_prog_sec; } + if (type == BPF_PROG_TYPE_EXT && dst_prog && + bpf_prog_is_dev_bound(dst_prog->aux)) { + err = bpf_prog_dev_bound_inherit(prog, dst_prog); + if (err) + goto free_prog_sec; + } + /* find program type: socket_filter vs tracing_filter */ err = find_prog_type(type, prog); if (err < 0) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 9009395206f8..800488289297 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -16813,8 +16813,9 @@ int bpf_check_attach_target(struct bpf_verifier_log *log, if (tgt_prog) { struct bpf_prog_aux *aux = tgt_prog->aux; - if (bpf_prog_is_dev_bound(tgt_prog->aux)) { - bpf_log(log, "Replacing device-bound programs not supported\n"); + if (bpf_prog_is_dev_bound(prog->aux) && + !bpf_prog_dev_bound_match(prog, tgt_prog)) { + bpf_log(log, "Target program bound device mismatch"); return -EINVAL; }