Technology

Chart Type 《大数据经典论文解读》 三驾马车学习 Spark 内存管理及调优 Yarn学习 从Spark部署模式开始讲源码分析 容器狂占内存资源怎么办? 多角度理解一致性 golang io使用及优化模式 Flink学习 c++学习 学习ebpf go设计哲学 ceph学习 学习mesh kvm虚拟化 学习MQ go编译器 学习go 为什么要有堆栈 汇编语言 计算机组成原理 运行时和库 Prometheus client mysql 事务 mysql 事务的隔离级别 mysql 索引 坏味道 学习分布式 学习网络 学习Linux go 内存管理 golang 系统调用与阻塞处理 Goroutine 调度过程 重新认识cpu mosn有的没的 负载均衡泛谈 单元测试的新解读 《Redis核心技术与实现》笔记 《Prometheus监控实战》笔记 Prometheus 告警学习 calico源码分析 对容器云平台的理解 Prometheus 源码分析 并发的成本 基础设施优化 hashicorp raft源码学习 docker 架构 mosn细节 与微服务框架整合 Java动态代理 编程范式 并发通信模型 《网络是怎样连接的》笔记 go channel codereview gc分析 jvm 线程实现 go打包机制 go interface及反射 如何学习Kubernetes 《编译原理之美》笔记——后端部分 《编译原理之美》笔记——前端部分 Pilot MCP协议分析 go gc 内存管理玩法汇总 软件机制 istio流量管理 Pilot源码分析 golang io 学习Spring mosn源码浅析 MOSN简介 《datacenter as a computer》笔记 学习JVM Tomcat源码分析 Linux可观测性 学习存储 学计算 Gotty源码分析 kubernetes operator kaggle泰坦尼克问题实践 kubernetes扩缩容 神经网络模型优化 直觉上理解深度学习 如何学习机器学习 TIDB源码分析 什么是云原生 Alibaba Java诊断工具Arthas TIDB存储——TIKV 《Apache Kafka源码分析》——简介 netty中的线程池 guava cache 源码分析 Springboot 启动过程分析 Spring 创建Bean的年代变迁 Linux内存管理 自定义CNI IPAM 共识算法 spring redis 源码分析 kafka实践 spring kafka 源码分析 Linux进程调度 让kafka支持优先级队列 Codis源码分析 Redis源码分析 C语言学习 《趣谈Linux操作系统》笔记 docker和k8s安全访问机制 jvm crash分析 Prometheus 学习 Kubernetes监控 容器日志采集 Kubernetes 控制器模型 容器狂占资源怎么办? Kubernetes资源调度——scheduler 时序性数据库介绍及对比 influxdb入门 maven的基本概念 《Apache Kafka源码分析》——server Kubernetes类型系统 源码分析体会 《数据结构与算法之美》——算法新解 Kubernetes源码分析——controller mananger Kubernetes源码分析——apiserver Kubernetes源码分析——kubelet Kubernetes介绍 ansible学习 Kubernetes源码分析——从kubectl开始 jib源码分析之Step实现 jib源码分析之细节 线程排队 跨主机容器通信 jib源码分析及应用 为容器选择一个合适的entrypoint kubernetes yaml配置 《持续交付36讲》笔记 mybatis学习 程序猿应该知道的 无锁数据结构和算法 CNI——容器网络是如何打通的 为什么很多业务程序猿觉得数据结构和算法没用? 串一串一致性协议 当我在说PaaS时,我在说什么 《数据结构与算法之美》——数据结构笔记 PouchContainer技术分享体会 harbor学习 用groovy 来动态化你的代码 精简代码的利器——lombok 学习 《深入剖析kubernetes》笔记 编程语言那些事儿 rxjava3——背压 rxjava2——线程切换 spring cloud 初识 《深入拆解java 虚拟机》笔记 《how tomcat works》笔记 hystrix 学习 rxjava1——概念 Redis 学习 TIDB 学习 如何分发计算 Storm 学习 AQS1——论文学习 Unsafe Spark Stream 学习 linux vfs轮廓 《自己动手写docker》笔记 java8 实践 中本聪比特币白皮书 细读 区块链泛谈 比特币 大杂烩 总纲——如何学习分布式系统 hbase 泛谈 forkjoin 泛谈 看不见摸不着的cdn是啥 《jdk8 in action》笔记 程序猿视角看网络 bgp初识 calico学习 AQS——粗略的代码分析 我们能用反射做什么 web 跨域问题 《clean code》笔记 《Elasticsearch权威指南》笔记 mockito简介及源码分析 2017软件开发小结—— 从做功能到做系统 《Apache Kafka源码分析》——clients dns隐藏的一个坑 《mysql技术内幕》笔记 log4j学习 为什么netty比较难懂? 回溯法 apollo client源码分析及看待面向对象设计 学习并发 docker运行java项目的常见问题 OpenTSDB 入门 spring事务小结 分布式事务 javascript应用在哪里 《netty in action》读书笔记 netty对http2协议的解析 ssl证书是什么东西 http那些事 苹果APNs推送框架pushy apple 推送那些事儿 编写java框架的几大利器 java内存模型 java exception Linux IO学习 netty内存管理 测试环境docker化实践 netty在框架中的使用套路 Nginx简单使用 《Linux内核设计的艺术》小结 Go并发机制及语言层工具 Linux网络源代码学习——数据包的发送与接收 《docker源码分析》小结 docker namespace和cgroup Linux网络源代码学习——整体介绍 zookeeper三重奏 数据库的一些知识 Spark 泛谈 链式处理的那些套路 netty回顾 Thrift基本原理与实践(二) Thrift基本原理与实践(一) 回调 异步执行抽象——Executor与Future Docker0.1.0源码分析 java gc Jedis源码分析 深度学习泛谈 Linux网络命令操作 JTA与TCC 换个角度看待设计模式 Scala初识 向Hadoop学习NIO的使用 以新的角度看数据结构 并发控制相关的硬件与内核支持 systemd 简介 quartz 源码分析 基于docker搭建测试环境(二) spring aop 实现原理简述 自己动手写spring(八) 支持AOP 自己动手写spring(七) 类结构设计调整 分析log日志 自己动手写spring(六) 支持FactoryBean 自己动手写spring(九) 总结 自己动手写spring(五) bean的生命周期管理 自己动手写spring(四) 整合xml与注解方式 自己动手写spring(三) 支持注解方式 自己动手写spring(二) 创建一个bean工厂 自己动手写spring(一) 使用digester varnish 简单使用 关于docker image的那点事儿 基于docker搭建测试环境 分布式配置系统 JVM执行 git maven/ant/gradle/make使用 再看tcp kv系统 java nio的多线程扩展 《Concurrency Models》笔记 回头看Spring IOC IntelliJ IDEA使用 Java泛型 vagrant 使用 Go常用的一些库 Python初学 Goroutine 调度模型 虚拟网络 《程序员的自我修养》小结 Kubernetes存储 访问Kubernetes上的Service Kubernetes副本管理 Kubernetes pod 组件 Go基础 JVM类加载 硬币和扑克牌问题 LRU实现 virtualbox 使用 ThreadLocal小结 docker快速入门

