RAT-SQL 论文复现 —— bug 总结与完整复现流程

复现 ACL'20 论文:RAT-SQL 时遇到的 bug 总结以及自己的复现流程。

首先尝试了docker之后遇到一堆坑,最后决定直接使用conda环境。需要知道root密码。环境:liunx,Ubuntu20.04,RTX3090。"/path/to/"表示该文件或目录的所在路径,比如"/path/to/rat-sql",在具体路径"/home/ps/rat-sql"中,"/path/to/"等于"/home/ps/"。

1 BUG总结

1.1 ValueError: Unsupported kind for param args: VAR_POSITIONAL

发生在preprocess时。原因在于pytorch版本过高或者python版本过高,可用如下命令安装pytorch

1
conda install pytorch==1.3.1 cudatoolkit=10.1
preprocess阶段还有类似bug,都可用这个方法解决,请使用python 3.7

1.2 no space left on device

发生在train时。微软给的代码中保存的模型检查点过多,非bert训练大约需要几十G,bert则需要几百G。 需要指定 --logdir 到足够大的硬盘中,或减少检查点数量。

1.3 找不到__LOGDIR__路径

发生在eval时。在infer.py和eval.py中的__LOGDIR__都被替换为了实际的log路径,但是在run.py中,没有被替换,可以把run.py中104行开始的如下两行代码

1
2
res_json = json.load(open(eval_output_path))
print(step, res_json['total_scores']['all']['exact'])
替换成如下代码
1
2
3
4
5
6
7
8
9
10
11
12
13
model_config = json.loads(_jsonnet.evaluate_file(
eval_config.config,
tla_codes={'args': eval_config.config_args}))
if 'model_name' in model_config:
specific_logdir = os.path.join(logdir, model_config['model_name'])
eval_output_path = eval_output_path.replace('__LOGDIR__', specific_logdir)
res_json = json.load(open(eval_output_path))
print(step, res_json['total_scores']['all']['exact'])
else:
specific_logdir = logdir
eval_output_path = eval_output_path.replace('__LOGDIR__', specific_logdir)
res_json = json.load(open(eval_output_path))
print(step, res_json['total_scores'])

1.4 assert next_choices is not None

发生在eval wikisql时,需要把experiments/wikisql-glove-run.jsonnet中第12行的

1
eval_use_heuristic: true
改为
1
eval_use_heuristic: false

1.5 AttributeError: 'RMKeyView' object has no attribute 'index'

依旧发生在eval wikisql时,是records包本身的bug。找到path/to/anaconda3/envs/ratsql/lib/python3.7/site-packages/records.py(ratsql是conda环境名;python3.7是python版本) 找到第40行keys函数中

1
return self._keys
改为
1
return list(self._keys)

1.6 把自定义的包路径加入conda环境中

遇到不能pip install或conda install的包时,比如third_party中的wikisql。用PYTHONPATH=""等方法加到当前终端(或类似方法加到linux当前用户,linux所有用户)感觉相当麻烦,我就想加到我的conda环境中,也不影响其他的项目也不影响别人。使用如下命令一行解决。

1
conda develop /path/to/rat-sql/third_party/wikisql/
具体原理是在"/path/to/anaconda3/envs/ratsql/lib/python3.7/site-packages"(ratsql是conda环境名;python3.7是python版本)目录下生成一个conda.pth文件,conda环境会把conda.pth文件中的路径加到sys.path中,因此只在该conda环境中有效。

2 复现流程

2.1 安装linux包

1
sudo su
1
2
3
4
5
6
7
8
9
10
mkdir -p /usr/share/man/man1 && \
apt-get update && apt-get install -y \
build-essential \
cifs-utils \
curl \
default-jdk \
dialog \
dos2unix \
git \
sudo
1
exit

2.2 创建conda环境,安装python包

1
conda create -n ratsql python=3.7
1
conda activate ratsql
1
pip install asdl==0.1.5
1
pip install astor==0.7.1
1
pip install attrs==18.2.0
1
pip install babel==2.7.0
1
pip install bpemb==0.2.11
1
pip install cython==0.29.1
1
pip install jsonnet==0.14.0
1
pip install networkx==2.2
1
pip install nltk==3.4
1
pip install pyrsistent==0.14.9
1
pip install pytest==5.3.2
1
pip install records==0.5.3
1
pip install stanford-corenlp==3.9.2
1
pip install tabulate==0.8.6
1
conda install pytorch==1.3.1 cudatoolkit=10.1
1
pip install torchtext==0.3.1
1
pip install tqdm==4.36.1
1
pip install transformers==2.3.0
1
pip install entmax
1
pip install scikit-learn

2.3 下载nltk_data和bert

1
python -c "import nltk; nltk.download('stopwords'); nltk.download('punkt')"
1
python -c "from transformers import BertModel; BertModel.from_pretrained('bert-large-uncased-whole-word-masking')"

2.4 下载stanford-corenlp和wikisql官方脚本

1
sudo su
1
2
3
4
5
mkdir -p third_party && \
cd third_party && \
curl https://download.cs.stanford.edu/nlp/software/stanford-corenlp-full-2018-10-05.zip | jar xv && \
cd .. && \
git clone https://github.com/salesforce/WikiSQL third_party/wikisql
1
exit

连不上github可以把https://换成git://

2.5 把下载下来并组织好的data(参照rat-sql的readme)复制到项目中

1
2
3
4
mkdir -p data && \
cd data && \
cp -r /path/to/data ./ && \
cd ..

2.6 将所有 shell 脚本转换为 Unix 行尾

1
/bin/bash -c 'if compgen -G "/path/to/rat-sql/**/*.sh" > /dev/null; then dos2unix /app/**/*.sh; fi'

2.7 把wikisql官方脚本加到conda环境

1
conda develop /path/to/rat-sql/third_party/wikisql

3 运行

3.1 运行命令

3.1.1 spider-glove

1
python run.py preprocess experiments/spider-glove-run.jsonnet
1
python run.py train experiments/spider-glove-run.jsonnet
1
python run.py eval experiments/spider-glove-run.jsonnet

3.1.2 spider-bert

1
python run.py preprocess experiments/spider-bert-run.jsonnet
1
python run.py train experiments/spider-bert-run.jsonnet
1
python run.py eval experiments/spider-bert-run.jsonnet

3.1.3 wikisql-glove

1
python run.py preprocess experiments/wikisql-glove-run.jsonnet
1
python run.py train experiments/wikisql-glove-run.jsonnet
1
python run.py eval experiments/wikisql-glove-run.jsonnet

RAT-SQL 论文复现 —— bug 总结与完整复现流程
https://ymliucs.github.io/run_ratsql_model/
作者
Yumeng Liu
发布于
2021年12月18日
许可协议