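Problem Description
You are given an array of strings products, where each element is a product name, and a string searchWord. Design a recommendation system that, after each letter of searchWord is typed, suggests at most three products from products whose prefix matches the part of searchWord typed so far. If more than three products share that prefix, return the three that are lexicographically smallest. Return the suggestions as a two-dimensional list, with one inner list for each typed letter of searchWord.
Analysis
Building on the trie implemented in the previous problem, we only need to find the first three words in the trie that start with the given prefix. We walk down to the node for the prefix, then search its child nodes recursively. Because the children are visited from 'a' to 'z', the words are found in lexicographic order, so the search can stop as soon as three have been collected.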
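To make the expected output concrete, here is a minimal brute-force sketch that follows the statement literally: sort the products, then for every prefix of searchWord scan for the first three matches. The class name BruteForceSuggestions and the method name suggest are placeholders chosen for this illustration, not part of the original solution, and the input is assumed to contain lowercase, distinct product names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class BruteForceSuggestions {
    // Hypothetical reference implementation: re-scan the sorted products for every prefix.
    static List<List<String>> suggest(String[] products, String searchWord) {
        String[] sorted = products.clone(); // leave the caller's array untouched
        Arrays.sort(sorted);                // natural String order is lexicographic
        List<List<String>> res = new ArrayList<>();
        for (int len = 1; len <= searchWord.length(); len++) {
            String prefix = searchWord.substring(0, len);
            List<String> matches = new ArrayList<>();
            for (String p : sorted) {
                if (matches.size() == 3) {
                    break; // at most three suggestions per typed letter
                }
                if (p.startsWith(prefix)) {
                    matches.add(p);
                }
            }
            res.add(matches);
        }
        return res;
    }

    public static void main(String[] args) {
        String[] products = {"mobile", "mouse", "moneypot", "monitor", "mousepad"};
        System.out.println(suggest(products, "mouse"));
    }
}
```

For this input the program prints [[mobile, moneypot, monitor], [mobile, moneypot, monitor], [mouse, mousepad], [mouse, mousepad], [mouse, mousepad]]. The solutions discussed in the analysis produce the same lists while avoiding the full scan for every prefix.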
// Nested inside the solution class; uses java.util.ArrayList and java.util.List.
static class Trie {
    boolean isEndOfWord;
    Trie[] children;

    public Trie() {
        this.isEndOfWord = false;
        this.children = new Trie[26]; // one slot per lowercase letter
    }

    public Trie(String[] stringList) {
        this();
        for (String s : stringList) {
            this.insert(s);
        }
    }

    public void insert(String word) {
        Trie cur = this;
        for (int i = 0; i < word.length(); i++) {
            int index = word.charAt(i) - 'a';
            if (cur.children[index] == null) {
                cur.children[index] = new Trie();
            }
            cur = cur.children[index];
        }
        cur.isEndOfWord = true;
    }

    // Pre-order DFS: record the current word first, then visit children from 'a' to 'z',
    // so words are added to res in lexicographic order. Stops once res holds num words.
    private void DFS(List<String> res, Trie node, int num, StringBuilder str) {
        if (res.size() >= num) {
            return;
        }
        if (node.isEndOfWord) {
            res.add(str.toString());
        }
        for (int i = 0; i < 26; i++) {
            if (node.children[i] != null) {
                str.append((char) (i + 'a'));
                DFS(res, node.children[i], num, str);
                str.deleteCharAt(str.length() - 1); // backtrack
            }
        }
    }

    // Returns up to num words that start with prefix, smallest first.
    public List<String> getRecommendation(String prefix, int num) {
        List<String> res = new ArrayList<>(num);
        Trie cur = this;
        for (int i = 0; i < prefix.length(); i++) {
            int index = prefix.charAt(i) - 'a';
            if (cur.children[index] == null) {
                return res; // no product has this prefix
            }
            cur = cur.children[index];
        }
        DFS(res, cur, num, new StringBuilder(prefix));
        return res;
    }
}
public List<List<String>> suggestedProducts(String[] products, String searchWord) {
    List<List<String>> res = new ArrayList<>(searchWord.length());
    Trie trie = new Trie(products);
    StringBuilder prefix = new StringBuilder();
    for (char c : searchWord.toCharArray()) {
        prefix.append(c);
        res.add(trie.getRecommendation(prefix.toString(), 3));
    }
    return res;
}
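The trie solution relies on one property that is easy to overlook: a pre-order walk that visits children from 'a' to 'z' yields words in lexicographic order, regardless of the order in which they were inserted. The following stripped-down sketch isolates that property. The names TrieOrderDemo, Node, collect, and firstWords are made up for this illustration, and lowercase a-z input is assumed.

```java
import java.util.ArrayList;
import java.util.List;

class TrieOrderDemo {
    // Minimal trie node, assuming lowercase a-z input only.
    static final class Node {
        boolean isWord;
        Node[] next = new Node[26];
    }

    static void insert(Node root, String word) {
        Node cur = root;
        for (char ch : word.toCharArray()) {
            int idx = ch - 'a';
            if (cur.next[idx] == null) {
                cur.next[idx] = new Node();
            }
            cur = cur.next[idx];
        }
        cur.isWord = true;
    }

    // Pre-order walk: report the node's own word first, then children from 'a' to 'z'.
    static void collect(Node node, StringBuilder path, int limit, List<String> out) {
        if (out.size() >= limit) {
            return;
        }
        if (node.isWord) {
            out.add(path.toString());
        }
        for (int i = 0; i < 26; i++) {
            if (node.next[i] != null) {
                path.append((char) ('a' + i));
                collect(node.next[i], path, limit, out);
                path.deleteCharAt(path.length() - 1); // backtrack
            }
        }
    }

    static List<String> firstWords(String[] words, int limit) {
        Node root = new Node();
        for (String w : words) {
            insert(root, w);
        }
        List<String> out = new ArrayList<>();
        collect(root, new StringBuilder(), limit, out);
        return out;
    }

    public static void main(String[] args) {
        // Insertion order is deliberately unsorted.
        System.out.println(firstWords(new String[]{"mouse", "mob", "mobile", "monitor"}, 3));
    }
}
```

This prints [mob, mobile, monitor]. Note that "mob" comes before "mobile": a word that is a prefix of another is recorded at its own node before the walk descends into the children.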
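In fact, a structure as heavy as a trie is not necessary. We can sort the product array directly (String ordering follows character codes, which is ASCII order for lowercase letters) and then locate each prefix with binary search. The comparison uses String's compareTo method: if products[mid].compareTo(prefix) < 0, the element we are looking for lies to the right of mid, so we continue in the right half. The search returns the index of the first product that is not smaller than the prefix. Any products that start with the prefix are contiguous from that index, so we scan forward and stop at the first product that does not match.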
public List<List<String>> suggestedProducts(String[] products, String searchWord) {
    Arrays.sort(products);
    List<List<String>> res = new ArrayList<>();
    String prefix = "";
    for (char c : searchWord.toCharArray()) {
        prefix += c;
        int start = binarySearch(products, prefix);
        List<String> recommendations = new ArrayList<>();
        // Matches are contiguous from start; stop at the first non-match.
        for (int i = start; i < products.length && recommendations.size() < 3; i++) {
            if (products[i].startsWith(prefix)) {
                recommendations.add(products[i]);
            } else {
                break;
            }
        }
        res.add(recommendations);
    }
    return res;
}

// Lower bound: index of the first product >= prefix, or products.length if there is none.
private int binarySearch(String[] products, String prefix) {
    int low = 0, high = products.length;
    while (low < high) {
        int mid = (low + high) / 2;
        if (products[mid].compareTo(prefix) < 0) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }
    return low;
}
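To see what the binary search returns for a few concrete prefixes, the sketch below runs the same lower-bound loop on a small sorted array. LowerBoundDemo and lowerBound are illustrative names, not part of the solution above.

```java
import java.util.Arrays;

class LowerBoundDemo {
    // Index of the first element that is >= key in a sorted array (array length if none).
    static int lowerBound(String[] sorted, String key) {
        int low = 0, high = sorted.length;
        while (low < high) {
            int mid = (low + high) >>> 1;
            if (sorted[mid].compareTo(key) < 0) {
                low = mid + 1; // sorted[mid] is too small, go right
            } else {
                high = mid;    // sorted[mid] could be the answer, keep it in range
            }
        }
        return low;
    }

    public static void main(String[] args) {
        String[] products = {"mobile", "mouse", "moneypot", "monitor", "mousepad"};
        Arrays.sort(products); // mobile, moneypot, monitor, mouse, mousepad
        System.out.println(lowerBound(products, "m"));     // 0
        System.out.println(lowerBound(products, "mou"));   // 3
        System.out.println(lowerBound(products, "mouse")); // 3
        System.out.println(lowerBound(products, "z"));     // 5
        System.out.println("mou".compareTo("mouse") < 0);  // true
    }
}
```

The last line shows why this works: a proper prefix always compares smaller than any string it prefixes, so every product starting with "mou" sits at or after the returned index.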
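The method above calls startsWith for every candidate. When the test cases contain very long strings, comparing whole prefixes again and again can become expensive. Instead, we can run two binary searches per typed letter to find the first and last index of the matching range. Every product in the current range already shares the first i characters with searchWord, so each search only needs to compare the character at position i, and once the range is known there is no need to call startsWith at all.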
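// Binary search inside a[start..end], where every string already shares the first i
// characters with the search word. Compares only the character at position i.
// Returns the first index whose i-th character equals s when left is true,
// otherwise the last such index.
private int findBorder(String[] a, int start, int end, int i, char s, boolean left) {
    while (start <= end) {
        int mid = (start + end) / 2;
        if (a[mid].length() <= i) {
            // Too short to have an i-th character: it sorts before every longer match.
            start = mid + 1;
            continue;
        }
        char c = a[mid].charAt(i);
        if (c < s) {
            start = mid + 1;
        } else if (c > s) {
            end = mid - 1;
        } else if (left) {
            end = mid - 1;   // match found, keep looking for an earlier one
        } else {
            start = mid + 1; // match found, keep looking for a later one
        }
    }
    if (left) {
        return start;
    } else {
        return end;
    }
}

private int findLeft(String[] a, int start, int end, int i, char s) {
    return findBorder(a, start, end, i, s, true);
}

private int findRight(String[] a, int start, int end, int i, char s) {
    return findBorder(a, start, end, i, s, false);
}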
public List<List<String>> suggestedProducts(String[] products, String searchWord) {
    char[] s = searchWord.toCharArray();
    Arrays.sort(products);
    int n = products.length;
    int start = 0;
    int end = n - 1;
    // Pre-fill with empty lists so letters typed after the range becomes empty still get an entry.
    List<List<String>> output = new ArrayList<>(s.length);
    for (int i = 0; i < s.length; i++) {
        output.add(new ArrayList<>());
    }
    for (int i = 0; i < s.length; i++) {
        char c = s[i];
        int tmpStart = start;
        int tmpEnd = end;
        // Narrow [start, end] to the products whose i-th character is c.
        start = findLeft(products, tmpStart, tmpEnd, i, c);
        end = findRight(products, tmpStart, tmpEnd, i, c);
        if (start > end) {
            break; // nothing matches this prefix, so nothing matches any longer one
        }
        for (int j = 0; j < Math.min(end - start + 1, 3); j++) {
            output.get(i).add(products[start + j]);
        }
    }
    return output;
}
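To illustrate how the range shrinks letter by letter, here is a small sketch that records the window after each typed character. It is a variation on the idea rather than the code above: it uses half-open windows [start, end) and a single lower-bound helper called twice (once for c, once for c + 1) instead of a left/right flag. The names WindowNarrowingDemo, columnChar, firstAtLeast, and windows are invented for this example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class WindowNarrowingDemo {
    // Character at column i, or -1 when the string is too short.
    // Inside a window whose strings all share the same first i characters,
    // a string of length i is the prefix itself and therefore sorts first.
    static int columnChar(String s, int i) {
        return i < s.length() ? s.charAt(i) : -1;
    }

    // First index in [lo, hi) whose column-i character is >= target.
    static int firstAtLeast(String[] a, int lo, int hi, int i, int target) {
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (columnChar(a[mid], i) < target) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    // Half-open windows [start, end) of matching products after each typed character.
    static List<int[]> windows(String[] sorted, String searchWord) {
        List<int[]> out = new ArrayList<>();
        int start = 0, end = sorted.length;
        for (int i = 0; i < searchWord.length(); i++) {
            char c = searchWord.charAt(i);
            int newStart = firstAtLeast(sorted, start, end, i, c);
            int newEnd = firstAtLeast(sorted, start, end, i, c + 1);
            start = newStart;
            end = newEnd;
            out.add(new int[]{start, end});
        }
        return out;
    }

    public static void main(String[] args) {
        String[] products = {"mobile", "mouse", "moneypot", "monitor", "mousepad"};
        Arrays.sort(products); // mobile, moneypot, monitor, mouse, mousepad
        for (int[] w : windows(products, "mouse")) {
            System.out.println(Arrays.toString(w));
        }
    }
}
```

For "mouse" this prints [0, 5], [0, 5], [3, 5], [3, 5], [3, 5]: the first two letters keep all five products, and typing 'u' narrows the window to "mouse" and "mousepad". The suggestions for each letter are simply the first three products of the corresponding window.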