Architecture

实时训练 分布式链路追踪 helm tensorflow原理——python层分析 如何学习tensorflow 数据并行——allreduce 数据并行——ps 机器学习中的python调用c 机器学习训练框架概述 embedding的原理及实践 tensornet源码分析 大模型训练 X的生成——特征工程 tvm tensorflow原理——core层分析 模型演变 《深度学习推荐系统实战》笔记 keras 和 Estimator tensorflow分布式训练 分布式训练的一些问题 基于Volcano的弹性训练 图神经网络 pytorch弹性分布式训练 在离线业务混部 RNN pytorch分布式训练 CNN 《动手学深度学习》笔记 pytorch与线性回归 多活 volcano特性源码分析 推理服务 kubebuilder 学习 mpi 学习pytorch client-go学习 tensorflow学习 提高gpu 利用率 GPU与容器的结合 GPU入门 AI云平台 tf-operator源码分析 k8s批处理调度 喜马拉雅容器化实践 Kubernetes 实践 学习rpc BFF 生命周期管理 openkruise学习 可观察性和监控系统 基于Kubernetes选主及应用 《许式伟的架构课》笔记 Kubernetes webhook 发布平台系统设计 k8s水平扩缩容 Scheduler如何给Node打分 Scheduler扩展 controller 组件介绍 openkruise cloneset学习 controller-runtime源码分析 pv与pvc实现 csi学习 client-go源码分析 kubelet 组件分析 调度实践 Pod是如何被创建出来的? 《软件设计之美》笔记 mecha 架构学习 Kubernetes events学习及应用 CRI 资源调度泛谈 业务系统设计原则 grpc学习 元编程 以应用为中心 istio学习 下一代微服务Service Mesh 《实现领域驱动设计》笔记 serverless 泛谈 概率论 《架构整洁之道》笔记 处理复杂性 那些年追过的并发 服务器端编程 网络通信协议 架构大杂烩 如何学习架构 《反应式设计模式》笔记 项目的演化特点 反应式架构摸索 函数式编程的设计模式 服务化 ddd反模式——CRUD的败笔 研发效能平台 重新看面向对象设计 业务系统设计的一些体会 函数式编程 《左耳听风》笔记 业务程序猿眼中的微服务管理 DDD实践——CQRS 项目隔离——案例研究 《编程的本质》笔记 系统故障排查汇总及教训 平台支持类系统的几个点 代码腾挪的艺术 abtest 系统设计汇总 《从0开始学架构》笔记 初级权限系统设计 领域驱动理念入门 现有上传协议分析 移动网络下的文件上传要注意的几个问题 推送系统的几个基本问题 用户登陆 做配置中心要想好的几个基本问题 不同层面的异步 分层那些事儿 性能问题分析 当我在说模板引擎的时候,我在说什么 用户认证问题 资源的分配与回收——池 消息/任务队列


embedding的原理及实践

2022年03月02日

简介

基本概念及原理

一种表述:Embedding 是个英文术语,如果非要找一个中文翻译对照的话,我觉得“向量化”(Vectorize)最合适。Embedding 的过程,就是把数据集合映射到向量空间,进而把数据进行向量化的过程。Embedding 的目标,就是找到一组合适的向量,来刻画现有的数据集合。

  1. 比如让国家作为模型参数,我们该如何用数字化的方式来表示它们呢?毕竟,模型只能消费数值,不能直接消费字符串。一种方法是把字符串转换为连续的整数,然后让模型去消费这些整数。。在理论上,这么做没有任何问题。但从模型的效果出发,整数的表达方式并不合理。为什么这么说呢?我们知道,连续整数之间,是存在比较关系的,比如 1 < 3,6 > 5,等等。但是原始的字符串之间,比如,国家并不存在大小关系,如果强行用 0 表示“中国”、用 1 表示“美国”,逻辑上就会出现“中国”<“美国”的悖论。仅仅是把字符串转换为数字,转换得到的数值是不能直接喂给模型做训练
  2. 我们需要把这些数字进一步向量化,才能交给模型去消费。Embedding 的方法也是日新月异、层出不穷。从最基本的热独编码到 PCA 降维,从 Word2Vec 到 Item2Vec,从矩阵分解到基于深度学习的协同过滤,可谓百花齐放、百家争鸣。

一种表述:embedding 是指将客观世界中离散的物体或对象(如单词、短语、图片)等映射到特征空间的操作,embedding向量是指映射后 的特征空间中连续且稠密的高维向量。在机器学习场景中,我们经常使用embedding向量 来描述客观世界的物体。embedding向量 不是对物体进行简单编号的结果,而是在尽量保持相似不变性的前提下 对物体进行特征抽象和编码的产物。通过不断训练,我们能够将客观世界中的物体不失真的映射到高维特征空间中,进而可以使用这些embedding向量 实现分类、回归和预测等操作。

Embedding 就是用一个数值向量“表示”一个对象(Object)的方法。“实体对象”可以是image、word等,“数值化表示”就是一个编码向量。例如对“颜色“这种实体对象用(R,G,B)这样一个三元素向量编码。embedding还可以理解成将离散目标投影到连续空间中的某个点上。数值化的embedding vector本身是没有意义的,不同vector之间的相对关系才是有实际意义的。例如:NLP中最基本的word embedding,给每一个单词一个N维编码向量(或者说将每个word投影到N维空间中),我们期望这种编码满足这样的特性:两个向量之间的”距离“越小,代表这两个单词含义越接近。比如利用 Word2vec 这个模型把单词映射到了高维空间中,从 king 到 queen 的向量和从 man 到 woman 的向量,无论从方向还是尺度来说它们都异常接近。

Embedding 技术对深度学习推荐系统的重要性

  1. Embedding 是处理稀疏特征的利器。因为推荐场景中的类别、ID 型特征非常多,大量使用 One-hot 编码会导致样本特征向量极度稀疏,而深度学习的结构特点又不利于稀疏特征向量的处理,因此几乎所有深度学习推荐模型都会由 Embedding 层负责将稀疏高维特征向量转换成稠密低维特征向量。
  2. Embedding 可以融合大量有价值信息,本身就是极其重要的特征向量 。 相比由原始信息直接处理得来的特征向量,Embedding 的表达能力更强,特别是 Graph Embedding 技术被提出后,Embedding 几乎可以引入任何信息进行编码,使其本身就包含大量有价值的信息,所以通过预训练得到的 Embedding 向量本身就是极其重要的特征向量。

Word2vec 是生成对“词”的向量表达的模型,其中,Word2vec 的训练样本是通过滑动窗口一一截取词组生成的。在训练完成后,模型输入向量矩阵的行向量,就是我们要提取的词向量。

在 Word2vec 诞生之后,Embedding 的思想迅速从自然语言处理领域扩散到几乎所有机器学习领域,既然 Word2vec 可以对词“序列”中的词进行 Embedding,那么对于用户购买“序列”中的一个商品,用户观看“序列”中的一个电影,也应该存在相应的 Embedding 方法。于是,微软于 2015 年提出了 Item2Vec 方法,它是对 Word2vec 方法的推广,使 Embedding 方法适用于几乎所有的序列数据。只要能够用序列数据的形式把我们要表达的对象表示出来,再把序列数据“喂”给 Word2vec 模型,我们就能够得到任意物品的 Embedding 了。假设我们知道 用户看过的电影的id 序列,比如296 380 344 588 593 231 595 318 480,那么此时电影id 是词,电影id 序列是句子,一个句子内的词有相互关系,那么就可以 根据 Item2vec 计算电影id 对应的 Embedding 向量。

Embedding这块,spark MLlib 和 机器学习库 都提供了处理函数。利用Tensorboard很容易将embedding进行可视化,不过既然是可视化,最高只能“可视”三维空间,所以高维向量需要被投影到三维(或二维空间)。不过不用担心细节,Tensorboard做了足够高质量的封装。

一文梳理推荐系统中Embedding应用实践

  1. 端到端的方法是将Embedding层作为神经网络的一部分,在进行BP更新每一层参数的时候同时更新Embedding,这种方法的好处是让Embedding的训练成为一个有监督的方式,可以很好的与最终的目标产生联系,使得Embedding与最终目标处于同一意义空间。但这样做的缺点同样显而易见的,由于Embedding层输入向量的维度甚大,Embedding层的加入会拖慢整个神经网络的收敛速度。大部分的训练时间和计算开销都被Embedding层所占据。正因为这个原因,「对于那些时间要求较为苛刻的场景,Embedding最好采用非端到端,也就是预训练的方式完成。」
  2. 非端到端(预训练),在一些时间要求比较高的场景下,Embedding的训练往往独立于深度学习网络进行,在得到稀疏特征的稠密表达之后,再与其他特征一起输入神经网络进行训练。在做任务时,将训练集中的词替换成事先训练好的向量表示放到网络中。Word2Vec,Doc2Vec,Item2Vec都是典型的非端到端的方法

在自然语言中,非端到端很常见,因为学到一个好的的词向量表示,就能很好地挖掘出词之间的潜在关系,那么在其他语料训练集和自然语言任务中,也能很好地表征这些词的内在联系,预训练的方式得到的Embedding并不会对最终的任务和模型造成太大影响,但却能够「提高效率节省时间」,这也是预训练的一大好处。但是在推荐场景下,根据不同目标构造出的序列不同,那么训练得到的Embedding挖掘出的关联信息也不同。所以,「在推荐中要想用预训练的方式,必须保证Embedding的预训练和最终任务目标处于同一意义空间」,否则就会造成预训练得到Embedding的意义和最终目标完全不一致。比如做召回阶段的深度模型的目标是衡量两个商品之间的相似性,但是CTR做的是预测用户点击商品的概率,初始化一个不相关的 Embedding 会给模型带来更大的负担,更慢地收敛。

在梯度下降这块对embedding weight也有针对性的优化算法,从梯度下降到FTRLFTRL是在广告/推荐领域会用到的优化方法,适用于对高维稀疏模型进行训练,获取稀疏解。

实践

《深度学习推荐系统实战》为什么深度学习的结构特点不利于稀疏特征向量的处理呢?一方面,如果我们深入到神经网络的梯度下降学习过程就会发现,特征过于稀疏会导致整个网络的收敛非常慢,因为每一个样本的学习只有极少数的权重会得到更新,这在样本数量有限的情况下会导致模型不收敛。另一个方面,One-hot 类稀疏特征的维度往往非常地大,可能会达到千万甚至亿的级别,如果直接连接进入深度学习网络,那整个模型的参数数量会非常庞大。因此,我们往往会先通过 Embedding 把原始稀疏特征稠密化,然后再输入复杂的深度学习网络进行训练,这相当于把原始特征向量跟上层复杂深度学习网络做一个隔离。

案例

从论文源码学习 之 embedding_lookup Embedding最重要的属性是:越“相似”的实体,Embedding之间的距离越小。比如用one-hot编码来表示4个梁山好汉。

李逵   [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
刘唐   [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
武松   [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
鲁智深 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] 
==>
        二  出  官   武
        货  家  阶   力
李逵    [1   0   0   0.5]
刘唐    [1   0   0   0.4]
武松    [0   1   0.5 0.8]
鲁智深  [0   1   0.75 0.8] 

Embedding层把我们的稀疏矩阵,通过一些线性变换(比如用全连接层进行转换,也称为查表操作),变成了一个密集矩阵,这个密集矩阵用了N(例子中N=4)个特征来表征所有的好汉。在这个密集矩阵中,表象上代表着密集矩阵跟单个好汉的一一对应关系,实际上还蕴含了大量的好汉与好汉之间的内在关系(如:我们得出的李逵跟刘唐的关系)。它们之间的关系,用嵌入层学习来的参数进行表征。这个从稀疏矩阵到密集矩阵的过程,叫做embedding,很多人也把它叫做查表,因为它们之间也是一个一一映射的关系。这种映射关系在反向传播的过程中一直在更新。因此能在多次epoch后,使得这个关系变成相对成熟,即:正确的表达整个语义以及各个语句之间的关系。这个成熟的关系,就是embedding层的所有权重参数。Embedding最大的劣势是无法解释每个维度的含义,这也是复杂机器学习模型的通病。

Embedding除了把独立向量联系起来之外,还有两个作用:降维,升维。

  1. embedding层 降维的原理就是矩阵乘法。比如一个 1 x 4 的矩阵,乘以一个 4 x 3 的矩阵,得倒一个 1 x 3 的矩阵。4 x 3 的矩阵缩小了 1 / 4。假如我们有一个100W X 10W的矩阵,用它乘上一个10W X 20的矩阵,我们可以把它降到100W X 20,瞬间量级降了。
  2. 升维可以理解为:前面有一幅图画,你离远了看不清楚,离近了看就可以看清楚细节。当对低维的数据进行升维时,可能把一些其他特征给放大了,或者把笼统的特征给分开了。同时这个embedding是一直在学习在优化的,就使得整个拉近拉远的过程慢慢形成一个良好的观察点。

如何生成?

  1. 矩阵分解
  2. 无监督建模
  3. 有监督建模

Embedding与深度学习推荐系统的结合

NVIDIA HugeCTR,GPU版本参数服务器— (5) 嵌入式hash表 具有两个嵌入表和多个全连接层的神经网络

Embedding 权重矩阵可以是一个 [item_size, embedding_size] 的稠密矩阵,item_size是需要embedding的物品个数,embedding_size是映射的向量长度,或者说矩阵的大小是:特征数量 * 嵌入维度。Embedding 权重矩阵的每一行对应输入的一个维度特征(one-hot之后的维度)。用户可以用一个index表示选择了哪个特征。

这样就把两个 1 x 9 的高维度,离散,稀疏向量,压缩到 两个 1 x 3 的低维稠密向量。这里把 One-Hot 向量中 “1”的位置叫做sparseID,就是一个编号。这个独热向量和嵌入表的矩阵乘法就等于利用sparseID进行的一次查表过程。

TensorFlow 的 embedding_lookup(params, ids) 函数的目的是按照ids从params这个矩阵中拿向量(行),所以ids就是这个矩阵索引(行号),需要int类型。即按照ids顺序返回params中的第ids行。比如说,ids=[1,3,2],就是返回params中第1,3,2行。返回结果为由params的1,3,2行组成的tensor。

embedding_lookup是一种特殊的全连接层的实现方法,其针对 输入是超高维 one hot向量的情况。

  1. 神经网络处理不了onehot编码,Z = WX + b。由于X是One-Hot Encoding 的原因,WX 的矩阵乘法看起来就像是取了Weights矩阵中对应的一行,看起来就像是在查表,所以叫做 lookup。embedding_lookup(W,X)等于说进行了一次矩阵相乘运算,其实就是一次线性变换。
  2. 假设embedding权重矩阵是一个[vocab_size, embed_size]的稠密矩阵W,vocab_size是需要embed的所有item的个数(比如:所有词的个数,所有商品的个数),embed_size是映射后的向量长度。所谓embedding_lookup(W, id1),可以想像成一个只在id1位为1的[1, vocab_size]的one_hot向量,与[vocab_size, embed_size]的W矩阵相乘,结果是一个[1, embed_size]的向量,它就是id1对应的embedding向量,实际上就是W矩阵的第id1行。但是,以上过程只是forward,因为W一般是随机初始化的,是待优化的变量。因此,embedding_lookup除了要完成以上矩阵相乘的过程(实现成“抽取id对应的行”),还要完成自动求导,以实现对W的更新。PS: 所以embedding_lookup 的底层是一个op,在tensorflow r1.4 分支下,底层执行的是 array_ops.gather

tensorflow 实现

一般在tensorflow中都会使用一个shape=[id_index_size, embedding_size]的Variable 矩阵做embedding参数,然后根据id特征的index去Variable矩阵中查表得到相应的embedding表示。这里需要注意的是:id_index_size的大小一般都不会等于对应id table的元素个数,因为有很多id元素不在原始的id table表中,比如新上架的一些商品等。此时需要将id_index_size设置的大一些,以留一些位置给那些不在id table表的元素使用。

使用tf.Variable 作为 embedding参数

import numpy as np
import tensorflow as tf
sess = tf.InteractiveSession()
embedding = tf.Variable(np.identity(6, dtype=np.int32))    # 创建一个embedding词典
input_ids = tf.placeholder(dtype=tf.int32, shape=[None])
# 相对于 feature_column 中的EmbeddingColumn,embedding_lookup 是有点偏底层的api/op
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)  # 把input_ids中给出的tensor表现成embedding中的形式

sess.run(tf.global_variables_initializer())
print("====== the embedding ====== ")
print(sess.run(embedding) )
print("====== the input_embedding ====== ")
print(sess.run(input_embedding, feed_dict={input_ids: [4, 0, 2]}))
====== the embedding ====== 
[[1 0 0 0 0 0]
 [0 1 0 0 0 0]
 [0 0 1 0 0 0]
 [0 0 0 1 0 0]
 [0 0 0 0 1 0]
 [0 0 0 0 0 1]]
====== the input_embedding ====== 
[[0 0 0 0 1 0]
 [1 0 0 0 0 0]
 [0 0 1 0 0 0]]

使用get_embedding_variable接口

var = tf.get_embedding_variable("var_0",embedding_dim=3,initializer=tf.ones_initializer(tf.float32),partitioner=tf.fixed_size_partitioner(num_shards=4))
shape = [var1.total_count() for var1 in var]
emb = tf.nn.embedding_lookup(var, tf.cast([0,1,2,5,6,7], tf.int64))
...

使用categorical_column_with_embedding接口

columns = tf.feature_column.categorical_column_with_embedding("col_emb", dtype=tf.dtypes.int64)
W = tf.feature_column.embedding_column(categorical_column=columns,dimension=3,initializer=tf.ones_initializer(tf.dtypes.float32))
ids={}
ids["col_emb"] = tf.SparseTensor(indices=[[0,0],[1,1],[2,2],[3,3],[4,4]], values=tf.cast([1,2,3,4,5], tf.dtypes.int64), dense_shape=[5, 4])
emb = tf.feature_column.input_layer(ids, [W])

从论文源码学习 之 embedding层如何自动更新input_embedding = embedding * input_ids 从效果上 可以把 input_ids 视为索引的作用,返回第4、0、2 行数据,但 embedding_lookup 函数 也可以看做是一个 矩阵乘法(底层两种都支持,是一个策略参数),也因此 embedding层可以通过 optimizer 进行更新。

原生的tf optimizer 根据 梯度/grad 的类型 来决定更新weight/ variable 的方法,当传来的梯度是普通tensor时,调用_apply_dense方法去更新参数;当传来的梯度是IndexedSlices类型时,则去调用optimizer._apply_sparse_duplicate_indices函数。 Embedding 参数的梯度中包含每个 tensor 中发生变化的数据切片 IndexedSlices。IndexedSlices类型是一种可以存储稀疏矩阵的数据结构,只需要存储对应的行号和相应的值即可。可以认为是一种类似 SparseTensor 的思想,用元素数据和元素位置表示一个较大 tensor 。将 tensor 按第一维度切片,从而将一个较大的形状为 [LARGE0, D1, .. , DN] 的 tensor 表示为多个较小的形状为 [D1, .. , DN] 的 tensor。

总结一下涉及到哪些问题: 稀疏参数的表示(开始由Variable 表示 ,各种框架提供EmbeddingVariable 表示)、存储(ps,底层是分布式hashmap)、通信(只通信部分,数据存在gpu + gpu 直接通信)、优化(稀疏参数的优化器与稠密参数的优化器不兼容) 和 稀疏参数的梯度的表示、通信(由IndexedSlices 表示)、优化

嵌入层的优化

DL 推荐模型的嵌入层是比较特殊的:它们为模型贡献了大量参数,但几乎不需要计算,而计算密集型denser layers的参数数量则要少得多。所以对于推荐系统,嵌入层的优化十分重要。

点击率预测模型Embedding层的学习和训练

原理上

TensorFlow在美团外卖推荐场景的GPU训练优化实践-参数规模的合理化

  1. 去交叉特征
  2. 精简特征
  3. 压缩Embedding向量数
  4. 压缩Embedding向量维度
  5. 量化压缩

    工程上

embedding部分的难点在于存储和检索。DNN这部分主要是稠密计算。Embedding 优化

  1. 把嵌入层分布在多个 GPU 和多个节点上
  2. Embedding 层模型并行,dense 层数据并行。